Implement hierarchical agent tree architecture
Core types with provability design: - Task, Budget, Complexity with documented invariants - VerificationCriteria with programmatic/LLM hybrid support - SubtaskPlan with topological sort for execution order Agent hierarchy: - Agent trait with pre/post-conditions documented - OrchestratorAgent for Root/Node agents - LeafAgent for specialized workers Leaf agents: - ComplexityEstimator: estimates task difficulty (0-1 score) - ModelSelector: U-curve optimization for cost/capability - TaskExecutor: refactored from original agent loop - Verifier: hybrid programmatic + LLM verification Orchestrators: - RootAgent: top-level, estimates complexity, splits tasks - NodeAgent: intermediate, handles delegated subtasks Budget system: - Budget allocation strategies (proportional, equal, priority) - OpenRouter pricing integration for cost estimation API updated to use hierarchical RootAgent
This commit is contained in:
@@ -36,6 +36,7 @@ walkdir = "2"
|
||||
urlencoding = "2"
|
||||
anyhow = "1"
|
||||
async-stream = "0.3"
|
||||
regex = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
|
||||
85
src/agents/context.rs
Normal file
85
src/agents/context.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Agent execution context - shared state across the agent tree.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::budget::ModelPricing;
|
||||
use crate::config::Config;
|
||||
use crate::llm::LlmClient;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Shared context passed to all agents during execution.
|
||||
///
|
||||
/// # Thread Safety
|
||||
/// Context is wrapped in Arc for sharing across async tasks.
|
||||
/// Individual components use interior mutability where needed.
|
||||
pub struct AgentContext {
|
||||
/// Application configuration
|
||||
pub config: Config,
|
||||
|
||||
/// LLM client for model calls
|
||||
pub llm: Arc<dyn LlmClient>,
|
||||
|
||||
/// Tool registry for task execution
|
||||
pub tools: ToolRegistry,
|
||||
|
||||
/// Model pricing information
|
||||
pub pricing: Arc<ModelPricing>,
|
||||
|
||||
/// Workspace path for file operations
|
||||
pub workspace: PathBuf,
|
||||
|
||||
/// Maximum depth for recursive task splitting
|
||||
pub max_split_depth: usize,
|
||||
|
||||
/// Maximum iterations per agent
|
||||
pub max_iterations: usize,
|
||||
}
|
||||
|
||||
impl AgentContext {
|
||||
/// Create a new agent context.
|
||||
pub fn new(
|
||||
config: Config,
|
||||
llm: Arc<dyn LlmClient>,
|
||||
tools: ToolRegistry,
|
||||
pricing: Arc<ModelPricing>,
|
||||
workspace: PathBuf,
|
||||
) -> Self {
|
||||
Self {
|
||||
max_iterations: config.max_iterations,
|
||||
config,
|
||||
llm,
|
||||
tools,
|
||||
pricing,
|
||||
workspace,
|
||||
max_split_depth: 3, // Default max recursion for splitting
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a child context with reduced split depth.
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `child.max_split_depth == self.max_split_depth - 1`
|
||||
pub fn child_context(&self) -> Self {
|
||||
Self {
|
||||
config: self.config.clone(),
|
||||
llm: Arc::clone(&self.llm),
|
||||
tools: ToolRegistry::new(), // Fresh tools for isolation
|
||||
pricing: Arc::clone(&self.pricing),
|
||||
workspace: self.workspace.clone(),
|
||||
max_split_depth: self.max_split_depth.saturating_sub(1),
|
||||
max_iterations: self.max_iterations,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if further task splitting is allowed.
|
||||
pub fn can_split(&self) -> bool {
|
||||
self.max_split_depth > 0
|
||||
}
|
||||
|
||||
/// Get the workspace path as a string.
|
||||
pub fn workspace_str(&self) -> String {
|
||||
self.workspace.to_string_lossy().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
242
src/agents/leaf/complexity.rs
Normal file
242
src/agents/leaf/complexity.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
//! Complexity estimation agent.
|
||||
//!
|
||||
//! Analyzes a task description and estimates:
|
||||
//! - Complexity score (0-1)
|
||||
//! - Whether to split into subtasks
|
||||
//! - Estimated token count
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentResult, AgentType, Complexity, LeafAgent, LeafCapability,
|
||||
};
|
||||
use crate::llm::{ChatMessage, Role};
|
||||
use crate::task::Task;
|
||||
|
||||
/// Agent that estimates task complexity.
|
||||
///
|
||||
/// # Purpose
|
||||
/// Given a task description, estimate how complex it is and whether
|
||||
/// it should be split into subtasks.
|
||||
///
|
||||
/// # Algorithm
|
||||
/// 1. Send task description to LLM with complexity evaluation prompt
|
||||
/// 2. Parse LLM response for complexity score and reasoning
|
||||
/// 3. Return structured Complexity object
|
||||
pub struct ComplexityEstimator {
|
||||
id: AgentId,
|
||||
}
|
||||
|
||||
impl ComplexityEstimator {
|
||||
/// Create a new complexity estimator.
|
||||
pub fn new() -> Self {
|
||||
Self { id: AgentId::new() }
|
||||
}
|
||||
|
||||
/// Prompt template for complexity estimation.
|
||||
///
|
||||
/// # Response Format
|
||||
/// LLM should respond with JSON containing:
|
||||
/// - score: float 0-1
|
||||
/// - reasoning: string explanation
|
||||
/// - estimated_tokens: int
|
||||
/// - subtasks: optional array if should split
|
||||
fn build_prompt(&self, task: &Task) -> String {
|
||||
format!(
|
||||
r#"You are a task complexity analyzer. Analyze the following task and estimate its complexity.
|
||||
|
||||
Task: {}
|
||||
|
||||
Respond with a JSON object containing:
|
||||
- "score": A float from 0.0 to 1.0 where:
|
||||
- 0.0-0.2: Trivial (single command, simple file operation)
|
||||
- 0.2-0.4: Simple (few steps, straightforward implementation)
|
||||
- 0.4-0.6: Moderate (multiple files, some decision making)
|
||||
- 0.6-0.8: Complex (many files, architectural decisions, testing)
|
||||
- 0.8-1.0: Very Complex (large refactoring, many dependencies)
|
||||
|
||||
- "reasoning": Brief explanation of why this complexity level
|
||||
|
||||
- "estimated_tokens": Estimated total tokens needed (input + output) to complete this task
|
||||
|
||||
- "should_split": Boolean, true if task should be broken into subtasks
|
||||
|
||||
- "subtasks": If should_split is true, array of suggested subtask descriptions
|
||||
|
||||
Respond with ONLY the JSON object, no other text."#,
|
||||
task.description()
|
||||
)
|
||||
}
|
||||
|
||||
/// Parse LLM response into Complexity struct.
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - Returns valid Complexity with score in [0, 1]
|
||||
/// - Falls back to moderate complexity on parse error
|
||||
fn parse_response(&self, response: &str) -> Complexity {
|
||||
// Try to parse as JSON
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) {
|
||||
let score = json["score"].as_f64().unwrap_or(0.5);
|
||||
let reasoning = json["reasoning"].as_str().unwrap_or("No reasoning provided");
|
||||
let estimated_tokens = json["estimated_tokens"].as_u64().unwrap_or(2000);
|
||||
|
||||
return Complexity::new(score, reasoning, estimated_tokens);
|
||||
}
|
||||
|
||||
// Try to extract score from text
|
||||
if let Some(score) = self.extract_score_from_text(response) {
|
||||
return Complexity::new(score, response, 2000);
|
||||
}
|
||||
|
||||
// Default to moderate complexity
|
||||
Complexity::moderate("Could not parse complexity response")
|
||||
}
|
||||
|
||||
/// Try to extract a score from free-form text.
|
||||
fn extract_score_from_text(&self, text: &str) -> Option<f64> {
|
||||
// Look for patterns like "0.5" or "score: 0.5" or "50%"
|
||||
let text_lower = text.to_lowercase();
|
||||
|
||||
// Check for keywords
|
||||
if text_lower.contains("trivial") || text_lower.contains("very simple") {
|
||||
return Some(0.1);
|
||||
}
|
||||
if text_lower.contains("very complex") || text_lower.contains("extremely") {
|
||||
return Some(0.9);
|
||||
}
|
||||
if text_lower.contains("complex") {
|
||||
return Some(0.7);
|
||||
}
|
||||
if text_lower.contains("moderate") || text_lower.contains("medium") {
|
||||
return Some(0.5);
|
||||
}
|
||||
if text_lower.contains("simple") || text_lower.contains("easy") {
|
||||
return Some(0.3);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ComplexityEstimator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for ComplexityEstimator {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::ComplexityEstimator
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Estimates task complexity and recommends splitting strategy"
|
||||
}
|
||||
|
||||
/// Estimate complexity of a task.
|
||||
///
|
||||
/// # Returns
|
||||
/// AgentResult with Complexity data in the `data` field.
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
let prompt = self.build_prompt(task);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: Role::System,
|
||||
content: Some("You are a precise task complexity analyzer. Respond only with JSON.".to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: Role::User,
|
||||
content: Some(prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Use a fast, cheap model for complexity estimation
|
||||
let model = "openai/gpt-4.1-mini";
|
||||
|
||||
match ctx.llm.chat_completion(model, &messages, None).await {
|
||||
Ok(response) => {
|
||||
let content = response.content.unwrap_or_default();
|
||||
let complexity = self.parse_response(&content);
|
||||
|
||||
// Estimate cost (rough: ~1000 tokens for this request)
|
||||
let cost_cents = 1; // Very cheap operation
|
||||
|
||||
AgentResult::success(
|
||||
format!(
|
||||
"Complexity: {:.2} - {}",
|
||||
complexity.score(),
|
||||
if complexity.should_split() { "Should split" } else { "Execute directly" }
|
||||
),
|
||||
cost_cents,
|
||||
)
|
||||
.with_model(model)
|
||||
.with_data(json!({
|
||||
"score": complexity.score(),
|
||||
"reasoning": complexity.reasoning(),
|
||||
"should_split": complexity.should_split(),
|
||||
"estimated_tokens": complexity.estimated_tokens(),
|
||||
}))
|
||||
}
|
||||
Err(e) => {
|
||||
// On error, return moderate complexity as fallback
|
||||
let fallback = Complexity::moderate(format!("LLM error, using fallback: {}", e));
|
||||
|
||||
AgentResult::success(
|
||||
"Using fallback complexity estimate due to LLM error",
|
||||
0,
|
||||
)
|
||||
.with_data(json!({
|
||||
"score": fallback.score(),
|
||||
"reasoning": fallback.reasoning(),
|
||||
"should_split": fallback.should_split(),
|
||||
"estimated_tokens": fallback.estimated_tokens(),
|
||||
"fallback": true,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LeafAgent for ComplexityEstimator {
|
||||
fn capability(&self) -> LeafCapability {
|
||||
LeafCapability::ComplexityEstimation
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_json_response() {
|
||||
let estimator = ComplexityEstimator::new();
|
||||
|
||||
let json_response = r#"{"score": 0.7, "reasoning": "Complex task", "estimated_tokens": 3000, "should_split": true}"#;
|
||||
let complexity = estimator.parse_response(json_response);
|
||||
|
||||
assert!((complexity.score() - 0.7).abs() < 0.01);
|
||||
assert!(complexity.should_split());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_text_response() {
|
||||
let estimator = ComplexityEstimator::new();
|
||||
|
||||
let text_response = "This is a very complex task";
|
||||
let complexity = estimator.parse_response(text_response);
|
||||
|
||||
assert!(complexity.score() > 0.6);
|
||||
}
|
||||
}
|
||||
|
||||
244
src/agents/leaf/executor.rs
Normal file
244
src/agents/leaf/executor.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Task executor agent - the main worker that uses tools.
|
||||
//!
|
||||
//! This is a refactored version of the original agent loop,
|
||||
//! now as a leaf agent in the hierarchical tree.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
|
||||
};
|
||||
use crate::llm::{ChatMessage, Role, ToolCall};
|
||||
use crate::task::Task;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Agent that executes tasks using tools.
|
||||
///
|
||||
/// # Algorithm
|
||||
/// 1. Build system prompt with available tools
|
||||
/// 2. Call LLM with task description
|
||||
/// 3. If LLM requests tool call: execute, feed back result
|
||||
/// 4. Repeat until LLM produces final response or max iterations
|
||||
///
|
||||
/// # Budget Management
|
||||
/// - Tracks token usage and costs
|
||||
/// - Stops if budget is exhausted
|
||||
pub struct TaskExecutor {
|
||||
id: AgentId,
|
||||
}
|
||||
|
||||
impl TaskExecutor {
|
||||
/// Create a new task executor.
|
||||
pub fn new() -> Self {
|
||||
Self { id: AgentId::new() }
|
||||
}
|
||||
|
||||
/// Build the system prompt for task execution.
|
||||
fn build_system_prompt(&self, workspace: &str, tools: &ToolRegistry) -> String {
|
||||
let tool_descriptions = tools
|
||||
.list_tools()
|
||||
.iter()
|
||||
.map(|t| format!("- **{}**: {}", t.name, t.description))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
format!(
|
||||
r#"You are an autonomous task executor with access to tools.
|
||||
You operate in the workspace: {workspace}
|
||||
|
||||
## Available Tools
|
||||
{tool_descriptions}
|
||||
|
||||
## Rules
|
||||
1. Use tools to accomplish the task - don't just describe what to do
|
||||
2. Read files before editing them
|
||||
3. Verify your work when possible
|
||||
4. If stuck, explain what's blocking you
|
||||
5. When done, summarize what you accomplished
|
||||
|
||||
## Response
|
||||
When task is complete, provide a clear summary of:
|
||||
- What you did
|
||||
- Files created/modified
|
||||
- How to verify the result"#,
|
||||
workspace = workspace,
|
||||
tool_descriptions = tool_descriptions
|
||||
)
|
||||
}
|
||||
|
||||
/// Execute a single tool call.
|
||||
async fn execute_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolCall,
|
||||
ctx: &AgentContext,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
ctx.tools
|
||||
.execute(&tool_call.function.name, args, &ctx.workspace)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Run the agent loop for a task.
|
||||
async fn run_loop(
|
||||
&self,
|
||||
task: &Task,
|
||||
model: &str,
|
||||
ctx: &AgentContext,
|
||||
) -> (String, u64, Vec<String>) {
|
||||
let mut total_cost = 0u64;
|
||||
let mut tool_log = Vec::new();
|
||||
|
||||
// Build initial messages
|
||||
let system_prompt = self.build_system_prompt(&ctx.workspace_str(), &ctx.tools);
|
||||
let mut messages = vec![
|
||||
ChatMessage {
|
||||
role: Role::System,
|
||||
content: Some(system_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: Role::User,
|
||||
content: Some(task.description().to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Get tool schemas
|
||||
let tool_schemas = ctx.tools.get_tool_schemas();
|
||||
|
||||
// Agent loop
|
||||
for iteration in 0..ctx.max_iterations {
|
||||
tracing::debug!("TaskExecutor iteration {}", iteration + 1);
|
||||
|
||||
// Check budget
|
||||
let remaining = task.budget().remaining_cents();
|
||||
if remaining == 0 && total_cost > 0 {
|
||||
return (
|
||||
"Budget exhausted before task completion".to_string(),
|
||||
total_cost,
|
||||
tool_log,
|
||||
);
|
||||
}
|
||||
|
||||
// Call LLM
|
||||
let response = match ctx.llm.chat_completion(model, &messages, Some(&tool_schemas)).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return (
|
||||
format!("LLM error: {}", e),
|
||||
total_cost,
|
||||
tool_log,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Estimate cost (rough: ~1 cent per request for cheap models)
|
||||
total_cost += 2;
|
||||
|
||||
// Check for tool calls
|
||||
if let Some(tool_calls) = &response.tool_calls {
|
||||
if !tool_calls.is_empty() {
|
||||
// Add assistant message with tool calls
|
||||
messages.push(ChatMessage {
|
||||
role: Role::Assistant,
|
||||
content: response.content.clone(),
|
||||
tool_calls: Some(tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
// Execute each tool call
|
||||
for tool_call in tool_calls {
|
||||
tool_log.push(format!(
|
||||
"Tool: {} Args: {}",
|
||||
tool_call.function.name,
|
||||
tool_call.function.arguments
|
||||
));
|
||||
|
||||
let result = match self.execute_tool_call(tool_call, ctx).await {
|
||||
Ok(output) => output,
|
||||
Err(e) => format!("Error: {}", e),
|
||||
};
|
||||
|
||||
// Add tool result
|
||||
messages.push(ChatMessage {
|
||||
role: Role::Tool,
|
||||
content: Some(result),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call.id.clone()),
|
||||
});
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// No tool calls - final response
|
||||
if let Some(content) = response.content {
|
||||
return (content, total_cost, tool_log);
|
||||
}
|
||||
|
||||
// Empty response
|
||||
return (
|
||||
"LLM returned empty response".to_string(),
|
||||
total_cost,
|
||||
tool_log,
|
||||
);
|
||||
}
|
||||
|
||||
(
|
||||
format!("Max iterations ({}) reached", ctx.max_iterations),
|
||||
total_cost,
|
||||
tool_log,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TaskExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for TaskExecutor {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::TaskExecutor
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Executes tasks using tools (file ops, terminal, search, etc.)"
|
||||
}
|
||||
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
// Use default model or from context
|
||||
let model = &ctx.config.default_model;
|
||||
|
||||
let (output, cost, tool_log) = self.run_loop(task, model, ctx).await;
|
||||
|
||||
// Update task budget
|
||||
let _ = task.budget_mut().try_spend(cost);
|
||||
|
||||
AgentResult::success(&output, cost)
|
||||
.with_model(model)
|
||||
.with_data(json!({
|
||||
"tool_calls": tool_log.len(),
|
||||
"tools_used": tool_log,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl LeafAgent for TaskExecutor {
|
||||
fn capability(&self) -> LeafCapability {
|
||||
LeafCapability::TaskExecution
|
||||
}
|
||||
}
|
||||
|
||||
18
src/agents/leaf/mod.rs
Normal file
18
src/agents/leaf/mod.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! Leaf agents - specialized agents that do actual work.
|
||||
//!
|
||||
//! # Leaf Agent Types
|
||||
//! - `ComplexityEstimator`: Estimates task complexity (0-1 score)
|
||||
//! - `ModelSelector`: Selects optimal model for task/budget
|
||||
//! - `TaskExecutor`: Executes tasks using tools (main worker)
|
||||
//! - `Verifier`: Validates task completion
|
||||
|
||||
mod complexity;
|
||||
mod model_select;
|
||||
mod executor;
|
||||
mod verifier;
|
||||
|
||||
pub use complexity::ComplexityEstimator;
|
||||
pub use model_select::ModelSelector;
|
||||
pub use executor::TaskExecutor;
|
||||
pub use verifier::Verifier;
|
||||
|
||||
364
src/agents/leaf/model_select.rs
Normal file
364
src/agents/leaf/model_select.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
//! Model selection agent with U-curve cost optimization.
|
||||
//!
|
||||
//! # U-Curve Optimization
|
||||
//! The total expected cost follows a U-shaped curve:
|
||||
//! - Cheap models: Low per-token cost, but may fail/retry, use more tokens
|
||||
//! - Expensive models: High per-token cost, but succeed more often
|
||||
//! - Optimal: Somewhere in the middle, minimizing total expected cost
|
||||
//!
|
||||
//! # Cost Model
|
||||
//! Expected Cost = base_cost * (1 + failure_rate * retry_multiplier) * token_efficiency
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
|
||||
};
|
||||
use crate::budget::PricingInfo;
|
||||
use crate::task::Task;
|
||||
|
||||
/// Agent that selects the optimal model for a task.
|
||||
///
|
||||
/// # Algorithm
|
||||
/// 1. Get task complexity and budget constraints
|
||||
/// 2. Fetch available models and pricing
|
||||
/// 3. For each model, calculate expected total cost
|
||||
/// 4. Return model with minimum expected cost within budget
|
||||
pub struct ModelSelector {
|
||||
id: AgentId,
|
||||
}
|
||||
|
||||
/// Model recommendation from the selector.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelRecommendation {
|
||||
/// Recommended model ID
|
||||
pub model_id: String,
|
||||
|
||||
/// Expected cost in cents
|
||||
pub expected_cost_cents: u64,
|
||||
|
||||
/// Confidence in this recommendation (0-1)
|
||||
pub confidence: f64,
|
||||
|
||||
/// Reasoning for the selection
|
||||
pub reasoning: String,
|
||||
|
||||
/// Alternative models if primary fails
|
||||
pub fallbacks: Vec<String>,
|
||||
}
|
||||
|
||||
impl ModelSelector {
|
||||
/// Create a new model selector.
|
||||
pub fn new() -> Self {
|
||||
Self { id: AgentId::new() }
|
||||
}
|
||||
|
||||
/// Calculate expected cost for a model given task complexity.
|
||||
///
|
||||
/// # Formula
|
||||
/// ```text
|
||||
/// expected_cost = base_cost * (1 + failure_prob * retry_cost) * inefficiency_factor
|
||||
/// ```
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `pricing`: Model pricing info
|
||||
/// - `complexity`: Task complexity (0-1)
|
||||
/// - `estimated_tokens`: Estimated tokens needed
|
||||
///
|
||||
/// # Returns
|
||||
/// Expected cost in cents
|
||||
///
|
||||
/// # Pure Function
|
||||
/// No side effects, deterministic output.
|
||||
fn calculate_expected_cost(
|
||||
&self,
|
||||
pricing: &PricingInfo,
|
||||
complexity: f64,
|
||||
estimated_tokens: u64,
|
||||
) -> ExpectedCost {
|
||||
// Model capability estimate based on pricing tier
|
||||
// Higher price generally means more capable
|
||||
let avg_cost = pricing.average_cost_per_token();
|
||||
let capability = self.estimate_capability(avg_cost);
|
||||
|
||||
// Failure probability: higher complexity + lower capability = more failures
|
||||
// Formula: P(fail) = complexity * (1 - capability)
|
||||
let failure_prob = (complexity * (1.0 - capability)).clamp(0.0, 0.9);
|
||||
|
||||
// Token inefficiency: weaker models need more tokens
|
||||
// Formula: inefficiency = 1 + (1 - capability) * 0.5
|
||||
let inefficiency = 1.0 + (1.0 - capability) * 0.5;
|
||||
|
||||
// Retry cost: if it fails, we pay again (possibly with a better model)
|
||||
let retry_multiplier = 1.5; // Retries cost 50% more (wasted context)
|
||||
|
||||
// Base cost for estimated tokens
|
||||
let input_tokens = estimated_tokens / 2;
|
||||
let output_tokens = estimated_tokens / 2;
|
||||
let base_cost = pricing.calculate_cost_cents(input_tokens, output_tokens);
|
||||
|
||||
// Adjusted for inefficiency (weak models use more tokens)
|
||||
let adjusted_tokens = ((estimated_tokens as f64) * inefficiency) as u64;
|
||||
let adjusted_cost = pricing.calculate_cost_cents(adjusted_tokens / 2, adjusted_tokens / 2);
|
||||
|
||||
// Expected cost including retry probability
|
||||
let expected_cost = (adjusted_cost as f64) * (1.0 + failure_prob * retry_multiplier);
|
||||
|
||||
ExpectedCost {
|
||||
model_id: pricing.model_id.clone(),
|
||||
base_cost_cents: base_cost,
|
||||
expected_cost_cents: expected_cost.ceil() as u64,
|
||||
failure_probability: failure_prob,
|
||||
capability,
|
||||
inefficiency,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate model capability from its cost.
|
||||
///
|
||||
/// # Heuristic
|
||||
/// More expensive models are generally more capable.
|
||||
/// Uses log scale to normalize across price ranges.
|
||||
///
|
||||
/// # Returns
|
||||
/// Capability score 0-1
|
||||
fn estimate_capability(&self, avg_cost_per_token: f64) -> f64 {
|
||||
// Cost tiers (per token):
|
||||
// < 0.0001: weak (capability ~0.3)
|
||||
// 0.0001-0.001: moderate (capability ~0.6)
|
||||
// > 0.001: strong (capability ~0.9)
|
||||
|
||||
if avg_cost_per_token < 0.0000001 {
|
||||
return 0.3; // Free/very cheap
|
||||
}
|
||||
|
||||
// Log scale normalization
|
||||
let log_cost = avg_cost_per_token.log10();
|
||||
// Map from ~-7 (cheap) to ~-3 (expensive) => 0.3 to 0.95
|
||||
let normalized = ((log_cost + 7.0) / 4.0).clamp(0.0, 1.0);
|
||||
|
||||
0.3 + normalized * 0.65
|
||||
}
|
||||
|
||||
/// Select optimal model from available options.
|
||||
///
|
||||
/// # Algorithm
|
||||
/// 1. Calculate expected cost for each model
|
||||
/// 2. Filter models exceeding budget
|
||||
/// 3. Select model with minimum expected cost
|
||||
/// 4. Include fallbacks in case of failure
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `models` is non-empty
|
||||
/// - `budget_cents > 0`
|
||||
fn select_optimal(
|
||||
&self,
|
||||
models: &[PricingInfo],
|
||||
complexity: f64,
|
||||
estimated_tokens: u64,
|
||||
budget_cents: u64,
|
||||
) -> Option<ModelRecommendation> {
|
||||
if models.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Calculate expected cost for all models
|
||||
let mut costs: Vec<ExpectedCost> = models
|
||||
.iter()
|
||||
.map(|m| self.calculate_expected_cost(m, complexity, estimated_tokens))
|
||||
.collect();
|
||||
|
||||
// Sort by expected cost (ascending)
|
||||
costs.sort_by(|a, b| {
|
||||
a.expected_cost_cents
|
||||
.cmp(&b.expected_cost_cents)
|
||||
});
|
||||
|
||||
// Find cheapest model within budget
|
||||
let within_budget: Vec<_> = costs
|
||||
.iter()
|
||||
.filter(|c| c.expected_cost_cents <= budget_cents)
|
||||
.collect();
|
||||
|
||||
let selected = within_budget.first().copied().or(costs.first())?;
|
||||
|
||||
// Get fallback models (next best options)
|
||||
let fallbacks: Vec<String> = costs
|
||||
.iter()
|
||||
.filter(|c| c.model_id != selected.model_id)
|
||||
.take(2)
|
||||
.map(|c| c.model_id.clone())
|
||||
.collect();
|
||||
|
||||
Some(ModelRecommendation {
|
||||
model_id: selected.model_id.clone(),
|
||||
expected_cost_cents: selected.expected_cost_cents,
|
||||
confidence: 1.0 - selected.failure_probability,
|
||||
reasoning: format!(
|
||||
"Selected {} with expected cost {} cents (capability: {:.2}, failure prob: {:.2})",
|
||||
selected.model_id,
|
||||
selected.expected_cost_cents,
|
||||
selected.capability,
|
||||
selected.failure_probability
|
||||
),
|
||||
fallbacks,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Intermediate calculation result for a model.
|
||||
#[derive(Debug)]
|
||||
struct ExpectedCost {
|
||||
model_id: String,
|
||||
base_cost_cents: u64,
|
||||
expected_cost_cents: u64,
|
||||
failure_probability: f64,
|
||||
capability: f64,
|
||||
inefficiency: f64,
|
||||
}
|
||||
|
||||
impl Default for ModelSelector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for ModelSelector {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::ModelSelector
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Selects optimal model for task based on complexity and budget (U-curve optimization)"
|
||||
}
|
||||
|
||||
/// Select the optimal model for a task.
|
||||
///
|
||||
/// # Expected Input
|
||||
/// Task should have complexity data in its context (from ComplexityEstimator).
|
||||
///
|
||||
/// # Returns
|
||||
/// AgentResult with ModelRecommendation in the `data` field.
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
// Get complexity (default to moderate if not estimated)
|
||||
// In practice, this would be passed from ComplexityEstimator
|
||||
let complexity = 0.5; // TODO: Get from task metadata
|
||||
let estimated_tokens = 2000_u64;
|
||||
|
||||
// Get available budget
|
||||
let budget_cents = task.budget().remaining_cents();
|
||||
|
||||
// Fetch pricing for all models
|
||||
let models = ctx.pricing.models_by_cost().await;
|
||||
|
||||
if models.is_empty() {
|
||||
// Use hardcoded defaults if no pricing available
|
||||
return AgentResult::success(
|
||||
"Using default model (no pricing data available)",
|
||||
0,
|
||||
)
|
||||
.with_data(json!({
|
||||
"model_id": "openai/gpt-4.1-mini",
|
||||
"expected_cost_cents": 10,
|
||||
"confidence": 0.5,
|
||||
"reasoning": "Fallback to default model",
|
||||
"fallbacks": ["openai/gpt-4o-mini", "anthropic/claude-3-haiku"],
|
||||
}));
|
||||
}
|
||||
|
||||
match self.select_optimal(&models, complexity, estimated_tokens, budget_cents) {
|
||||
Some(rec) => {
|
||||
AgentResult::success(
|
||||
&rec.reasoning,
|
||||
1, // Minimal cost for selection itself
|
||||
)
|
||||
.with_data(json!({
|
||||
"model_id": rec.model_id,
|
||||
"expected_cost_cents": rec.expected_cost_cents,
|
||||
"confidence": rec.confidence,
|
||||
"reasoning": rec.reasoning,
|
||||
"fallbacks": rec.fallbacks,
|
||||
}))
|
||||
}
|
||||
None => {
|
||||
AgentResult::failure(
|
||||
"No suitable model found within budget",
|
||||
0,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LeafAgent for ModelSelector {
|
||||
fn capability(&self) -> LeafCapability {
|
||||
LeafCapability::ModelSelection
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_pricing(id: &str, prompt: f64, completion: f64) -> PricingInfo {
|
||||
PricingInfo {
|
||||
model_id: id.to_string(),
|
||||
prompt_cost_per_million: prompt,
|
||||
completion_cost_per_million: completion,
|
||||
context_length: 100000,
|
||||
max_output_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expected_cost_u_curve() {
|
||||
let selector = ModelSelector::new();
|
||||
|
||||
let cheap = make_pricing("cheap", 0.1, 0.2);
|
||||
let medium = make_pricing("medium", 1.0, 2.0);
|
||||
let expensive = make_pricing("expensive", 10.0, 20.0);
|
||||
|
||||
let complexity = 0.7;
|
||||
let tokens = 2000;
|
||||
|
||||
let cheap_cost = selector.calculate_expected_cost(&cheap, complexity, tokens);
|
||||
let medium_cost = selector.calculate_expected_cost(&medium, complexity, tokens);
|
||||
let expensive_cost = selector.calculate_expected_cost(&expensive, complexity, tokens);
|
||||
|
||||
// For complex tasks, medium should be optimal (U-curve)
|
||||
// Cheap model has high failure rate
|
||||
// Expensive model has high base cost
|
||||
println!("Cheap: {} (fail: {})", cheap_cost.expected_cost_cents, cheap_cost.failure_probability);
|
||||
println!("Medium: {} (fail: {})", medium_cost.expected_cost_cents, medium_cost.failure_probability);
|
||||
println!("Expensive: {} (fail: {})", expensive_cost.expected_cost_cents, expensive_cost.failure_probability);
|
||||
|
||||
// Basic sanity check: cheap model should have higher failure rate
|
||||
assert!(cheap_cost.failure_probability > medium_cost.failure_probability);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_optimal() {
|
||||
let selector = ModelSelector::new();
|
||||
|
||||
let models = vec![
|
||||
make_pricing("cheap", 0.1, 0.2),
|
||||
make_pricing("medium", 1.0, 2.0),
|
||||
make_pricing("expensive", 10.0, 20.0),
|
||||
];
|
||||
|
||||
// For moderate complexity, should pick cost-effective option
|
||||
let rec = selector.select_optimal(&models, 0.5, 1000, 1000);
|
||||
assert!(rec.is_some());
|
||||
|
||||
// For very low budget, might be forced to pick cheap
|
||||
let rec_low = selector.select_optimal(&models, 0.5, 1000, 1);
|
||||
assert!(rec_low.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
347
src/agents/leaf/verifier.rs
Normal file
347
src/agents/leaf/verifier.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! Verification agent - validates task completion.
|
||||
//!
|
||||
//! # Verification Strategy (Hybrid)
|
||||
//! 1. Try programmatic verification first (fast, deterministic)
|
||||
//! 2. Fall back to LLM verification if needed
|
||||
//!
|
||||
//! # Programmatic Checks
|
||||
//! - File exists
|
||||
//! - Command succeeds
|
||||
//! - Output matches pattern
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
|
||||
};
|
||||
use crate::llm::{ChatMessage, Role};
|
||||
use crate::task::{ProgrammaticCheck, Task, VerificationCriteria, VerificationMethod, VerificationResult};
|
||||
|
||||
/// Agent that verifies task completion.
|
||||
///
|
||||
/// # Hybrid Verification
|
||||
/// - Programmatic: Fast, deterministic, no cost
|
||||
/// - LLM: Flexible, for subjective criteria
|
||||
pub struct Verifier {
|
||||
id: AgentId,
|
||||
}
|
||||
|
||||
impl Verifier {
|
||||
/// Create a new verifier.
|
||||
pub fn new() -> Self {
|
||||
Self { id: AgentId::new() }
|
||||
}
|
||||
|
||||
/// Execute a programmatic check.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(true)` if check passes, `Ok(false)` if fails, `Err` on error.
|
||||
async fn run_programmatic_check(
|
||||
&self,
|
||||
check: &ProgrammaticCheck,
|
||||
workspace: &Path,
|
||||
) -> Result<bool, String> {
|
||||
match check {
|
||||
ProgrammaticCheck::FileExists { path } => {
|
||||
let full_path = workspace.join(path);
|
||||
Ok(full_path.exists())
|
||||
}
|
||||
|
||||
ProgrammaticCheck::FileContains { path, content } => {
|
||||
let full_path = workspace.join(path);
|
||||
match tokio::fs::read_to_string(&full_path).await {
|
||||
Ok(file_content) => Ok(file_content.contains(content)),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
ProgrammaticCheck::CommandSucceeds { command } => {
|
||||
let output = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.current_dir(workspace)
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
Ok(output.success())
|
||||
}
|
||||
|
||||
ProgrammaticCheck::CommandOutputMatches { command, pattern } => {
|
||||
let output = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.current_dir(workspace)
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let regex = regex::Regex::new(pattern).map_err(|e| e.to_string())?;
|
||||
Ok(regex.is_match(&stdout))
|
||||
}
|
||||
|
||||
ProgrammaticCheck::DirectoryExists { path } => {
|
||||
let full_path = workspace.join(path);
|
||||
Ok(full_path.is_dir())
|
||||
}
|
||||
|
||||
ProgrammaticCheck::FileMatchesRegex { path, pattern } => {
|
||||
let full_path = workspace.join(path);
|
||||
match tokio::fs::read_to_string(&full_path).await {
|
||||
Ok(content) => {
|
||||
let regex = regex::Regex::new(pattern).map_err(|e| e.to_string())?;
|
||||
Ok(regex.is_match(&content))
|
||||
}
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
ProgrammaticCheck::All(checks) => {
|
||||
for c in checks {
|
||||
if !Box::pin(self.run_programmatic_check(c, workspace)).await? {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
ProgrammaticCheck::Any(checks) => {
|
||||
for c in checks {
|
||||
if Box::pin(self.run_programmatic_check(c, workspace)).await? {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify using LLM.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `task`: The task that was executed
|
||||
/// - `success_criteria`: What success looks like
|
||||
/// - `ctx`: Agent context
|
||||
///
|
||||
/// # Returns
|
||||
/// VerificationResult with LLM's assessment
|
||||
async fn verify_with_llm(
|
||||
&self,
|
||||
task: &Task,
|
||||
success_criteria: &str,
|
||||
ctx: &AgentContext,
|
||||
) -> VerificationResult {
|
||||
let prompt = format!(
|
||||
r#"You are verifying if a task was completed correctly.
|
||||
|
||||
Task: {}
|
||||
|
||||
Success Criteria: {}
|
||||
|
||||
Based on your assessment, respond with a JSON object:
|
||||
{{
|
||||
"passed": true/false,
|
||||
"reasoning": "explanation of why the task passed or failed"
|
||||
}}
|
||||
|
||||
Respond ONLY with the JSON object."#,
|
||||
task.description(),
|
||||
success_criteria
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: Role::System,
|
||||
content: Some("You are a precise task verifier. Respond only with JSON.".to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: Role::User,
|
||||
content: Some(prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
let model = "openai/gpt-4.1-mini";
|
||||
|
||||
match ctx.llm.chat_completion(model, &messages, None).await {
|
||||
Ok(response) => {
|
||||
let content = response.content.unwrap_or_default();
|
||||
self.parse_llm_verification(&content, model)
|
||||
}
|
||||
Err(e) => {
|
||||
VerificationResult::fail(
|
||||
format!("LLM verification failed: {}", e),
|
||||
VerificationMethod::Llm { model: model.to_string() },
|
||||
0,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse LLM verification response.
|
||||
fn parse_llm_verification(&self, response: &str, model: &str) -> VerificationResult {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) {
|
||||
let passed = json["passed"].as_bool().unwrap_or(false);
|
||||
let reasoning = json["reasoning"]
|
||||
.as_str()
|
||||
.unwrap_or("No reasoning provided")
|
||||
.to_string();
|
||||
|
||||
if passed {
|
||||
VerificationResult::pass(
|
||||
reasoning,
|
||||
VerificationMethod::Llm { model: model.to_string() },
|
||||
1, // Minimal cost
|
||||
)
|
||||
} else {
|
||||
VerificationResult::fail(
|
||||
reasoning,
|
||||
VerificationMethod::Llm { model: model.to_string() },
|
||||
1,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Try to infer from text
|
||||
let passed = response.to_lowercase().contains("pass")
|
||||
|| response.to_lowercase().contains("success")
|
||||
|| response.to_lowercase().contains("completed");
|
||||
|
||||
if passed {
|
||||
VerificationResult::pass(
|
||||
response.to_string(),
|
||||
VerificationMethod::Llm { model: model.to_string() },
|
||||
1,
|
||||
)
|
||||
} else {
|
||||
VerificationResult::fail(
|
||||
response.to_string(),
|
||||
VerificationMethod::Llm { model: model.to_string() },
|
||||
1,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run verification according to criteria.
|
||||
async fn verify(
|
||||
&self,
|
||||
task: &Task,
|
||||
ctx: &AgentContext,
|
||||
) -> VerificationResult {
|
||||
match task.verification() {
|
||||
VerificationCriteria::None => {
|
||||
VerificationResult::pass(
|
||||
"No verification required",
|
||||
VerificationMethod::None,
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
VerificationCriteria::Programmatic(check) => {
|
||||
match self.run_programmatic_check(check, &ctx.workspace).await {
|
||||
Ok(true) => VerificationResult::pass(
|
||||
"Programmatic check passed",
|
||||
VerificationMethod::Programmatic,
|
||||
0,
|
||||
),
|
||||
Ok(false) => VerificationResult::fail(
|
||||
"Programmatic check failed",
|
||||
VerificationMethod::Programmatic,
|
||||
0,
|
||||
),
|
||||
Err(e) => VerificationResult::fail(
|
||||
format!("Programmatic check error: {}", e),
|
||||
VerificationMethod::Programmatic,
|
||||
0,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
VerificationCriteria::LlmBased { success_criteria } => {
|
||||
self.verify_with_llm(task, success_criteria, ctx).await
|
||||
}
|
||||
|
||||
VerificationCriteria::Hybrid { programmatic, llm_fallback } => {
|
||||
// Try programmatic first
|
||||
match self.run_programmatic_check(programmatic, &ctx.workspace).await {
|
||||
Ok(true) => VerificationResult::pass(
|
||||
"Programmatic check passed",
|
||||
VerificationMethod::Programmatic,
|
||||
0,
|
||||
),
|
||||
Ok(false) => {
|
||||
// Fall back to LLM
|
||||
self.verify_with_llm(task, llm_fallback, ctx).await
|
||||
}
|
||||
Err(_) => {
|
||||
// Error in programmatic, fall back to LLM
|
||||
self.verify_with_llm(task, llm_fallback, ctx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Verifier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for Verifier {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::Verifier
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Verifies task completion using programmatic checks and LLM fallback"
|
||||
}
|
||||
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
let result = self.verify(task, ctx).await;
|
||||
|
||||
if result.passed() {
|
||||
AgentResult::success(
|
||||
result.reasoning(),
|
||||
result.cost_cents(),
|
||||
)
|
||||
.with_data(serde_json::json!({
|
||||
"passed": true,
|
||||
"method": format!("{:?}", result.method()),
|
||||
"reasoning": result.reasoning(),
|
||||
}))
|
||||
} else {
|
||||
AgentResult::failure(
|
||||
result.reasoning(),
|
||||
result.cost_cents(),
|
||||
)
|
||||
.with_data(serde_json::json!({
|
||||
"passed": false,
|
||||
"method": format!("{:?}", result.method()),
|
||||
"reasoning": result.reasoning(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LeafAgent for Verifier {
|
||||
fn capability(&self) -> LeafCapability {
|
||||
LeafCapability::Verification
|
||||
}
|
||||
}
|
||||
|
||||
139
src/agents/mod.rs
Normal file
139
src/agents/mod.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
//! Agents module - the hierarchical agent tree.
|
||||
//!
|
||||
//! # Agent Types
|
||||
//! - **RootAgent**: Top-level orchestrator, receives tasks from API
|
||||
//! - **NodeAgent**: Intermediate orchestrator, delegates to children
|
||||
//! - **LeafAgent**: Specialized agents that do actual work
|
||||
//!
|
||||
//! # Leaf Agent Specializations
|
||||
//! - `ComplexityEstimator`: Estimates task difficulty
|
||||
//! - `ModelSelector`: Chooses optimal model for cost/capability
|
||||
//! - `TaskExecutor`: Executes tasks using tools
|
||||
//! - `Verifier`: Validates task completion
|
||||
//!
|
||||
//! # Design Principles
|
||||
//! - Agents communicate synchronously (parent calls child, child returns)
|
||||
//! - Designed for future async message passing migration
|
||||
//! - All operations return `Result` with meaningful errors
|
||||
|
||||
mod types;
|
||||
mod context;
|
||||
mod tree;
|
||||
pub mod orchestrator;
|
||||
pub mod leaf;
|
||||
|
||||
pub use types::{AgentId, AgentType, AgentResult, AgentError, Complexity};
|
||||
pub use context::AgentContext;
|
||||
pub use tree::{AgentTree, AgentRef};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use crate::task::Task;
|
||||
|
||||
/// Base trait for all agents.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - `execute()` returns `Ok` only if the task was actually completed or delegated
|
||||
/// - `execute()` never panics; all errors are returned as `Err`
|
||||
///
|
||||
/// # Design for Provability
|
||||
/// - Preconditions and postconditions are documented
|
||||
/// - Pure functions are preferred where possible
|
||||
#[async_trait]
|
||||
pub trait Agent: Send + Sync {
|
||||
/// Get the unique identifier for this agent.
|
||||
fn id(&self) -> &AgentId;
|
||||
|
||||
/// Get the type/role of this agent.
|
||||
fn agent_type(&self) -> AgentType;
|
||||
|
||||
/// Execute a task.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `task.budget().remaining_cents() > 0` (has budget)
|
||||
/// - `task.status() == Pending || task.status() == Running`
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - On success: task is completed or delegated appropriately
|
||||
/// - `result.cost_cents <= task.budget().total_cents()`
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if:
|
||||
/// - Task cannot be executed (insufficient budget, invalid state)
|
||||
/// - Execution fails (tool error, LLM error, etc.)
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult;
|
||||
|
||||
/// Get a human-readable description of this agent.
|
||||
fn description(&self) -> &str {
|
||||
"Generic agent"
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for orchestrator agents (Root and Node) that can have children.
|
||||
///
|
||||
/// # Child Management
|
||||
/// Orchestrators can delegate work to child agents.
|
||||
#[async_trait]
|
||||
pub trait OrchestratorAgent: Agent {
|
||||
/// Get references to child agents.
|
||||
fn children(&self) -> Vec<AgentRef>;
|
||||
|
||||
/// Find a child agent by capability.
|
||||
fn find_child(&self, agent_type: AgentType) -> Option<AgentRef>;
|
||||
|
||||
/// Delegate a task to a specific child.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - Child exists and is capable of handling the task
|
||||
/// - Task has sufficient budget
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - Child's execute() is called
|
||||
/// - Results are aggregated and returned
|
||||
async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult;
|
||||
|
||||
/// Delegate multiple tasks to appropriate children.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - Sum of task budgets <= available budget
|
||||
/// - All tasks can be matched to capable children
|
||||
async fn delegate_all(
|
||||
&self,
|
||||
tasks: &mut [Task],
|
||||
ctx: &AgentContext,
|
||||
) -> Vec<AgentResult>;
|
||||
}
|
||||
|
||||
/// Trait for leaf agents with specialized capabilities.
|
||||
pub trait LeafAgent: Agent {
|
||||
/// Get the specific capability of this leaf agent.
|
||||
fn capability(&self) -> LeafCapability;
|
||||
}
|
||||
|
||||
/// Capabilities of leaf agents.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum LeafCapability {
|
||||
/// Can estimate task complexity
|
||||
ComplexityEstimation,
|
||||
|
||||
/// Can select optimal model for a task
|
||||
ModelSelection,
|
||||
|
||||
/// Can execute tasks using tools
|
||||
TaskExecution,
|
||||
|
||||
/// Can verify task completion
|
||||
Verification,
|
||||
}
|
||||
|
||||
impl LeafCapability {
|
||||
/// Get the agent type for this capability.
|
||||
pub fn agent_type(&self) -> AgentType {
|
||||
match self {
|
||||
Self::ComplexityEstimation => AgentType::ComplexityEstimator,
|
||||
Self::ModelSelection => AgentType::ModelSelector,
|
||||
Self::TaskExecution => AgentType::TaskExecutor,
|
||||
Self::Verification => AgentType::Verifier,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
8
src/agents/orchestrator/mod.rs
Normal file
8
src/agents/orchestrator/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! Orchestrator agents - Root and Node agents that manage the tree.
|
||||
|
||||
mod root;
|
||||
mod node;
|
||||
|
||||
pub use root::RootAgent;
|
||||
pub use node::NodeAgent;
|
||||
|
||||
163
src/agents/orchestrator/node.rs
Normal file
163
src/agents/orchestrator/node.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
//! Node agent - intermediate orchestrator in the agent tree.
|
||||
//!
|
||||
//! Node agents are like mini-root agents that can:
|
||||
//! - Receive delegated tasks from parent
|
||||
//! - Split complex subtasks further
|
||||
//! - Delegate to their own children
|
||||
//! - Aggregate results for parent
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, OrchestratorAgent,
|
||||
leaf::{TaskExecutor, Verifier},
|
||||
};
|
||||
use crate::task::Task;
|
||||
|
||||
/// Node agent - intermediate orchestrator.
|
||||
///
|
||||
/// # Purpose
|
||||
/// Handles subtasks that may still be complex enough
|
||||
/// to warrant further splitting.
|
||||
///
|
||||
/// # Differences from Root
|
||||
/// - No complexity estimation (parent already decided to split)
|
||||
/// - Simpler child set (just executor and verifier)
|
||||
/// - Limited split depth (prevents infinite recursion)
|
||||
pub struct NodeAgent {
|
||||
id: AgentId,
|
||||
|
||||
/// Name for identification in logs
|
||||
name: String,
|
||||
|
||||
// Child agents
|
||||
task_executor: Arc<TaskExecutor>,
|
||||
verifier: Arc<Verifier>,
|
||||
|
||||
// Child node agents (for further splitting)
|
||||
child_nodes: Vec<Arc<NodeAgent>>,
|
||||
}
|
||||
|
||||
impl NodeAgent {
|
||||
/// Create a new node agent.
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: AgentId::new(),
|
||||
name: name.into(),
|
||||
task_executor: Arc::new(TaskExecutor::new()),
|
||||
verifier: Arc::new(Verifier::new()),
|
||||
child_nodes: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a node with custom executor.
|
||||
pub fn with_executor(mut self, executor: Arc<TaskExecutor>) -> Self {
|
||||
self.task_executor = executor;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a child node for hierarchical delegation.
|
||||
pub fn add_child_node(&mut self, child: Arc<NodeAgent>) {
|
||||
self.child_nodes.push(child);
|
||||
}
|
||||
|
||||
/// Get the node's name.
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeAgent {
|
||||
fn default() -> Self {
|
||||
Self::new("node")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for NodeAgent {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::Node
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Intermediate orchestrator for subtask delegation"
|
||||
}
|
||||
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
tracing::debug!("NodeAgent '{}' executing task", self.name);
|
||||
|
||||
// Execute the task
|
||||
let result = self.task_executor.execute(task, ctx).await;
|
||||
|
||||
if !result.success {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Verify if criteria specified
|
||||
let verification = self.verifier.execute(task, ctx).await;
|
||||
|
||||
if verification.success {
|
||||
result
|
||||
} else {
|
||||
AgentResult::failure(
|
||||
format!(
|
||||
"Task completed but verification failed: {}",
|
||||
verification.output
|
||||
),
|
||||
result.cost_cents + verification.cost_cents,
|
||||
)
|
||||
.with_data(json!({
|
||||
"execution": result.data,
|
||||
"verification": verification.data,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl OrchestratorAgent for NodeAgent {
|
||||
fn children(&self) -> Vec<AgentRef> {
|
||||
let mut children: Vec<AgentRef> = vec![
|
||||
Arc::clone(&self.task_executor) as AgentRef,
|
||||
Arc::clone(&self.verifier) as AgentRef,
|
||||
];
|
||||
|
||||
for node in &self.child_nodes {
|
||||
children.push(Arc::clone(node) as AgentRef);
|
||||
}
|
||||
|
||||
children
|
||||
}
|
||||
|
||||
fn find_child(&self, agent_type: AgentType) -> Option<AgentRef> {
|
||||
match agent_type {
|
||||
AgentType::TaskExecutor => Some(Arc::clone(&self.task_executor) as AgentRef),
|
||||
AgentType::Verifier => Some(Arc::clone(&self.verifier) as AgentRef),
|
||||
AgentType::Node => self.child_nodes.first().map(|n| Arc::clone(n) as AgentRef),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult {
|
||||
child.execute(task, ctx).await
|
||||
}
|
||||
|
||||
async fn delegate_all(&self, tasks: &mut [Task], ctx: &AgentContext) -> Vec<AgentResult> {
|
||||
let mut results = Vec::with_capacity(tasks.len());
|
||||
|
||||
for task in tasks {
|
||||
let result = self.task_executor.execute(task, ctx).await;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
346
src/agents/orchestrator/root.rs
Normal file
346
src/agents/orchestrator/root.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
//! Root agent - top-level orchestrator of the agent tree.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! 1. Receive tasks from the API
|
||||
//! 2. Estimate complexity
|
||||
//! 3. Decide: execute directly or split into subtasks
|
||||
//! 4. Delegate to appropriate children
|
||||
//! 5. Aggregate results
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, Complexity,
|
||||
OrchestratorAgent,
|
||||
leaf::{ComplexityEstimator, ModelSelector, TaskExecutor, Verifier},
|
||||
};
|
||||
use crate::budget::Budget;
|
||||
use crate::task::{Task, Subtask, SubtaskPlan, VerificationCriteria};
|
||||
|
||||
/// Root agent - the top of the agent tree.
|
||||
///
|
||||
/// # Task Processing Flow
|
||||
/// ```text
|
||||
/// 1. Estimate complexity (ComplexityEstimator)
|
||||
/// 2. If simple: execute directly (TaskExecutor)
|
||||
/// 3. If complex:
|
||||
/// a. Split into subtasks (LLM-based)
|
||||
/// b. Select model for each subtask (ModelSelector)
|
||||
/// c. Execute subtasks (TaskExecutor)
|
||||
/// d. Verify results (Verifier)
|
||||
/// 4. Return aggregated result
|
||||
/// ```
|
||||
pub struct RootAgent {
|
||||
id: AgentId,
|
||||
|
||||
// Child agents
|
||||
complexity_estimator: Arc<ComplexityEstimator>,
|
||||
model_selector: Arc<ModelSelector>,
|
||||
task_executor: Arc<TaskExecutor>,
|
||||
verifier: Arc<Verifier>,
|
||||
}
|
||||
|
||||
impl RootAgent {
|
||||
/// Create a new root agent with default children.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: AgentId::new(),
|
||||
complexity_estimator: Arc::new(ComplexityEstimator::new()),
|
||||
model_selector: Arc::new(ModelSelector::new()),
|
||||
task_executor: Arc::new(TaskExecutor::new()),
|
||||
verifier: Arc::new(Verifier::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate complexity of a task.
|
||||
async fn estimate_complexity(&self, task: &mut Task, ctx: &AgentContext) -> Complexity {
|
||||
let result = self.complexity_estimator.execute(task, ctx).await;
|
||||
|
||||
if let Some(data) = result.data {
|
||||
let score = data["score"].as_f64().unwrap_or(0.5);
|
||||
let reasoning = data["reasoning"].as_str().unwrap_or("").to_string();
|
||||
let estimated_tokens = data["estimated_tokens"].as_u64().unwrap_or(2000);
|
||||
let should_split = data["should_split"].as_bool().unwrap_or(false);
|
||||
|
||||
Complexity::new(score, reasoning, estimated_tokens)
|
||||
.with_split(should_split)
|
||||
} else {
|
||||
Complexity::moderate("Could not estimate complexity")
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a complex task into subtasks.
|
||||
///
|
||||
/// Uses LLM to analyze the task and propose subtasks.
|
||||
async fn split_task(&self, task: &Task, ctx: &AgentContext) -> Result<SubtaskPlan, AgentResult> {
|
||||
let prompt = format!(
|
||||
r#"You are a task planner. Break down this task into smaller, manageable subtasks.
|
||||
|
||||
Task: {}
|
||||
|
||||
Respond with a JSON object:
|
||||
{{
|
||||
"subtasks": [
|
||||
{{
|
||||
"description": "What to do",
|
||||
"verification": "How to verify it's done",
|
||||
"weight": 1.0
|
||||
}}
|
||||
],
|
||||
"reasoning": "Why this breakdown makes sense"
|
||||
}}
|
||||
|
||||
Guidelines:
|
||||
- Each subtask should be independently executable
|
||||
- Include verification for each subtask
|
||||
- Weight indicates relative effort (higher = more work)
|
||||
- Keep subtasks focused and specific
|
||||
|
||||
Respond ONLY with the JSON object."#,
|
||||
task.description()
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
crate::llm::ChatMessage {
|
||||
role: crate::llm::Role::System,
|
||||
content: Some("You are a precise task planner. Respond only with JSON.".to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
crate::llm::ChatMessage {
|
||||
role: crate::llm::Role::User,
|
||||
content: Some(prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
let response = ctx.llm
|
||||
.chat_completion("openai/gpt-4.1-mini", &messages, None)
|
||||
.await
|
||||
.map_err(|e| AgentResult::failure(format!("LLM error: {}", e), 1))?;
|
||||
|
||||
let content = response.content.unwrap_or_default();
|
||||
self.parse_subtask_plan(&content, task.id())
|
||||
}
|
||||
|
||||
/// Parse LLM response into SubtaskPlan.
|
||||
fn parse_subtask_plan(
|
||||
&self,
|
||||
response: &str,
|
||||
parent_id: crate::task::TaskId,
|
||||
) -> Result<SubtaskPlan, AgentResult> {
|
||||
let json: serde_json::Value = serde_json::from_str(response)
|
||||
.map_err(|e| AgentResult::failure(format!("Failed to parse subtasks: {}", e), 0))?;
|
||||
|
||||
let reasoning = json["reasoning"]
|
||||
.as_str()
|
||||
.unwrap_or("No reasoning provided")
|
||||
.to_string();
|
||||
|
||||
let subtasks: Vec<Subtask> = json["subtasks"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.map(|s| {
|
||||
let desc = s["description"].as_str().unwrap_or("").to_string();
|
||||
let verification = s["verification"].as_str().unwrap_or("");
|
||||
let weight = s["weight"].as_f64().unwrap_or(1.0);
|
||||
|
||||
Subtask::new(
|
||||
desc,
|
||||
VerificationCriteria::llm_based(verification),
|
||||
weight,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if subtasks.is_empty() {
|
||||
return Err(AgentResult::failure("No subtasks generated", 1));
|
||||
}
|
||||
|
||||
SubtaskPlan::new(parent_id, subtasks, reasoning)
|
||||
.map_err(|e| AgentResult::failure(format!("Invalid subtask plan: {}", e), 0))
|
||||
}
|
||||
|
||||
/// Execute subtasks and aggregate results.
|
||||
async fn execute_subtasks(
|
||||
&self,
|
||||
subtask_plan: SubtaskPlan,
|
||||
parent_budget: &Budget,
|
||||
ctx: &AgentContext,
|
||||
) -> AgentResult {
|
||||
// Convert plan to tasks
|
||||
let mut tasks = match subtask_plan.into_tasks(parent_budget) {
|
||||
Ok(t) => t,
|
||||
Err(e) => return AgentResult::failure(format!("Failed to create subtasks: {}", e), 0),
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
let mut total_cost = 0u64;
|
||||
|
||||
// Execute each subtask
|
||||
for task in &mut tasks {
|
||||
let result = self.task_executor.execute(task, ctx).await;
|
||||
total_cost += result.cost_cents;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
// Aggregate results
|
||||
let successes = results.iter().filter(|r| r.success).count();
|
||||
let total = results.len();
|
||||
|
||||
if successes == total {
|
||||
AgentResult::success(
|
||||
format!("All {} subtasks completed successfully", total),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
"subtasks_total": total,
|
||||
"subtasks_succeeded": successes,
|
||||
"results": results.iter().map(|r| &r.output).collect::<Vec<_>>(),
|
||||
}))
|
||||
} else {
|
||||
AgentResult::failure(
|
||||
format!("{}/{} subtasks succeeded", successes, total),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
"subtasks_total": total,
|
||||
"subtasks_succeeded": successes,
|
||||
"results": results.iter().map(|r| json!({
|
||||
"success": r.success,
|
||||
"output": &r.output,
|
||||
})).collect::<Vec<_>>(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RootAgent {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for RootAgent {
|
||||
fn id(&self) -> &AgentId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn agent_type(&self) -> AgentType {
|
||||
AgentType::Root
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Root orchestrator: estimates complexity, splits tasks, delegates execution"
|
||||
}
|
||||
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
let mut total_cost = 0u64;
|
||||
|
||||
// Step 1: Estimate complexity
|
||||
let complexity = self.estimate_complexity(task, ctx).await;
|
||||
total_cost += 1; // Complexity estimation cost
|
||||
|
||||
tracing::info!(
|
||||
"Task complexity: {:.2} (should_split: {})",
|
||||
complexity.score(),
|
||||
complexity.should_split()
|
||||
);
|
||||
|
||||
// Step 2: Decide execution strategy
|
||||
if complexity.should_split() && ctx.can_split() {
|
||||
// Complex task: split and delegate
|
||||
match self.split_task(task, ctx).await {
|
||||
Ok(plan) => {
|
||||
total_cost += 2; // Splitting cost
|
||||
|
||||
// Execute subtasks
|
||||
let child_ctx = ctx.child_context();
|
||||
let result = self.execute_subtasks(plan, task.budget(), &child_ctx).await;
|
||||
|
||||
return AgentResult {
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
cost_cents: total_cost + result.cost_cents,
|
||||
model_used: result.model_used,
|
||||
data: result.data,
|
||||
};
|
||||
}
|
||||
Err(e) => {
|
||||
// Couldn't split, fall back to direct execution
|
||||
tracing::warn!("Couldn't split task, executing directly: {}", e.output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simple task or failed to split: execute directly
|
||||
let result = self.task_executor.execute(task, ctx).await;
|
||||
|
||||
// Step 3: Verify (if verification criteria specified)
|
||||
let verification = self.verifier.execute(task, ctx).await;
|
||||
total_cost += verification.cost_cents;
|
||||
|
||||
AgentResult {
|
||||
success: result.success && verification.success,
|
||||
output: if verification.success {
|
||||
result.output
|
||||
} else {
|
||||
format!("{}\n\nVerification failed: {}", result.output, verification.output)
|
||||
},
|
||||
cost_cents: total_cost + result.cost_cents,
|
||||
model_used: result.model_used,
|
||||
data: json!({
|
||||
"complexity": complexity.score(),
|
||||
"was_split": false,
|
||||
"verification": verification.data,
|
||||
"execution": result.data,
|
||||
}).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl OrchestratorAgent for RootAgent {
|
||||
fn children(&self) -> Vec<AgentRef> {
|
||||
vec![
|
||||
Arc::clone(&self.complexity_estimator) as AgentRef,
|
||||
Arc::clone(&self.model_selector) as AgentRef,
|
||||
Arc::clone(&self.task_executor) as AgentRef,
|
||||
Arc::clone(&self.verifier) as AgentRef,
|
||||
]
|
||||
}
|
||||
|
||||
fn find_child(&self, agent_type: AgentType) -> Option<AgentRef> {
|
||||
match agent_type {
|
||||
AgentType::ComplexityEstimator => Some(Arc::clone(&self.complexity_estimator) as AgentRef),
|
||||
AgentType::ModelSelector => Some(Arc::clone(&self.model_selector) as AgentRef),
|
||||
AgentType::TaskExecutor => Some(Arc::clone(&self.task_executor) as AgentRef),
|
||||
AgentType::Verifier => Some(Arc::clone(&self.verifier) as AgentRef),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult {
|
||||
child.execute(task, ctx).await
|
||||
}
|
||||
|
||||
async fn delegate_all(&self, tasks: &mut [Task], ctx: &AgentContext) -> Vec<AgentResult> {
|
||||
let mut results = Vec::with_capacity(tasks.len());
|
||||
|
||||
for task in tasks {
|
||||
let result = self.task_executor.execute(task, ctx).await;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
150
src/agents/tree.rs
Normal file
150
src/agents/tree.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
//! Agent tree structure and management.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{Agent, AgentId, AgentType};
|
||||
|
||||
/// Reference to an agent in the tree.
|
||||
pub type AgentRef = Arc<dyn Agent>;
|
||||
|
||||
/// The agent tree structure.
|
||||
///
|
||||
/// # Structure
|
||||
/// - Root agent at the top
|
||||
/// - Node agents as intermediate orchestrators
|
||||
/// - Leaf agents doing specialized work
|
||||
///
|
||||
/// # Invariants
|
||||
/// - Exactly one root agent
|
||||
/// - All non-root agents have a parent
|
||||
/// - No cycles in parent-child relationships
|
||||
pub struct AgentTree {
|
||||
/// All agents indexed by ID
|
||||
agents: HashMap<AgentId, AgentRef>,
|
||||
|
||||
/// Parent-child relationships
|
||||
children: HashMap<AgentId, Vec<AgentId>>,
|
||||
|
||||
/// Root agent ID
|
||||
root_id: Option<AgentId>,
|
||||
}
|
||||
|
||||
impl AgentTree {
|
||||
/// Create a new empty agent tree.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
agents: HashMap::new(),
|
||||
children: HashMap::new(),
|
||||
root_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the root agent.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - No root agent currently set
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if root already exists.
|
||||
pub fn set_root(&mut self, agent: AgentRef) -> Result<(), TreeError> {
|
||||
if self.root_id.is_some() {
|
||||
return Err(TreeError::RootAlreadyExists);
|
||||
}
|
||||
|
||||
let id = *agent.id();
|
||||
self.agents.insert(id, agent);
|
||||
self.children.insert(id, Vec::new());
|
||||
self.root_id = Some(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a child agent to a parent.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - Parent exists in the tree
|
||||
/// - Child is not already in the tree
|
||||
pub fn add_child(&mut self, parent_id: AgentId, child: AgentRef) -> Result<(), TreeError> {
|
||||
if !self.agents.contains_key(&parent_id) {
|
||||
return Err(TreeError::ParentNotFound(parent_id));
|
||||
}
|
||||
|
||||
let child_id = *child.id();
|
||||
|
||||
if self.agents.contains_key(&child_id) {
|
||||
return Err(TreeError::AgentAlreadyExists(child_id));
|
||||
}
|
||||
|
||||
self.agents.insert(child_id, child);
|
||||
self.children.insert(child_id, Vec::new());
|
||||
self.children.get_mut(&parent_id).unwrap().push(child_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the root agent.
|
||||
pub fn root(&self) -> Option<AgentRef> {
|
||||
self.root_id.and_then(|id| self.agents.get(&id).cloned())
|
||||
}
|
||||
|
||||
/// Get an agent by ID.
|
||||
pub fn get(&self, id: &AgentId) -> Option<AgentRef> {
|
||||
self.agents.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Get children of an agent.
|
||||
pub fn get_children(&self, id: &AgentId) -> Vec<AgentRef> {
|
||||
self.children
|
||||
.get(id)
|
||||
.map(|ids| {
|
||||
ids.iter()
|
||||
.filter_map(|id| self.agents.get(id).cloned())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Find agents by type.
|
||||
pub fn find_by_type(&self, agent_type: AgentType) -> Vec<AgentRef> {
|
||||
self.agents
|
||||
.values()
|
||||
.filter(|a| a.agent_type() == agent_type)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all agents in the tree.
|
||||
pub fn all_agents(&self) -> Vec<AgentRef> {
|
||||
self.agents.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the number of agents in the tree.
|
||||
pub fn len(&self) -> usize {
|
||||
self.agents.len()
|
||||
}
|
||||
|
||||
/// Check if the tree is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.agents.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentTree {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors in tree operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TreeError {
|
||||
#[error("Root agent already exists")]
|
||||
RootAlreadyExists,
|
||||
|
||||
#[error("Parent agent not found: {0}")]
|
||||
ParentNotFound(AgentId),
|
||||
|
||||
#[error("Agent already exists in tree: {0}")]
|
||||
AgentAlreadyExists(AgentId),
|
||||
}
|
||||
|
||||
248
src/agents/types.rs
Normal file
248
src/agents/types.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
//! Core types for the agent system.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for an agent.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct AgentId(Uuid);
|
||||
|
||||
impl AgentId {
|
||||
/// Create a new unique agent ID.
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Create an agent ID from a string (for testing).
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
Self(Uuid::parse_str(s).unwrap_or_else(|_| Uuid::new_v4()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AgentId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of agent in the hierarchy.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum AgentType {
|
||||
/// Root orchestrator (top of tree)
|
||||
Root,
|
||||
/// Intermediate orchestrator (can have children)
|
||||
Node,
|
||||
/// Estimates task complexity
|
||||
ComplexityEstimator,
|
||||
/// Selects optimal model
|
||||
ModelSelector,
|
||||
/// Executes tasks using tools
|
||||
TaskExecutor,
|
||||
/// Verifies task completion
|
||||
Verifier,
|
||||
}
|
||||
|
||||
impl AgentType {
|
||||
/// Check if this is an orchestrator type (can have children).
|
||||
pub fn is_orchestrator(&self) -> bool {
|
||||
matches!(self, Self::Root | Self::Node)
|
||||
}
|
||||
|
||||
/// Check if this is a leaf type.
|
||||
pub fn is_leaf(&self) -> bool {
|
||||
!self.is_orchestrator()
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of an agent executing a task.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - If `success == true`, the task was completed
|
||||
/// - `cost_cents` reflects actual cost incurred
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentResult {
|
||||
/// Whether the task was successful
|
||||
pub success: bool,
|
||||
|
||||
/// Output or response from the agent
|
||||
pub output: String,
|
||||
|
||||
/// Cost incurred in cents
|
||||
pub cost_cents: u64,
|
||||
|
||||
/// Model used (if any)
|
||||
pub model_used: Option<String>,
|
||||
|
||||
/// Detailed result data (type-specific)
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl AgentResult {
|
||||
/// Create a successful result.
|
||||
pub fn success(output: impl Into<String>, cost_cents: u64) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
output: output.into(),
|
||||
cost_cents,
|
||||
model_used: None,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a failure result.
|
||||
pub fn failure(error: impl Into<String>, cost_cents: u64) -> Self {
|
||||
Self {
|
||||
success: false,
|
||||
output: error.into(),
|
||||
cost_cents,
|
||||
model_used: None,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add model information to the result.
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model_used = Some(model.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add additional data to the result.
|
||||
pub fn with_data(mut self, data: serde_json::Value) -> Self {
|
||||
self.data = Some(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Complexity estimation for a task.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - `score` is in range [0.0, 1.0]
|
||||
/// - `should_split` is derived from score threshold
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Complexity {
|
||||
/// Complexity score: 0.0 = trivial, 1.0 = extremely complex
|
||||
score: f64,
|
||||
|
||||
/// Human-readable explanation
|
||||
reasoning: String,
|
||||
|
||||
/// Whether the task should be split into subtasks
|
||||
should_split: bool,
|
||||
|
||||
/// Estimated token count for this task
|
||||
estimated_tokens: u64,
|
||||
}
|
||||
|
||||
impl Complexity {
|
||||
/// Create a new complexity estimate.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `score` is in [0.0, 1.0] (will be clamped if not)
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - `self.score` is in [0.0, 1.0]
|
||||
/// - `self.should_split` is true if score > threshold (0.6)
|
||||
pub fn new(score: f64, reasoning: impl Into<String>, estimated_tokens: u64) -> Self {
|
||||
let clamped_score = score.clamp(0.0, 1.0);
|
||||
Self {
|
||||
score: clamped_score,
|
||||
reasoning: reasoning.into(),
|
||||
should_split: clamped_score > Self::SPLIT_THRESHOLD,
|
||||
estimated_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Threshold above which tasks should be split.
|
||||
pub const SPLIT_THRESHOLD: f64 = 0.6;
|
||||
|
||||
/// Get the complexity score.
|
||||
pub fn score(&self) -> f64 {
|
||||
self.score
|
||||
}
|
||||
|
||||
/// Get the reasoning explanation.
|
||||
pub fn reasoning(&self) -> &str {
|
||||
&self.reasoning
|
||||
}
|
||||
|
||||
/// Check if the task should be split.
|
||||
pub fn should_split(&self) -> bool {
|
||||
self.should_split
|
||||
}
|
||||
|
||||
/// Get estimated token count.
|
||||
pub fn estimated_tokens(&self) -> u64 {
|
||||
self.estimated_tokens
|
||||
}
|
||||
|
||||
/// Create a simple (low complexity) estimate.
|
||||
pub fn simple(reasoning: impl Into<String>) -> Self {
|
||||
Self::new(0.2, reasoning, 500)
|
||||
}
|
||||
|
||||
/// Create a moderate complexity estimate.
|
||||
pub fn moderate(reasoning: impl Into<String>) -> Self {
|
||||
Self::new(0.5, reasoning, 2000)
|
||||
}
|
||||
|
||||
/// Create a complex estimate that should be split.
|
||||
pub fn complex(reasoning: impl Into<String>) -> Self {
|
||||
Self::new(0.8, reasoning, 5000)
|
||||
}
|
||||
|
||||
/// Override the should_split decision.
|
||||
pub fn with_split(mut self, should_split: bool) -> Self {
|
||||
self.should_split = should_split;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur in agent operations.
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum AgentError {
|
||||
#[error("Task error: {0}")]
|
||||
TaskError(String),
|
||||
|
||||
#[error("Budget exhausted: needed {needed} cents, had {available} cents")]
|
||||
BudgetExhausted { needed: u64, available: u64 },
|
||||
|
||||
#[error("No capable agent found for task")]
|
||||
NoCapableAgent,
|
||||
|
||||
#[error("LLM error: {0}")]
|
||||
LlmError(String),
|
||||
|
||||
#[error("Tool error: {0}")]
|
||||
ToolError(String),
|
||||
|
||||
#[error("Verification failed: {0}")]
|
||||
VerificationFailed(String),
|
||||
|
||||
#[error("Max iterations reached: {0}")]
|
||||
MaxIterations(usize),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl From<crate::task::TaskError> for AgentError {
|
||||
fn from(e: crate::task::TaskError) -> Self {
|
||||
Self::TaskError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::budget::BudgetError> for AgentError {
|
||||
fn from(_e: crate::budget::BudgetError) -> Self {
|
||||
Self::BudgetExhausted {
|
||||
needed: 0,
|
||||
available: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,12 @@ use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent::Agent;
|
||||
use crate::agents::{Agent, AgentContext, AgentRef};
|
||||
use crate::agents::orchestrator::RootAgent;
|
||||
use crate::budget::ModelPricing;
|
||||
use crate::config::Config;
|
||||
use crate::llm::OpenRouterClient;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
use super::types::*;
|
||||
|
||||
@@ -28,17 +32,36 @@ use super::types::*;
|
||||
pub struct AppState {
|
||||
pub config: Config,
|
||||
pub tasks: RwLock<HashMap<Uuid, TaskState>>,
|
||||
pub agent: Agent,
|
||||
/// The hierarchical root agent
|
||||
pub root_agent: AgentRef,
|
||||
/// Shared context for agent execution
|
||||
pub agent_context: AgentContext,
|
||||
}
|
||||
|
||||
/// Start the HTTP server.
|
||||
pub async fn serve(config: Config) -> anyhow::Result<()> {
|
||||
let agent = Agent::new(config.clone());
|
||||
// Create the root agent (hierarchical)
|
||||
let root_agent: AgentRef = Arc::new(RootAgent::new());
|
||||
|
||||
// Create shared agent context
|
||||
let llm = Arc::new(OpenRouterClient::new(config.api_key.clone()));
|
||||
let tools = ToolRegistry::new();
|
||||
let pricing = Arc::new(ModelPricing::new());
|
||||
let workspace = config.workspace_path.clone();
|
||||
|
||||
let agent_context = AgentContext::new(
|
||||
config.clone(),
|
||||
llm,
|
||||
tools,
|
||||
pricing,
|
||||
workspace,
|
||||
);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
config: config.clone(),
|
||||
tasks: RwLock::new(HashMap::new()),
|
||||
agent,
|
||||
root_agent,
|
||||
agent_context,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
@@ -112,8 +135,8 @@ async fn create_task(
|
||||
async fn run_agent_task(
|
||||
state: Arc<AppState>,
|
||||
task_id: Uuid,
|
||||
task: String,
|
||||
model: String,
|
||||
task_description: String,
|
||||
_model: String,
|
||||
workspace_path: std::path::PathBuf,
|
||||
) {
|
||||
// Update status to running
|
||||
@@ -124,24 +147,77 @@ async fn run_agent_task(
|
||||
}
|
||||
}
|
||||
|
||||
// Run the agent
|
||||
let result = state.agent.run_task(&task, &model, &workspace_path).await;
|
||||
// Create a Task object for the hierarchical agent
|
||||
let budget = crate::budget::Budget::new(1000); // $10 default budget
|
||||
let verification = crate::task::VerificationCriteria::None;
|
||||
|
||||
let task_result = crate::task::Task::new(
|
||||
task_description.clone(),
|
||||
verification,
|
||||
budget,
|
||||
);
|
||||
|
||||
let mut task = match task_result {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let mut tasks = state.tasks.write().await;
|
||||
if let Some(task_state) = tasks.get_mut(&task_id) {
|
||||
task_state.status = TaskStatus::Failed;
|
||||
task_state.result = Some(format!("Failed to create task: {}", e));
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Create context with the specified workspace
|
||||
let llm = Arc::new(OpenRouterClient::new(state.config.api_key.clone()));
|
||||
let tools = ToolRegistry::new();
|
||||
let pricing = Arc::new(ModelPricing::new());
|
||||
|
||||
let ctx = AgentContext::new(
|
||||
state.config.clone(),
|
||||
llm,
|
||||
tools,
|
||||
pricing,
|
||||
workspace_path,
|
||||
);
|
||||
|
||||
// Run the hierarchical agent
|
||||
let result = state.root_agent.execute(&mut task, &ctx).await;
|
||||
|
||||
// Update task with result
|
||||
{
|
||||
let mut tasks = state.tasks.write().await;
|
||||
if let Some(task_state) = tasks.get_mut(&task_id) {
|
||||
match result {
|
||||
Ok((response, log)) => {
|
||||
task_state.status = TaskStatus::Completed;
|
||||
task_state.result = Some(response);
|
||||
task_state.log = log;
|
||||
}
|
||||
Err(e) => {
|
||||
task_state.status = TaskStatus::Failed;
|
||||
task_state.result = Some(format!("Error: {}", e));
|
||||
// Add log entries from result data
|
||||
if let Some(data) = &result.data {
|
||||
if let Some(tools_used) = data.get("tools_used") {
|
||||
if let Some(arr) = tools_used.as_array() {
|
||||
for tool in arr {
|
||||
task_state.log.push(TaskLogEntry {
|
||||
timestamp: "0".to_string(),
|
||||
entry_type: LogEntryType::ToolCall,
|
||||
content: tool.as_str().unwrap_or("").to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add final response log
|
||||
task_state.log.push(TaskLogEntry {
|
||||
timestamp: "0".to_string(),
|
||||
entry_type: LogEntryType::Response,
|
||||
content: result.output.clone(),
|
||||
});
|
||||
|
||||
if result.success {
|
||||
task_state.status = TaskStatus::Completed;
|
||||
task_state.result = Some(result.output);
|
||||
} else {
|
||||
task_state.status = TaskStatus::Failed;
|
||||
task_state.result = Some(format!("Error: {}", result.output));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,4 +293,3 @@ async fn stream_task(
|
||||
|
||||
Ok(Sse::new(stream))
|
||||
}
|
||||
|
||||
|
||||
205
src/budget/allocation.rs
Normal file
205
src/budget/allocation.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
//! Budget allocation strategies for subtasks.
|
||||
//!
|
||||
//! # Strategies
|
||||
//! - Proportional: allocate based on estimated complexity
|
||||
//! - Equal: split evenly among subtasks
|
||||
//! - Priority: allocate more to critical subtasks
|
||||
|
||||
use crate::task::Subtask;
|
||||
|
||||
/// Strategy for allocating budget across subtasks.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AllocationStrategy {
|
||||
/// Allocate proportionally based on subtask weights
|
||||
Proportional,
|
||||
|
||||
/// Allocate equally among all subtasks
|
||||
Equal,
|
||||
|
||||
/// Allocate based on priority (first subtasks get more)
|
||||
PriorityFirst,
|
||||
}
|
||||
|
||||
/// Allocate budget across subtasks.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `subtasks` is non-empty
|
||||
/// - `total_budget > 0`
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - `result.len() == subtasks.len()`
|
||||
/// - `result.iter().sum() <= total_budget`
|
||||
///
|
||||
/// # Pure Function
|
||||
/// This is a pure function with no side effects.
|
||||
pub fn allocate_budget(
|
||||
subtasks: &[Subtask],
|
||||
total_budget: u64,
|
||||
strategy: AllocationStrategy,
|
||||
) -> Vec<u64> {
|
||||
if subtasks.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
match strategy {
|
||||
AllocationStrategy::Proportional => allocate_proportional(subtasks, total_budget),
|
||||
AllocationStrategy::Equal => allocate_equal(subtasks, total_budget),
|
||||
AllocationStrategy::PriorityFirst => allocate_priority(subtasks, total_budget),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate proportionally based on weights.
|
||||
///
|
||||
/// # Invariant
|
||||
/// Sum of allocations == total_budget (minus rounding)
|
||||
fn allocate_proportional(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
|
||||
let total_weight: f64 = subtasks.iter().map(|s| s.weight).sum();
|
||||
|
||||
if total_weight <= 0.0 {
|
||||
return allocate_equal(subtasks, total_budget);
|
||||
}
|
||||
|
||||
let mut allocations: Vec<u64> = subtasks
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let proportion = s.weight / total_weight;
|
||||
((total_budget as f64) * proportion).floor() as u64
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Distribute remainder to maintain total
|
||||
let allocated: u64 = allocations.iter().sum();
|
||||
let remainder = total_budget.saturating_sub(allocated);
|
||||
|
||||
// Give remainder to the first subtask (or distribute evenly)
|
||||
if remainder > 0 && !allocations.is_empty() {
|
||||
allocations[0] += remainder;
|
||||
}
|
||||
|
||||
allocations
|
||||
}
|
||||
|
||||
/// Allocate equally among all subtasks.
|
||||
fn allocate_equal(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
|
||||
let n = subtasks.len() as u64;
|
||||
if n == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let base = total_budget / n;
|
||||
let remainder = total_budget % n;
|
||||
|
||||
let mut allocations = vec![base; subtasks.len()];
|
||||
|
||||
// Distribute remainder
|
||||
for i in 0..(remainder as usize) {
|
||||
allocations[i] += 1;
|
||||
}
|
||||
|
||||
allocations
|
||||
}
|
||||
|
||||
/// Allocate with priority to earlier subtasks.
|
||||
///
|
||||
/// Earlier subtasks (lower index) get proportionally more.
|
||||
/// Uses exponential decay: weight[i] = 2^(n-i)
|
||||
fn allocate_priority(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
|
||||
let n = subtasks.len();
|
||||
if n == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Compute exponential weights
|
||||
let weights: Vec<f64> = (0..n)
|
||||
.map(|i| 2.0_f64.powi((n - i - 1) as i32))
|
||||
.collect();
|
||||
|
||||
let total_weight: f64 = weights.iter().sum();
|
||||
|
||||
let mut allocations: Vec<u64> = weights
|
||||
.iter()
|
||||
.map(|w| ((total_budget as f64) * (w / total_weight)).floor() as u64)
|
||||
.collect();
|
||||
|
||||
// Distribute remainder
|
||||
let allocated: u64 = allocations.iter().sum();
|
||||
let remainder = total_budget.saturating_sub(allocated);
|
||||
|
||||
if remainder > 0 && !allocations.is_empty() {
|
||||
allocations[0] += remainder;
|
||||
}
|
||||
|
||||
allocations
|
||||
}
|
||||
|
||||
/// Estimate reasonable budget for a task based on complexity.
|
||||
///
|
||||
/// # Formula
|
||||
/// Uses a heuristic based on complexity score:
|
||||
/// - 0.0-0.2: ~10 cents (simple task)
|
||||
/// - 0.2-0.5: ~50 cents (moderate task)
|
||||
/// - 0.5-0.8: ~200 cents (complex task)
|
||||
/// - 0.8-1.0: ~500 cents (very complex task)
|
||||
///
|
||||
/// # Pure Function
|
||||
pub fn estimate_budget_for_complexity(complexity_score: f64) -> u64 {
|
||||
let clamped = complexity_score.clamp(0.0, 1.0);
|
||||
|
||||
// Exponential scaling
|
||||
let base: f64 = 10.0; // Minimum 10 cents
|
||||
let max: f64 = 500.0; // Maximum 500 cents ($5)
|
||||
|
||||
let budget = base * (max / base).powf(clamped);
|
||||
budget.ceil() as u64
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::task::VerificationCriteria;
|
||||
|
||||
fn make_subtasks(weights: &[f64]) -> Vec<Subtask> {
|
||||
weights
|
||||
.iter()
|
||||
.map(|&w| Subtask::new("test", VerificationCriteria::None, w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proportional_allocation() {
|
||||
let subtasks = make_subtasks(&[1.0, 2.0, 1.0]);
|
||||
let allocs = allocate_budget(&subtasks, 100, AllocationStrategy::Proportional);
|
||||
|
||||
assert_eq!(allocs.len(), 3);
|
||||
// Should be roughly 25, 50, 25
|
||||
assert!(allocs[1] > allocs[0]);
|
||||
assert_eq!(allocs.iter().sum::<u64>(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equal_allocation() {
|
||||
let subtasks = make_subtasks(&[1.0, 1.0, 1.0]);
|
||||
let allocs = allocate_budget(&subtasks, 99, AllocationStrategy::Equal);
|
||||
|
||||
assert_eq!(allocs.len(), 3);
|
||||
assert_eq!(allocs.iter().sum::<u64>(), 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_allocation() {
|
||||
let subtasks = make_subtasks(&[1.0, 1.0, 1.0]);
|
||||
let allocs = allocate_budget(&subtasks, 100, AllocationStrategy::PriorityFirst);
|
||||
|
||||
assert_eq!(allocs.len(), 3);
|
||||
assert!(allocs[0] > allocs[2]); // First gets more
|
||||
assert_eq!(allocs.iter().sum::<u64>(), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complexity_budget_estimation() {
|
||||
assert!(estimate_budget_for_complexity(0.0) <= 15);
|
||||
assert!(estimate_budget_for_complexity(0.5) > 50);
|
||||
assert!(estimate_budget_for_complexity(1.0) >= 450);
|
||||
}
|
||||
}
|
||||
|
||||
249
src/budget/budget.rs
Normal file
249
src/budget/budget.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
//! Budget tracking for tasks.
|
||||
//!
|
||||
//! # Invariants
|
||||
//! - `allocated_cents <= total_cents` (enforced at all times)
|
||||
//! - `spent_cents <= allocated_cents` (enforced at all times)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Budget for a task, tracking total, allocated, and spent amounts.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - `allocated_cents <= total_cents`
|
||||
/// - `spent_cents <= allocated_cents`
|
||||
///
|
||||
/// # Design for Provability
|
||||
/// All mutations go through methods that enforce invariants.
|
||||
/// Direct field access is prevented (fields are private).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Budget {
|
||||
/// Total budget available (in cents)
|
||||
total_cents: u64,
|
||||
|
||||
/// Amount allocated to subtasks (in cents)
|
||||
allocated_cents: u64,
|
||||
|
||||
/// Amount actually spent (in cents)
|
||||
spent_cents: u64,
|
||||
}
|
||||
|
||||
impl Budget {
|
||||
/// Create a new budget with the given total.
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - `budget.total_cents == total_cents`
|
||||
/// - `budget.allocated_cents == 0`
|
||||
/// - `budget.spent_cents == 0`
|
||||
pub fn new(total_cents: u64) -> Self {
|
||||
Self {
|
||||
total_cents,
|
||||
allocated_cents: 0,
|
||||
spent_cents: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a budget with unlimited funds (for testing).
|
||||
///
|
||||
/// # Warning
|
||||
/// This should only be used for testing, not production.
|
||||
pub fn unlimited() -> Self {
|
||||
Self {
|
||||
total_cents: u64::MAX,
|
||||
allocated_cents: 0,
|
||||
spent_cents: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Getters
|
||||
|
||||
/// Get the total budget in cents.
|
||||
pub fn total_cents(&self) -> u64 {
|
||||
self.total_cents
|
||||
}
|
||||
|
||||
/// Get the allocated amount in cents.
|
||||
pub fn allocated_cents(&self) -> u64 {
|
||||
self.allocated_cents
|
||||
}
|
||||
|
||||
/// Get the spent amount in cents.
|
||||
pub fn spent_cents(&self) -> u64 {
|
||||
self.spent_cents
|
||||
}
|
||||
|
||||
/// Get the remaining unallocated budget in cents.
|
||||
///
|
||||
/// # Property
|
||||
/// `remaining_cents() == total_cents - allocated_cents`
|
||||
pub fn remaining_cents(&self) -> u64 {
|
||||
self.total_cents.saturating_sub(self.allocated_cents)
|
||||
}
|
||||
|
||||
/// Get the unspent allocated budget in cents.
|
||||
///
|
||||
/// # Property
|
||||
/// `unspent_cents() == allocated_cents - spent_cents`
|
||||
pub fn unspent_cents(&self) -> u64 {
|
||||
self.allocated_cents.saturating_sub(self.spent_cents)
|
||||
}
|
||||
|
||||
/// Check if there's any remaining budget to allocate.
|
||||
pub fn has_remaining(&self) -> bool {
|
||||
self.remaining_cents() > 0
|
||||
}
|
||||
|
||||
/// Check if the budget is exhausted (all spent).
|
||||
pub fn is_exhausted(&self) -> bool {
|
||||
self.spent_cents >= self.allocated_cents
|
||||
}
|
||||
|
||||
// Mutations with invariant enforcement
|
||||
|
||||
/// Allocate some budget for a subtask.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `amount <= self.remaining_cents()`
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `self.allocated_cents` increases by `amount`
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if allocation would exceed total.
|
||||
pub fn allocate(&mut self, amount: u64) -> Result<(), BudgetError> {
|
||||
let new_allocated = self.allocated_cents.saturating_add(amount);
|
||||
|
||||
if new_allocated > self.total_cents {
|
||||
return Err(BudgetError::AllocationExceedsTotal {
|
||||
requested: amount,
|
||||
remaining: self.remaining_cents(),
|
||||
});
|
||||
}
|
||||
|
||||
self.allocated_cents = new_allocated;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record spending against the allocated budget.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `amount <= self.unspent_cents()`
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `self.spent_cents` increases by `amount`
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if spending would exceed allocated.
|
||||
pub fn spend(&mut self, amount: u64) -> Result<(), BudgetError> {
|
||||
let new_spent = self.spent_cents.saturating_add(amount);
|
||||
|
||||
if new_spent > self.allocated_cents {
|
||||
return Err(BudgetError::SpendingExceedsAllocated {
|
||||
requested: amount,
|
||||
available: self.unspent_cents(),
|
||||
});
|
||||
}
|
||||
|
||||
self.spent_cents = new_spent;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Try to spend, returning how much was actually spent.
|
||||
///
|
||||
/// This is a "best effort" version that won't fail,
|
||||
/// but may spend less than requested.
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `result <= amount`
|
||||
/// `result <= self.unspent_cents()` (before call)
|
||||
pub fn try_spend(&mut self, amount: u64) -> u64 {
|
||||
let available = self.unspent_cents();
|
||||
let actual = amount.min(available);
|
||||
self.spent_cents += actual;
|
||||
actual
|
||||
}
|
||||
|
||||
/// Check if we can afford a given cost.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if `cost <= self.unspent_cents()`
|
||||
pub fn can_afford(&self, cost: u64) -> bool {
|
||||
cost <= self.unspent_cents()
|
||||
}
|
||||
|
||||
/// Create a sub-budget from this budget.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `amount <= self.remaining_cents()`
|
||||
///
|
||||
/// # Side Effects
|
||||
/// Allocates `amount` from this budget.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new budget with `total_cents == amount`.
|
||||
pub fn create_sub_budget(&mut self, amount: u64) -> Result<Budget, BudgetError> {
|
||||
self.allocate(amount)?;
|
||||
Ok(Budget::new(amount))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Budget {
|
||||
/// Default budget is $1.00 (100 cents).
|
||||
fn default() -> Self {
|
||||
Self::new(100)
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors related to budget operations.
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum BudgetError {
|
||||
#[error("Allocation of {requested} cents exceeds remaining budget of {remaining} cents")]
|
||||
AllocationExceedsTotal { requested: u64, remaining: u64 },
|
||||
|
||||
#[error("Spending of {requested} cents exceeds available budget of {available} cents")]
|
||||
SpendingExceedsAllocated { requested: u64, available: u64 },
|
||||
|
||||
#[error("Insufficient budget: need {needed} cents, have {available} cents")]
|
||||
InsufficientBudget { needed: u64, available: u64 },
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_budget_invariants() {
|
||||
let mut budget = Budget::new(100);
|
||||
|
||||
// Initially, nothing is allocated or spent
|
||||
assert_eq!(budget.remaining_cents(), 100);
|
||||
assert_eq!(budget.unspent_cents(), 0);
|
||||
|
||||
// Allocate some
|
||||
budget.allocate(50).unwrap();
|
||||
assert_eq!(budget.remaining_cents(), 50);
|
||||
assert_eq!(budget.unspent_cents(), 50);
|
||||
|
||||
// Spend some
|
||||
budget.spend(30).unwrap();
|
||||
assert_eq!(budget.unspent_cents(), 20);
|
||||
assert_eq!(budget.spent_cents(), 30);
|
||||
|
||||
// Can't over-allocate
|
||||
assert!(budget.allocate(60).is_err());
|
||||
|
||||
// Can't over-spend
|
||||
assert!(budget.spend(30).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_budget() {
|
||||
let mut parent = Budget::new(100);
|
||||
|
||||
let child = parent.create_sub_budget(40).unwrap();
|
||||
|
||||
assert_eq!(parent.remaining_cents(), 60);
|
||||
assert_eq!(child.total_cents(), 40);
|
||||
assert_eq!(child.remaining_cents(), 40);
|
||||
}
|
||||
}
|
||||
|
||||
15
src/budget/mod.rs
Normal file
15
src/budget/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Budget module - cost tracking and model pricing.
|
||||
//!
|
||||
//! # Key Concepts
|
||||
//! - Budget: tracks total and allocated costs for a task
|
||||
//! - Pricing: fetches and caches OpenRouter model pricing
|
||||
//! - Allocation: algorithms for distributing budget across subtasks
|
||||
|
||||
mod budget;
|
||||
mod pricing;
|
||||
mod allocation;
|
||||
|
||||
pub use budget::{Budget, BudgetError};
|
||||
pub use pricing::{ModelPricing, PricingInfo};
|
||||
pub use allocation::{AllocationStrategy, allocate_budget};
|
||||
|
||||
239
src/budget/pricing.rs
Normal file
239
src/budget/pricing.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
//! OpenRouter pricing information and caching.
|
||||
//!
|
||||
//! Fetches real-time pricing from OpenRouter API to enable
|
||||
//! cost-aware model selection.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Pricing information for a single model.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PricingInfo {
|
||||
/// Model identifier (e.g., "openai/gpt-4.1-mini")
|
||||
pub model_id: String,
|
||||
|
||||
/// Cost per 1M input tokens in dollars
|
||||
pub prompt_cost_per_million: f64,
|
||||
|
||||
/// Cost per 1M output tokens in dollars
|
||||
pub completion_cost_per_million: f64,
|
||||
|
||||
/// Context window size in tokens
|
||||
pub context_length: u64,
|
||||
|
||||
/// Maximum output tokens
|
||||
pub max_output_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
impl PricingInfo {
|
||||
/// Calculate cost in cents for given token counts.
|
||||
///
|
||||
/// # Formula
|
||||
/// `cost = (input_tokens * prompt_rate + output_tokens * completion_rate) / 1_000_000 * 100`
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `result >= 0`
|
||||
pub fn calculate_cost_cents(&self, input_tokens: u64, output_tokens: u64) -> u64 {
|
||||
let input_cost = (input_tokens as f64) * self.prompt_cost_per_million / 1_000_000.0;
|
||||
let output_cost = (output_tokens as f64) * self.completion_cost_per_million / 1_000_000.0;
|
||||
let total_dollars = input_cost + output_cost;
|
||||
(total_dollars * 100.0).ceil() as u64
|
||||
}
|
||||
|
||||
/// Estimate cost for a task given estimated token counts.
|
||||
///
|
||||
/// Adds a safety margin of 20% for estimation errors.
|
||||
pub fn estimate_cost_cents(&self, estimated_input: u64, estimated_output: u64) -> u64 {
|
||||
let base_cost = self.calculate_cost_cents(estimated_input, estimated_output);
|
||||
// Add 20% safety margin
|
||||
(base_cost as f64 * 1.2).ceil() as u64
|
||||
}
|
||||
|
||||
/// Get the average cost per token (for rough comparisons).
|
||||
pub fn average_cost_per_token(&self) -> f64 {
|
||||
(self.prompt_cost_per_million + self.completion_cost_per_million) / 2.0 / 1_000_000.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Model pricing cache and fetcher.
|
||||
pub struct ModelPricing {
|
||||
/// Cached pricing data
|
||||
cache: Arc<RwLock<HashMap<String, PricingInfo>>>,
|
||||
|
||||
/// HTTP client for fetching pricing
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl ModelPricing {
|
||||
/// Create a new pricing manager.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with pre-populated pricing data (for testing or offline use).
|
||||
pub fn with_pricing(pricing: HashMap<String, PricingInfo>) -> Self {
|
||||
Self {
|
||||
cache: Arc::new(RwLock::new(pricing)),
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get pricing for a specific model.
|
||||
///
|
||||
/// Returns cached data if available, otherwise fetches from API.
|
||||
pub async fn get_pricing(&self, model_id: &str) -> Option<PricingInfo> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.cache.read().await;
|
||||
if let Some(info) = cache.get(model_id) {
|
||||
return Some(info.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// If not in cache, try to fetch all models
|
||||
if let Ok(()) = self.refresh_pricing().await {
|
||||
let cache = self.cache.read().await;
|
||||
return cache.get(model_id).cloned();
|
||||
}
|
||||
|
||||
// Fall back to hardcoded defaults for common models
|
||||
self.default_pricing(model_id)
|
||||
}
|
||||
|
||||
/// Get all cached pricing info.
|
||||
pub async fn all_pricing(&self) -> HashMap<String, PricingInfo> {
|
||||
self.cache.read().await.clone()
|
||||
}
|
||||
|
||||
/// Refresh pricing from OpenRouter API.
|
||||
pub async fn refresh_pricing(&self) -> Result<(), PricingError> {
|
||||
let response = self
|
||||
.client
|
||||
.get("https://openrouter.ai/api/v1/models")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| PricingError::NetworkError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(PricingError::ApiError(format!(
|
||||
"Status: {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let data: OpenRouterModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| PricingError::ParseError(e.to_string()))?;
|
||||
|
||||
let mut cache = self.cache.write().await;
|
||||
|
||||
for model in data.data {
|
||||
let info = PricingInfo {
|
||||
model_id: model.id.clone(),
|
||||
prompt_cost_per_million: parse_price(&model.pricing.prompt),
|
||||
completion_cost_per_million: parse_price(&model.pricing.completion),
|
||||
context_length: model.context_length.unwrap_or(4096),
|
||||
max_output_tokens: model.top_provider.as_ref()
|
||||
.and_then(|p| p.max_completion_tokens),
|
||||
};
|
||||
cache.insert(model.id, info);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get default pricing for common models (fallback).
|
||||
fn default_pricing(&self, model_id: &str) -> Option<PricingInfo> {
|
||||
// Hardcoded defaults for when API is unavailable
|
||||
let defaults = [
|
||||
("openai/gpt-4.1-mini", 0.40, 1.60, 1_000_000),
|
||||
("openai/gpt-4.1", 2.50, 10.00, 1_000_000),
|
||||
("openai/gpt-4o", 2.50, 10.00, 128_000),
|
||||
("openai/gpt-4o-mini", 0.15, 0.60, 128_000),
|
||||
("anthropic/claude-3.5-sonnet", 3.00, 15.00, 200_000),
|
||||
("anthropic/claude-3-haiku", 0.25, 1.25, 200_000),
|
||||
];
|
||||
|
||||
for (id, prompt, completion, context) in defaults {
|
||||
if model_id == id || model_id.contains(id.split('/').last().unwrap_or("")) {
|
||||
return Some(PricingInfo {
|
||||
model_id: model_id.to_string(),
|
||||
prompt_cost_per_million: prompt,
|
||||
completion_cost_per_million: completion,
|
||||
context_length: context,
|
||||
max_output_tokens: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get models sorted by cost (cheapest first).
|
||||
pub async fn models_by_cost(&self) -> Vec<PricingInfo> {
|
||||
let cache = self.cache.read().await;
|
||||
let mut models: Vec<_> = cache.values().cloned().collect();
|
||||
models.sort_by(|a, b| {
|
||||
a.average_cost_per_token()
|
||||
.partial_cmp(&b.average_cost_per_token())
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
models
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ModelPricing {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse price string from OpenRouter API.
|
||||
fn parse_price(price: &str) -> f64 {
|
||||
price.parse().unwrap_or(0.0)
|
||||
}
|
||||
|
||||
/// OpenRouter API response structures.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterModelsResponse {
|
||||
data: Vec<OpenRouterModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterModel {
|
||||
id: String,
|
||||
pricing: OpenRouterPricing,
|
||||
context_length: Option<u64>,
|
||||
top_provider: Option<OpenRouterProvider>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterPricing {
|
||||
prompt: String,
|
||||
completion: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterProvider {
|
||||
max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// Pricing-related errors.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PricingError {
|
||||
#[error("Network error: {0}")]
|
||||
NetworkError(String),
|
||||
|
||||
#[error("API error: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("Parse error: {0}")]
|
||||
ParseError(String),
|
||||
}
|
||||
|
||||
45
src/lib.rs
45
src/lib.rs
@@ -4,31 +4,44 @@
|
||||
//!
|
||||
//! This library provides:
|
||||
//! - An HTTP API for task submission and monitoring
|
||||
//! - A tool-based agent loop for autonomous code editing
|
||||
//! - A hierarchical agent tree for complex task handling
|
||||
//! - Tool-based execution for autonomous code editing
|
||||
//! - Integration with OpenRouter for LLM access
|
||||
//!
|
||||
//! ## Architecture
|
||||
//! ## Architecture (v2: Hierarchical Agent Tree)
|
||||
//!
|
||||
//! The agent follows the "tools in a loop" pattern:
|
||||
//! 1. Receive a task via the API
|
||||
//! 2. Build context with system prompt and available tools
|
||||
//! 3. Call LLM, parse response, execute any tool calls
|
||||
//! 4. Feed results back to LLM, repeat until task complete
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use open_agent::{config::Config, agent::Agent};
|
||||
//!
|
||||
//! let config = Config::from_env()?;
|
||||
//! let agent = Agent::new(config);
|
||||
//! let result = agent.run_task("Create a hello world script").await?;
|
||||
//! ```text
|
||||
//! ┌─────────────┐
|
||||
//! │ RootAgent │
|
||||
//! └──────┬──────┘
|
||||
//! ┌─────────────────┼─────────────────┐
|
||||
//! ▼ ▼ ▼
|
||||
//! ┌───────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
//! │ Complexity │ │ Model │ │ Task │
|
||||
//! │ Estimator │ │ Selector │ │ Executor │
|
||||
//! └───────────────┘ └─────────────┘ └─────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! ## Task Flow
|
||||
//! 1. Receive task via API
|
||||
//! 2. Estimate complexity (should we split?)
|
||||
//! 3. Select optimal model (U-curve cost optimization)
|
||||
//! 4. Execute (directly or via subtasks)
|
||||
//! 5. Verify completion (programmatic + LLM hybrid)
|
||||
//!
|
||||
//! ## Modules
|
||||
//! - `agents`: Hierarchical agent tree (Root, Node, Leaf agents)
|
||||
//! - `task`: Task, subtask, and verification types
|
||||
//! - `budget`: Cost tracking and model pricing
|
||||
//! - `agent`: Original simple agent (kept for compatibility)
|
||||
|
||||
pub mod api;
|
||||
pub mod agent;
|
||||
pub mod agents;
|
||||
pub mod budget;
|
||||
pub mod config;
|
||||
pub mod llm;
|
||||
pub mod task;
|
||||
pub mod tools;
|
||||
|
||||
pub use config::Config;
|
||||
|
||||
15
src/task/mod.rs
Normal file
15
src/task/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Task module - defines tasks, subtasks, and verification criteria.
|
||||
//!
|
||||
//! This module is designed with formal verification in mind:
|
||||
//! - All types use algebraic data types with exhaustive matching
|
||||
//! - Invariants are documented and enforced in constructors
|
||||
//! - Pure functions are separated from IO operations
|
||||
|
||||
pub mod task;
|
||||
mod subtask;
|
||||
mod verification;
|
||||
|
||||
pub use task::{Task, TaskId, TaskStatus, TaskError};
|
||||
pub use subtask::{Subtask, SubtaskPlan, SubtaskPlanError};
|
||||
pub use verification::{VerificationCriteria, VerificationResult, VerificationMethod, ProgrammaticCheck};
|
||||
|
||||
249
src/task/subtask.rs
Normal file
249
src/task/subtask.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
//! Subtask definitions and splitting logic.
|
||||
//!
|
||||
//! When a task is too complex, it can be split into subtasks.
|
||||
//! Each subtask has:
|
||||
//! - A description of what to do
|
||||
//! - Verification criteria (how to check it's done)
|
||||
//! - A budget allocation
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{Task, TaskId, VerificationCriteria};
|
||||
use crate::budget::Budget;
|
||||
|
||||
/// A planned subtask before it becomes a full Task.
|
||||
///
|
||||
/// # Purpose
|
||||
/// Represents the output of task splitting before budget allocation.
|
||||
/// Once budgets are assigned, these become full `Task` objects.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Subtask {
|
||||
/// Description of what this subtask should accomplish
|
||||
pub description: String,
|
||||
|
||||
/// How to verify this subtask is complete
|
||||
pub verification: VerificationCriteria,
|
||||
|
||||
/// Relative weight for budget allocation (higher = more budget)
|
||||
pub weight: f64,
|
||||
|
||||
/// Dependencies: IDs of subtasks that must complete first
|
||||
pub dependencies: Vec<usize>,
|
||||
}
|
||||
|
||||
impl Subtask {
|
||||
/// Create a new subtask with no dependencies.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `description` is non-empty
|
||||
/// - `weight > 0.0`
|
||||
pub fn new(
|
||||
description: impl Into<String>,
|
||||
verification: VerificationCriteria,
|
||||
weight: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
description: description.into(),
|
||||
verification,
|
||||
weight: weight.max(0.01), // Ensure positive weight
|
||||
dependencies: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a dependency on another subtask (by index).
|
||||
pub fn with_dependency(mut self, index: usize) -> Self {
|
||||
self.dependencies.push(index);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple dependencies.
|
||||
pub fn with_dependencies(mut self, indices: Vec<usize>) -> Self {
|
||||
self.dependencies.extend(indices);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A plan for splitting a task into subtasks.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - `subtasks` is non-empty
|
||||
/// - All dependency indices are valid (< subtasks.len())
|
||||
/// - No circular dependencies
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SubtaskPlan {
|
||||
/// The parent task ID
|
||||
parent_id: TaskId,
|
||||
|
||||
/// The subtasks to create
|
||||
subtasks: Vec<Subtask>,
|
||||
|
||||
/// Reasoning for the split
|
||||
reasoning: String,
|
||||
}
|
||||
|
||||
impl SubtaskPlan {
|
||||
/// Create a new subtask plan.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `subtasks` is non-empty
|
||||
/// - All dependency indices are valid
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if preconditions are violated.
|
||||
pub fn new(
|
||||
parent_id: TaskId,
|
||||
subtasks: Vec<Subtask>,
|
||||
reasoning: impl Into<String>,
|
||||
) -> Result<Self, SubtaskPlanError> {
|
||||
if subtasks.is_empty() {
|
||||
return Err(SubtaskPlanError::EmptySubtasks);
|
||||
}
|
||||
|
||||
// Validate dependency indices
|
||||
for (i, subtask) in subtasks.iter().enumerate() {
|
||||
for &dep in &subtask.dependencies {
|
||||
if dep >= subtasks.len() {
|
||||
return Err(SubtaskPlanError::InvalidDependency {
|
||||
subtask_index: i,
|
||||
dependency_index: dep
|
||||
});
|
||||
}
|
||||
if dep == i {
|
||||
return Err(SubtaskPlanError::SelfDependency { subtask_index: i });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Check for circular dependencies (would need topological sort)
|
||||
|
||||
Ok(Self {
|
||||
parent_id,
|
||||
subtasks,
|
||||
reasoning: reasoning.into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the parent task ID.
|
||||
pub fn parent_id(&self) -> TaskId {
|
||||
self.parent_id
|
||||
}
|
||||
|
||||
/// Get the subtasks.
|
||||
pub fn subtasks(&self) -> &[Subtask] {
|
||||
&self.subtasks
|
||||
}
|
||||
|
||||
/// Get the reasoning for the split.
|
||||
pub fn reasoning(&self) -> &str {
|
||||
&self.reasoning
|
||||
}
|
||||
|
||||
/// Convert this plan into actual Task objects with allocated budgets.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `total_budget.remaining() > 0`
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - Sum of subtask budgets <= total_budget
|
||||
/// - Each subtask has parent_id set to self.parent_id
|
||||
///
|
||||
/// # Budget Allocation
|
||||
/// Budget is allocated proportionally based on subtask weights.
|
||||
pub fn into_tasks(self, total_budget: &Budget) -> Result<Vec<Task>, SubtaskPlanError> {
|
||||
let total_weight: f64 = self.subtasks.iter().map(|s| s.weight).sum();
|
||||
|
||||
if total_weight <= 0.0 {
|
||||
return Err(SubtaskPlanError::ZeroTotalWeight);
|
||||
}
|
||||
|
||||
let available = total_budget.remaining_cents();
|
||||
|
||||
self.subtasks
|
||||
.into_iter()
|
||||
.map(|subtask| {
|
||||
// Allocate budget proportionally
|
||||
let proportion = subtask.weight / total_weight;
|
||||
let allocated = ((available as f64) * proportion) as u64;
|
||||
|
||||
let budget = Budget::new(allocated);
|
||||
|
||||
Task::new_subtask(
|
||||
subtask.description,
|
||||
subtask.verification,
|
||||
budget,
|
||||
self.parent_id,
|
||||
).map_err(|e| SubtaskPlanError::TaskCreation(e.to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get execution order respecting dependencies (topological sort).
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of subtask indices in valid execution order.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if there are circular dependencies.
|
||||
pub fn execution_order(&self) -> Result<Vec<usize>, SubtaskPlanError> {
|
||||
let n = self.subtasks.len();
|
||||
let mut in_degree = vec![0usize; n];
|
||||
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
||||
|
||||
// Build adjacency list and compute in-degrees
|
||||
for (i, subtask) in self.subtasks.iter().enumerate() {
|
||||
for &dep in &subtask.dependencies {
|
||||
adj[dep].push(i);
|
||||
in_degree[i] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Kahn's algorithm for topological sort
|
||||
let mut queue: Vec<usize> = in_degree
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &d)| d == 0)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
let mut order = Vec::with_capacity(n);
|
||||
|
||||
while let Some(node) = queue.pop() {
|
||||
order.push(node);
|
||||
for &next in &adj[node] {
|
||||
in_degree[next] -= 1;
|
||||
if in_degree[next] == 0 {
|
||||
queue.push(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if order.len() != n {
|
||||
Err(SubtaskPlanError::CircularDependency)
|
||||
} else {
|
||||
Ok(order)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors in subtask plan creation or execution.
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum SubtaskPlanError {
|
||||
#[error("Subtask list cannot be empty")]
|
||||
EmptySubtasks,
|
||||
|
||||
#[error("Subtask {subtask_index} has invalid dependency index {dependency_index}")]
|
||||
InvalidDependency { subtask_index: usize, dependency_index: usize },
|
||||
|
||||
#[error("Subtask {subtask_index} depends on itself")]
|
||||
SelfDependency { subtask_index: usize },
|
||||
|
||||
#[error("Circular dependency detected in subtask plan")]
|
||||
CircularDependency,
|
||||
|
||||
#[error("Total weight of subtasks is zero")]
|
||||
ZeroTotalWeight,
|
||||
|
||||
#[error("Failed to create task: {0}")]
|
||||
TaskCreation(String),
|
||||
}
|
||||
|
||||
291
src/task/task.rs
Normal file
291
src/task/task.rs
Normal file
@@ -0,0 +1,291 @@
|
||||
//! Core Task type with budget and verification criteria.
|
||||
//!
|
||||
//! # Invariants
|
||||
//! - `budget.allocated_cents <= budget.total_cents`
|
||||
//! - `id` is unique within an agent tree execution
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::verification::VerificationCriteria;
|
||||
use crate::budget::Budget;
|
||||
|
||||
/// Unique identifier for a task.
|
||||
///
|
||||
/// # Properties
|
||||
/// - Globally unique within an execution context
|
||||
/// - Immutable once created
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct TaskId(Uuid);
|
||||
|
||||
impl TaskId {
|
||||
/// Create a new unique task ID.
|
||||
///
|
||||
/// # Postcondition
|
||||
/// Returns a fresh ID that has never been used before in this process.
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Get the inner UUID.
|
||||
pub fn as_uuid(&self) -> Uuid {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TaskId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of a task in its lifecycle.
|
||||
///
|
||||
/// # State Machine
|
||||
/// ```text
|
||||
/// Pending -> Running -> Completed
|
||||
/// \-> Failed
|
||||
/// \-> Cancelled
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum TaskStatus {
|
||||
/// Task is waiting to be executed
|
||||
Pending,
|
||||
/// Task is currently being executed
|
||||
Running,
|
||||
/// Task completed successfully
|
||||
Completed,
|
||||
/// Task failed with an error
|
||||
Failed { reason: String },
|
||||
/// Task was cancelled before completion
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
impl TaskStatus {
|
||||
/// Check if the task is in a terminal state.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the task is Completed, Failed, or Cancelled.
|
||||
///
|
||||
/// # Property
|
||||
/// `is_terminal() => !can_transition()`
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(self, TaskStatus::Completed | TaskStatus::Failed { .. } | TaskStatus::Cancelled)
|
||||
}
|
||||
|
||||
/// Check if the task is still active (can make progress).
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the task is Pending or Running.
|
||||
pub fn is_active(&self) -> bool {
|
||||
matches!(self, TaskStatus::Pending | TaskStatus::Running)
|
||||
}
|
||||
}
|
||||
|
||||
/// A task to be executed by an agent.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - `budget.allocated_cents <= budget.total_cents`
|
||||
/// - If `parent_id.is_some()`, this is a subtask
|
||||
///
|
||||
/// # Design for Provability
|
||||
/// - All fields are immutable after construction (except status via explicit transitions)
|
||||
/// - Budget constraints are checked at construction time
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
/// Unique identifier for this task
|
||||
id: TaskId,
|
||||
|
||||
/// Human-readable description of what to accomplish
|
||||
description: String,
|
||||
|
||||
/// How to verify the task was completed correctly
|
||||
verification: VerificationCriteria,
|
||||
|
||||
/// Budget constraints for this task
|
||||
budget: Budget,
|
||||
|
||||
/// Parent task ID if this is a subtask
|
||||
parent_id: Option<TaskId>,
|
||||
|
||||
/// Current status
|
||||
status: TaskStatus,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
/// Create a new task with the given parameters.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - `budget.allocated_cents <= budget.total_cents`
|
||||
/// - `description` is non-empty
|
||||
///
|
||||
/// # Postconditions
|
||||
/// - Returns a task with `status == Pending`
|
||||
/// - `task.id` is a fresh unique identifier
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if preconditions are violated.
|
||||
pub fn new(
|
||||
description: String,
|
||||
verification: VerificationCriteria,
|
||||
budget: Budget,
|
||||
) -> Result<Self, TaskError> {
|
||||
if description.is_empty() {
|
||||
return Err(TaskError::EmptyDescription);
|
||||
}
|
||||
|
||||
// Budget invariant is enforced by Budget::new()
|
||||
|
||||
Ok(Self {
|
||||
id: TaskId::new(),
|
||||
description,
|
||||
verification,
|
||||
budget,
|
||||
parent_id: None,
|
||||
status: TaskStatus::Pending,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a subtask with a parent reference.
|
||||
///
|
||||
/// # Preconditions
|
||||
/// - Same as `new()`
|
||||
/// - `parent_id` refers to an existing task
|
||||
pub fn new_subtask(
|
||||
description: String,
|
||||
verification: VerificationCriteria,
|
||||
budget: Budget,
|
||||
parent_id: TaskId,
|
||||
) -> Result<Self, TaskError> {
|
||||
let mut task = Self::new(description, verification, budget)?;
|
||||
task.parent_id = Some(parent_id);
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
// Getters - all return references to preserve immutability semantics
|
||||
|
||||
pub fn id(&self) -> TaskId {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
pub fn verification(&self) -> &VerificationCriteria {
|
||||
&self.verification
|
||||
}
|
||||
|
||||
pub fn budget(&self) -> &Budget {
|
||||
&self.budget
|
||||
}
|
||||
|
||||
pub fn budget_mut(&mut self) -> &mut Budget {
|
||||
&mut self.budget
|
||||
}
|
||||
|
||||
pub fn parent_id(&self) -> Option<TaskId> {
|
||||
self.parent_id
|
||||
}
|
||||
|
||||
pub fn status(&self) -> &TaskStatus {
|
||||
&self.status
|
||||
}
|
||||
|
||||
/// Check if this task is a subtask (has a parent).
|
||||
pub fn is_subtask(&self) -> bool {
|
||||
self.parent_id.is_some()
|
||||
}
|
||||
|
||||
// State transitions - explicit and validated
|
||||
|
||||
/// Transition the task to Running state.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `self.status == Pending`
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns `Err` if the task is not in Pending state.
|
||||
pub fn start(&mut self) -> Result<(), TaskError> {
|
||||
match &self.status {
|
||||
TaskStatus::Pending => {
|
||||
self.status = TaskStatus::Running;
|
||||
Ok(())
|
||||
}
|
||||
other => Err(TaskError::InvalidTransition {
|
||||
from: format!("{:?}", other),
|
||||
to: "Running".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition the task to Completed state.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `self.status == Running`
|
||||
pub fn complete(&mut self) -> Result<(), TaskError> {
|
||||
match &self.status {
|
||||
TaskStatus::Running => {
|
||||
self.status = TaskStatus::Completed;
|
||||
Ok(())
|
||||
}
|
||||
other => Err(TaskError::InvalidTransition {
|
||||
from: format!("{:?}", other),
|
||||
to: "Completed".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition the task to Failed state.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `self.status == Running`
|
||||
pub fn fail(&mut self, reason: String) -> Result<(), TaskError> {
|
||||
match &self.status {
|
||||
TaskStatus::Running => {
|
||||
self.status = TaskStatus::Failed { reason };
|
||||
Ok(())
|
||||
}
|
||||
other => Err(TaskError::InvalidTransition {
|
||||
from: format!("{:?}", other),
|
||||
to: "Failed".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition the task to Cancelled state.
|
||||
///
|
||||
/// # Precondition
|
||||
/// `self.status.is_active()`
|
||||
pub fn cancel(&mut self) -> Result<(), TaskError> {
|
||||
if self.status.is_active() {
|
||||
self.status = TaskStatus::Cancelled;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TaskError::InvalidTransition {
|
||||
from: format!("{:?}", self.status),
|
||||
to: "Cancelled".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur during task operations.
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum TaskError {
|
||||
#[error("Task description cannot be empty")]
|
||||
EmptyDescription,
|
||||
|
||||
#[error("Invalid state transition from {from} to {to}")]
|
||||
InvalidTransition { from: String, to: String },
|
||||
|
||||
#[error("Budget error: {0}")]
|
||||
BudgetError(String),
|
||||
}
|
||||
|
||||
215
src/task/verification.rs
Normal file
215
src/task/verification.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
//! Verification criteria and results for tasks.
|
||||
//!
|
||||
//! Supports hybrid verification: programmatic checks with LLM fallback.
|
||||
//!
|
||||
//! # Design Principles
|
||||
//! - Prefer programmatic verification when possible (deterministic, fast)
|
||||
//! - Use LLM verification for subjective or complex assessments
|
||||
//! - Hybrid mode tries programmatic first, falls back to LLM
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Programmatic checks that can be performed without LLM.
|
||||
///
|
||||
/// # Exhaustive Matching
|
||||
/// All variants must be handled explicitly - no catch-all allowed.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ProgrammaticCheck {
|
||||
/// Check if a file exists at the given path
|
||||
FileExists { path: String },
|
||||
|
||||
/// Check if a file contains specific content
|
||||
FileContains { path: String, content: String },
|
||||
|
||||
/// Check if a command exits with code 0
|
||||
CommandSucceeds { command: String },
|
||||
|
||||
/// Check if a command output matches expected pattern
|
||||
CommandOutputMatches { command: String, pattern: String },
|
||||
|
||||
/// Check if a directory exists
|
||||
DirectoryExists { path: String },
|
||||
|
||||
/// Check if a file matches a regex pattern
|
||||
FileMatchesRegex { path: String, pattern: String },
|
||||
|
||||
/// Multiple checks that must all pass
|
||||
All(Vec<ProgrammaticCheck>),
|
||||
|
||||
/// At least one check must pass
|
||||
Any(Vec<ProgrammaticCheck>),
|
||||
}
|
||||
|
||||
impl ProgrammaticCheck {
|
||||
/// Create a file exists check.
|
||||
pub fn file_exists(path: impl Into<String>) -> Self {
|
||||
Self::FileExists { path: path.into() }
|
||||
}
|
||||
|
||||
/// Create a command succeeds check.
|
||||
pub fn command_succeeds(command: impl Into<String>) -> Self {
|
||||
Self::CommandSucceeds { command: command.into() }
|
||||
}
|
||||
|
||||
/// Create an "all must pass" composite check.
|
||||
pub fn all(checks: Vec<ProgrammaticCheck>) -> Self {
|
||||
Self::All(checks)
|
||||
}
|
||||
|
||||
/// Create an "any must pass" composite check.
|
||||
pub fn any(checks: Vec<ProgrammaticCheck>) -> Self {
|
||||
Self::Any(checks)
|
||||
}
|
||||
}
|
||||
|
||||
/// How to verify a task was completed correctly.
|
||||
///
|
||||
/// # Variants
|
||||
/// - `Programmatic`: Use only programmatic checks (fast, deterministic)
|
||||
/// - `LlmBased`: Use LLM to verify (flexible, slower)
|
||||
/// - `Hybrid`: Try programmatic first, fall back to LLM if inconclusive
|
||||
/// - `None`: No verification required (use with caution)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum VerificationCriteria {
|
||||
/// Programmatic verification only
|
||||
Programmatic(ProgrammaticCheck),
|
||||
|
||||
/// LLM-based verification with a prompt describing success criteria
|
||||
LlmBased {
|
||||
/// Prompt describing what "success" looks like
|
||||
success_criteria: String,
|
||||
},
|
||||
|
||||
/// Try programmatic first, fall back to LLM
|
||||
Hybrid {
|
||||
/// Programmatic check to try first
|
||||
programmatic: ProgrammaticCheck,
|
||||
/// LLM prompt to use if programmatic is inconclusive
|
||||
llm_fallback: String,
|
||||
},
|
||||
|
||||
/// No verification (task is considered complete when agent says so)
|
||||
None,
|
||||
}
|
||||
|
||||
impl VerificationCriteria {
|
||||
/// Create a programmatic-only verification.
|
||||
pub fn programmatic(check: ProgrammaticCheck) -> Self {
|
||||
Self::Programmatic(check)
|
||||
}
|
||||
|
||||
/// Create an LLM-based verification.
|
||||
pub fn llm_based(success_criteria: impl Into<String>) -> Self {
|
||||
Self::LlmBased {
|
||||
success_criteria: success_criteria.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a hybrid verification.
|
||||
pub fn hybrid(programmatic: ProgrammaticCheck, llm_fallback: impl Into<String>) -> Self {
|
||||
Self::Hybrid {
|
||||
programmatic,
|
||||
llm_fallback: llm_fallback.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a no-verification criteria.
|
||||
pub fn none() -> Self {
|
||||
Self::None
|
||||
}
|
||||
|
||||
/// Check if this verification requires LLM access.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if LLM may be needed (LlmBased or Hybrid)
|
||||
pub fn may_require_llm(&self) -> bool {
|
||||
matches!(self, Self::LlmBased { .. } | Self::Hybrid { .. })
|
||||
}
|
||||
|
||||
/// Check if this verification is purely programmatic.
|
||||
pub fn is_programmatic_only(&self) -> bool {
|
||||
matches!(self, Self::Programmatic(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VerificationCriteria {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a verification attempt.
|
||||
///
|
||||
/// # Invariants
|
||||
/// - If `passed == true`, the task is considered successfully completed
|
||||
/// - `reasoning` should always explain why the verification passed or failed
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VerificationResult {
|
||||
/// Whether the verification passed
|
||||
passed: bool,
|
||||
|
||||
/// Explanation of why the verification passed or failed
|
||||
reasoning: String,
|
||||
|
||||
/// Which method was used for verification
|
||||
method: VerificationMethod,
|
||||
|
||||
/// Cost in cents if LLM was used
|
||||
cost_cents: u64,
|
||||
}
|
||||
|
||||
impl VerificationResult {
|
||||
/// Create a passing result.
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `result.passed == true`
|
||||
pub fn pass(reasoning: impl Into<String>, method: VerificationMethod, cost_cents: u64) -> Self {
|
||||
Self {
|
||||
passed: true,
|
||||
reasoning: reasoning.into(),
|
||||
method,
|
||||
cost_cents,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a failing result.
|
||||
///
|
||||
/// # Postcondition
|
||||
/// `result.passed == false`
|
||||
pub fn fail(reasoning: impl Into<String>, method: VerificationMethod, cost_cents: u64) -> Self {
|
||||
Self {
|
||||
passed: false,
|
||||
reasoning: reasoning.into(),
|
||||
method,
|
||||
cost_cents,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn passed(&self) -> bool {
|
||||
self.passed
|
||||
}
|
||||
|
||||
pub fn reasoning(&self) -> &str {
|
||||
&self.reasoning
|
||||
}
|
||||
|
||||
pub fn method(&self) -> &VerificationMethod {
|
||||
&self.method
|
||||
}
|
||||
|
||||
pub fn cost_cents(&self) -> u64 {
|
||||
self.cost_cents
|
||||
}
|
||||
}
|
||||
|
||||
/// Method used for verification.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum VerificationMethod {
|
||||
/// Programmatic check was used
|
||||
Programmatic,
|
||||
/// LLM was used for verification
|
||||
Llm { model: String },
|
||||
/// No verification was performed
|
||||
None,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user