Add learning system for data-driven optimization

## New Features

### task_outcomes table (Supabase)
- Stores predictions vs actuals for every task execution
- Fields: predicted_complexity, predicted_tokens, predicted_cost_cents
- Fields: actual_tokens, actual_cost_cents, success, iterations, tool_calls
- Computed: cost_error_ratio, token_error_ratio
- Vector embedding for similarity search

### RPC Functions
- get_model_stats(complexity_min, complexity_max) - Model performance by tier
- search_similar_outcomes(embedding, threshold, limit) - Find similar tasks
- get_global_learning_stats() - Overall system metrics

### Rust Integration
- MemoryWriter.record_task_outcome() - Record after execution
- MemoryRetriever.get_model_stats() - Query model performance
- MemoryRetriever.find_similar_tasks() - Semantic similarity search
- MemoryRetriever.get_historical_context() - Aggregated learning data
- New types: DbTaskOutcome, ModelStats, HistoricalContext

### Documentation
- Updated .cursor/rules/project.md with learning architecture diagram
- Documented data flow and integration points
- Added recommended models section
Author: Thomas Marchand
Date: 2025-12-15 10:52:49 +00:00
Commit: ccfbcd9b37 (parent 06f234d62f)
6 changed files with 720 additions and 5 deletions


@@ -104,17 +104,99 @@ src/
- **Long tasks beyond context**: persist step-by-step execution so the agent can retrieve relevant context later
- **Fast query + browsing**: structured metadata in Postgres, heavy blobs in Storage
- **Embedding + rerank**: Qwen3 Embedding 8B for vectors, Qwen reranker for precision
- **Learning from execution**: store predictions vs actuals to improve estimates over time
### Data Flow
1. Agents emit events via `EventRecorder`
2. `MemoryWriter` persists to Supabase Postgres + Storage
3. Before LLM calls, `MemoryRetriever` fetches relevant context
4. On completion, run is archived with summary embedding
5. **Task outcomes recorded for learning** (complexity, cost, tokens, success)
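A minimal sketch of step 5, assuming a constructed `MemoryWriter` (its signature appears in the Rust changes below); the IDs, description, and numbers are illustrative:

```rust
use uuid::Uuid;

// Hypothetical call site after a task finishes; assumes `MemoryWriter` is in
// scope (construction not shown) and IDs come from the current run.
async fn record_outcome(writer: &MemoryWriter, run_id: Uuid, task_id: Uuid) -> anyhow::Result<()> {
    let _outcome_id = writer.record_task_outcome(
        run_id,
        task_id,
        "Refactor the auth module to use middleware", // illustrative description
        Some(0.55),   // predicted_complexity
        Some(12_000), // predicted_tokens
        Some(40),     // predicted_cost_cents
        Some("anthropic/claude-sonnet-4.5".to_string()),
        Some(15_300), // actual_tokens
        Some(52),     // actual_cost_cents
        true,         // success
        Some(3),      // iterations
        Some(7),      // tool_calls_count
    ).await?;
    Ok(())
}
```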
### Storage Strategy
- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings), **task_outcomes**
- **Supabase Storage**: full event streams (jsonl), large artifacts
## Learning System (v3)
### Purpose
Enable data-driven optimization of:
- **Complexity estimation**: learn actual token usage vs predicted
- **Model selection**: learn actual success rates per model/complexity
- **Budget allocation**: learn actual costs vs estimated
### Architecture
```
┌──────────────────────────────────────────────────────────────────────┐
│ Memory-Enhanced Agent Flow │
├──────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ Query similar ┌─────────────────────────────┐ │
│ │ New Task │ ───────────────────▶│ MemoryRetriever │ │
│ └────┬─────┘ past tasks │ - find_similar_tasks() │ │
│ │ │ - get_historical_context() │ │
│ ▼ │ - get_model_stats() │ │
│ ┌────────────────┐ └───────────────┬─────────────┘ │
│ │ Complexity │◀── historical context ─────────┘ │
│ │ Estimator │ (avg token ratio, avg cost ratio) │
│ │ (enhanced) │ │
│ └────────┬───────┘ │
│ │ │
│ ▼ │
│ ┌────────────────┐ Query: "models at complexity ~0.6" │
│ │ Model Selector │ Returns: actual success rates, cost ratios │
│ │ (enhanced) │ │
│ └────────┬───────┘ │
│ │ │
│ ▼ │
│ ┌────────────────┐ │
│ │ TaskExecutor │──▶ record_task_outcome() ──▶ task_outcomes │
│ └────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────┘
```
### Database Schema: `task_outcomes`
```sql
CREATE TABLE task_outcomes (
id uuid PRIMARY KEY,
run_id uuid REFERENCES runs(id),
task_id uuid REFERENCES tasks(id),
-- Predictions
predicted_complexity float,
predicted_tokens bigint,
predicted_cost_cents bigint,
selected_model text,
-- Actuals
actual_tokens bigint,
actual_cost_cents bigint,
success boolean,
iterations int,
tool_calls_count int,
-- Computed ratios (actual/predicted)
cost_error_ratio float,
token_error_ratio float,
-- Similarity search
task_description text,
task_embedding vector(1536)
);
```
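A quick worked example of the computed ratios (illustrative numbers): predicting 10,000 tokens at 30¢ but observing 14,000 tokens at 42¢ gives `token_error_ratio = cost_error_ratio = 1.4`, i.e. both estimates were 40% low. The `DbTaskOutcome::new` constructor in the Rust types below computes exactly this:

```rust
use uuid::Uuid;

// Assumes the crate's DbTaskOutcome is in scope; values are illustrative.
fn main() {
    let outcome = DbTaskOutcome::new(
        Uuid::new_v4(),                                  // run_id
        Uuid::new_v4(),                                  // task_id
        "Add pagination to the list endpoint".to_string(),
        Some(0.4),                                       // predicted_complexity
        Some(10_000),                                    // predicted_tokens
        Some(30),                                        // predicted_cost_cents
        Some("anthropic/claude-3.5-haiku".to_string()),
        Some(14_000),                                    // actual_tokens
        Some(42),                                        // actual_cost_cents
        true,                                            // success
        Some(2),                                         // iterations
        Some(4),                                         // tool_calls_count
    );
    assert_eq!(outcome.token_error_ratio, Some(1.4)); // 14_000 / 10_000
    assert_eq!(outcome.cost_error_ratio, Some(1.4));  // 42 / 30
}
```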
### RPC Functions
- `get_model_stats(complexity_min, complexity_max)` - Model performance by complexity tier
- `search_similar_outcomes(embedding, threshold, limit)` - Find similar past tasks
- `get_global_learning_stats()` - Overall system metrics
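These are reachable from Rust through the memory layer; a hedged sketch, assuming a `MemoryRetriever` is in scope:

```rust
// Picks the best-performing model near complexity 0.6 (±0.2), which maps to
// the RPC call get_model_stats(0.4, 0.8).
async fn best_model_near(retriever: &MemoryRetriever) -> anyhow::Result<Option<String>> {
    let stats = retriever.get_model_stats(0.6, 0.2).await?;
    // Rows arrive ordered by success_rate DESC, avg_cost_ratio ASC (see the RPC).
    Ok(stats.first().map(|s| s.model_id.clone()))
}
```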
### Learning Integration Points
1. **ComplexityEstimator**: Query similar tasks → adjust token estimate by `avg_token_ratio` (see the sketch after this list)
2. **ModelSelector**: Query model stats → use actual success rates instead of heuristics
3. **TaskExecutor**: After execution → call `record_task_outcome()` with all metrics
4. **Budget**: Use historical cost ratios to add appropriate safety margins
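A sketch of integration point 1; `MemoryRetriever::get_historical_context` is real (see the Rust changes below), while the scaling policy itself is illustrative, not prescribed:

```rust
// Hypothetical helper: adjust a heuristic token estimate using history.
async fn adjusted_token_estimate(
    retriever: &MemoryRetriever,
    task_description: &str,
    base_tokens: i64,
) -> anyhow::Result<i64> {
    match retriever.get_historical_context(task_description, 5).await? {
        // Scale the heuristic estimate by how far off we were on similar tasks.
        Some(ctx) => Ok((base_tokens as f64 * ctx.avg_token_multiplier).round() as i64),
        // No comparable history yet: keep the heuristic estimate.
        None => Ok(base_tokens),
    }
}
```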
## Design for Provability
### Conventions for Future Lean Proofs
@@ -163,7 +245,7 @@ GET /api/memory/search - Semantic search across memory
```
OPENROUTER_API_KEY - Required. Your OpenRouter API key
DEFAULT_MODEL - Optional. Default: anthropic/claude-sonnet-4.5
WORKSPACE_PATH - Optional. Default: current directory
HOST - Optional. Default: 127.0.0.1
PORT - Optional. Default: 3000
@@ -174,6 +256,11 @@ MEMORY_EMBED_MODEL - Optional. Default: qwen/qwen3-embedding-8b
MEMORY_RERANK_MODEL - Optional. Default: qwen/qwen3-reranker-8b
```
### Recommended Models
- **Default (tools)**: `anthropic/claude-sonnet-4.5` - Best coding, 1M context, $3/$15 per 1M tokens
- **Budget fallback**: `anthropic/claude-3.5-haiku` - Fast, cheap, good for simple tasks
- **Complex tasks**: `anthropic/claude-opus-4.5` - Highest capability when needed
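One possible mapping from these recommendations onto the complexity score; the thresholds here are assumptions for illustration, not values from the codebase (actual selection is data-driven via `get_model_stats`):

```rust
// Illustrative tiering only; thresholds are invented for this sketch.
fn recommended_model(complexity: f64) -> &'static str {
    match complexity {
        c if c < 0.3 => "anthropic/claude-3.5-haiku",  // budget fallback
        c if c < 0.8 => "anthropic/claude-sonnet-4.5", // default (tools)
        _ => "anthropic/claude-opus-4.5",              // complex tasks
    }
}
```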
## Security Considerations
This agent has **full machine access**. It can:
@@ -192,8 +279,12 @@ When deploying:
- [ ] Formal verification in Lean (extract pure logic)
- [ ] WebSocket for bidirectional streaming
- [ ] Budget overflow strategies (fallback to cheaper model, request extension)
- [ ] Enhanced ComplexityEstimator with historical context injection
- [ ] Enhanced ModelSelector with data-driven success rates
- [x] Semantic code search (embeddings-based)
- [x] Multi-model support (U-curve optimization)
- [x] Cost tracking (Budget system)
- [x] Persistent memory (Supabase + pgvector)
- [x] Learning system (task_outcomes table, historical queries)


@@ -0,0 +1,239 @@
-- Migration: Task Outcomes for Learning
-- This table stores predictions vs actuals for each task execution,
-- enabling the agent to learn optimal model selection and cost estimation.
-- ============================================================================
-- Table: task_outcomes
-- ============================================================================
CREATE TABLE IF NOT EXISTS task_outcomes (
    id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
    run_id uuid REFERENCES runs(id) ON DELETE CASCADE,
    task_id uuid REFERENCES tasks(id) ON DELETE CASCADE,
    -- Predictions (what we estimated before execution)
    predicted_complexity float,
    predicted_tokens bigint,
    predicted_cost_cents bigint,
    selected_model text,
    -- Actuals (what happened during execution)
    actual_tokens bigint,
    actual_cost_cents bigint,
    success boolean NOT NULL DEFAULT false,
    iterations int,
    tool_calls_count int,
    -- Metadata for similarity search
    task_description text NOT NULL,
    task_type text, -- 'file_create', 'refactor', 'debug', etc.
    task_embedding vector(1536), -- For finding similar tasks
    -- Computed ratios for quick stats
    cost_error_ratio float, -- actual/predicted (1.0 = accurate)
    token_error_ratio float,
    created_at timestamptz DEFAULT now()
);
-- Indexes for efficient queries
CREATE INDEX IF NOT EXISTS idx_outcomes_run_id ON task_outcomes(run_id);
CREATE INDEX IF NOT EXISTS idx_outcomes_task_id ON task_outcomes(task_id);
CREATE INDEX IF NOT EXISTS idx_outcomes_model ON task_outcomes(selected_model, success);
CREATE INDEX IF NOT EXISTS idx_outcomes_complexity ON task_outcomes(predicted_complexity);
CREATE INDEX IF NOT EXISTS idx_outcomes_created ON task_outcomes(created_at DESC);
-- Vector index for similarity search (HNSW)
CREATE INDEX IF NOT EXISTS idx_outcomes_embedding ON task_outcomes
USING hnsw (task_embedding vector_cosine_ops);
-- ============================================================================
-- RPC: get_model_stats
-- Returns aggregated statistics per model for a given complexity range.
-- ============================================================================
CREATE OR REPLACE FUNCTION get_model_stats(
    complexity_min float DEFAULT 0.0,
    complexity_max float DEFAULT 1.0
)
RETURNS TABLE (
    model_id text,
    success_rate float,
    avg_cost_ratio float,
    avg_token_ratio float,
    avg_iterations float,
    sample_count bigint
)
LANGUAGE sql STABLE
AS $$
SELECT
    selected_model as model_id,
    AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
    COALESCE(AVG(cost_error_ratio), 1.0) as avg_cost_ratio,
    COALESCE(AVG(token_error_ratio), 1.0) as avg_token_ratio,
    COALESCE(AVG(iterations), 1.0) as avg_iterations,
    COUNT(*) as sample_count
FROM task_outcomes
WHERE
    selected_model IS NOT NULL
    AND predicted_complexity >= complexity_min
    AND predicted_complexity <= complexity_max
GROUP BY selected_model
HAVING COUNT(*) >= 3 -- Minimum samples for reliability
ORDER BY success_rate DESC, avg_cost_ratio ASC;
$$;
-- ============================================================================
-- RPC: search_similar_outcomes
-- Find similar past task outcomes by embedding similarity.
-- ============================================================================
CREATE OR REPLACE FUNCTION search_similar_outcomes(
    query_embedding vector(1536),
    match_threshold float DEFAULT 0.7,
    match_count int DEFAULT 5
)
RETURNS TABLE (
    id uuid,
    run_id uuid,
    task_id uuid,
    predicted_complexity float,
    predicted_tokens bigint,
    predicted_cost_cents bigint,
    selected_model text,
    actual_tokens bigint,
    actual_cost_cents bigint,
    success boolean,
    iterations int,
    tool_calls_count int,
    task_description text,
    task_type text,
    cost_error_ratio float,
    token_error_ratio float,
    created_at timestamptz,
    similarity float
)
LANGUAGE sql STABLE
AS $$
SELECT
    o.id,
    o.run_id,
    o.task_id,
    o.predicted_complexity,
    o.predicted_tokens,
    o.predicted_cost_cents,
    o.selected_model,
    o.actual_tokens,
    o.actual_cost_cents,
    o.success,
    o.iterations,
    o.tool_calls_count,
    o.task_description,
    o.task_type,
    o.cost_error_ratio,
    o.token_error_ratio,
    o.created_at,
    1 - (o.task_embedding <=> query_embedding) as similarity
FROM task_outcomes o
WHERE
    o.task_embedding IS NOT NULL
    AND 1 - (o.task_embedding <=> query_embedding) > match_threshold
ORDER BY o.task_embedding <=> query_embedding
LIMIT match_count;
$$;
-- ============================================================================
-- RPC: get_global_learning_stats
-- Returns overall system learning statistics for monitoring/tuning.
-- ============================================================================
CREATE OR REPLACE FUNCTION get_global_learning_stats()
RETURNS json
LANGUAGE sql STABLE
AS $$
SELECT json_build_object(
    'total_outcomes', (SELECT COUNT(*) FROM task_outcomes),
    'success_rate', (SELECT AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) FROM task_outcomes),
    'avg_cost_error', (SELECT AVG(cost_error_ratio) FROM task_outcomes WHERE cost_error_ratio IS NOT NULL),
    'avg_token_error', (SELECT AVG(token_error_ratio) FROM task_outcomes WHERE token_error_ratio IS NOT NULL),
    'models_used', (SELECT COUNT(DISTINCT selected_model) FROM task_outcomes),
    'top_models', (
        SELECT json_agg(row_to_json(t))
        FROM (
            SELECT
                selected_model,
                COUNT(*) as uses,
                AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
            FROM task_outcomes
            WHERE selected_model IS NOT NULL
            GROUP BY selected_model
            ORDER BY uses DESC
            LIMIT 5
        ) t
    ),
    'complexity_distribution', (
        SELECT json_agg(row_to_json(t))
        FROM (
            SELECT
                CASE
                    WHEN predicted_complexity < 0.2 THEN 'trivial'
                    WHEN predicted_complexity < 0.4 THEN 'simple'
                    WHEN predicted_complexity < 0.6 THEN 'moderate'
                    WHEN predicted_complexity < 0.8 THEN 'complex'
                    ELSE 'very_complex'
                END as tier,
                COUNT(*) as count,
                AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
            FROM task_outcomes
            WHERE predicted_complexity IS NOT NULL
            GROUP BY tier
        ) t
    )
);
$$;
-- ============================================================================
-- RPC: get_optimal_model_for_complexity
-- Returns the best model for a given complexity based on historical data.
-- ============================================================================
CREATE OR REPLACE FUNCTION get_optimal_model_for_complexity(
    target_complexity float,
    budget_cents bigint DEFAULT 1000
)
RETURNS TABLE (
    model_id text,
    expected_success_rate float,
    expected_cost_cents float,
    confidence float
)
LANGUAGE sql STABLE
AS $$
WITH model_perf AS (
    SELECT
        selected_model,
        AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
        AVG(actual_cost_cents) as avg_cost,
        COUNT(*) as samples,
        -- Confidence grows with sample size; recency is enforced by the
        -- 30-day filter below
        LEAST(1.0, COUNT(*) / 10.0) as sample_confidence
    FROM task_outcomes
    WHERE
        selected_model IS NOT NULL
        AND ABS(predicted_complexity - target_complexity) < 0.2
        AND created_at > now() - interval '30 days'
    GROUP BY selected_model
)
SELECT
    selected_model as model_id,
    success_rate as expected_success_rate,
    avg_cost as expected_cost_cents,
    sample_confidence as confidence
FROM model_perf
WHERE avg_cost <= budget_cents OR samples < 3 -- Allow trying new models
ORDER BY
    -- Balance success rate and cost
    (success_rate * 0.7 + (1.0 - LEAST(avg_cost / budget_cents, 1.0)) * 0.3) DESC
LIMIT 3;
$$;
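-- Worked example for the ordering score above (illustrative numbers):
--   success_rate 0.9, avg_cost 300, budget_cents 1000
--   0.9 * 0.7 + (1.0 - 300.0 / 1000) * 0.3 = 0.63 + 0.21 = 0.84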


@@ -6,7 +6,7 @@ use serde::Deserialize;
use super::supabase::SupabaseClient;
use super::embed::EmbeddingClient;
use super::types::{SearchResult, ContextPack, DbTaskOutcome, ModelStats, HistoricalContext};
/// Default similarity threshold for vector search.
const DEFAULT_THRESHOLD: f64 = 0.5;
@@ -230,6 +230,82 @@ Only return the JSON array, nothing else."#,
    pub async fn list_runs(&self, limit: usize, offset: usize) -> anyhow::Result<Vec<super::types::DbRun>> {
        self.supabase.list_runs(limit, offset).await
    }

    // ==================== Learning Methods ====================

    /// Get model performance statistics for a given complexity range.
    ///
    /// Returns historical success rates, cost ratios, etc. for each model
    /// that has been used at the given complexity level.
    pub async fn get_model_stats(
        &self,
        complexity: f64,
        range: f64,
    ) -> anyhow::Result<Vec<ModelStats>> {
        let min = (complexity - range).max(0.0);
        let max = (complexity + range).min(1.0);
        self.supabase.get_model_stats(min, max).await
    }

    /// Find similar past tasks and their outcomes.
    ///
    /// Uses embedding similarity to find tasks that are semantically similar
    /// to the given task description, then returns their execution outcomes.
    pub async fn find_similar_tasks(
        &self,
        task_description: &str,
        limit: usize,
    ) -> anyhow::Result<Vec<DbTaskOutcome>> {
        // Generate embedding for the task description
        let embedding = self.embedder.embed(task_description).await?;
        // Search for similar outcomes
        self.supabase.search_similar_outcomes(&embedding, 0.6, limit).await
    }

    /// Get historical context for a task.
    ///
    /// Returns aggregated learning data from similar past tasks including:
    /// - Average cost adjustment multiplier
    /// - Average token adjustment multiplier
    /// - Success rate for similar tasks
    pub async fn get_historical_context(
        &self,
        task_description: &str,
        limit: usize,
    ) -> anyhow::Result<Option<HistoricalContext>> {
        let similar = self.find_similar_tasks(task_description, limit).await?;
        if similar.is_empty() {
            return Ok(None);
        }
        // Calculate aggregated stats. The ratio multipliers default to 1.0
        // (no adjustment) when none of the similar outcomes recorded them.
        let total = similar.len() as f64;
        let avg_or_one = |ratios: Vec<f64>| -> f64 {
            if ratios.is_empty() {
                1.0
            } else {
                ratios.iter().sum::<f64>() / ratios.len() as f64
            }
        };
        let avg_cost_multiplier =
            avg_or_one(similar.iter().filter_map(|o| o.cost_error_ratio).collect());
        let avg_token_multiplier =
            avg_or_one(similar.iter().filter_map(|o| o.token_error_ratio).collect());
        let success_count = similar.iter().filter(|o| o.success).count() as f64;
        let similar_success_rate = success_count / total;
        Ok(Some(HistoricalContext {
            similar_outcomes: similar,
            avg_cost_multiplier,
            avg_token_multiplier,
            similar_success_rate,
        }))
    }

    /// Get global learning statistics.
    pub async fn get_learning_stats(&self) -> anyhow::Result<serde_json::Value> {
        self.supabase.get_global_stats().await
    }
}
/// Truncate a string to max length.


@@ -3,7 +3,7 @@
use reqwest::Client;
use uuid::Uuid;
use super::types::{DbRun, DbTask, DbEvent, DbChunk, DbTaskOutcome, SearchResult, ModelStats};
/// Supabase client for database and storage operations.
pub struct SupabaseClient {
@@ -357,5 +357,156 @@ impl SupabaseClient {
        self.update_run(run_id, body).await
    }

    // ==================== Task Outcomes (Learning) ====================

    /// Insert a task outcome for learning.
    pub async fn insert_task_outcome(
        &self,
        outcome: &DbTaskOutcome,
        embedding: Option<&[f32]>,
    ) -> anyhow::Result<Uuid> {
        let embedding_str = embedding.map(|e| format!(
            "[{}]",
            e.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        ));
        let body = serde_json::json!({
            "run_id": outcome.run_id,
            "task_id": outcome.task_id,
            "predicted_complexity": outcome.predicted_complexity,
            "predicted_tokens": outcome.predicted_tokens,
            "predicted_cost_cents": outcome.predicted_cost_cents,
            "selected_model": outcome.selected_model,
            "actual_tokens": outcome.actual_tokens,
            "actual_cost_cents": outcome.actual_cost_cents,
            "success": outcome.success,
            "iterations": outcome.iterations,
            "tool_calls_count": outcome.tool_calls_count,
            "task_description": outcome.task_description,
            "task_type": outcome.task_type,
            "cost_error_ratio": outcome.cost_error_ratio,
            "token_error_ratio": outcome.token_error_ratio,
            "task_embedding": embedding_str
        });
        let resp = self.client
            .post(format!("{}/task_outcomes", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .header("Prefer", "return=representation")
            .json(&body)
            .send()
            .await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            anyhow::bail!("Failed to insert task outcome: {} - {}", status, text);
        }
        let outcomes: Vec<DbTaskOutcome> = serde_json::from_str(&text)?;
        outcomes.into_iter().next()
            .and_then(|o| o.id)
            .ok_or_else(|| anyhow::anyhow!("No outcome ID returned"))
    }

    /// Get model statistics for a complexity range.
    ///
    /// Returns aggregated stats for each model that has been used
    /// for tasks in the given complexity range.
    pub async fn get_model_stats(
        &self,
        complexity_min: f64,
        complexity_max: f64,
    ) -> anyhow::Result<Vec<ModelStats>> {
        let body = serde_json::json!({
            "complexity_min": complexity_min,
            "complexity_max": complexity_max
        });
        let resp = self.client
            .post(format!("{}/rpc/get_model_stats", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            // If RPC doesn't exist yet, return empty
            tracing::debug!("get_model_stats RPC not available: {}", text);
            return Ok(vec![]);
        }
        Ok(serde_json::from_str(&text).unwrap_or_default())
    }

    /// Search for similar task outcomes by embedding similarity.
    pub async fn search_similar_outcomes(
        &self,
        embedding: &[f32],
        threshold: f64,
        limit: usize,
    ) -> anyhow::Result<Vec<DbTaskOutcome>> {
        let embedding_str = format!(
            "[{}]",
            embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
        );
        let body = serde_json::json!({
            "query_embedding": embedding_str,
            "match_threshold": threshold,
            "match_count": limit
        });
        let resp = self.client
            .post(format!("{}/rpc/search_similar_outcomes", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            // If RPC doesn't exist yet, return empty
            tracing::debug!("search_similar_outcomes RPC not available: {}", text);
            return Ok(vec![]);
        }
        Ok(serde_json::from_str(&text).unwrap_or_default())
    }

    /// Get global learning statistics (for tuning).
    pub async fn get_global_stats(&self) -> anyhow::Result<serde_json::Value> {
        let resp = self.client
            .post(format!("{}/rpc/get_global_learning_stats", self.rest_url()))
            .header("apikey", &self.service_role_key)
            .header("Authorization", format!("Bearer {}", self.service_role_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({}))
            .send()
            .await?;
        let status = resp.status();
        let text = resp.text().await?;
        if !status.is_success() {
            tracing::debug!("get_global_learning_stats RPC not available: {}", text);
            return Ok(serde_json::json!({}));
        }
        Ok(serde_json::from_str(&text).unwrap_or_default())
    }
}


@@ -154,6 +154,123 @@ pub struct ContextPack {
    pub query: String,
}

/// Task outcome record for learning from execution history.
///
/// Captures predictions vs actuals to enable data-driven optimization
/// of complexity estimation, model selection, and budget allocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbTaskOutcome {
    pub id: Option<Uuid>,
    pub run_id: Uuid,
    pub task_id: Uuid,
    // Predictions (what we estimated before execution)
    pub predicted_complexity: Option<f64>,
    pub predicted_tokens: Option<i64>,
    pub predicted_cost_cents: Option<i64>,
    pub selected_model: Option<String>,
    // Actuals (what happened during execution)
    pub actual_tokens: Option<i64>,
    pub actual_cost_cents: Option<i64>,
    pub success: bool,
    pub iterations: Option<i32>,
    pub tool_calls_count: Option<i32>,
    // Metadata for similarity search
    pub task_description: String,
    /// Category of task (inferred or explicit)
    pub task_type: Option<String>,
    // Computed ratios for quick stats
    pub cost_error_ratio: Option<f64>,
    pub token_error_ratio: Option<f64>,
    pub created_at: Option<String>,
}
impl DbTaskOutcome {
    /// Create a new outcome from predictions and actuals.
    pub fn new(
        run_id: Uuid,
        task_id: Uuid,
        task_description: String,
        predicted_complexity: Option<f64>,
        predicted_tokens: Option<i64>,
        predicted_cost_cents: Option<i64>,
        selected_model: Option<String>,
        actual_tokens: Option<i64>,
        actual_cost_cents: Option<i64>,
        success: bool,
        iterations: Option<i32>,
        tool_calls_count: Option<i32>,
    ) -> Self {
        // Compute error ratios
        let cost_error_ratio = match (actual_cost_cents, predicted_cost_cents) {
            (Some(actual), Some(predicted)) if predicted > 0 => {
                Some(actual as f64 / predicted as f64)
            }
            _ => None,
        };
        let token_error_ratio = match (actual_tokens, predicted_tokens) {
            (Some(actual), Some(predicted)) if predicted > 0 => {
                Some(actual as f64 / predicted as f64)
            }
            _ => None,
        };
        Self {
            id: None,
            run_id,
            task_id,
            predicted_complexity,
            predicted_tokens,
            predicted_cost_cents,
            selected_model,
            actual_tokens,
            actual_cost_cents,
            success,
            iterations,
            tool_calls_count,
            task_description,
            task_type: None,
            cost_error_ratio,
            token_error_ratio,
            created_at: None,
        }
    }
}
/// Model performance statistics from historical data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStats {
    pub model_id: String,
    /// Success rate (0-1)
    pub success_rate: f64,
    /// Average cost vs predicted (1.0 = accurate, >1 = underestimated)
    pub avg_cost_ratio: f64,
    /// Average tokens vs predicted
    pub avg_token_ratio: f64,
    /// Average iterations needed
    pub avg_iterations: f64,
    /// Number of samples
    pub sample_count: i64,
}

/// Historical context for a task (similar past tasks).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistoricalContext {
    /// Similar past task outcomes
    pub similar_outcomes: Vec<DbTaskOutcome>,
    /// Average cost adjustment multiplier
    pub avg_cost_multiplier: f64,
    /// Average token adjustment multiplier
    pub avg_token_multiplier: f64,
    /// Success rate for similar tasks
    pub similar_success_rate: f64,
}

impl ContextPack {
    /// Format as a string for prompt injection.
    pub fn format_for_prompt(&self) -> String {


@@ -7,7 +7,7 @@ use serde_json::json;
use super::supabase::SupabaseClient;
use super::embed::EmbeddingClient;
use super::types::{DbTask, DbEvent, DbChunk, DbTaskOutcome, EventKind, MemoryStatus};
/// Maximum chunk size in characters.
const MAX_CHUNK_SIZE: usize = 2000;
@@ -241,6 +241,47 @@ impl MemoryWriter {
        self.supabase.update_run_summary(run_id, summary, &embedding).await
    }

    /// Record a task outcome for learning.
    ///
    /// This captures predictions vs actuals to enable data-driven optimization
    /// of complexity estimation, model selection, and budget allocation.
    pub async fn record_task_outcome(
        &self,
        run_id: Uuid,
        task_id: Uuid,
        task_description: &str,
        predicted_complexity: Option<f64>,
        predicted_tokens: Option<i64>,
        predicted_cost_cents: Option<i64>,
        selected_model: Option<String>,
        actual_tokens: Option<i64>,
        actual_cost_cents: Option<i64>,
        success: bool,
        iterations: Option<i32>,
        tool_calls_count: Option<i32>,
    ) -> anyhow::Result<Uuid> {
        // Create the outcome record
        let outcome = DbTaskOutcome::new(
            run_id,
            task_id,
            task_description.to_string(),
            predicted_complexity,
            predicted_tokens,
            predicted_cost_cents,
            selected_model,
            actual_tokens,
            actual_cost_cents,
            success,
            iterations,
            tool_calls_count,
        );
        // Generate embedding for similarity search
        let embedding = self.embedder.embed(task_description).await.ok();
        self.supabase.insert_task_outcome(&outcome, embedding.as_deref()).await
    }

    /// Split text into chunks.
    fn chunk_text(&self, text: &str) -> Vec<String> {
        let mut chunks = Vec::new();