diff --git a/.cursor/rules/project.md b/.cursor/rules/project.md
index f89454e..cfbf6cb 100644
--- a/.cursor/rules/project.md
+++ b/.cursor/rules/project.md
@@ -104,17 +104,99 @@ src/
 - **Long tasks beyond context**: persist step-by-step execution so the agent can retrieve relevant context later
 - **Fast query + browsing**: structured metadata in Postgres, heavy blobs in Storage
 - **Embedding + rerank**: Qwen3 Embedding 8B for vectors, Qwen reranker for precision
+- **Learning from execution**: store predictions vs actuals to improve estimates over time
 
 ### Data Flow
 1. Agents emit events via `EventRecorder`
 2. `MemoryWriter` persists to Supabase Postgres + Storage
 3. Before LLM calls, `MemoryRetriever` fetches relevant context
 4. On completion, run is archived with summary embedding
+5. **Task outcomes recorded for learning** (complexity, cost, tokens, success)
 
 ### Storage Strategy
-- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings)
+- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings), **task_outcomes**
 - **Supabase Storage**: full event streams (jsonl), large artifacts
 
+## Learning System (v3)
+
+### Purpose
+Enable data-driven optimization of:
+- **Complexity estimation**: learn actual token usage vs predicted
+- **Model selection**: learn actual success rates per model/complexity
+- **Budget allocation**: learn actual costs vs estimated
+
+### Architecture
+```
+┌──────────────────────────────────────────────────────────────────────┐
+│                      Memory-Enhanced Agent Flow                      │
+├──────────────────────────────────────────────────────────────────────┤
+│                                                                      │
+│  ┌──────────┐     Query similar     ┌─────────────────────────────┐  │
+│  │ New Task │ ─────────────────────▶│ MemoryRetriever             │  │
+│  └────┬─────┘      past tasks       │ - find_similar_tasks()      │  │
+│       │                             │ - get_historical_context()  │  │
+│       ▼                             │ - get_model_stats()         │  │
+│  ┌────────────────┐                 └──────────────┬──────────────┘  │
+│  │ Complexity     │◀── historical context ─────────┘                 │
+│  │ Estimator      │   (avg token ratio, avg cost ratio)              │
+│  │ (enhanced)     │                                                  │
+│  └────────┬───────┘                                                  │
+│           │                                                          │
+│           ▼                                                          │
+│  ┌────────────────┐   Query: "models at complexity ~0.6"             │
+│  │ Model Selector │   Returns: actual success rates, cost ratios     │
+│  │ (enhanced)     │                                                  │
+│  └────────┬───────┘                                                  │
+│           │                                                          │
+│           ▼                                                          │
+│  ┌────────────────┐                                                  │
+│  │ TaskExecutor   │──▶ record_task_outcome() ──▶ task_outcomes       │
+│  └────────────────┘                                                  │
+│                                                                      │
+└──────────────────────────────────────────────────────────────────────┘
+```
+
+### Database Schema: `task_outcomes`
+```sql
+CREATE TABLE task_outcomes (
+    id uuid PRIMARY KEY,
+    run_id uuid REFERENCES runs(id),
+    task_id uuid REFERENCES tasks(id),
+
+    -- Predictions
+    predicted_complexity float,
+    predicted_tokens bigint,
+    predicted_cost_cents bigint,
+    selected_model text,
+
+    -- Actuals
+    actual_tokens bigint,
+    actual_cost_cents bigint,
+    success boolean,
+    iterations int,
+    tool_calls_count int,
+
+    -- Computed ratios (actual/predicted)
+    cost_error_ratio float,
+    token_error_ratio float,
+
+    -- Similarity search
+    task_description text,
+    task_embedding vector(1536)
+);
+```
+
+### RPC Functions
+- `get_model_stats(complexity_min, complexity_max)` - Model performance by complexity tier
+- `search_similar_outcomes(embedding, threshold, limit)` - Find similar past tasks
+- `get_global_learning_stats()` - Overall system metrics
+
+### Learning Integration Points
+1. **ComplexityEstimator**: Query similar tasks → adjust token estimate by `avg_token_ratio` (see the sketch after this list)
+2. **ModelSelector**: Query model stats → use actual success rates instead of heuristics
+3. **TaskExecutor**: After execution → call `record_task_outcome()` with all metrics
+4. **Budget**: Use historical cost ratios to add appropriate safety margins
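+
+A minimal sketch of integration point 1. The type matches `HistoricalContext` in
+`src/memory/types.rs`; the `base_tokens` input and the clamping bounds are
+illustrative assumptions, not the shipped estimator:
+
+```rust
+fn adjust_token_estimate(base_tokens: u64, history: Option<&HistoricalContext>) -> u64 {
+    match history {
+        // Scale the heuristic estimate by the observed actual/predicted ratio,
+        // clamped so a handful of outlier tasks cannot swing estimates wildly.
+        Some(h) => (base_tokens as f64 * h.avg_token_multiplier.clamp(0.5, 3.0)).round() as u64,
+        // No similar history yet: fall back to the raw heuristic.
+        None => base_tokens,
+    }
+}
+```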
+
 ## Design for Provability
 
 ### Conventions for Future Lean Proofs
@@ -163,7 +245,7 @@ GET /api/memory/search - Semantic search across memory
 ```
 OPENROUTER_API_KEY  - Required. Your OpenRouter API key
-DEFAULT_MODEL       - Optional. Default: openai/gpt-4.1-mini
+DEFAULT_MODEL       - Optional. Default: anthropic/claude-sonnet-4.5
 WORKSPACE_PATH      - Optional. Default: current directory
 HOST                - Optional. Default: 127.0.0.1
 PORT                - Optional. Default: 3000
@@ -174,6 +256,11 @@ MEMORY_EMBED_MODEL - Optional. Default: qwen/qwen3-embedding-8b
 MEMORY_RERANK_MODEL - Optional. Default: qwen/qwen3-reranker-8b
 ```
+
+### Recommended Models
+- **Default (tools)**: `anthropic/claude-sonnet-4.5` - Best coding, 1M context, $3/$15 per 1M tokens
+- **Budget fallback**: `anthropic/claude-3.5-haiku` - Fast, cheap, good for simple tasks
+- **Complex tasks**: `anthropic/claude-opus-4.5` - Highest capability when needed
+
 ## Security Considerations
 
 This agent has **full machine access**. It can:
@@ -192,8 +279,12 @@ When deploying:
 - [ ] Formal verification in Lean (extract pure logic)
 - [ ] WebSocket for bidirectional streaming
+- [ ] Budget overflow strategies (fallback to cheaper model, request extension)
+- [ ] Enhanced ComplexityEstimator with historical context injection
+- [ ] Enhanced ModelSelector with data-driven success rates
 - [x] Semantic code search (embeddings-based)
 - [x] Multi-model support (U-curve optimization)
 - [x] Cost tracking (Budget system)
 - [x] Persistent memory (Supabase + pgvector)
+- [x] Learning system (task_outcomes table, historical queries)
+
diff --git a/migrations/002_task_outcomes.sql b/migrations/002_task_outcomes.sql
new file mode 100644
index 0000000..ea3d175
--- /dev/null
+++ b/migrations/002_task_outcomes.sql
@@ -0,0 +1,239 @@
+-- Migration: Task Outcomes for Learning
+-- This table stores predictions vs actuals for each task execution,
+-- enabling the agent to learn optimal model selection and cost estimation.
+
+-- ============================================================================
+-- Table: task_outcomes
+-- ============================================================================
+
+CREATE TABLE IF NOT EXISTS task_outcomes (
+    id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
+    run_id uuid REFERENCES runs(id) ON DELETE CASCADE,
+    task_id uuid REFERENCES tasks(id) ON DELETE CASCADE,
+
+    -- Predictions (what we estimated before execution)
+    predicted_complexity float,
+    predicted_tokens bigint,
+    predicted_cost_cents bigint,
+    selected_model text,
+
+    -- Actuals (what happened during execution)
+    actual_tokens bigint,
+    actual_cost_cents bigint,
+    success boolean NOT NULL DEFAULT false,
+    iterations int,
+    tool_calls_count int,
+
+    -- Metadata for similarity search
+    task_description text NOT NULL,
+    task_type text,               -- 'file_create', 'refactor', 'debug', etc.
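+    -- NOTE: the vector dimension below must match the output size of the
+    -- configured embedding model (1536 here); a mismatch will fail at insert.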
+    task_embedding vector(1536), -- For finding similar tasks
+
+    -- Computed ratios for quick stats
+    cost_error_ratio float,       -- actual/predicted (1.0 = accurate)
+    token_error_ratio float,
+
+    created_at timestamptz DEFAULT now()
+);
+
+-- Indexes for efficient queries
+CREATE INDEX IF NOT EXISTS idx_outcomes_run_id ON task_outcomes(run_id);
+CREATE INDEX IF NOT EXISTS idx_outcomes_task_id ON task_outcomes(task_id);
+CREATE INDEX IF NOT EXISTS idx_outcomes_model ON task_outcomes(selected_model, success);
+CREATE INDEX IF NOT EXISTS idx_outcomes_complexity ON task_outcomes(predicted_complexity);
+CREATE INDEX IF NOT EXISTS idx_outcomes_created ON task_outcomes(created_at DESC);
+
+-- Vector index for similarity search (HNSW)
+CREATE INDEX IF NOT EXISTS idx_outcomes_embedding ON task_outcomes
+    USING hnsw (task_embedding vector_cosine_ops);
+
+-- ============================================================================
+-- RPC: get_model_stats
+-- Returns aggregated statistics per model for a given complexity range.
+-- ============================================================================
+
+CREATE OR REPLACE FUNCTION get_model_stats(
+    complexity_min float DEFAULT 0.0,
+    complexity_max float DEFAULT 1.0
+)
+RETURNS TABLE (
+    model_id text,
+    success_rate float,
+    avg_cost_ratio float,
+    avg_token_ratio float,
+    avg_iterations float,
+    sample_count bigint
+)
+LANGUAGE sql STABLE
+AS $$
+    SELECT
+        selected_model as model_id,
+        AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
+        COALESCE(AVG(cost_error_ratio), 1.0) as avg_cost_ratio,
+        COALESCE(AVG(token_error_ratio), 1.0) as avg_token_ratio,
+        COALESCE(AVG(iterations), 1.0) as avg_iterations,
+        COUNT(*) as sample_count
+    FROM task_outcomes
+    WHERE
+        selected_model IS NOT NULL
+        AND predicted_complexity >= complexity_min
+        AND predicted_complexity <= complexity_max
+    GROUP BY selected_model
+    HAVING COUNT(*) >= 3 -- Minimum samples for reliability
+    ORDER BY success_rate DESC, avg_cost_ratio ASC;
+$$;
+
+-- ============================================================================
+-- RPC: search_similar_outcomes
+-- Find similar past task outcomes by embedding similarity.
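+--
+-- Example call (vector literal truncated; in practice the Rust client passes
+-- the full 1536-dim embedding):
+--
+--   SELECT selected_model, success, similarity
+--   FROM search_similar_outcomes('[0.12, -0.03, ...]'::vector, 0.7, 5);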
+-- ============================================================================
+
+CREATE OR REPLACE FUNCTION search_similar_outcomes(
+    query_embedding vector(1536),
+    match_threshold float DEFAULT 0.7,
+    match_count int DEFAULT 5
+)
+RETURNS TABLE (
+    id uuid,
+    run_id uuid,
+    task_id uuid,
+    predicted_complexity float,
+    predicted_tokens bigint,
+    predicted_cost_cents bigint,
+    selected_model text,
+    actual_tokens bigint,
+    actual_cost_cents bigint,
+    success boolean,
+    iterations int,
+    tool_calls_count int,
+    task_description text,
+    task_type text,
+    cost_error_ratio float,
+    token_error_ratio float,
+    created_at timestamptz,
+    similarity float
+)
+LANGUAGE sql STABLE
+AS $$
+    SELECT
+        o.id,
+        o.run_id,
+        o.task_id,
+        o.predicted_complexity,
+        o.predicted_tokens,
+        o.predicted_cost_cents,
+        o.selected_model,
+        o.actual_tokens,
+        o.actual_cost_cents,
+        o.success,
+        o.iterations,
+        o.tool_calls_count,
+        o.task_description,
+        o.task_type,
+        o.cost_error_ratio,
+        o.token_error_ratio,
+        o.created_at,
+        1 - (o.task_embedding <=> query_embedding) as similarity
+    FROM task_outcomes o
+    WHERE
+        o.task_embedding IS NOT NULL
+        AND 1 - (o.task_embedding <=> query_embedding) > match_threshold
+    ORDER BY o.task_embedding <=> query_embedding
+    LIMIT match_count;
+$$;
+
+-- ============================================================================
+-- RPC: get_global_learning_stats
+-- Returns overall system learning statistics for monitoring/tuning.
+-- ============================================================================
+
+CREATE OR REPLACE FUNCTION get_global_learning_stats()
+RETURNS json
+LANGUAGE sql STABLE
+AS $$
+    SELECT json_build_object(
+        'total_outcomes', (SELECT COUNT(*) FROM task_outcomes),
+        'success_rate', (SELECT AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) FROM task_outcomes),
+        'avg_cost_error', (SELECT AVG(cost_error_ratio) FROM task_outcomes WHERE cost_error_ratio IS NOT NULL),
+        'avg_token_error', (SELECT AVG(token_error_ratio) FROM task_outcomes WHERE token_error_ratio IS NOT NULL),
+        'models_used', (SELECT COUNT(DISTINCT selected_model) FROM task_outcomes),
+        'top_models', (
+            SELECT json_agg(row_to_json(t))
+            FROM (
+                SELECT
+                    selected_model,
+                    COUNT(*) as uses,
+                    AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
+                FROM task_outcomes
+                WHERE selected_model IS NOT NULL
+                GROUP BY selected_model
+                ORDER BY uses DESC
+                LIMIT 5
+            ) t
+        ),
+        'complexity_distribution', (
+            SELECT json_agg(row_to_json(t))
+            FROM (
+                SELECT
+                    CASE
+                        WHEN predicted_complexity < 0.2 THEN 'trivial'
+                        WHEN predicted_complexity < 0.4 THEN 'simple'
+                        WHEN predicted_complexity < 0.6 THEN 'moderate'
+                        WHEN predicted_complexity < 0.8 THEN 'complex'
+                        ELSE 'very_complex'
+                    END as tier,
+                    COUNT(*) as count,
+                    AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
+                FROM task_outcomes
+                WHERE predicted_complexity IS NOT NULL
+                GROUP BY tier
+            ) t
+        )
+    );
+$$;
+
+-- ============================================================================
+-- RPC: get_optimal_model_for_complexity
+-- Returns the best model for a given complexity based on historical data.
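+--
+-- Example: candidate models for a ~0.6-complexity task with a 500-cent budget:
+--
+--   SELECT * FROM get_optimal_model_for_complexity(0.6, 500);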
+-- ============================================================================
+
+CREATE OR REPLACE FUNCTION get_optimal_model_for_complexity(
+    target_complexity float,
+    budget_cents bigint DEFAULT 1000
+)
+RETURNS TABLE (
+    model_id text,
+    expected_success_rate float,
+    expected_cost_cents float,
+    confidence float
+)
+LANGUAGE sql STABLE
+AS $$
+    WITH model_perf AS (
+        SELECT
+            selected_model,
+            AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
+            AVG(actual_cost_cents) as avg_cost,
+            COUNT(*) as samples,
+            -- Confidence based on sample size and recency
+            LEAST(1.0, COUNT(*) / 10.0) as sample_confidence
+        FROM task_outcomes
+        WHERE
+            selected_model IS NOT NULL
+            AND ABS(predicted_complexity - target_complexity) < 0.2
+            AND created_at > now() - interval '30 days'
+        GROUP BY selected_model
+    )
+    SELECT
+        selected_model as model_id,
+        success_rate as expected_success_rate,
+        avg_cost as expected_cost_cents,
+        sample_confidence as confidence
+    FROM model_perf
+    WHERE avg_cost <= budget_cents OR samples < 3 -- Allow trying new models
+    ORDER BY
+        -- Balance success rate and cost
+        (success_rate * 0.7 + (1.0 - LEAST(avg_cost / budget_cents, 1.0)) * 0.3) DESC
+    LIMIT 3;
+$$;
+
diff --git a/src/memory/retriever.rs b/src/memory/retriever.rs
index f173841..c1c447d 100644
--- a/src/memory/retriever.rs
+++ b/src/memory/retriever.rs
@@ -6,7 +6,7 @@ use serde::Deserialize;
 
 use super::supabase::SupabaseClient;
 use super::embed::EmbeddingClient;
-use super::types::{SearchResult, ContextPack};
+use super::types::{SearchResult, ContextPack, DbTaskOutcome, ModelStats, HistoricalContext};
 
 /// Default similarity threshold for vector search.
 const DEFAULT_THRESHOLD: f64 = 0.5;
@@ -230,6 +230,82 @@ Only return the JSON array, nothing else."#,
     pub async fn list_runs(&self, limit: usize, offset: usize) -> anyhow::Result<Vec<DbRun>> {
         self.supabase.list_runs(limit, offset).await
     }
+
+    // ==================== Learning Methods ====================
+
+    /// Get model performance statistics for a given complexity range.
+    ///
+    /// Returns historical success rates, cost ratios, etc. for each model
+    /// that has been used at the given complexity level.
+    pub async fn get_model_stats(
+        &self,
+        complexity: f64,
+        range: f64,
+    ) -> anyhow::Result<Vec<ModelStats>> {
+        let min = (complexity - range).max(0.0);
+        let max = (complexity + range).min(1.0);
+        self.supabase.get_model_stats(min, max).await
+    }
+
+    /// Find similar past tasks and their outcomes.
+    ///
+    /// Uses embedding similarity to find tasks that are semantically similar
+    /// to the given task description, then returns their execution outcomes.
+    pub async fn find_similar_tasks(
+        &self,
+        task_description: &str,
+        limit: usize,
+    ) -> anyhow::Result<Vec<DbTaskOutcome>> {
+        // Generate embedding for the task description
+        let embedding = self.embedder.embed(task_description).await?;
+
+        // Search for similar outcomes
+        self.supabase.search_similar_outcomes(&embedding, 0.6, limit).await
+    }
+
+    /// Get historical context for a task.
+    ///
+    /// Returns aggregated learning data from similar past tasks including:
+    /// - Average cost adjustment multiplier
+    /// - Average token adjustment multiplier
+    /// - Success rate for similar tasks
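+    ///
+    /// Usage sketch (`retriever` construction and `base_estimate` are assumed):
+    ///
+    /// ```ignore
+    /// if let Some(ctx) = retriever.get_historical_context("refactor the auth module", 5).await? {
+    ///     // e.g. scale a heuristic token estimate by the learned ratio
+    ///     let adjusted = (base_estimate as f64 * ctx.avg_token_multiplier) as u64;
+    ///     println!("similar tasks succeeded {:.0}% of the time", ctx.similar_success_rate * 100.0);
+    /// }
+    /// ```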
+    pub async fn get_historical_context(
+        &self,
+        task_description: &str,
+        limit: usize,
+    ) -> anyhow::Result<Option<HistoricalContext>> {
+        let similar = self.find_similar_tasks(task_description, limit).await?;
+
+        if similar.is_empty() {
+            return Ok(None);
+        }
+
+        // Calculate aggregated stats
+        let total = similar.len() as f64;
+
+        let avg_cost_multiplier = similar.iter()
+            .filter_map(|o| o.cost_error_ratio)
+            .sum::<f64>() / similar.iter().filter(|o| o.cost_error_ratio.is_some()).count().max(1) as f64;
+
+        let avg_token_multiplier = similar.iter()
+            .filter_map(|o| o.token_error_ratio)
+            .sum::<f64>() / similar.iter().filter(|o| o.token_error_ratio.is_some()).count().max(1) as f64;
+
+        let success_count = similar.iter().filter(|o| o.success).count() as f64;
+        let similar_success_rate = success_count / total;
+
+        Ok(Some(HistoricalContext {
+            similar_outcomes: similar,
+            avg_cost_multiplier: if avg_cost_multiplier.is_nan() { 1.0 } else { avg_cost_multiplier },
+            avg_token_multiplier: if avg_token_multiplier.is_nan() { 1.0 } else { avg_token_multiplier },
+            similar_success_rate,
+        }))
+    }
+
+    /// Get global learning statistics.
+    pub async fn get_learning_stats(&self) -> anyhow::Result<serde_json::Value> {
+        self.supabase.get_global_stats().await
+    }
 }
 
 /// Truncate a string to max length.
diff --git a/src/memory/supabase.rs b/src/memory/supabase.rs
index 421072a..a9dc19b 100644
--- a/src/memory/supabase.rs
+++ b/src/memory/supabase.rs
@@ -3,7 +3,7 @@ use reqwest::Client;
 use uuid::Uuid;
 
-use super::types::{DbRun, DbTask, DbEvent, DbChunk, SearchResult};
+use super::types::{DbRun, DbTask, DbEvent, DbChunk, DbTaskOutcome, SearchResult, ModelStats};
 
 /// Supabase client for database and storage operations.
 pub struct SupabaseClient {
@@ -357,5 +357,156 @@ impl SupabaseClient {
         self.update_run(run_id, body).await
     }
+
+    // ==================== Task Outcomes (Learning) ====================
+
+    /// Insert a task outcome for learning.
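+    ///
+    /// The embedding, when present, is serialized to a pgvector text literal
+    /// (e.g. `[0.1,0.2,...]`); passing `None` simply skips similarity indexing
+    /// for this row. Sketch with assumed values:
+    ///
+    /// ```ignore
+    /// let outcome_id = supabase.insert_task_outcome(&outcome, Some(&embedding)).await?;
+    /// ```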
+    pub async fn insert_task_outcome(
+        &self,
+        outcome: &DbTaskOutcome,
+        embedding: Option<&[f32]>,
+    ) -> anyhow::Result<Uuid> {
+        let embedding_str = embedding.map(|e| format!(
+            "[{}]",
+            e.iter().map(|f| f.to_string()).collect::<Vec<String>>().join(",")
+        ));
+
+        let body = serde_json::json!({
+            "run_id": outcome.run_id,
+            "task_id": outcome.task_id,
+            "predicted_complexity": outcome.predicted_complexity,
+            "predicted_tokens": outcome.predicted_tokens,
+            "predicted_cost_cents": outcome.predicted_cost_cents,
+            "selected_model": outcome.selected_model,
+            "actual_tokens": outcome.actual_tokens,
+            "actual_cost_cents": outcome.actual_cost_cents,
+            "success": outcome.success,
+            "iterations": outcome.iterations,
+            "tool_calls_count": outcome.tool_calls_count,
+            "task_description": outcome.task_description,
+            "task_type": outcome.task_type,
+            "cost_error_ratio": outcome.cost_error_ratio,
+            "token_error_ratio": outcome.token_error_ratio,
+            "task_embedding": embedding_str
+        });
+
+        let resp = self.client
+            .post(format!("{}/task_outcomes", self.rest_url()))
+            .header("apikey", &self.service_role_key)
+            .header("Authorization", format!("Bearer {}", self.service_role_key))
+            .header("Content-Type", "application/json")
+            .header("Prefer", "return=representation")
+            .json(&body)
+            .send()
+            .await?;
+
+        let status = resp.status();
+        let text = resp.text().await?;
+
+        if !status.is_success() {
+            anyhow::bail!("Failed to insert task outcome: {} - {}", status, text);
+        }
+
+        let outcomes: Vec<DbTaskOutcome> = serde_json::from_str(&text)?;
+        outcomes.into_iter().next()
+            .and_then(|o| o.id)
+            .ok_or_else(|| anyhow::anyhow!("No outcome ID returned"))
+    }
+
+    /// Get model statistics for a complexity range.
+    ///
+    /// Returns aggregated stats for each model that has been used
+    /// for tasks in the given complexity range.
+    pub async fn get_model_stats(
+        &self,
+        complexity_min: f64,
+        complexity_max: f64,
+    ) -> anyhow::Result<Vec<ModelStats>> {
+        let body = serde_json::json!({
+            "complexity_min": complexity_min,
+            "complexity_max": complexity_max
+        });
+
+        let resp = self.client
+            .post(format!("{}/rpc/get_model_stats", self.rest_url()))
+            .header("apikey", &self.service_role_key)
+            .header("Authorization", format!("Bearer {}", self.service_role_key))
+            .header("Content-Type", "application/json")
+            .json(&body)
+            .send()
+            .await?;
+
+        let status = resp.status();
+        let text = resp.text().await?;
+
+        if !status.is_success() {
+            // If RPC doesn't exist yet, return empty
+            tracing::debug!("get_model_stats RPC not available: {}", text);
+            return Ok(vec![]);
+        }
+
+        Ok(serde_json::from_str(&text).unwrap_or_default())
+    }
+
+    /// Search for similar task outcomes by embedding similarity.
+    pub async fn search_similar_outcomes(
+        &self,
+        embedding: &[f32],
+        threshold: f64,
+        limit: usize,
+    ) -> anyhow::Result<Vec<DbTaskOutcome>> {
+        let embedding_str = format!(
+            "[{}]",
+            embedding.iter().map(|f| f.to_string()).collect::<Vec<String>>().join(",")
+        );
+
+        let body = serde_json::json!({
+            "query_embedding": embedding_str,
+            "match_threshold": threshold,
+            "match_count": limit
+        });
+
+        let resp = self.client
+            .post(format!("{}/rpc/search_similar_outcomes", self.rest_url()))
+            .header("apikey", &self.service_role_key)
+            .header("Authorization", format!("Bearer {}", self.service_role_key))
+            .header("Content-Type", "application/json")
+            .json(&body)
+            .send()
+            .await?;
+
+        let status = resp.status();
+        let text = resp.text().await?;
+
+        if !status.is_success() {
+            // If RPC doesn't exist yet, return empty
+            tracing::debug!("search_similar_outcomes RPC not available: {}", text);
+            return Ok(vec![]);
+        }
+
+        Ok(serde_json::from_str(&text).unwrap_or_default())
+    }
+
+    /// Get global learning statistics (for tuning).
+    pub async fn get_global_stats(&self) -> anyhow::Result<serde_json::Value> {
+        let resp = self.client
+            .post(format!("{}/rpc/get_global_learning_stats", self.rest_url()))
+            .header("apikey", &self.service_role_key)
+            .header("Authorization", format!("Bearer {}", self.service_role_key))
+            .header("Content-Type", "application/json")
+            .json(&serde_json::json!({}))
+            .send()
+            .await?;
+
+        let status = resp.status();
+        let text = resp.text().await?;
+
+        if !status.is_success() {
+            tracing::debug!("get_global_learning_stats RPC not available: {}", text);
+            return Ok(serde_json::json!({}));
+        }
+
+        Ok(serde_json::from_str(&text).unwrap_or_default())
+    }
 }
diff --git a/src/memory/types.rs b/src/memory/types.rs
index 586335a..b72934b 100644
--- a/src/memory/types.rs
+++ b/src/memory/types.rs
@@ -154,6 +154,123 @@ pub struct ContextPack {
     pub query: String,
 }
 
+/// Task outcome record for learning from execution history.
+///
+/// Captures predictions vs actuals to enable data-driven optimization
+/// of complexity estimation, model selection, and budget allocation.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbTaskOutcome {
+    pub id: Option<Uuid>,
+    pub run_id: Uuid,
+    pub task_id: Uuid,
+
+    // Predictions (what we estimated before execution)
+    pub predicted_complexity: Option<f64>,
+    pub predicted_tokens: Option<i64>,
+    pub predicted_cost_cents: Option<i64>,
+    pub selected_model: Option<String>,
+
+    // Actuals (what happened during execution)
+    pub actual_tokens: Option<i64>,
+    pub actual_cost_cents: Option<i64>,
+    pub success: bool,
+    pub iterations: Option<i32>,
+    pub tool_calls_count: Option<i32>,
+
+    // Metadata for similarity search
+    pub task_description: String,
+    /// Category of task (inferred or explicit)
+    pub task_type: Option<String>,
+
+    // Computed ratios for quick stats
+    pub cost_error_ratio: Option<f64>,
+    pub token_error_ratio: Option<f64>,
+
+    pub created_at: Option<String>,
+}
+
+impl DbTaskOutcome {
+    /// Create a new outcome from predictions and actuals.
+    pub fn new(
+        run_id: Uuid,
+        task_id: Uuid,
+        task_description: String,
+        predicted_complexity: Option<f64>,
+        predicted_tokens: Option<i64>,
+        predicted_cost_cents: Option<i64>,
+        selected_model: Option<String>,
+        actual_tokens: Option<i64>,
+        actual_cost_cents: Option<i64>,
+        success: bool,
+        iterations: Option<i32>,
+        tool_calls_count: Option<i32>,
+    ) -> Self {
+        // Compute error ratios
+        let cost_error_ratio = match (actual_cost_cents, predicted_cost_cents) {
+            (Some(actual), Some(predicted)) if predicted > 0 => {
+                Some(actual as f64 / predicted as f64)
+            }
+            _ => None,
+        };
+
+        let token_error_ratio = match (actual_tokens, predicted_tokens) {
+            (Some(actual), Some(predicted)) if predicted > 0 => {
+                Some(actual as f64 / predicted as f64)
+            }
+            _ => None,
+        };
+
+        Self {
+            id: None,
+            run_id,
+            task_id,
+            predicted_complexity,
+            predicted_tokens,
+            predicted_cost_cents,
+            selected_model,
+            actual_tokens,
+            actual_cost_cents,
+            success,
+            iterations,
+            tool_calls_count,
+            task_description,
+            task_type: None,
+            cost_error_ratio,
+            token_error_ratio,
+            created_at: None,
+        }
+    }
+}
+
+/// Model performance statistics from historical data.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ModelStats {
+    pub model_id: String,
+    /// Success rate (0-1)
+    pub success_rate: f64,
+    /// Average cost vs predicted (1.0 = accurate, >1 = underestimated)
+    pub avg_cost_ratio: f64,
+    /// Average tokens vs predicted
+    pub avg_token_ratio: f64,
+    /// Average iterations needed
+    pub avg_iterations: f64,
+    /// Number of samples
+    pub sample_count: i64,
+}
+
+/// Historical context for a task (similar past tasks).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HistoricalContext {
+    /// Similar past task outcomes
+    pub similar_outcomes: Vec<DbTaskOutcome>,
+    /// Average cost adjustment multiplier
+    pub avg_cost_multiplier: f64,
+    /// Average token adjustment multiplier
+    pub avg_token_multiplier: f64,
+    /// Success rate for similar tasks
+    pub similar_success_rate: f64,
+}
+
 impl ContextPack {
     /// Format as a string for prompt injection.
     pub fn format_for_prompt(&self) -> String {
diff --git a/src/memory/writer.rs b/src/memory/writer.rs
index 38b4a12..4465eef 100644
--- a/src/memory/writer.rs
+++ b/src/memory/writer.rs
@@ -7,7 +7,7 @@ use serde_json::json;
 
 use super::supabase::SupabaseClient;
 use super::embed::EmbeddingClient;
-use super::types::{DbTask, DbEvent, DbChunk, EventKind, MemoryStatus};
+use super::types::{DbTask, DbEvent, DbChunk, DbTaskOutcome, EventKind, MemoryStatus};
 
 /// Maximum chunk size in characters.
 const MAX_CHUNK_SIZE: usize = 2000;
@@ -241,6 +241,47 @@ impl MemoryWriter {
         self.supabase.update_run_summary(run_id, summary, &embedding).await
     }
 
+    /// Record a task outcome for learning.
+    ///
+    /// This captures predictions vs actuals to enable data-driven optimization
+    /// of complexity estimation, model selection, and budget allocation.
+    pub async fn record_task_outcome(
+        &self,
+        run_id: Uuid,
+        task_id: Uuid,
+        task_description: &str,
+        predicted_complexity: Option<f64>,
+        predicted_tokens: Option<i64>,
+        predicted_cost_cents: Option<i64>,
+        selected_model: Option<String>,
+        actual_tokens: Option<i64>,
+        actual_cost_cents: Option<i64>,
+        success: bool,
+        iterations: Option<i32>,
+        tool_calls_count: Option<i32>,
+    ) -> anyhow::Result<Uuid> {
+        // Create the outcome record
+        let outcome = DbTaskOutcome::new(
+            run_id,
+            task_id,
+            task_description.to_string(),
+            predicted_complexity,
+            predicted_tokens,
+            predicted_cost_cents,
+            selected_model,
+            actual_tokens,
+            actual_cost_cents,
+            success,
+            iterations,
+            tool_calls_count,
+        );
+
+        // Generate embedding for similarity search
+        let embedding = self.embedder.embed(task_description).await.ok();
+
+        self.supabase.insert_task_outcome(&outcome, embedding.as_deref()).await
+    }
+
     /// Split text into chunks.
     fn chunk_text(&self, text: &str) -> Vec<String> {
         let mut chunks = Vec::new();