feat: implement memory-enhanced learning for agents
- ComplexityEstimator now queries historical context to adjust token estimates
- ModelSelector uses actual success rates from get_model_stats() instead of heuristics
- TaskExecutor auto-discovers /root/tools/ and injects tool inventory into prompt
- Task outcomes are now recorded via record_task_outcome() for future learning
- Restore accidentally deleted memory/types.rs and memory/supabase.rs
- Update cursor rules to document learning system implementation
@@ -190,7 +190,7 @@ src/
- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings), **task_outcomes**
- **Supabase Storage**: full event streams (jsonl), large artifacts

## Learning System (v3)
## Learning System (v3 - Implemented)

### Purpose
Enable data-driven optimization of:
@@ -198,6 +198,15 @@ Enable data-driven optimization of:
- **Model selection**: learn actual success rates per model/complexity
- **Budget allocation**: learn actual costs vs estimated

### Current Implementation Status (as of Dec 2024)

| Component | Learning Integration | Status |
|-----------|---------------------|--------|
| ComplexityEstimator | Queries `get_historical_context()` to adjust token estimates | ✅ Implemented |
| ModelSelector | Queries `get_model_stats()` for actual success rates | ✅ Implemented |
| TaskExecutor | Auto-discovers `/root/tools/` and injects tool inventory | ✅ Implemented |
| routes.rs | Records task outcomes via `record_task_outcome()` | ✅ Implemented |
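The clamped-multiplier adjustment the table refers to can be sketched as follows. This is a minimal standalone sketch, not the crate's actual types: `adjust_estimate` and its thresholds are illustrative, mirroring the idea of scaling a base token estimate by the historical actual-vs-predicted ratio.

```rust
// Sketch (illustrative names): adjust a base token estimate by the ratio of
// actual to predicted tokens observed on similar past tasks.
fn adjust_estimate(base_tokens: u64, avg_token_multiplier: f64, similar_tasks: usize) -> u64 {
    // Only trust history once at least two similar outcomes exist.
    if similar_tasks < 2 {
        return base_tokens;
    }
    // Clamp so a few outliers cannot swing the estimate wildly.
    let multiplier = avg_token_multiplier.clamp(0.5, 3.0);
    (base_tokens as f64 * multiplier).round() as u64
}

fn main() {
    assert_eq!(adjust_estimate(1000, 1.4, 5), 1400);
    assert_eq!(adjust_estimate(1000, 10.0, 5), 3000); // clamped at 3.0
    assert_eq!(adjust_estimate(1000, 1.4, 1), 1000); // too little history
    println!("ok");
}
```

The clamp keeps a handful of unusual past runs from tripling (or halving) every future estimate.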

### Architecture

```
┌──────────────────────────────────────────────────────────────────────┐
@@ -212,23 +221,35 @@ Enable data-driven optimization of:
│ ┌────────────────┐ └───────────────┬─────────────┘ │
│ │ Complexity │◀── historical context ─────────┘ │
│ │ Estimator │ (avg token ratio, avg cost ratio) │
│ │ (enhanced) │ │
│ │ (learning) │ │
│ └────────┬───────┘ │
│ │ │
│ ▼ │
│ ┌────────────────┐ Query: "models at complexity ~0.6" │
│ │ Model Selector │ Returns: actual success rates, cost ratios │
│ │ (enhanced) │ │
│ │ (learning) │ │
│ └────────┬───────┘ │
│ │ │
│ ▼ │
│ ┌────────────────┐ │
│ ┌────────────────┐ Discovers /root/tools/ inventory │
│ │ TaskExecutor │──▶ record_task_outcome() ──▶ task_outcomes │
│ │ (tool reuse) │ │
│ └────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────┘
```

### Tool Reuse System

The TaskExecutor automatically discovers reusable tools:

1. **At execution start**: Scans `/root/tools/` (or `{working_dir}/tools`)
2. **Extracts documentation**: Reads README.md files and lists scripts
3. **Injects into prompt**: Adds tool inventory to system prompt
4. **Agent guidance**: Prompts agent to check existing tools before creating new ones

This enables cross-task learning where tools created in one task can be reused in future tasks.
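The discovery steps above can be sketched roughly as follows. This is a simplified standalone sketch, not the executor's real code; `tool_inventory` and the demo paths are illustrative names.

```rust
use std::fs;
use std::path::Path;

// Sketch: scan a tools directory and render a markdown inventory suitable
// for injection into a system prompt. Tool folders with a README.md get a
// one-line description; everything else is listed by name.
fn tool_inventory(tools_dir: &Path) -> String {
    let mut lines = Vec::new();
    if let Ok(entries) = fs::read_dir(tools_dir) {
        for entry in entries.flatten() {
            let name = entry.file_name().to_string_lossy().to_string();
            if name.starts_with('.') {
                continue; // skip hidden files
            }
            let readme = entry.path().join("README.md");
            if entry.path().is_dir() && readme.exists() {
                // First non-empty line of the README serves as the description.
                let desc = fs::read_to_string(&readme)
                    .ok()
                    .and_then(|c| c.lines().find(|l| !l.trim().is_empty()).map(String::from))
                    .unwrap_or_default();
                lines.push(format!("- **{}**: {}", name, desc));
            } else {
                lines.push(format!("- **{}**", name));
            }
        }
    }
    lines.join("\n")
}

fn main() {
    let dir = std::env::temp_dir().join("tools-demo");
    fs::create_dir_all(dir.join("unpack")).unwrap();
    fs::write(dir.join("unpack/README.md"), "Extracts nested archives.\n").unwrap();
    let inv = tool_inventory(&dir);
    assert!(inv.contains("- **unpack**: Extracts nested archives."));
    println!("{}", inv);
}
```

A nonexistent directory simply yields an empty inventory, so the prompt section can be omitted when there are no tools yet.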

### Database Schema: `task_outcomes`
```sql
CREATE TABLE task_outcomes (
```
@@ -472,12 +493,14 @@ This is intentional - the agent is designed to be a powerful system-wide assistant

- [ ] Formal verification in Lean (extract pure logic)
- [ ] WebSocket for bidirectional streaming
- [ ] Enhanced ComplexityEstimator with historical context injection
- [ ] Enhanced ModelSelector with data-driven success rates
- [x] Enhanced ComplexityEstimator with historical context injection (Dec 2024)
- [x] Enhanced ModelSelector with data-driven success rates (Dec 2024)
- [x] Tool reuse system with auto-discovery (Dec 2024)
- [x] Semantic code search (embeddings-based)
- [x] Multi-model support (U-curve optimization)
- [x] Cost tracking (Budget system)
- [x] Persistent memory (Supabase + pgvector)
- [x] Learning system (task_outcomes table, historical queries)
- [x] Smart retry strategy (analyze failure mode → upgrade/downgrade model)
- [x] Task outcome recording for learning (Dec 2024)


@@ -4,6 +4,10 @@
//! - Complexity score (0-1)
//! - Whether to split into subtasks
//! - Estimated token count
//!
//! ## Learning Integration
//! When memory is available, the estimator queries similar past tasks
//! and adjusts predictions based on historical actual token usage.

use async_trait::async_trait;
use serde_json::json;
@@ -12,6 +16,7 @@ use crate::agents::{
    Agent, AgentContext, AgentId, AgentResult, AgentType, Complexity, LeafAgent, LeafCapability,
};
use crate::llm::{ChatMessage, ChatOptions, Role};
use crate::memory::HistoricalContext;
use crate::task::Task;

/// Agent that estimates task complexity.
@@ -176,6 +181,61 @@ Rubric for score:

        None
    }

    /// Query historical context for similar tasks.
    ///
    /// Returns adjustment multipliers based on past actual vs predicted values.
    async fn get_historical_adjustments(
        &self,
        task_description: &str,
        ctx: &AgentContext,
    ) -> Option<HistoricalContext> {
        let memory = ctx.memory.as_ref()?;

        match memory.retriever.get_historical_context(task_description, 5).await {
            Ok(context) => {
                if let Some(ref hist) = context {
                    tracing::debug!(
                        "Historical context found: {} similar tasks, avg token ratio: {:.2}, success rate: {:.2}",
                        hist.similar_outcomes.len(),
                        hist.avg_token_multiplier,
                        hist.similar_success_rate
                    );
                }
                context
            }
            Err(e) => {
                tracing::warn!("Failed to fetch historical context: {}", e);
                None
            }
        }
    }

    /// Adjust token estimate based on historical data.
    ///
    /// If similar past tasks consistently used more/fewer tokens than predicted,
    /// we adjust our estimate accordingly.
    fn apply_historical_adjustment(
        &self,
        base_tokens: u64,
        historical: Option<&HistoricalContext>,
    ) -> u64 {
        match historical {
            Some(hist) if hist.similar_outcomes.len() >= 2 => {
                // Apply the historical token multiplier (clamped to reasonable range)
                let multiplier = hist.avg_token_multiplier.clamp(0.5, 3.0);
                let adjusted = (base_tokens as f64 * multiplier).round() as u64;

                tracing::debug!(
                    "Adjusted token estimate: {} -> {} (multiplier: {:.2})",
                    base_tokens, adjusted, multiplier
                );

                adjusted
            }
            _ => base_tokens,
        }
    }
}

impl Default for ComplexityEstimator {
@@ -202,7 +262,14 @@ impl Agent for ComplexityEstimator {
    ///
    /// # Returns
    /// AgentResult with Complexity data in the `data` field.
    ///
    /// # Learning Integration
    /// When memory is available, queries similar past tasks to adjust predictions
    /// based on actual historical token usage.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        // Query historical context for similar tasks (if memory available)
        let historical = self.get_historical_adjustments(task.description(), ctx).await;

        let prompt = self.build_prompt(task);

        let messages = vec![
@@ -240,9 +307,13 @@ impl Agent for ComplexityEstimator {
        let parsed = self.parse_response(&content);

        // Apply calibrated adjustments (pure post-processing).
        let adjusted_tokens = ((parsed.estimated_tokens() as f64) * self.token_multiplier)
        let base_tokens = ((parsed.estimated_tokens() as f64) * self.token_multiplier)
            .round()
            .max(1.0) as u64;

        // Apply historical adjustment if we have relevant data
        let adjusted_tokens = self.apply_historical_adjustment(base_tokens, historical.as_ref());

        let should_split = parsed.score() > self.split_threshold;
        let complexity = Complexity::new(parsed.score(), parsed.reasoning(), adjusted_tokens)
            .with_split(should_split);
@@ -262,11 +333,20 @@ impl Agent for ComplexityEstimator {
            _ => 1, // fallback tiny cost
        };

        // Build historical info for response data
        let historical_info = historical.as_ref().map(|h| json!({
            "similar_tasks_found": h.similar_outcomes.len(),
            "avg_token_multiplier": h.avg_token_multiplier,
            "avg_cost_multiplier": h.avg_cost_multiplier,
            "similar_success_rate": h.similar_success_rate,
        }));

        AgentResult::success(
            format!(
                "Complexity: {:.2} - {}",
                "Complexity: {:.2} - {}{}",
                complexity.score(),
                if complexity.should_split() { "Should split" } else { "Execute directly" }
                if complexity.should_split() { "Should split" } else { "Execute directly" },
                if historical.is_some() { " (adjusted from history)" } else { "" }
            ),
            cost_cents,
        )
@@ -276,6 +356,8 @@ impl Agent for ComplexityEstimator {
            "reasoning": complexity.reasoning(),
            "should_split": complexity.should_split(),
            "estimated_tokens": complexity.estimated_tokens(),
            "base_tokens_before_history": base_tokens,
            "historical_adjustment": historical_info,
            "usage": response.usage.as_ref().map(|u| json!({
                "prompt_tokens": u.prompt_tokens,
                "completion_tokens": u.completion_tokens,

@@ -2,9 +2,14 @@
//!
//! This is a refactored version of the original agent loop,
//! now as a leaf agent in the hierarchical tree.
//!
//! ## Tool Reuse
//! The executor automatically discovers and lists reusable tools in `/root/tools/`
//! at the start of each execution, injecting their documentation into the system prompt.

use async_trait::async_trait;
use serde_json::json;
use std::path::Path;

use crate::agents::{
    Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
@@ -53,8 +58,112 @@ impl TaskExecutor {
        Self { id: AgentId::new() }
    }

    /// Discover reusable tools in /root/tools/.
    ///
    /// Scans the directory for README.md files and tool scripts,
    /// building an inventory of available reusable tools.
    async fn discover_reusable_tools(&self, working_dir: &str) -> String {
        let tools_dir = if working_dir.starts_with('/') {
            // Find the root (e.g., /root or the working directory's ancestor)
            if working_dir.contains("/root") {
                "/root/tools".to_string()
            } else {
                format!("{}/tools", working_dir)
            }
        } else {
            "/root/tools".to_string()
        };

        let tools_path = Path::new(&tools_dir);
        if !tools_path.exists() {
            return String::new();
        }

        let mut tool_inventory = Vec::new();

        // Try to read the directory
        if let Ok(entries) = std::fs::read_dir(&tools_dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                let name = entry.file_name().to_string_lossy().to_string();

                // Skip hidden files
                if name.starts_with('.') {
                    continue;
                }

                if path.is_dir() {
                    // Check for README.md in the tool folder
                    let readme_path = path.join("README.md");
                    if readme_path.exists() {
                        if let Ok(content) = std::fs::read_to_string(&readme_path) {
                            // Extract the first few non-heading lines (up to 300 chars) as description
                            let description = content
                                .lines()
                                .filter(|l| !l.starts_with('#') && !l.trim().is_empty())
                                .take(3)
                                .collect::<Vec<_>>()
                                .join(" ")
                                .chars()
                                .take(300)
                                .collect::<String>();

                            tool_inventory.push(format!("- **{}**: {}", name, description));
                        } else {
                            tool_inventory.push(format!("- **{}**: (tool folder, check README.md for details)", name));
                        }
                    } else {
                        // List scripts in the folder
                        let scripts: Vec<_> = std::fs::read_dir(&path)
                            .ok()
                            .map(|entries| {
                                entries
                                    .flatten()
                                    .filter(|e| {
                                        let name = e.file_name().to_string_lossy().to_string();
                                        name.ends_with(".sh") || name.ends_with(".py") || name.ends_with(".rs")
                                    })
                                    .map(|e| e.file_name().to_string_lossy().to_string())
                                    .collect()
                            })
                            .unwrap_or_default();

                        if !scripts.is_empty() {
                            tool_inventory.push(format!("- **{}**: scripts: {}", name, scripts.join(", ")));
                        }
                    }
                } else if name.ends_with(".sh") || name.ends_with(".py") || name.ends_with(".rs") {
                    // Standalone script
                    tool_inventory.push(format!("- **{}**: standalone script", name));
                }
            }
        }

        // Also check for top-level README
        let main_readme = Path::new(&tools_dir).join("README.md");
        if main_readme.exists() {
            if let Ok(content) = std::fs::read_to_string(&main_readme) {
                // Return the tools README (truncated to 1000 chars) plus the inventory
                return format!(
                    "\n## Available Reusable Tools (from /root/tools/)\n\n{}\n\n### Tool Inventory\n{}",
                    content.chars().take(1000).collect::<String>(),
                    tool_inventory.join("\n")
                );
            }
        }

        if tool_inventory.is_empty() {
            String::new()
        } else {
            format!(
                "\n## Available Reusable Tools (from /root/tools/)\n\nThese tools have been created in previous runs. Check their documentation before recreating them!\n\n{}\n",
                tool_inventory.join("\n")
            )
        }
    }

    /// Build the system prompt for task execution.
    fn build_system_prompt(&self, working_dir: &str, tools: &ToolRegistry) -> String {
    fn build_system_prompt(&self, working_dir: &str, tools: &ToolRegistry, reusable_tools: &str) -> String {
        let tool_descriptions = tools
            .list_tools()
            .iter()
@@ -91,13 +200,14 @@ Run `ls -la /root/context/` to see what's available before doing anything else.
- Don't dump files directly in `/root/` or `/root/work/`
- Clean up temp files when done
- Document your tools with README files

## Available Tools
{reusable_tools}
## Available Tools (API)
{tool_descriptions}

## Philosophy: BE PROACTIVE
You are encouraged to **experiment and try things**:
- Install ANY software you need without asking (decompilers, debuggers, analyzers, language runtimes)
- **IMPORTANT: Before creating new helper scripts, check /root/tools/ for existing reusable tools!**
- Create helper scripts and save them in /root/tools/ for reuse
- Write documentation for your tools so future runs can use them
- If a tool doesn't exist, build it or find an alternative
@@ -107,19 +217,20 @@ You are encouraged to **experiment and try things**:
## Workflow for Unknown Files
When encountering files you need to analyze:
1. **IDENTIFY** — Run `file <filename>` to detect the file type
2. **INSTALL TOOLS** — Install appropriate tools for that file type:
2. **CHECK EXISTING TOOLS** — Look in `/root/tools/` for reusable scripts for this file type
3. **INSTALL TOOLS** — Install appropriate tools for that file type:
   - **Java/JAR/Class**: `apt install -y default-jdk jadx` (jadx is a Java decompiler)
   - **Android APK**: `apt install -y jadx apktool`
   - **Native binaries**: `apt install -y ghidra radare2 binutils`
   - **Python .pyc**: `pip install uncompyle6 decompyle3`
   - **.NET**: `apt install -y mono-complete; pip install dnfile`
   - **Archives**: `apt install -y p7zip-full unzip`
3. **ANALYZE** — Use the installed tools to examine/decompile the file
4. **Handle obfuscation** — If code is obfuscated:
4. **ANALYZE** — Use the installed tools to examine/decompile the file
5. **Handle obfuscation** — If code is obfuscated:
   - Java: Try `java-deobfuscator` or `cfr` with string decryption
   - Look for string encryption patterns, rename variables to understand flow
   - Run the code dynamically if static analysis fails
5. **Document findings** — Save analysis notes to your task folder in `/root/work/`
6. **Document findings** — Save analysis notes to your task folder in `/root/work/`

## Java Reverse Engineering (Common)
For .jar or .class files:
@@ -146,23 +257,26 @@ java -jar /root/tools/cfr.jar <jar_file> --outputdir /root/work/java-analysis/ou
## Rules
1. **Act, don't just describe** — Use tools to accomplish tasks, don't just explain what to do
2. **Check /root/context/ first** — This is where users put files for you
3. **Stay organized** — Create task-specific folders in /root/work/, keep /root/context/ read-only
4. **Identify before analyzing** — Always run `file` on unknown files
5. **Install what you need** — Don't ask permission, just `apt install` or `pip install`
6. **Handle obfuscation** — If decompiled code looks obfuscated, install deobfuscators and try them
7. **Create reusable tools** — Save useful scripts to /root/tools/ with README
8. **Verify your work** — Test, run, check outputs when possible
9. **Iterate** — If first approach fails, try alternatives before giving up
3. **Check /root/tools/ for existing tools** — Reuse scripts before creating new ones
4. **Stay organized** — Create task-specific folders in /root/work/, keep /root/context/ read-only
5. **Identify before analyzing** — Always run `file` on unknown files
6. **Install what you need** — Don't ask permission, just `apt install` or `pip install`
7. **Handle obfuscation** — If decompiled code looks obfuscated, install deobfuscators and try them
8. **Create reusable tools** — Save useful scripts to /root/tools/ with README
9. **Verify your work** — Test, run, check outputs when possible
10. **Iterate** — If first approach fails, try alternatives before giving up

## Response
When task is complete, provide a clear summary of:
- What you did (approach taken)
- Files created/modified (with full paths, organized in /root/work/[task]/)
- Tools installed (for future reference)
- Tools reused from /root/tools/ (if any)
- How to verify the result
- Any reusable scripts saved to /root/tools/"#,
- Any NEW reusable scripts saved to /root/tools/"#,
        working_dir = working_dir,
        tool_descriptions = tool_descriptions
        tool_descriptions = tool_descriptions,
        reusable_tools = reusable_tools
    )
}

@@ -203,8 +317,14 @@ When task is complete, provide a clear summary of:
        // If we can fetch pricing, compute real costs from token usage.
        let pricing = ctx.pricing.get_pricing(model).await;

        // Build initial messages
        let system_prompt = self.build_system_prompt(&ctx.working_dir_str(), &ctx.tools);
        // Discover reusable tools from /root/tools/ (or working_dir/tools)
        let reusable_tools = self.discover_reusable_tools(&ctx.working_dir_str()).await;
        if !reusable_tools.is_empty() {
            tracing::info!("Discovered reusable tools inventory");
        }

        // Build initial messages with reusable tools info
        let system_prompt = self.build_system_prompt(&ctx.working_dir_str(), &ctx.tools, &reusable_tools);
        let mut messages = vec![
            ChatMessage {
                role: Role::System,

@@ -8,14 +8,20 @@
//!
//! # Cost Model
//! Expected Cost = base_cost * (1 + failure_rate * retry_multiplier) * token_efficiency
//!
//! # Learning Integration
//! When memory is available, uses historical model statistics (actual success rates,
//! cost ratios) instead of pure heuristics.

use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;

use crate::agents::{
    Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
};
use crate::budget::PricingInfo;
use crate::memory::ModelStats;
use crate::task::Task;

/// Agent that selects the optimal model for a task.
@@ -49,6 +55,9 @@ pub struct ModelRecommendation {

    /// Alternative models if primary fails
    pub fallbacks: Vec<String>,

    /// Whether historical data was used for this selection
    pub used_historical_data: bool,
}

impl ModelSelector {
@@ -176,15 +185,22 @@
        complexity: f64,
        estimated_tokens: u64,
        budget_cents: u64,
        historical_stats: Option<&HashMap<String, ModelStats>>,
    ) -> Option<ModelRecommendation> {
        if models.is_empty() {
            return None;
        }

        // Calculate expected cost for all models
        // Calculate expected cost for all models, using historical stats when available
        let mut costs: Vec<ExpectedCost> = models
            .iter()
            .map(|m| self.calculate_expected_cost(m, complexity, estimated_tokens))
            .map(|m| {
                if let Some(stats) = historical_stats.and_then(|h| h.get(&m.model_id)) {
                    self.calculate_expected_cost_with_history(m, complexity, estimated_tokens, stats)
                } else {
                    self.calculate_expected_cost(m, complexity, estimated_tokens)
                }
            })
            .collect();

        // Sort by expected cost (ascending)
@@ -209,20 +225,102 @@
            .map(|c| c.model_id.clone())
            .collect();

        let used_history = historical_stats.and_then(|h| h.get(&selected.model_id)).is_some();

        Some(ModelRecommendation {
            model_id: selected.model_id.clone(),
            expected_cost_cents: selected.expected_cost_cents,
            confidence: 1.0 - selected.failure_probability,
            reasoning: format!(
                "Selected {} with expected cost {} cents (capability: {:.2}, failure prob: {:.2})",
                "Selected {} with expected cost {} cents (capability: {:.2}, failure prob: {:.2}){}",
                selected.model_id,
                selected.expected_cost_cents,
                selected.capability,
                selected.failure_probability
                selected.failure_probability,
                if used_history { " [from historical data]" } else { "" }
            ),
            fallbacks,
            used_historical_data: used_history,
        })
    }

    /// Calculate expected cost using actual historical statistics.
    ///
    /// This uses real success rates and cost ratios from past executions
    /// instead of heuristic estimates.
    fn calculate_expected_cost_with_history(
        &self,
        pricing: &PricingInfo,
        _complexity: f64,
        estimated_tokens: u64,
        stats: &ModelStats,
    ) -> ExpectedCost {
        // Use actual failure rate from history (inverted success rate)
        let failure_prob = (1.0 - stats.success_rate).clamp(0.0, self.max_failure_probability);

        // Use actual token ratio from history for inefficiency
        let inefficiency = stats.avg_token_ratio.clamp(0.5, 3.0);

        // Base cost for estimated tokens
        let input_tokens = estimated_tokens / 2;
        let output_tokens = estimated_tokens / 2;
        let base_cost = pricing.calculate_cost_cents(input_tokens, output_tokens);

        // Adjust for actual inefficiency
        let adjusted_tokens = ((estimated_tokens as f64) * inefficiency) as u64;
        let adjusted_cost = pricing.calculate_cost_cents(adjusted_tokens / 2, adjusted_tokens / 2);

        // Apply actual cost ratio (how much more/less than predicted)
        let cost_with_ratio = (adjusted_cost as f64) * stats.avg_cost_ratio.clamp(0.5, 3.0);

        // Expected cost including retry probability
        let expected_cost = cost_with_ratio * (1.0 + failure_prob * self.retry_multiplier);

        // Capability estimated from success rate rather than price
        let capability = stats.success_rate.clamp(0.3, 0.95);

        ExpectedCost {
            model_id: pricing.model_id.clone(),
            base_cost_cents: base_cost,
            expected_cost_cents: expected_cost.ceil() as u64,
            failure_probability: failure_prob,
            capability,
            inefficiency,
        }
    }

    /// Query historical model stats from memory.
    async fn get_historical_model_stats(
        &self,
        complexity: f64,
        ctx: &AgentContext,
    ) -> Option<HashMap<String, ModelStats>> {
        let memory = ctx.memory.as_ref()?;

        // Query stats for models at similar complexity levels (+/- 0.2)
        match memory.retriever.get_model_stats(complexity, 0.2).await {
            Ok(stats) if !stats.is_empty() => {
                tracing::debug!(
                    "Found historical stats for {} models at complexity ~{:.2}",
                    stats.len(),
                    complexity
                );

                // Convert to HashMap for easy lookup
                Some(stats.into_iter()
                    .map(|s| (s.model_id.clone(), s))
                    .collect())
            }
            Ok(_) => {
                tracing::debug!("No historical stats found for complexity ~{:.2}", complexity);
                None
            }
            Err(e) => {
                tracing::warn!("Failed to fetch model stats: {}", e);
                None
            }
        }
    }
}

/// Intermediate calculation result for a model.
@@ -265,6 +363,10 @@ impl Agent for ModelSelector {
    ///
    /// # Returns
    /// AgentResult with ModelRecommendation in the `data` field.
    ///
    /// # Learning Integration
    /// When memory is available, queries historical model statistics and uses
    /// actual success rates/cost ratios instead of heuristics.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        // Get complexity + estimated tokens from task analysis (populated by ComplexityEstimator).
        let complexity = task
@@ -277,6 +379,9 @@ impl Agent for ModelSelector {
        // Get available budget
        let budget_cents = task.budget().remaining_cents();

        // Query historical model stats (if memory available)
        let historical_stats = self.get_historical_model_stats(complexity, ctx).await;

        // Fetch pricing for tool-supporting models only
        let models = ctx.pricing.models_by_cost_filtered(true).await;

@@ -300,10 +405,11 @@
            "confidence": 0.8,
            "reasoning": "Fallback to configured default model",
            "fallbacks": [],
            "used_historical_data": false,
        }));
    }

    match self.select_optimal(&models, complexity, estimated_tokens, budget_cents) {
    match self.select_optimal(&models, complexity, estimated_tokens, budget_cents, historical_stats.as_ref()) {
        Some(rec) => {
            // Record selection in analysis
            {
@@ -322,6 +428,8 @@ impl Agent for ModelSelector {
                "confidence": rec.confidence,
                "reasoning": rec.reasoning,
                "fallbacks": rec.fallbacks,
                "used_historical_data": rec.used_historical_data,
                "historical_stats_available": historical_stats.as_ref().map(|h| h.len()),
                "inputs": {
                    "complexity": complexity,
                    "estimated_tokens": estimated_tokens,

@@ -362,6 +362,43 @@ async fn run_agent_task(
    )
    .await;

    // Record task outcome for learning system
    // Extract metrics from the task analysis and result
    let analysis = task.analysis();
    let actual_usage = analysis.actual_usage.as_ref();
    let actual_tokens = actual_usage.map(|u| u.total_tokens as i64);

    // Extract tool call count from result data
    let tool_calls_count = result.data.as_ref()
        .and_then(|d| d.get("tool_calls"))
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    // Extract iterations from execution signals if available
    let iterations = result.data.as_ref()
        .and_then(|d| d.get("execution_signals"))
        .and_then(|s| s.get("iterations"))
        .and_then(|v| v.as_i64())
        .map(|v| v as i32);

    let _ = mem
        .writer
        .record_task_outcome(
            run_id,
            task.id().as_uuid(),
            &task_description,
            analysis.complexity_score,
            analysis.estimated_total_tokens.map(|t| t as i64),
            analysis.estimated_cost_cents.map(|c| c as i64),
            analysis.selected_model.clone(),
            actual_tokens,
            Some(result.cost_cents as i64),
            result.success,
            iterations,
            tool_calls_count,
        )
        .await;

    // Generate and store summary
    let summary = format!(
        "Task: {}\nResult: {}\nSuccess: {}",
@@ -627,4 +664,4 @@ async fn search_memory(
        "query": params.q,
        "results": results
    })))
}
}
@@ -1 +1,643 @@
|
||||
//! Supabase client for PostgREST and Storage APIs.
|
||||
|
||||
use reqwest::Client;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::types::{DbRun, DbTask, DbEvent, DbChunk, DbTaskOutcome, SearchResult, ModelStats, DbMission, MissionMessage};
|
||||
|
||||
/// Supabase client for database and storage operations.
|
||||
pub struct SupabaseClient {
|
||||
client: Client,
|
||||
url: String,
|
||||
service_role_key: String,
|
||||
}
|
||||
|
||||
impl SupabaseClient {
|
||||
/// Create a new Supabase client.
|
||||
pub fn new(url: &str, service_role_key: &str) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
url: url.trim_end_matches('/').to_string(),
|
||||
service_role_key: service_role_key.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the PostgREST URL.
|
||||
fn rest_url(&self) -> String {
|
||||
format!("{}/rest/v1", self.url)
|
||||
}
|
||||
|
||||
/// Get the Storage URL.
|
||||
fn storage_url(&self) -> String {
|
||||
format!("{}/storage/v1", self.url)
|
||||
}
|
||||
|
||||
    // ==================== Runs ====================

    /// Create a new run.
    pub async fn create_run(&self, input_text: &str) -> anyhow::Result<DbRun> {
        let body = serde_json::json!({
            "input_text": input_text,
            "status": "pending"
        });

        let resp = self.client
            .post(format!("{}/runs", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to create run: {} - {}", status, text);
        }

        let runs: Vec<DbRun> = serde_json::from_str(&text)?;
        runs.into_iter().next().ok_or_else(|| anyhow::anyhow!("No run returned"))
    }

    /// Update a run.
    pub async fn update_run(&self, id: Uuid, updates: serde_json::Value) -> anyhow::Result<()> {
        let resp = self.client
            .patch(format!("{}/runs?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&updates)
            .send()
            .await?;

        if !resp.status().is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to update run: {}", text);
        }

        Ok(())
    }

    /// Get a run by ID.
    pub async fn get_run(&self, id: Uuid) -> anyhow::Result<Option<DbRun>> {
        let resp = self.client
            .get(format!("{}/runs?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        let runs: Vec<DbRun> = resp.json().await?;
        Ok(runs.into_iter().next())
    }

    /// List runs with pagination.
    pub async fn list_runs(&self, limit: usize, offset: usize) -> anyhow::Result<Vec<DbRun>> {
        let resp = self.client
            .get(format!(
                "{}/runs?order=created_at.desc&limit={}&offset={}",
                self.rest_url(), limit, offset
            ))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        Ok(resp.json().await?)
    }
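These methods rely on PostgREST's URL conventions: `column=eq.value` for equality filters and `order`/`limit`/`offset` for pagination. A sketch of the URL shapes `get_run` and `list_runs` produce (the host and UUID are placeholders):

```rust
// Sketch of the PostgREST URL conventions used by the run queries.
fn main() {
    let rest = "https://example.supabase.co/rest/v1";

    // Equality filter: PostgREST encodes it as `?id=eq.<value>`.
    let by_id = format!("{}/runs?id=eq.{}", rest, "123e4567-e89b-12d3-a456-426614174000");
    assert!(by_id.ends_with("runs?id=eq.123e4567-e89b-12d3-a456-426614174000"));

    // Pagination: newest first, page size 20, second page.
    let page = format!("{}/runs?order=created_at.desc&limit={}&offset={}", rest, 20, 20);
    assert_eq!(
        page,
        "https://example.supabase.co/rest/v1/runs?order=created_at.desc&limit=20&offset=20"
    );
}
```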
    // ==================== Tasks ====================

    /// Create a task.
    pub async fn create_task(&self, task: &DbTask) -> anyhow::Result<DbTask> {
        let resp = self.client
            .post(format!("{}/tasks", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(task)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to create task: {} - {}", status, text);
        }

        let tasks: Vec<DbTask> = serde_json::from_str(&text)?;
        tasks.into_iter().next().ok_or_else(|| anyhow::anyhow!("No task returned"))
    }

    /// Update a task.
    pub async fn update_task(&self, id: Uuid, updates: serde_json::Value) -> anyhow::Result<()> {
        let resp = self.client
            .patch(format!("{}/tasks?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&updates)
            .send()
            .await?;

        if !resp.status().is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to update task: {}", text);
        }

        Ok(())
    }

    /// Get tasks for a run.
    pub async fn get_tasks_for_run(&self, run_id: Uuid) -> anyhow::Result<Vec<DbTask>> {
        let resp = self.client
            .get(format!(
                "{}/tasks?run_id=eq.{}&order=depth,seq",
                self.rest_url(), run_id
            ))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        Ok(resp.json().await?)
    }
    // ==================== Events ====================

    /// Insert an event.
    pub async fn insert_event(&self, event: &DbEvent) -> anyhow::Result<i64> {
        let resp = self.client
            .post(format!("{}/events", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(event)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to insert event: {} - {}", status, text);
        }

        let events: Vec<DbEvent> = serde_json::from_str(&text)?;
        events.into_iter().next()
            .and_then(|e| e.id)
            .ok_or_else(|| anyhow::anyhow!("No event ID returned"))
    }

    /// Get events for a run.
    pub async fn get_events_for_run(&self, run_id: Uuid, limit: Option<usize>) -> anyhow::Result<Vec<DbEvent>> {
        let limit_str = limit.map(|l| format!("&limit={}", l)).unwrap_or_default();
        let resp = self.client
            .get(format!(
                "{}/events?run_id=eq.{}&order=seq{}",
                self.rest_url(), run_id, limit_str
            ))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        Ok(resp.json().await?)
    }
    // ==================== Chunks ====================

    /// Insert a chunk with embedding.
    pub async fn insert_chunk(&self, chunk: &DbChunk, embedding: &[f32]) -> anyhow::Result<Uuid> {
        // Format embedding as Postgres array literal
        let embedding_str = format!(
            "[{}]",
            embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        );

        let body = serde_json::json!({
            "run_id": chunk.run_id,
            "task_id": chunk.task_id,
            "source_event_id": chunk.source_event_id,
            "chunk_text": chunk.chunk_text,
            "embedding": embedding_str,
            "meta": chunk.meta
        });

        let resp = self.client
            .post(format!("{}/chunks", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to insert chunk: {} - {}", status, text);
        }

        let chunks: Vec<DbChunk> = serde_json::from_str(&text)?;
        chunks.into_iter().next()
            .and_then(|c| c.id)
            .ok_or_else(|| anyhow::anyhow!("No chunk ID returned"))
    }

    /// Search chunks by embedding similarity.
    pub async fn search_chunks(
        &self,
        embedding: &[f32],
        threshold: f64,
        limit: usize,
        filter_run_id: Option<Uuid>,
    ) -> anyhow::Result<Vec<SearchResult>> {
        let embedding_str = format!(
            "[{}]",
            embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        );

        let body = serde_json::json!({
            "query_embedding": embedding_str,
            "match_threshold": threshold,
            "match_count": limit,
            "filter_run_id": filter_run_id
        });

        let resp = self.client
            .post(format!("{}/rpc/search_chunks", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to search chunks: {} - {}", status, text);
        }

        Ok(serde_json::from_str(&text)?)
    }
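Both chunk methods serialize embeddings as a pgvector text literal (`[v1,v2,...]`) before sending them through PostgREST. The shared formatting logic, extracted into a standalone sketch:

```rust
// Sketch: formatting an f32 slice as the pgvector text literal that
// insert_chunk and search_chunks send, e.g. "[0.1,0.2,0.3]".
fn to_pgvector_literal(embedding: &[f32]) -> String {
    format!(
        "[{}]",
        embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
    )
}

fn main() {
    assert_eq!(to_pgvector_literal(&[1.0, 0.5, 0.25]), "[1,0.5,0.25]");
    assert_eq!(to_pgvector_literal(&[]), "[]");
}
```

Since the same expression appears in four methods, factoring it into a helper like this would be a natural follow-up refactor.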
    // ==================== Storage ====================

    /// Upload a file to storage.
    pub async fn upload_file(
        &self,
        bucket: &str,
        path: &str,
        content: &[u8],
        content_type: &str,
    ) -> anyhow::Result<String> {
        let resp = self.client
            .post(format!("{}/object/{}/{}", self.storage_url(), bucket, path))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", content_type)
            .body(content.to_vec())
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to upload file: {} - {}", status, text);
        }

        Ok(format!("{}/{}", bucket, path))
    }

    /// Download a file from storage.
    pub async fn download_file(&self, bucket: &str, path: &str) -> anyhow::Result<Vec<u8>> {
        let resp = self.client
            .get(format!("{}/object/{}/{}", self.storage_url(), bucket, path))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to download file: {} - {}", status, text);
        }

        Ok(resp.bytes().await?.to_vec())
    }

    /// Update run with summary embedding.
    pub async fn update_run_summary(
        &self,
        run_id: Uuid,
        summary_text: &str,
        embedding: &[f32],
    ) -> anyhow::Result<()> {
        let embedding_str = format!(
            "[{}]",
            embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        );

        let body = serde_json::json!({
            "summary_text": summary_text,
            "summary_embedding": embedding_str
        });

        self.update_run(run_id, body).await
    }
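The storage methods address objects as `/object/<bucket>/<path>`, and `upload_file` returns the bucket-relative location rather than a full URL. A sketch of both shapes (host and paths are placeholders):

```rust
// Sketch of the Storage object URL built by upload_file/download_file,
// and the bucket-relative path upload_file returns.
fn main() {
    let storage = "https://example.supabase.co/storage/v1";
    let (bucket, path) = ("artifacts", "runs/123/events.jsonl");

    let object_url = format!("{}/object/{}/{}", storage, bucket, path);
    assert_eq!(
        object_url,
        "https://example.supabase.co/storage/v1/object/artifacts/runs/123/events.jsonl"
    );

    // upload_file returns "<bucket>/<path>", suitable for DbRun::archive_path.
    assert_eq!(format!("{}/{}", bucket, path), "artifacts/runs/123/events.jsonl");
}
```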
    // ==================== Task Outcomes (Learning) ====================

    /// Insert a task outcome for learning.
    pub async fn insert_task_outcome(
        &self,
        outcome: &DbTaskOutcome,
        embedding: Option<&[f32]>,
    ) -> anyhow::Result<Uuid> {
        let embedding_str = embedding.map(|e| format!(
            "[{}]",
            e.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        ));

        let body = serde_json::json!({
            "run_id": outcome.run_id,
            "task_id": outcome.task_id,
            "predicted_complexity": outcome.predicted_complexity,
            "predicted_tokens": outcome.predicted_tokens,
            "predicted_cost_cents": outcome.predicted_cost_cents,
            "selected_model": outcome.selected_model,
            "actual_tokens": outcome.actual_tokens,
            "actual_cost_cents": outcome.actual_cost_cents,
            "success": outcome.success,
            "iterations": outcome.iterations,
            "tool_calls_count": outcome.tool_calls_count,
            "task_description": outcome.task_description,
            "task_type": outcome.task_type,
            "cost_error_ratio": outcome.cost_error_ratio,
            "token_error_ratio": outcome.token_error_ratio,
            "task_embedding": embedding_str
        });

        let resp = self.client
            .post(format!("{}/task_outcomes", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to insert task outcome: {} - {}", status, text);
        }

        let outcomes: Vec<DbTaskOutcome> = serde_json::from_str(&text)?;
        outcomes.into_iter().next()
            .and_then(|o| o.id)
            .ok_or_else(|| anyhow::anyhow!("No outcome ID returned"))
    }

    /// Get model statistics for a complexity range.
    ///
    /// Returns aggregated stats for each model that has been used
    /// for tasks in the given complexity range.
    pub async fn get_model_stats(
        &self,
        complexity_min: f64,
        complexity_max: f64,
    ) -> anyhow::Result<Vec<ModelStats>> {
        let body = serde_json::json!({
            "complexity_min": complexity_min,
            "complexity_max": complexity_max
        });

        let resp = self.client
            .post(format!("{}/rpc/get_model_stats", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            // If RPC doesn't exist yet, return empty
            tracing::debug!("get_model_stats RPC not available: {}", text);
            return Ok(vec![]);
        }

        Ok(serde_json::from_str(&text).unwrap_or_default())
    }

    /// Search for similar task outcomes by embedding similarity.
    pub async fn search_similar_outcomes(
        &self,
        embedding: &[f32],
        threshold: f64,
        limit: usize,
    ) -> anyhow::Result<Vec<DbTaskOutcome>> {
        let embedding_str = format!(
            "[{}]",
            embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        );

        let body = serde_json::json!({
            "query_embedding": embedding_str,
            "match_threshold": threshold,
            "match_count": limit
        });

        let resp = self.client
            .post(format!("{}/rpc/search_similar_outcomes", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            // If RPC doesn't exist yet, return empty
            tracing::debug!("search_similar_outcomes RPC not available: {}", text);
            return Ok(vec![]);
        }

        Ok(serde_json::from_str(&text).unwrap_or_default())
    }

    /// Get global learning statistics (for tuning).
    pub async fn get_global_stats(&self) -> anyhow::Result<serde_json::Value> {
        let resp = self.client
            .post(format!("{}/rpc/get_global_learning_stats", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({}))
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            tracing::debug!("get_global_learning_stats RPC not available: {}", text);
            return Ok(serde_json::json!({}));
        }

        Ok(serde_json::from_str(&text).unwrap_or_default())
    }
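Note that the learning RPCs degrade to empty results when the underlying SQL functions have not been deployed yet, so the learning system stays additive rather than becoming a hard dependency. That decision logic, reduced to a standalone sketch (the function name and simplified types are illustrative):

```rust
// Illustrative sketch of the fallback used by get_model_stats and
// search_similar_outcomes: a non-success response or unparseable body
// yields an empty Vec instead of an error.
fn stats_or_empty(status_ok: bool, parsed: Option<Vec<i64>>) -> Vec<i64> {
    if !status_ok {
        // RPC missing or failing: treat as "no history yet".
        return Vec::new();
    }
    parsed.unwrap_or_default()
}

fn main() {
    assert!(stats_or_empty(false, Some(vec![1, 2])).is_empty());
    assert!(stats_or_empty(true, None).is_empty());
    assert_eq!(stats_or_empty(true, Some(vec![3])), vec![3]);
}
```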
    // ==================== Missions ====================

    /// Create a new mission.
    pub async fn create_mission(&self, title: Option<&str>) -> anyhow::Result<DbMission> {
        let body = serde_json::json!({
            "title": title,
            "status": "active",
            "history": []
        });

        let resp = self.client
            .post(format!("{}/missions", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        let text = resp.text().await?;

        if !status.is_success() {
            anyhow::bail!("Failed to create mission: {} - {}", status, text);
        }

        let missions: Vec<DbMission> = serde_json::from_str(&text)?;
        missions.into_iter().next().ok_or_else(|| anyhow::anyhow!("No mission returned"))
    }

    /// Get a mission by ID.
    pub async fn get_mission(&self, id: Uuid) -> anyhow::Result<Option<DbMission>> {
        let resp = self.client
            .get(format!("{}/missions?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        let missions: Vec<DbMission> = resp.json().await?;
        Ok(missions.into_iter().next())
    }

    /// List missions with pagination.
    pub async fn list_missions(&self, limit: usize, offset: usize) -> anyhow::Result<Vec<DbMission>> {
        let resp = self.client
            .get(format!(
                "{}/missions?order=updated_at.desc&limit={}&offset={}",
                self.rest_url(), limit, offset
            ))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .send()
            .await?;

        Ok(resp.json().await?)
    }

    /// Update mission status.
    pub async fn update_mission_status(&self, id: Uuid, status: &str) -> anyhow::Result<()> {
        let body = serde_json::json!({
            "status": status,
            "updated_at": chrono::Utc::now().to_rfc3339()
        });

        let resp = self.client
            .patch(format!("{}/missions?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to update mission status: {}", text);
        }

        Ok(())
    }

    /// Update mission history.
    pub async fn update_mission_history(&self, id: Uuid, history: &[MissionMessage]) -> anyhow::Result<()> {
        let body = serde_json::json!({
            "history": history,
            "updated_at": chrono::Utc::now().to_rfc3339()
        });

        let resp = self.client
            .patch(format!("{}/missions?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to update mission history: {}", text);
        }

        Ok(())
    }

    /// Update mission title.
    pub async fn update_mission_title(&self, id: Uuid, title: &str) -> anyhow::Result<()> {
        let body = serde_json::json!({
            "title": title,
            "updated_at": chrono::Utc::now().to_rfc3339()
        });

        let resp = self.client
            .patch(format!("{}/missions?id=eq.{}", self.rest_url(), id))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let text = resp.text().await?;
            anyhow::bail!("Failed to update mission title: {}", text);
        }

        Ok(())
    }
}
@@ -1 +1,345 @@
//! Types for the memory subsystem.

use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// Status of a run or task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MemoryStatus {
    Pending,
    Running,
    Completed,
    Failed,
}

impl std::fmt::Display for MemoryStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pending => write!(f, "pending"),
            Self::Running => write!(f, "running"),
            Self::Completed => write!(f, "completed"),
            Self::Failed => write!(f, "failed"),
        }
    }
}
/// A run stored in the database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbRun {
    pub id: Uuid,
    pub created_at: String,
    pub updated_at: String,
    pub status: String,
    pub input_text: String,
    pub final_output: Option<String>,
    pub total_cost_cents: Option<i32>,
    pub summary_text: Option<String>,
    pub archive_path: Option<String>,
}

/// A task stored in the database (hierarchical).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbTask {
    pub id: Uuid,
    pub run_id: Uuid,
    pub parent_id: Option<Uuid>,
    pub depth: i32,
    pub seq: i32,
    pub description: String,
    pub status: String,
    pub complexity_score: Option<f64>,
    pub model_used: Option<String>,
    pub budget_cents: Option<i32>,
    pub spent_cents: Option<i32>,
    pub output: Option<String>,
    pub created_at: String,
    pub completed_at: Option<String>,
}

/// An event stored in the database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbEvent {
    pub id: Option<i64>,
    pub run_id: Uuid,
    pub task_id: Option<Uuid>,
    pub seq: i32,
    pub ts: Option<String>,
    pub agent_type: String,
    pub event_kind: String,
    pub preview_text: Option<String>,
    pub meta: Option<serde_json::Value>,
    pub blob_path: Option<String>,
    pub prompt_tokens: Option<i32>,
    pub completion_tokens: Option<i32>,
    pub cost_cents: Option<i32>,
}

/// A chunk for vector search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbChunk {
    pub id: Option<Uuid>,
    pub run_id: Uuid,
    pub task_id: Option<Uuid>,
    pub source_event_id: Option<i64>,
    pub chunk_text: String,
    pub meta: Option<serde_json::Value>,
}
/// Event kinds for the event stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EventKind {
    /// Task started
    TaskStart,
    /// Task completed
    TaskEnd,
    /// LLM request sent
    LlmRequest,
    /// LLM response received
    LlmResponse,
    /// Tool invoked
    ToolCall,
    /// Tool result received
    ToolResult,
    /// Complexity estimation
    ComplexityEstimate,
    /// Model selection decision
    ModelSelect,
    /// Verification result
    Verification,
    /// Task split into subtasks
    TaskSplit,
    /// Error occurred
    Error,
}

impl std::fmt::Display for EventKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            Self::TaskStart => "task_start",
            Self::TaskEnd => "task_end",
            Self::LlmRequest => "llm_request",
            Self::LlmResponse => "llm_response",
            Self::ToolCall => "tool_call",
            Self::ToolResult => "tool_result",
            Self::ComplexityEstimate => "complexity_estimate",
            Self::ModelSelect => "model_select",
            Self::Verification => "verification",
            Self::TaskSplit => "task_split",
            Self::Error => "error",
        };
        write!(f, "{}", s)
    }
}
/// Search result from vector similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub id: Uuid,
    pub run_id: Uuid,
    pub task_id: Option<Uuid>,
    pub chunk_text: String,
    pub meta: Option<serde_json::Value>,
    pub similarity: f64,
}

/// Context pack for injection into prompts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextPack {
    /// Relevant chunks from memory
    pub chunks: Vec<SearchResult>,
    /// Total token estimate for the context
    pub estimated_tokens: usize,
    /// Query that was used
    pub query: String,
}
/// Task outcome record for learning from execution history.
///
/// Captures predictions vs actuals to enable data-driven optimization
/// of complexity estimation, model selection, and budget allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbTaskOutcome {
    pub id: Option<Uuid>,
    pub run_id: Uuid,
    pub task_id: Uuid,

    // Predictions (what we estimated before execution)
    pub predicted_complexity: Option<f64>,
    pub predicted_tokens: Option<i64>,
    pub predicted_cost_cents: Option<i64>,
    pub selected_model: Option<String>,

    // Actuals (what happened during execution)
    pub actual_tokens: Option<i64>,
    pub actual_cost_cents: Option<i64>,
    pub success: bool,
    pub iterations: Option<i32>,
    pub tool_calls_count: Option<i32>,

    // Metadata for similarity search
    pub task_description: String,
    /// Category of task (inferred or explicit)
    pub task_type: Option<String>,

    // Computed ratios for quick stats
    pub cost_error_ratio: Option<f64>,
    pub token_error_ratio: Option<f64>,

    pub created_at: Option<String>,
}

impl DbTaskOutcome {
    /// Create a new outcome from predictions and actuals.
    pub fn new(
        run_id: Uuid,
        task_id: Uuid,
        task_description: String,
        predicted_complexity: Option<f64>,
        predicted_tokens: Option<i64>,
        predicted_cost_cents: Option<i64>,
        selected_model: Option<String>,
        actual_tokens: Option<i64>,
        actual_cost_cents: Option<i64>,
        success: bool,
        iterations: Option<i32>,
        tool_calls_count: Option<i32>,
    ) -> Self {
        // Compute error ratios
        let cost_error_ratio = match (actual_cost_cents, predicted_cost_cents) {
            (Some(actual), Some(predicted)) if predicted > 0 => {
                Some(actual as f64 / predicted as f64)
            }
            _ => None,
        };

        let token_error_ratio = match (actual_tokens, predicted_tokens) {
            (Some(actual), Some(predicted)) if predicted > 0 => {
                Some(actual as f64 / predicted as f64)
            }
            _ => None,
        };

        Self {
            id: None,
            run_id,
            task_id,
            predicted_complexity,
            predicted_tokens,
            predicted_cost_cents,
            selected_model,
            actual_tokens,
            actual_cost_cents,
            success,
            iterations,
            tool_calls_count,
            task_description,
            task_type: None,
            cost_error_ratio,
            token_error_ratio,
            created_at: None,
        }
    }
}
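The "error ratio" here is simply actual over predicted, defined only when a positive prediction exists, so a value above 1.0 means the estimate was too low. The computation, extracted into a standalone sketch:

```rust
// Standalone sketch of the ratio computation in DbTaskOutcome::new:
// actual / predicted, defined only for positive predictions.
fn error_ratio(actual: Option<i64>, predicted: Option<i64>) -> Option<f64> {
    match (actual, predicted) {
        (Some(a), Some(p)) if p > 0 => Some(a as f64 / p as f64),
        _ => None,
    }
}

fn main() {
    assert_eq!(error_ratio(Some(150), Some(100)), Some(1.5)); // 50% over estimate
    assert_eq!(error_ratio(Some(80), Some(100)), Some(0.8));  // under estimate
    assert_eq!(error_ratio(Some(80), None), None);            // no prediction made
}
```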
/// Model performance statistics from historical data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStats {
    pub model_id: String,
    /// Success rate (0-1)
    pub success_rate: f64,
    /// Average cost vs predicted (1.0 = accurate, >1 = underestimated)
    pub avg_cost_ratio: f64,
    /// Average tokens vs predicted
    pub avg_token_ratio: f64,
    /// Average iterations needed
    pub avg_iterations: f64,
    /// Number of samples
    pub sample_count: i64,
}

/// Historical context for a task (similar past tasks).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistoricalContext {
    /// Similar past task outcomes
    pub similar_outcomes: Vec<DbTaskOutcome>,
    /// Average cost adjustment multiplier
    pub avg_cost_multiplier: f64,
    /// Average token adjustment multiplier
    pub avg_token_multiplier: f64,
    /// Success rate for similar tasks
    pub similar_success_rate: f64,
}
/// Mission status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MissionStatus {
    Active,
    Completed,
    Failed,
}

impl std::fmt::Display for MissionStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Active => write!(f, "active"),
            Self::Completed => write!(f, "completed"),
            Self::Failed => write!(f, "failed"),
        }
    }
}

impl std::str::FromStr for MissionStatus {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "active" => Ok(Self::Active),
            "completed" => Ok(Self::Completed),
            "failed" => Ok(Self::Failed),
            _ => Err(format!("Invalid mission status: {}", s)),
        }
    }
}
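The `Display` and `FromStr` impls form a case-insensitive round trip with the database's lowercase status strings. A self-contained sketch of that contract (the enum is re-declared locally for illustration only):

```rust
// Self-contained sketch of the Display/FromStr round-trip contract the
// MissionStatus impls provide (enum re-declared here for illustration).
use std::str::FromStr;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MissionStatus {
    Active,
    Completed,
    Failed,
}

impl std::fmt::Display for MissionStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Active => write!(f, "active"),
            Self::Completed => write!(f, "completed"),
            Self::Failed => write!(f, "failed"),
        }
    }
}

impl FromStr for MissionStatus {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "active" => Ok(Self::Active),
            "completed" => Ok(Self::Completed),
            "failed" => Ok(Self::Failed),
            _ => Err(format!("Invalid mission status: {}", s)),
        }
    }
}

fn main() {
    // Lowercase round trip, case-insensitive parsing, strict rejection.
    assert_eq!(MissionStatus::Active.to_string(), "active");
    assert_eq!(MissionStatus::from_str("ACTIVE"), Ok(MissionStatus::Active));
    assert!(MissionStatus::from_str("paused").is_err());
}
```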
/// A mission stored in the database.
/// Represents a persistent goal-oriented agent session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbMission {
    pub id: Uuid,
    pub status: String,
    pub title: Option<String>,
    /// Conversation history as JSON array of {role, content} objects
    pub history: serde_json::Value,
    pub created_at: String,
    pub updated_at: String,
}

/// A message in the mission history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MissionMessage {
    pub role: String,
    pub content: String,
}
impl ContextPack {
    /// Format as a string for prompt injection.
    pub fn format_for_prompt(&self) -> String {
        if self.chunks.is_empty() {
            return String::new();
        }

        let mut out = String::from("## Relevant Context from Memory\n\n");
        for (i, chunk) in self.chunks.iter().enumerate() {
            out.push_str(&format!(
                "### Context {} (similarity: {:.2})\n{}\n\n",
                i + 1,
                chunk.similarity,
                chunk.chunk_text
            ));
        }
        out
    }
}
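To make the output shape of `format_for_prompt` concrete, here is a simplified, runnable sketch that stands in a `(similarity, text)` tuple for the full `SearchResult` type:

```rust
// Simplified sketch of ContextPack::format_for_prompt, using a
// (similarity, text) tuple in place of the full SearchResult type.
fn format_chunks(chunks: &[(f64, &str)]) -> String {
    if chunks.is_empty() {
        return String::new();
    }
    let mut out = String::from("## Relevant Context from Memory\n\n");
    for (i, (sim, text)) in chunks.iter().enumerate() {
        out.push_str(&format!(
            "### Context {} (similarity: {:.2})\n{}\n\n",
            i + 1, sim, text
        ));
    }
    out
}

fn main() {
    let s = format_chunks(&[(0.91, "past run solved a similar task")]);
    assert!(s.starts_with("## Relevant Context from Memory"));
    assert!(s.contains("### Context 1 (similarity: 0.91)"));
    assert!(format_chunks(&[]).is_empty());
}
```

Returning an empty string for an empty pack means callers can unconditionally append the result to a prompt without adding a dangling section header.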