Add learning system for data-driven optimization
## New Features ### task_outcomes table (Supabase) - Stores predictions vs actuals for every task execution - Fields: predicted_complexity, predicted_tokens, predicted_cost_cents - Fields: actual_tokens, actual_cost_cents, success, iterations, tool_calls - Computed: cost_error_ratio, token_error_ratio - Vector embedding for similarity search ### RPC Functions - get_model_stats(complexity_min, complexity_max) - Model performance by tier - search_similar_outcomes(embedding, threshold, limit) - Find similar tasks - get_global_learning_stats() - Overall system metrics ### Rust Integration - MemoryWriter.record_task_outcome() - Record after execution - MemoryRetriever.get_model_stats() - Query model performance - MemoryRetriever.find_similar_tasks() - Semantic similarity search - MemoryRetriever.get_historical_context() - Aggregated learning data - New types: DbTaskOutcome, ModelStats, HistoricalContext ### Documentation - Updated .cursor/rules/project.md with learning architecture diagram - Documented data flow and integration points - Added recommended models section
This commit is contained in:
@@ -104,17 +104,99 @@ src/
|
||||
- **Long tasks beyond context**: persist step-by-step execution so the agent can retrieve relevant context later
|
||||
- **Fast query + browsing**: structured metadata in Postgres, heavy blobs in Storage
|
||||
- **Embedding + rerank**: Qwen3 Embedding 8B for vectors, Qwen reranker for precision
|
||||
- **Learning from execution**: store predictions vs actuals to improve estimates over time
|
||||
|
||||
### Data Flow
|
||||
1. Agents emit events via `EventRecorder`
|
||||
2. `MemoryWriter` persists to Supabase Postgres + Storage
|
||||
3. Before LLM calls, `MemoryRetriever` fetches relevant context
|
||||
4. On completion, run is archived with summary embedding
|
||||
5. **Task outcomes recorded for learning** (complexity, cost, tokens, success)
|
||||
|
||||
### Storage Strategy
|
||||
- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings)
|
||||
- **Postgres (pgvector)**: runs, tasks (hierarchical), events (preview), chunks (embeddings), **task_outcomes**
|
||||
- **Supabase Storage**: full event streams (jsonl), large artifacts
|
||||
|
||||
## Learning System (v3)
|
||||
|
||||
### Purpose
|
||||
Enable data-driven optimization of:
|
||||
- **Complexity estimation**: learn actual token usage vs predicted
|
||||
- **Model selection**: learn actual success rates per model/complexity
|
||||
- **Budget allocation**: learn actual costs vs estimated
|
||||
|
||||
### Architecture
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ Memory-Enhanced Agent Flow │
|
||||
├──────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────┐ Query similar ┌─────────────────────────────┐ │
|
||||
│ │ New Task │ ───────────────────▶│ MemoryRetriever │ │
|
||||
│ └────┬─────┘ past tasks │ - find_similar_tasks() │ │
|
||||
│ │ │ - get_historical_context() │ │
|
||||
│ ▼ │ - get_model_stats() │ │
|
||||
│ ┌────────────────┐ └───────────────┬─────────────┘ │
|
||||
│ │ Complexity │◀── historical context ─────────┘ │
|
||||
│ │ Estimator │ (avg token ratio, avg cost ratio) │
|
||||
│ │ (enhanced) │ │
|
||||
│ └────────┬───────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌────────────────┐ Query: "models at complexity ~0.6" │
|
||||
│ │ Model Selector │ Returns: actual success rates, cost ratios │
|
||||
│ │ (enhanced) │ │
|
||||
│ └────────┬───────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌────────────────┐ │
|
||||
│ │ TaskExecutor │──▶ record_task_outcome() ──▶ task_outcomes │
|
||||
│ └────────────────┘ │
|
||||
│ │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Database Schema: `task_outcomes`
|
||||
```sql
|
||||
CREATE TABLE task_outcomes (
|
||||
id uuid PRIMARY KEY,
|
||||
run_id uuid REFERENCES runs(id),
|
||||
task_id uuid REFERENCES tasks(id),
|
||||
|
||||
-- Predictions
|
||||
predicted_complexity float,
|
||||
predicted_tokens bigint,
|
||||
predicted_cost_cents bigint,
|
||||
selected_model text,
|
||||
|
||||
-- Actuals
|
||||
actual_tokens bigint,
|
||||
actual_cost_cents bigint,
|
||||
success boolean,
|
||||
iterations int,
|
||||
tool_calls_count int,
|
||||
|
||||
-- Computed ratios (actual/predicted)
|
||||
cost_error_ratio float,
|
||||
token_error_ratio float,
|
||||
|
||||
-- Similarity search
|
||||
task_description text,
|
||||
task_embedding vector(1536)
|
||||
);
|
||||
```
|
||||
|
||||
### RPC Functions
|
||||
- `get_model_stats(complexity_min, complexity_max)` - Model performance by complexity tier
|
||||
- `search_similar_outcomes(embedding, threshold, limit)` - Find similar past tasks
|
||||
- `get_global_learning_stats()` - Overall system metrics
|
||||
|
||||
### Learning Integration Points
|
||||
1. **ComplexityEstimator**: Query similar tasks → adjust token estimate by `avg_token_ratio`
|
||||
2. **ModelSelector**: Query model stats → use actual success rates instead of heuristics
|
||||
3. **TaskExecutor**: After execution → call `record_task_outcome()` with all metrics
|
||||
4. **Budget**: Use historical cost ratios to add appropriate safety margins
|
||||
|
||||
## Design for Provability
|
||||
|
||||
### Conventions for Future Lean Proofs
|
||||
@@ -163,7 +245,7 @@ GET /api/memory/search - Semantic search across memory
|
||||
|
||||
```
|
||||
OPENROUTER_API_KEY - Required. Your OpenRouter API key
|
||||
DEFAULT_MODEL - Optional. Default: openai/gpt-4.1-mini
|
||||
DEFAULT_MODEL - Optional. Default: anthropic/claude-sonnet-4.5
|
||||
WORKSPACE_PATH - Optional. Default: current directory
|
||||
HOST - Optional. Default: 127.0.0.1
|
||||
PORT - Optional. Default: 3000
|
||||
@@ -174,6 +256,11 @@ MEMORY_EMBED_MODEL - Optional. Default: qwen/qwen3-embedding-8b
|
||||
MEMORY_RERANK_MODEL - Optional. Default: qwen/qwen3-reranker-8b
|
||||
```
|
||||
|
||||
### Recommended Models
|
||||
- **Default (tools)**: `anthropic/claude-sonnet-4.5` - Best coding, 1M context, $3/$15 per 1M tokens
|
||||
- **Budget fallback**: `anthropic/claude-3.5-haiku` - Fast, cheap, good for simple tasks
|
||||
- **Complex tasks**: `anthropic/claude-opus-4.5` - Highest capability when needed
|
||||
|
||||
## Security Considerations
|
||||
|
||||
This agent has **full machine access**. It can:
|
||||
@@ -192,8 +279,12 @@ When deploying:
|
||||
|
||||
- [ ] Formal verification in Lean (extract pure logic)
|
||||
- [ ] WebSocket for bidirectional streaming
|
||||
- [ ] Budget overflow strategies (fallback to cheaper model, request extension)
|
||||
- [ ] Enhanced ComplexityEstimator with historical context injection
|
||||
- [ ] Enhanced ModelSelector with data-driven success rates
|
||||
- [x] Semantic code search (embeddings-based)
|
||||
- [x] Multi-model support (U-curve optimization)
|
||||
- [x] Cost tracking (Budget system)
|
||||
- [x] Persistent memory (Supabase + pgvector)
|
||||
- [x] Learning system (task_outcomes table, historical queries)
|
||||
|
||||
|
||||
239
migrations/002_task_outcomes.sql
Normal file
239
migrations/002_task_outcomes.sql
Normal file
@@ -0,0 +1,239 @@
|
||||
-- Migration: Task Outcomes for Learning
|
||||
-- This table stores predictions vs actuals for each task execution,
|
||||
-- enabling the agent to learn optimal model selection and cost estimation.
|
||||
|
||||
-- ============================================================================
|
||||
-- Table: task_outcomes
|
||||
-- ============================================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS task_outcomes (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
run_id uuid REFERENCES runs(id) ON DELETE CASCADE,
|
||||
task_id uuid REFERENCES tasks(id) ON DELETE CASCADE,
|
||||
|
||||
-- Predictions (what we estimated before execution)
|
||||
predicted_complexity float,
|
||||
predicted_tokens bigint,
|
||||
predicted_cost_cents bigint,
|
||||
selected_model text,
|
||||
|
||||
-- Actuals (what happened during execution)
|
||||
actual_tokens bigint,
|
||||
actual_cost_cents bigint,
|
||||
success boolean NOT NULL DEFAULT false,
|
||||
iterations int,
|
||||
tool_calls_count int,
|
||||
|
||||
-- Metadata for similarity search
|
||||
task_description text NOT NULL,
|
||||
task_type text, -- 'file_create', 'refactor', 'debug', etc.
|
||||
task_embedding vector(1536), -- For finding similar tasks
|
||||
|
||||
-- Computed ratios for quick stats
|
||||
cost_error_ratio float, -- actual/predicted (1.0 = accurate)
|
||||
token_error_ratio float,
|
||||
|
||||
created_at timestamptz DEFAULT now()
|
||||
);
|
||||
|
||||
-- Indexes for efficient queries
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_run_id ON task_outcomes(run_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_task_id ON task_outcomes(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_model ON task_outcomes(selected_model, success);
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_complexity ON task_outcomes(predicted_complexity);
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_created ON task_outcomes(created_at DESC);
|
||||
|
||||
-- Vector index for similarity search (HNSW)
|
||||
CREATE INDEX IF NOT EXISTS idx_outcomes_embedding ON task_outcomes
|
||||
USING hnsw (task_embedding vector_cosine_ops);
|
||||
|
||||
-- ============================================================================
|
||||
-- RPC: get_model_stats
|
||||
-- Returns aggregated statistics per model for a given complexity range.
|
||||
-- ============================================================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION get_model_stats(
|
||||
complexity_min float DEFAULT 0.0,
|
||||
complexity_max float DEFAULT 1.0
|
||||
)
|
||||
RETURNS TABLE (
|
||||
model_id text,
|
||||
success_rate float,
|
||||
avg_cost_ratio float,
|
||||
avg_token_ratio float,
|
||||
avg_iterations float,
|
||||
sample_count bigint
|
||||
)
|
||||
LANGUAGE sql STABLE
|
||||
AS $$
|
||||
SELECT
|
||||
selected_model as model_id,
|
||||
AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
|
||||
COALESCE(AVG(cost_error_ratio), 1.0) as avg_cost_ratio,
|
||||
COALESCE(AVG(token_error_ratio), 1.0) as avg_token_ratio,
|
||||
COALESCE(AVG(iterations), 1.0) as avg_iterations,
|
||||
COUNT(*) as sample_count
|
||||
FROM task_outcomes
|
||||
WHERE
|
||||
selected_model IS NOT NULL
|
||||
AND predicted_complexity >= complexity_min
|
||||
AND predicted_complexity <= complexity_max
|
||||
GROUP BY selected_model
|
||||
HAVING COUNT(*) >= 3 -- Minimum samples for reliability
|
||||
ORDER BY success_rate DESC, avg_cost_ratio ASC;
|
||||
$$;
|
||||
|
||||
-- ============================================================================
|
||||
-- RPC: search_similar_outcomes
|
||||
-- Find similar past task outcomes by embedding similarity.
|
||||
-- ============================================================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION search_similar_outcomes(
|
||||
query_embedding vector(1536),
|
||||
match_threshold float DEFAULT 0.7,
|
||||
match_count int DEFAULT 5
|
||||
)
|
||||
RETURNS TABLE (
|
||||
id uuid,
|
||||
run_id uuid,
|
||||
task_id uuid,
|
||||
predicted_complexity float,
|
||||
predicted_tokens bigint,
|
||||
predicted_cost_cents bigint,
|
||||
selected_model text,
|
||||
actual_tokens bigint,
|
||||
actual_cost_cents bigint,
|
||||
success boolean,
|
||||
iterations int,
|
||||
tool_calls_count int,
|
||||
task_description text,
|
||||
task_type text,
|
||||
cost_error_ratio float,
|
||||
token_error_ratio float,
|
||||
created_at timestamptz,
|
||||
similarity float
|
||||
)
|
||||
LANGUAGE sql STABLE
|
||||
AS $$
|
||||
SELECT
|
||||
o.id,
|
||||
o.run_id,
|
||||
o.task_id,
|
||||
o.predicted_complexity,
|
||||
o.predicted_tokens,
|
||||
o.predicted_cost_cents,
|
||||
o.selected_model,
|
||||
o.actual_tokens,
|
||||
o.actual_cost_cents,
|
||||
o.success,
|
||||
o.iterations,
|
||||
o.tool_calls_count,
|
||||
o.task_description,
|
||||
o.task_type,
|
||||
o.cost_error_ratio,
|
||||
o.token_error_ratio,
|
||||
o.created_at,
|
||||
1 - (o.task_embedding <=> query_embedding) as similarity
|
||||
FROM task_outcomes o
|
||||
WHERE
|
||||
o.task_embedding IS NOT NULL
|
||||
AND 1 - (o.task_embedding <=> query_embedding) > match_threshold
|
||||
ORDER BY o.task_embedding <=> query_embedding
|
||||
LIMIT match_count;
|
||||
$$;
|
||||
|
||||
-- ============================================================================
|
||||
-- RPC: get_global_learning_stats
|
||||
-- Returns overall system learning statistics for monitoring/tuning.
|
||||
-- ============================================================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION get_global_learning_stats()
|
||||
RETURNS json
|
||||
LANGUAGE sql STABLE
|
||||
AS $$
|
||||
SELECT json_build_object(
|
||||
'total_outcomes', (SELECT COUNT(*) FROM task_outcomes),
|
||||
'success_rate', (SELECT AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) FROM task_outcomes),
|
||||
'avg_cost_error', (SELECT AVG(cost_error_ratio) FROM task_outcomes WHERE cost_error_ratio IS NOT NULL),
|
||||
'avg_token_error', (SELECT AVG(token_error_ratio) FROM task_outcomes WHERE token_error_ratio IS NOT NULL),
|
||||
'models_used', (SELECT COUNT(DISTINCT selected_model) FROM task_outcomes),
|
||||
'top_models', (
|
||||
SELECT json_agg(row_to_json(t))
|
||||
FROM (
|
||||
SELECT
|
||||
selected_model,
|
||||
COUNT(*) as uses,
|
||||
AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
|
||||
FROM task_outcomes
|
||||
WHERE selected_model IS NOT NULL
|
||||
GROUP BY selected_model
|
||||
ORDER BY uses DESC
|
||||
LIMIT 5
|
||||
) t
|
||||
),
|
||||
'complexity_distribution', (
|
||||
SELECT json_agg(row_to_json(t))
|
||||
FROM (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN predicted_complexity < 0.2 THEN 'trivial'
|
||||
WHEN predicted_complexity < 0.4 THEN 'simple'
|
||||
WHEN predicted_complexity < 0.6 THEN 'moderate'
|
||||
WHEN predicted_complexity < 0.8 THEN 'complex'
|
||||
ELSE 'very_complex'
|
||||
END as tier,
|
||||
COUNT(*) as count,
|
||||
AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate
|
||||
FROM task_outcomes
|
||||
WHERE predicted_complexity IS NOT NULL
|
||||
GROUP BY tier
|
||||
) t
|
||||
)
|
||||
);
|
||||
$$;
|
||||
|
||||
-- ============================================================================
|
||||
-- RPC: get_optimal_model_for_complexity
|
||||
-- Returns the best model for a given complexity based on historical data.
|
||||
-- ============================================================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION get_optimal_model_for_complexity(
|
||||
target_complexity float,
|
||||
budget_cents bigint DEFAULT 1000
|
||||
)
|
||||
RETURNS TABLE (
|
||||
model_id text,
|
||||
expected_success_rate float,
|
||||
expected_cost_cents float,
|
||||
confidence float
|
||||
)
|
||||
LANGUAGE sql STABLE
|
||||
AS $$
|
||||
WITH model_perf AS (
|
||||
SELECT
|
||||
selected_model,
|
||||
AVG(CASE WHEN success THEN 1.0 ELSE 0.0 END) as success_rate,
|
||||
AVG(actual_cost_cents) as avg_cost,
|
||||
COUNT(*) as samples,
|
||||
-- Confidence based on sample size and recency
|
||||
LEAST(1.0, COUNT(*) / 10.0) as sample_confidence
|
||||
FROM task_outcomes
|
||||
WHERE
|
||||
selected_model IS NOT NULL
|
||||
AND ABS(predicted_complexity - target_complexity) < 0.2
|
||||
AND created_at > now() - interval '30 days'
|
||||
GROUP BY selected_model
|
||||
)
|
||||
SELECT
|
||||
selected_model as model_id,
|
||||
success_rate as expected_success_rate,
|
||||
avg_cost as expected_cost_cents,
|
||||
sample_confidence as confidence
|
||||
FROM model_perf
|
||||
WHERE avg_cost <= budget_cents OR samples < 3 -- Allow trying new models
|
||||
ORDER BY
|
||||
-- Balance success rate and cost
|
||||
(success_rate * 0.7 + (1.0 - LEAST(avg_cost / budget_cents, 1.0)) * 0.3) DESC
|
||||
LIMIT 3;
|
||||
$$;
|
||||
|
||||
@@ -6,7 +6,7 @@ use serde::Deserialize;
|
||||
|
||||
use super::supabase::SupabaseClient;
|
||||
use super::embed::EmbeddingClient;
|
||||
use super::types::{SearchResult, ContextPack};
|
||||
use super::types::{SearchResult, ContextPack, DbTaskOutcome, ModelStats, HistoricalContext};
|
||||
|
||||
/// Default similarity threshold for vector search.
|
||||
const DEFAULT_THRESHOLD: f64 = 0.5;
|
||||
@@ -230,6 +230,82 @@ Only return the JSON array, nothing else."#,
|
||||
pub async fn list_runs(&self, limit: usize, offset: usize) -> anyhow::Result<Vec<super::types::DbRun>> {
|
||||
self.supabase.list_runs(limit, offset).await
|
||||
}
|
||||
|
||||
// ==================== Learning Methods ====================
|
||||
|
||||
/// Get model performance statistics for a given complexity range.
|
||||
///
|
||||
/// Returns historical success rates, cost ratios, etc. for each model
|
||||
/// that has been used at the given complexity level.
|
||||
pub async fn get_model_stats(
|
||||
&self,
|
||||
complexity: f64,
|
||||
range: f64,
|
||||
) -> anyhow::Result<Vec<ModelStats>> {
|
||||
let min = (complexity - range).max(0.0);
|
||||
let max = (complexity + range).min(1.0);
|
||||
self.supabase.get_model_stats(min, max).await
|
||||
}
|
||||
|
||||
/// Find similar past tasks and their outcomes.
|
||||
///
|
||||
/// Uses embedding similarity to find tasks that are semantically similar
|
||||
/// to the given task description, then returns their execution outcomes.
|
||||
pub async fn find_similar_tasks(
|
||||
&self,
|
||||
task_description: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<DbTaskOutcome>> {
|
||||
// Generate embedding for the task description
|
||||
let embedding = self.embedder.embed(task_description).await?;
|
||||
|
||||
// Search for similar outcomes
|
||||
self.supabase.search_similar_outcomes(&embedding, 0.6, limit).await
|
||||
}
|
||||
|
||||
/// Get historical context for a task.
|
||||
///
|
||||
/// Returns aggregated learning data from similar past tasks including:
|
||||
/// - Average cost adjustment multiplier
|
||||
/// - Average token adjustment multiplier
|
||||
/// - Success rate for similar tasks
|
||||
pub async fn get_historical_context(
|
||||
&self,
|
||||
task_description: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Option<HistoricalContext>> {
|
||||
let similar = self.find_similar_tasks(task_description, limit).await?;
|
||||
|
||||
if similar.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Calculate aggregated stats
|
||||
let total = similar.len() as f64;
|
||||
|
||||
let avg_cost_multiplier = similar.iter()
|
||||
.filter_map(|o| o.cost_error_ratio)
|
||||
.sum::<f64>() / similar.iter().filter(|o| o.cost_error_ratio.is_some()).count().max(1) as f64;
|
||||
|
||||
let avg_token_multiplier = similar.iter()
|
||||
.filter_map(|o| o.token_error_ratio)
|
||||
.sum::<f64>() / similar.iter().filter(|o| o.token_error_ratio.is_some()).count().max(1) as f64;
|
||||
|
||||
let success_count = similar.iter().filter(|o| o.success).count() as f64;
|
||||
let similar_success_rate = success_count / total;
|
||||
|
||||
Ok(Some(HistoricalContext {
|
||||
similar_outcomes: similar,
|
||||
avg_cost_multiplier: if avg_cost_multiplier.is_nan() { 1.0 } else { avg_cost_multiplier },
|
||||
avg_token_multiplier: if avg_token_multiplier.is_nan() { 1.0 } else { avg_token_multiplier },
|
||||
similar_success_rate,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get global learning statistics.
|
||||
pub async fn get_learning_stats(&self) -> anyhow::Result<serde_json::Value> {
|
||||
self.supabase.get_global_stats().await
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate a string to max length.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use reqwest::Client;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::types::{DbRun, DbTask, DbEvent, DbChunk, SearchResult};
|
||||
use super::types::{DbRun, DbTask, DbEvent, DbChunk, DbTaskOutcome, SearchResult, ModelStats};
|
||||
|
||||
/// Supabase client for database and storage operations.
|
||||
pub struct SupabaseClient {
|
||||
@@ -357,5 +357,156 @@ impl SupabaseClient {
|
||||
|
||||
self.update_run(run_id, body).await
|
||||
}
|
||||
|
||||
// ==================== Task Outcomes (Learning) ====================
|
||||
|
||||
/// Insert a task outcome for learning.
|
||||
pub async fn insert_task_outcome(
|
||||
&self,
|
||||
outcome: &DbTaskOutcome,
|
||||
embedding: Option<&[f32]>,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
let embedding_str = embedding.map(|e| format!(
|
||||
"[{}]",
|
||||
e.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
|
||||
));
|
||||
|
||||
let body = serde_json::json!({
|
||||
"run_id": outcome.run_id,
|
||||
"task_id": outcome.task_id,
|
||||
"predicted_complexity": outcome.predicted_complexity,
|
||||
"predicted_tokens": outcome.predicted_tokens,
|
||||
"predicted_cost_cents": outcome.predicted_cost_cents,
|
||||
"selected_model": outcome.selected_model,
|
||||
"actual_tokens": outcome.actual_tokens,
|
||||
"actual_cost_cents": outcome.actual_cost_cents,
|
||||
"success": outcome.success,
|
||||
"iterations": outcome.iterations,
|
||||
"tool_calls_count": outcome.tool_calls_count,
|
||||
"task_description": outcome.task_description,
|
||||
"task_type": outcome.task_type,
|
||||
"cost_error_ratio": outcome.cost_error_ratio,
|
||||
"token_error_ratio": outcome.token_error_ratio,
|
||||
"task_embedding": embedding_str
|
||||
});
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/task_outcomes", self.rest_url()))
|
||||
.header("apikey", &self.service_role_key)
|
||||
.header("Authorization", format!("Bearer {}", self.service_role_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Prefer", "return=representation")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("Failed to insert task outcome: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let outcomes: Vec<DbTaskOutcome> = serde_json::from_str(&text)?;
|
||||
outcomes.into_iter().next()
|
||||
.and_then(|o| o.id)
|
||||
.ok_or_else(|| anyhow::anyhow!("No outcome ID returned"))
|
||||
}
|
||||
|
||||
/// Get model statistics for a complexity range.
|
||||
///
|
||||
/// Returns aggregated stats for each model that has been used
|
||||
/// for tasks in the given complexity range.
|
||||
pub async fn get_model_stats(
|
||||
&self,
|
||||
complexity_min: f64,
|
||||
complexity_max: f64,
|
||||
) -> anyhow::Result<Vec<ModelStats>> {
|
||||
let body = serde_json::json!({
|
||||
"complexity_min": complexity_min,
|
||||
"complexity_max": complexity_max
|
||||
});
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/rpc/get_model_stats", self.rest_url()))
|
||||
.header("apikey", &self.service_role_key)
|
||||
.header("Authorization", format!("Bearer {}", self.service_role_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
// If RPC doesn't exist yet, return empty
|
||||
tracing::debug!("get_model_stats RPC not available: {}", text);
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
Ok(serde_json::from_str(&text).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Search for similar task outcomes by embedding similarity.
|
||||
pub async fn search_similar_outcomes(
|
||||
&self,
|
||||
embedding: &[f32],
|
||||
threshold: f64,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<DbTaskOutcome>> {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"query_embedding": embedding_str,
|
||||
"match_threshold": threshold,
|
||||
"match_count": limit
|
||||
});
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/rpc/search_similar_outcomes", self.rest_url()))
|
||||
.header("apikey", &self.service_role_key)
|
||||
.header("Authorization", format!("Bearer {}", self.service_role_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
// If RPC doesn't exist yet, return empty
|
||||
tracing::debug!("search_similar_outcomes RPC not available: {}", text);
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
Ok(serde_json::from_str(&text).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Get global learning statistics (for tuning).
|
||||
pub async fn get_global_stats(&self) -> anyhow::Result<serde_json::Value> {
|
||||
let resp = self.client
|
||||
.post(format!("{}/rpc/get_global_learning_stats", self.rest_url()))
|
||||
.header("apikey", &self.service_role_key)
|
||||
.header("Authorization", format!("Bearer {}", self.service_role_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&serde_json::json!({}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
tracing::debug!("get_global_learning_stats RPC not available: {}", text);
|
||||
return Ok(serde_json::json!({}));
|
||||
}
|
||||
|
||||
Ok(serde_json::from_str(&text).unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -154,6 +154,123 @@ pub struct ContextPack {
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
/// Task outcome record for learning from execution history.
|
||||
///
|
||||
/// Captures predictions vs actuals to enable data-driven optimization
|
||||
/// of complexity estimation, model selection, and budget allocation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbTaskOutcome {
|
||||
pub id: Option<Uuid>,
|
||||
pub run_id: Uuid,
|
||||
pub task_id: Uuid,
|
||||
|
||||
// Predictions (what we estimated before execution)
|
||||
pub predicted_complexity: Option<f64>,
|
||||
pub predicted_tokens: Option<i64>,
|
||||
pub predicted_cost_cents: Option<i64>,
|
||||
pub selected_model: Option<String>,
|
||||
|
||||
// Actuals (what happened during execution)
|
||||
pub actual_tokens: Option<i64>,
|
||||
pub actual_cost_cents: Option<i64>,
|
||||
pub success: bool,
|
||||
pub iterations: Option<i32>,
|
||||
pub tool_calls_count: Option<i32>,
|
||||
|
||||
// Metadata for similarity search
|
||||
pub task_description: String,
|
||||
/// Category of task (inferred or explicit)
|
||||
pub task_type: Option<String>,
|
||||
|
||||
// Computed ratios for quick stats
|
||||
pub cost_error_ratio: Option<f64>,
|
||||
pub token_error_ratio: Option<f64>,
|
||||
|
||||
pub created_at: Option<String>,
|
||||
}
|
||||
|
||||
impl DbTaskOutcome {
|
||||
/// Create a new outcome from predictions and actuals.
|
||||
pub fn new(
|
||||
run_id: Uuid,
|
||||
task_id: Uuid,
|
||||
task_description: String,
|
||||
predicted_complexity: Option<f64>,
|
||||
predicted_tokens: Option<i64>,
|
||||
predicted_cost_cents: Option<i64>,
|
||||
selected_model: Option<String>,
|
||||
actual_tokens: Option<i64>,
|
||||
actual_cost_cents: Option<i64>,
|
||||
success: bool,
|
||||
iterations: Option<i32>,
|
||||
tool_calls_count: Option<i32>,
|
||||
) -> Self {
|
||||
// Compute error ratios
|
||||
let cost_error_ratio = match (actual_cost_cents, predicted_cost_cents) {
|
||||
(Some(actual), Some(predicted)) if predicted > 0 => {
|
||||
Some(actual as f64 / predicted as f64)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let token_error_ratio = match (actual_tokens, predicted_tokens) {
|
||||
(Some(actual), Some(predicted)) if predicted > 0 => {
|
||||
Some(actual as f64 / predicted as f64)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Self {
|
||||
id: None,
|
||||
run_id,
|
||||
task_id,
|
||||
predicted_complexity,
|
||||
predicted_tokens,
|
||||
predicted_cost_cents,
|
||||
selected_model,
|
||||
actual_tokens,
|
||||
actual_cost_cents,
|
||||
success,
|
||||
iterations,
|
||||
tool_calls_count,
|
||||
task_description,
|
||||
task_type: None,
|
||||
cost_error_ratio,
|
||||
token_error_ratio,
|
||||
created_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Model performance statistics from historical data.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelStats {
|
||||
pub model_id: String,
|
||||
/// Success rate (0-1)
|
||||
pub success_rate: f64,
|
||||
/// Average cost vs predicted (1.0 = accurate, >1 = underestimated)
|
||||
pub avg_cost_ratio: f64,
|
||||
/// Average tokens vs predicted
|
||||
pub avg_token_ratio: f64,
|
||||
/// Average iterations needed
|
||||
pub avg_iterations: f64,
|
||||
/// Number of samples
|
||||
pub sample_count: i64,
|
||||
}
|
||||
|
||||
/// Historical context for a task (similar past tasks).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HistoricalContext {
|
||||
/// Similar past task outcomes
|
||||
pub similar_outcomes: Vec<DbTaskOutcome>,
|
||||
/// Average cost adjustment multiplier
|
||||
pub avg_cost_multiplier: f64,
|
||||
/// Average token adjustment multiplier
|
||||
pub avg_token_multiplier: f64,
|
||||
/// Success rate for similar tasks
|
||||
pub similar_success_rate: f64,
|
||||
}
|
||||
|
||||
impl ContextPack {
|
||||
/// Format as a string for prompt injection.
|
||||
pub fn format_for_prompt(&self) -> String {
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde_json::json;
|
||||
|
||||
use super::supabase::SupabaseClient;
|
||||
use super::embed::EmbeddingClient;
|
||||
use super::types::{DbTask, DbEvent, DbChunk, EventKind, MemoryStatus};
|
||||
use super::types::{DbTask, DbEvent, DbChunk, DbTaskOutcome, EventKind, MemoryStatus};
|
||||
|
||||
/// Maximum chunk size in characters.
|
||||
const MAX_CHUNK_SIZE: usize = 2000;
|
||||
@@ -241,6 +241,47 @@ impl MemoryWriter {
|
||||
self.supabase.update_run_summary(run_id, summary, &embedding).await
|
||||
}
|
||||
|
||||
/// Record a task outcome for learning.
|
||||
///
|
||||
/// This captures predictions vs actuals to enable data-driven optimization
|
||||
/// of complexity estimation, model selection, and budget allocation.
|
||||
pub async fn record_task_outcome(
|
||||
&self,
|
||||
run_id: Uuid,
|
||||
task_id: Uuid,
|
||||
task_description: &str,
|
||||
predicted_complexity: Option<f64>,
|
||||
predicted_tokens: Option<i64>,
|
||||
predicted_cost_cents: Option<i64>,
|
||||
selected_model: Option<String>,
|
||||
actual_tokens: Option<i64>,
|
||||
actual_cost_cents: Option<i64>,
|
||||
success: bool,
|
||||
iterations: Option<i32>,
|
||||
tool_calls_count: Option<i32>,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
// Create the outcome record
|
||||
let outcome = DbTaskOutcome::new(
|
||||
run_id,
|
||||
task_id,
|
||||
task_description.to_string(),
|
||||
predicted_complexity,
|
||||
predicted_tokens,
|
||||
predicted_cost_cents,
|
||||
selected_model,
|
||||
actual_tokens,
|
||||
actual_cost_cents,
|
||||
success,
|
||||
iterations,
|
||||
tool_calls_count,
|
||||
);
|
||||
|
||||
// Generate embedding for similarity search
|
||||
let embedding = self.embedder.embed(task_description).await.ok();
|
||||
|
||||
self.supabase.insert_task_outcome(&outcome, embedding.as_deref()).await
|
||||
}
|
||||
|
||||
/// Split text into chunks.
|
||||
fn chunk_text(&self, text: &str) -> Vec<String> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
Reference in New Issue
Block a user