Implement hierarchical agent tree architecture

Core types with provability design:
- Task, Budget, Complexity with documented invariants
- VerificationCriteria with programmatic/LLM hybrid support
- SubtaskPlan with topological sort for execution order

Agent hierarchy:
- Agent trait with pre/post-conditions documented
- OrchestratorAgent for Root/Node agents
- LeafAgent for specialized workers

Leaf agents:
- ComplexityEstimator: estimates task difficulty (0-1 score)
- ModelSelector: U-curve optimization for cost/capability
- TaskExecutor: refactored from original agent loop
- Verifier: hybrid programmatic + LLM verification

Orchestrators:
- RootAgent: top-level, estimates complexity, splits tasks
- NodeAgent: intermediate, handles delegated subtasks

Budget system:
- Budget allocation strategies (proportional, equal, priority)
- OpenRouter pricing integration for cost estimation

API updated to use hierarchical RootAgent
This commit is contained in:
Thomas Marchand
2025-12-14 21:49:03 +00:00
parent 4ffa605c0c
commit 773991ffba
23 changed files with 3955 additions and 34 deletions

View File

@@ -36,6 +36,7 @@ walkdir = "2"
urlencoding = "2"
anyhow = "1"
async-stream = "0.3"
regex = "1"
[dev-dependencies]
tokio-test = "0.4"

85
src/agents/context.rs Normal file
View File

@@ -0,0 +1,85 @@
//! Agent execution context - shared state across the agent tree.
use std::path::PathBuf;
use std::sync::Arc;
use crate::budget::ModelPricing;
use crate::config::Config;
use crate::llm::LlmClient;
use crate::tools::ToolRegistry;
/// Shared context passed to all agents during execution.
///
/// Bundles the configuration, LLM client, tool registry, pricing data and
/// workspace location an agent reads while running, together with the
/// split-depth and iteration limits that bound its work.
///
/// # Thread Safety
/// Context is wrapped in Arc for sharing across async tasks.
/// Individual components use interior mutability where needed.
pub struct AgentContext {
/// Application configuration
pub config: Config,
/// LLM client for model calls (shared, reference-counted)
pub llm: Arc<dyn LlmClient>,
/// Tool registry for task execution
pub tools: ToolRegistry,
/// Model pricing information (shared, reference-counted)
pub pricing: Arc<ModelPricing>,
/// Workspace path for file operations
pub workspace: PathBuf,
/// Maximum depth for recursive task splitting; `0` means no further splitting
pub max_split_depth: usize,
/// Maximum iterations per agent
pub max_iterations: usize,
}
impl AgentContext {
    /// Build the root context for an agent tree.
    ///
    /// The iteration limit is copied from `config.max_iterations` and the
    /// split depth starts at 3.
    pub fn new(
        config: Config,
        llm: Arc<dyn LlmClient>,
        tools: ToolRegistry,
        pricing: Arc<ModelPricing>,
        workspace: PathBuf,
    ) -> Self {
        // Read the limit before `config` is moved into the struct.
        let max_iterations = config.max_iterations;
        // Default max recursion for splitting.
        let max_split_depth = 3;

        Self {
            max_iterations,
            config,
            llm,
            tools,
            pricing,
            workspace,
            max_split_depth,
        }
    }

    /// Derive a context for a child agent, one split level deeper.
    ///
    /// The configuration and workspace are cloned, the LLM client and pricing
    /// handles are shared, and the child receives a newly constructed
    /// `ToolRegistry` instead of the parent's.
    ///
    /// # Postcondition
    /// `child.max_split_depth == self.max_split_depth.saturating_sub(1)`
    /// (the depth never drops below zero).
    pub fn child_context(&self) -> Self {
        let remaining_depth = self.max_split_depth.saturating_sub(1);

        Self {
            config: self.config.clone(),
            llm: Arc::clone(&self.llm),
            tools: ToolRegistry::new(), // Fresh tools for isolation
            pricing: Arc::clone(&self.pricing),
            workspace: self.workspace.clone(),
            max_split_depth: remaining_depth,
            max_iterations: self.max_iterations,
        }
    }

    /// Report whether this level may still split a task into subtasks.
    pub fn can_split(&self) -> bool {
        self.max_split_depth != 0
    }

    /// Render the workspace path as an owned string.
    ///
    /// Non-UTF-8 path components are replaced lossily.
    pub fn workspace_str(&self) -> String {
        let rendered = self.workspace.to_string_lossy();
        rendered.into_owned()
    }
}

View File

@@ -0,0 +1,242 @@
//! Complexity estimation agent.
//!
//! Analyzes a task description and estimates:
//! - Complexity score (0-1)
//! - Whether to split into subtasks
//! - Estimated token count
use async_trait::async_trait;
use serde_json::json;
use crate::agents::{
Agent, AgentContext, AgentId, AgentResult, AgentType, Complexity, LeafAgent, LeafCapability,
};
use crate::llm::{ChatMessage, Role};
use crate::task::Task;
/// Agent that estimates task complexity.
///
/// # Purpose
/// Given a task description, estimate how complex it is and whether
/// it should be split into subtasks.
///
/// # Algorithm
/// 1. Send task description to LLM with complexity evaluation prompt
/// 2. Parse LLM response for complexity score and reasoning
/// 3. Return structured Complexity object
///
/// If the LLM call fails or its reply cannot be parsed, a moderate
/// complexity estimate is used instead of reporting an error.
pub struct ComplexityEstimator {
/// Identifier of this agent instance.
id: AgentId,
}
impl ComplexityEstimator {
    /// Create a new complexity estimator.
    pub fn new() -> Self {
        Self { id: AgentId::new() }
    }

    /// Prompt template for complexity estimation.
    ///
    /// # Response Format
    /// LLM should respond with JSON containing:
    /// - score: float 0-1
    /// - reasoning: string explanation
    /// - estimated_tokens: int
    /// - should_split: bool
    /// - subtasks: optional array if should split
    fn build_prompt(&self, task: &Task) -> String {
        format!(
            r#"You are a task complexity analyzer. Analyze the following task and estimate its complexity.
Task: {}
Respond with a JSON object containing:
- "score": A float from 0.0 to 1.0 where:
- 0.0-0.2: Trivial (single command, simple file operation)
- 0.2-0.4: Simple (few steps, straightforward implementation)
- 0.4-0.6: Moderate (multiple files, some decision making)
- 0.6-0.8: Complex (many files, architectural decisions, testing)
- 0.8-1.0: Very Complex (large refactoring, many dependencies)
- "reasoning": Brief explanation of why this complexity level
- "estimated_tokens": Estimated total tokens needed (input + output) to complete this task
- "should_split": Boolean, true if task should be broken into subtasks
- "subtasks": If should_split is true, array of suggested subtask descriptions
Respond with ONLY the JSON object, no other text."#,
            task.description()
        )
    }

    /// Return the outermost `{...}` span of `text`, if there is one.
    ///
    /// Used to recover a JSON object that the model wrapped in Markdown
    /// fences or surrounded with prose.
    fn extract_json_object(text: &str) -> Option<&str> {
        let start = text.find('{')?;
        let end = text.rfind('}')?;
        if start < end {
            // Both delimiters are single-byte ASCII, so these are valid
            // char boundaries.
            Some(&text[start..=end])
        } else {
            None
        }
    }

    /// Parse LLM response into Complexity struct.
    ///
    /// Parsing is attempted in this order:
    /// 1. the whole reply as JSON;
    /// 2. the outermost `{...}` span of the reply as JSON (covers fenced or
    ///    prefixed replies);
    /// 3. keyword matching on the free-form text;
    /// 4. a moderate-complexity default.
    ///
    /// Missing JSON fields default to a score of `0.5`, a placeholder
    /// reasoning string and `2000` estimated tokens.
    ///
    /// # Postconditions
    /// - The score handed to `Complexity::new` is within [0, 1]
    /// - Falls back to moderate complexity on parse error
    fn parse_response(&self, response: &str) -> Complexity {
        let trimmed = response.trim();

        // Models often wrap the object in Markdown fences or add a sentence
        // around it even when told not to. Without this second attempt such a
        // reply would be scored by keyword matching over the JSON text itself.
        let parsed = serde_json::from_str::<serde_json::Value>(trimmed)
            .ok()
            .or_else(|| {
                Self::extract_json_object(trimmed)
                    .and_then(|candidate| serde_json::from_str(candidate).ok())
            });

        if let Some(json) = parsed {
            let score = json["score"].as_f64().unwrap_or(0.5).clamp(0.0, 1.0);
            let reasoning = json["reasoning"].as_str().unwrap_or("No reasoning provided");
            let estimated_tokens = json["estimated_tokens"].as_u64().unwrap_or(2000);
            return Complexity::new(score, reasoning, estimated_tokens);
        }

        // Try to extract score from text
        if let Some(score) = self.extract_score_from_text(response) {
            return Complexity::new(score, response, 2000);
        }

        // Default to moderate complexity
        Complexity::moderate("Could not parse complexity response")
    }

    /// Try to derive a score from free-form text by keyword.
    ///
    /// Only difficulty keywords are recognised; numeric values in the text
    /// are not parsed. The checks are ordered so that the more specific
    /// phrases ("very simple", "very complex") win over the bare words
    /// ("simple", "complex") they contain.
    ///
    /// # Returns
    /// `Some(score)` for the first matching keyword, `None` if none match.
    fn extract_score_from_text(&self, text: &str) -> Option<f64> {
        let text_lower = text.to_lowercase();

        if text_lower.contains("trivial") || text_lower.contains("very simple") {
            return Some(0.1);
        }
        if text_lower.contains("very complex") || text_lower.contains("extremely") {
            return Some(0.9);
        }
        if text_lower.contains("complex") {
            return Some(0.7);
        }
        if text_lower.contains("moderate") || text_lower.contains("medium") {
            return Some(0.5);
        }
        if text_lower.contains("simple") || text_lower.contains("easy") {
            return Some(0.3);
        }
        None
    }
}
impl Default for ComplexityEstimator {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for ComplexityEstimator {
    /// Identifier of this agent instance.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::ComplexityEstimator`.
    fn agent_type(&self) -> AgentType {
        AgentType::ComplexityEstimator
    }

    /// Short human-readable description of the agent.
    fn description(&self) -> &str {
        "Estimates task complexity and recommends splitting strategy"
    }

    /// Estimate complexity of a task.
    ///
    /// Sends one chat completion request and parses the reply. If the
    /// request fails, a moderate estimate is reported (still as a success)
    /// with `"fallback": true` in the data and a cost of zero.
    ///
    /// # Returns
    /// AgentResult with `score`, `reasoning`, `should_split` and
    /// `estimated_tokens` in the `data` field.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        // Use a fast, cheap model for complexity estimation
        let model = "openai/gpt-4.1-mini";

        let system_message = ChatMessage {
            role: Role::System,
            content: Some("You are a precise task complexity analyzer. Respond only with JSON.".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        let user_message = ChatMessage {
            role: Role::User,
            content: Some(self.build_prompt(task)),
            tool_calls: None,
            tool_call_id: None,
        };
        let messages = vec![system_message, user_message];

        let response = match ctx.llm.chat_completion(model, &messages, None).await {
            Ok(response) => response,
            Err(e) => {
                // On error, return moderate complexity as fallback
                let fallback = Complexity::moderate(format!("LLM error, using fallback: {}", e));
                return AgentResult::success(
                    "Using fallback complexity estimate due to LLM error",
                    0,
                )
                .with_data(json!({
                    "score": fallback.score(),
                    "reasoning": fallback.reasoning(),
                    "should_split": fallback.should_split(),
                    "estimated_tokens": fallback.estimated_tokens(),
                    "fallback": true,
                }));
            }
        };

        let reply = response.content.unwrap_or_default();
        let complexity = self.parse_response(&reply);

        let recommendation = if complexity.should_split() {
            "Should split"
        } else {
            "Execute directly"
        };
        let summary = format!("Complexity: {:.2} - {}", complexity.score(), recommendation);

        // Flat estimate for this single small request.
        AgentResult::success(summary, 1)
            .with_model(model)
            .with_data(json!({
                "score": complexity.score(),
                "reasoning": complexity.reasoning(),
                "should_split": complexity.should_split(),
                "estimated_tokens": complexity.estimated_tokens(),
            }))
    }
}
impl LeafAgent for ComplexityEstimator {
    /// This leaf provides complexity estimation.
    fn capability(&self) -> LeafCapability {
        let provided = LeafCapability::ComplexityEstimation;
        provided
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed JSON reply yields its own score and a split recommendation.
    #[test]
    fn test_parse_json_response() {
        let reply = r#"{"score": 0.7, "reasoning": "Complex task", "estimated_tokens": 3000, "should_split": true}"#;
        let parsed = ComplexityEstimator::new().parse_response(reply);

        let distance = (parsed.score() - 0.7).abs();
        assert!(distance < 0.01);
        assert!(parsed.should_split());
    }

    /// A free-form reply is scored through keyword matching.
    #[test]
    fn test_parse_text_response() {
        let reply = "This is a very complex task";
        let parsed = ComplexityEstimator::new().parse_response(reply);

        assert!(parsed.score() > 0.6);
    }
}

244
src/agents/leaf/executor.rs Normal file
View File

@@ -0,0 +1,244 @@
//! Task executor agent - the main worker that uses tools.
//!
//! This is a refactored version of the original agent loop,
//! now as a leaf agent in the hierarchical tree.
use async_trait::async_trait;
use serde_json::json;
use crate::agents::{
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
};
use crate::llm::{ChatMessage, Role, ToolCall};
use crate::task::Task;
use crate::tools::ToolRegistry;
/// Agent that executes tasks using tools.
///
/// # Algorithm
/// 1. Build system prompt with available tools
/// 2. Call LLM with task description
/// 3. If LLM requests tool call: execute, feed back result
/// 4. Repeat until LLM produces final response or max iterations
///
/// # Budget Management
/// - Accumulates a flat per-request cost estimate (not measured token usage)
/// - Stops if budget is exhausted
pub struct TaskExecutor {
/// Identifier of this agent instance.
id: AgentId,
}
impl TaskExecutor {
/// Construct a task executor with a newly generated agent id.
pub fn new() -> Self {
    let id = AgentId::new();
    Self { id }
}
/// Render the system prompt used for task execution.
///
/// The prompt names the workspace, lists every tool in `tools` as a
/// Markdown bullet (`- **name**: description`), and states the working
/// rules and the expected shape of the final summary.
fn build_system_prompt(&self, workspace: &str, tools: &ToolRegistry) -> String {
    // One bullet per registered tool, joined with newlines.
    let mut bullets = Vec::new();
    for tool in tools.list_tools().iter() {
        bullets.push(format!("- **{}**: {}", tool.name, tool.description));
    }
    let tool_descriptions = bullets.join("\n");

    format!(
        r#"You are an autonomous task executor with access to tools.
You operate in the workspace: {workspace}
## Available Tools
{tool_descriptions}
## Rules
1. Use tools to accomplish the task - don't just describe what to do
2. Read files before editing them
3. Verify your work when possible
4. If stuck, explain what's blocking you
5. When done, summarize what you accomplished
## Response
When task is complete, provide a clear summary of:
- What you did
- Files created/modified
- How to verify the result"#,
        workspace = workspace,
        tool_descriptions = tool_descriptions
    )
}
/// Execute a single tool call.
///
/// The call's argument string is parsed as JSON and handed to the tool
/// registry together with the workspace path.
///
/// # Arguments
/// - Empty (or whitespace-only) arguments are passed to the tool as JSON
///   `null`, so tools without parameters keep working.
/// - Non-empty arguments that are not valid JSON are rejected with an
///   error instead of being silently replaced by `null`. The caller feeds
///   that error back to the model, which can then correct the call.
///
/// # Errors
/// Returns an error if the arguments are malformed JSON or if the tool
/// itself fails.
async fn execute_tool_call(
    &self,
    tool_call: &ToolCall,
    ctx: &AgentContext,
) -> anyhow::Result<String> {
    let raw_args = tool_call.function.arguments.trim();

    let args: serde_json::Value = if raw_args.is_empty() {
        serde_json::Value::Null
    } else {
        serde_json::from_str(raw_args).map_err(|e| {
            anyhow::anyhow!(
                "Invalid JSON arguments for tool '{}': {}",
                tool_call.function.name,
                e
            )
        })?
    };

    ctx.tools
        .execute(&tool_call.function.name, args, &ctx.workspace)
        .await
}
/// Run the agent loop for a task.
///
/// Repeatedly calls the LLM with the conversation so far. When the reply
/// requests tool calls, each one is executed and its output (or error
/// text) is appended to the conversation. The loop ends when the model
/// answers without tool calls, the LLM call fails, the budget is used up,
/// or `ctx.max_iterations` is reached.
///
/// # Budget
/// `task` is borrowed immutably here, so its budget is not debited while
/// the loop runs. The loop therefore subtracts its own accumulated
/// estimate from `task.budget().remaining_cents()` to decide whether to
/// continue. At least one LLM call is always attempted.
///
/// # Returns
/// `(final_text, estimated_cost_cents, tool_log)`, where `tool_log` holds
/// one `"Tool: <name> Args: <args>"` entry per tool call attempted.
async fn run_loop(
    &self,
    task: &Task,
    model: &str,
    ctx: &AgentContext,
) -> (String, u64, Vec<String>) {
    // Flat estimate charged per LLM request; real token usage is not measured.
    const ESTIMATED_CENTS_PER_CALL: u64 = 2;

    let mut total_cost = 0u64;
    let mut tool_log = Vec::new();

    // Build initial messages
    let system_prompt = self.build_system_prompt(&ctx.workspace_str(), &ctx.tools);
    let mut messages = vec![
        ChatMessage {
            role: Role::System,
            content: Some(system_prompt),
            tool_calls: None,
            tool_call_id: None,
        },
        ChatMessage {
            role: Role::User,
            content: Some(task.description().to_string()),
            tool_calls: None,
            tool_call_id: None,
        },
    ];

    // Get tool schemas
    let tool_schemas = ctx.tools.get_tool_schemas();

    // Agent loop
    for iteration in 0..ctx.max_iterations {
        tracing::debug!("TaskExecutor iteration {}", iteration + 1);

        // Check budget, accounting for what this loop has already spent.
        let remaining = task
            .budget()
            .remaining_cents()
            .saturating_sub(total_cost);
        if remaining == 0 && total_cost > 0 {
            return (
                "Budget exhausted before task completion".to_string(),
                total_cost,
                tool_log,
            );
        }

        // Call LLM
        let response = match ctx.llm.chat_completion(model, &messages, Some(&tool_schemas)).await {
            Ok(r) => r,
            Err(e) => {
                return (
                    format!("LLM error: {}", e),
                    total_cost,
                    tool_log,
                );
            }
        };

        total_cost = total_cost.saturating_add(ESTIMATED_CENTS_PER_CALL);

        // Check for tool calls
        if let Some(tool_calls) = &response.tool_calls {
            if !tool_calls.is_empty() {
                // Add assistant message with tool calls
                messages.push(ChatMessage {
                    role: Role::Assistant,
                    content: response.content.clone(),
                    tool_calls: Some(tool_calls.clone()),
                    tool_call_id: None,
                });

                // Execute each tool call
                for tool_call in tool_calls {
                    tool_log.push(format!(
                        "Tool: {} Args: {}",
                        tool_call.function.name,
                        tool_call.function.arguments
                    ));

                    // Tool failures are reported back to the model as text
                    // so it can react, rather than aborting the loop.
                    let result = match self.execute_tool_call(tool_call, ctx).await {
                        Ok(output) => output,
                        Err(e) => format!("Error: {}", e),
                    };

                    // Add tool result
                    messages.push(ChatMessage {
                        role: Role::Tool,
                        content: Some(result),
                        tool_calls: None,
                        tool_call_id: Some(tool_call.id.clone()),
                    });
                }
                continue;
            }
        }

        // No tool calls - final response
        if let Some(content) = response.content {
            return (content, total_cost, tool_log);
        }

        // Empty response
        return (
            "LLM returned empty response".to_string(),
            total_cost,
            tool_log,
        );
    }

    (
        format!("Max iterations ({}) reached", ctx.max_iterations),
        total_cost,
        tool_log,
    )
}
}
impl Default for TaskExecutor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for TaskExecutor {
    /// Identifier of this agent instance.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::TaskExecutor`.
    fn agent_type(&self) -> AgentType {
        AgentType::TaskExecutor
    }

    /// Short human-readable description of the agent.
    fn description(&self) -> &str {
        "Executes tasks using tools (file ops, terminal, search, etc.)"
    }

    /// Run the tool loop for `task` with the configured default model.
    ///
    /// The loop's estimated cost is recorded against the task budget; the
    /// outcome of that bookkeeping is intentionally ignored. The result is
    /// always reported through `AgentResult::success`, carrying the loop's
    /// final text plus the number and list of tool calls in `data`.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        // Use default model or from context
        let model = &ctx.config.default_model;

        let (output, cost, tool_log) = self.run_loop(task, model, ctx).await;

        // Update task budget
        let _ = task.budget_mut().try_spend(cost);

        let call_count = tool_log.len();
        let details = json!({
            "tool_calls": call_count,
            "tools_used": tool_log,
        });

        AgentResult::success(&output, cost)
            .with_model(model)
            .with_data(details)
    }
}
impl LeafAgent for TaskExecutor {
    /// This leaf provides task execution.
    fn capability(&self) -> LeafCapability {
        let provided = LeafCapability::TaskExecution;
        provided
    }
}

18
src/agents/leaf/mod.rs Normal file
View File

@@ -0,0 +1,18 @@
//! Leaf agents - specialized agents that do actual work.
//!
//! # Leaf Agent Types
//! - `ComplexityEstimator`: Estimates task complexity (0-1 score)
//! - `ModelSelector`: Selects optimal model for task/budget
//! - `TaskExecutor`: Executes tasks using tools (main worker)
//! - `Verifier`: Validates task completion
mod complexity;
mod model_select;
mod executor;
mod verifier;
pub use complexity::ComplexityEstimator;
pub use model_select::ModelSelector;
pub use executor::TaskExecutor;
pub use verifier::Verifier;

View File

@@ -0,0 +1,364 @@
//! Model selection agent with U-curve cost optimization.
//!
//! # U-Curve Optimization
//! The total expected cost follows a U-shaped curve:
//! - Cheap models: Low per-token cost, but may fail/retry, use more tokens
//! - Expensive models: High per-token cost, but succeed more often
//! - Optimal: Somewhere in the middle, minimizing total expected cost
//!
//! # Cost Model
//! Expected Cost = base_cost * (1 + failure_rate * retry_multiplier) * token_efficiency
use async_trait::async_trait;
use serde_json::json;
use crate::agents::{
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
};
use crate::budget::PricingInfo;
use crate::task::Task;
/// Agent that selects the optimal model for a task.
///
/// # Algorithm
/// 1. Get task complexity and budget constraints
/// 2. Fetch available models and pricing
/// 3. For each model, calculate expected total cost
/// 4. Return model with minimum expected cost within budget
///
/// When no model fits the budget, the model with the lowest expected cost
/// overall is returned instead.
pub struct ModelSelector {
/// Identifier of this agent instance.
id: AgentId,
}
/// Model recommendation from the selector.
///
/// Produced by the selector's cost ranking: the chosen model plus the
/// figures that justify the choice.
#[derive(Debug, Clone)]
pub struct ModelRecommendation {
/// Recommended model ID
pub model_id: String,
/// Expected cost in cents
pub expected_cost_cents: u64,
/// Confidence in this recommendation (0-1); computed as one minus the
/// estimated failure probability of the chosen model
pub confidence: f64,
/// Reasoning for the selection
pub reasoning: String,
/// Alternative models if primary fails (at most two)
pub fallbacks: Vec<String>,
}
impl ModelSelector {
/// Construct a model selector with a newly generated agent id.
pub fn new() -> Self {
    let id = AgentId::new();
    Self { id }
}
/// Estimate what running a task on a given model is expected to cost.
///
/// # Model
/// ```text
/// capability    = estimate_capability(average cost per token)
/// failure_prob  = clamp(complexity * (1 - capability), 0, 0.9)
/// inefficiency  = 1 + (1 - capability) * 0.5
/// adjusted_cost = cost(estimated_tokens * inefficiency)
/// expected_cost = adjusted_cost * (1 + failure_prob * 1.5)
/// ```
/// Tokens are split evenly between input and output when priced.
///
/// # Parameters
/// - `pricing`: Model pricing info
/// - `complexity`: Task complexity (0-1)
/// - `estimated_tokens`: Estimated tokens needed
///
/// # Returns
/// An `ExpectedCost` holding the unadjusted base cost, the expected cost
/// (rounded up to whole cents) and the intermediate factors.
fn calculate_expected_cost(
    &self,
    pricing: &PricingInfo,
    complexity: f64,
    estimated_tokens: u64,
) -> ExpectedCost {
    // Retries cost 50% more than a first attempt (wasted context).
    let retry_multiplier = 1.5;

    // Price is used as a proxy for capability.
    let capability = self.estimate_capability(pricing.average_cost_per_token());
    let shortfall = 1.0 - capability;

    // Harder tasks on weaker models fail more often; capped at 90%.
    let failure_probability = (complexity * shortfall).clamp(0.0, 0.9);

    // Weaker models are assumed to need up to 50% more tokens.
    let inefficiency = 1.0 + shortfall * 0.5;

    // Cost of the raw estimate, half input and half output.
    let half_tokens = estimated_tokens / 2;
    let base_cost_cents = pricing.calculate_cost_cents(half_tokens, half_tokens);

    // Cost after inflating the token count for inefficiency.
    let adjusted_tokens = ((estimated_tokens as f64) * inefficiency) as u64;
    let adjusted_half = adjusted_tokens / 2;
    let adjusted_cost = pricing.calculate_cost_cents(adjusted_half, adjusted_half);

    // Fold in the expected cost of retrying after a failure.
    let expected = (adjusted_cost as f64) * (1.0 + failure_probability * retry_multiplier);

    ExpectedCost {
        model_id: pricing.model_id.clone(),
        base_cost_cents,
        expected_cost_cents: expected.ceil() as u64,
        failure_probability,
        capability,
        inefficiency,
    }
}
/// Estimate model capability from its cost.
///
/// # Heuristic
/// More expensive models are assumed to be more capable. The average
/// per-token cost is placed on a base-10 log scale: `1e-7` maps to the
/// floor of `0.3` and `1e-3` maps to the ceiling of `0.95`, with linear
/// interpolation in between and clamping outside that range. Costs below
/// `1e-7` (free or nearly free models) short-circuit to the floor.
///
/// # Returns
/// Capability score between 0.3 and 0.95.
fn estimate_capability(&self, avg_cost_per_token: f64) -> f64 {
    const FLOOR: f64 = 0.3;
    const SPAN: f64 = 0.65;

    if avg_cost_per_token < 0.0000001 {
        // Free/very cheap
        FLOOR
    } else {
        // Position of the cost between 10^-7 and 10^-3, as a 0..1 fraction.
        let position = ((avg_cost_per_token.log10() + 7.0) / 4.0).clamp(0.0, 1.0);
        FLOOR + position * SPAN
    }
}
/// Pick the model with the lowest expected cost.
///
/// # Algorithm
/// 1. Compute the expected cost of every model
/// 2. Rank the models by expected cost, cheapest first
/// 3. Take the cheapest model whose expected cost fits `budget_cents`;
///    if none fits, take the cheapest model overall
/// 4. Offer the two cheapest other models as fallbacks
///
/// # Returns
/// `None` only when `models` is empty.
fn select_optimal(
    &self,
    models: &[PricingInfo],
    complexity: f64,
    estimated_tokens: u64,
    budget_cents: u64,
) -> Option<ModelRecommendation> {
    if models.is_empty() {
        return None;
    }

    // Rank every candidate by expected cost (ascending, stable).
    let mut ranked: Vec<ExpectedCost> = models
        .iter()
        .map(|model| self.calculate_expected_cost(model, complexity, estimated_tokens))
        .collect();
    ranked.sort_by_key(|candidate| candidate.expected_cost_cents);

    // Cheapest candidate within budget, else cheapest overall.
    let selected = ranked
        .iter()
        .find(|candidate| candidate.expected_cost_cents <= budget_cents)
        .or_else(|| ranked.first())?;

    // Next best options, in ranking order.
    let fallbacks: Vec<String> = ranked
        .iter()
        .filter(|candidate| candidate.model_id != selected.model_id)
        .map(|candidate| candidate.model_id.clone())
        .take(2)
        .collect();

    let reasoning = format!(
        "Selected {} with expected cost {} cents (capability: {:.2}, failure prob: {:.2})",
        selected.model_id,
        selected.expected_cost_cents,
        selected.capability,
        selected.failure_probability
    );

    Some(ModelRecommendation {
        model_id: selected.model_id.clone(),
        expected_cost_cents: selected.expected_cost_cents,
        confidence: 1.0 - selected.failure_probability,
        reasoning,
        fallbacks,
    })
}
}
/// Intermediate calculation result for a model.
///
/// Holds the expected-cost figure used for ranking together with the
/// factors it was derived from.
#[derive(Debug)]
struct ExpectedCost {
/// ID of the model this estimate belongs to
model_id: String,
/// Cost in cents of the raw token estimate, before any adjustment
base_cost_cents: u64,
/// Cost in cents after the inefficiency and retry adjustments, rounded up
expected_cost_cents: u64,
/// Estimated probability that the model fails the task (capped at 0.9)
failure_probability: f64,
/// Capability score estimated from the model's price
capability: f64,
/// Token inflation factor applied for weaker models (>= 1.0)
inefficiency: f64,
}
impl Default for ModelSelector {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for ModelSelector {
    /// Identifier of this agent instance.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::ModelSelector`.
    fn agent_type(&self) -> AgentType {
        AgentType::ModelSelector
    }

    /// Short human-readable description of the agent.
    fn description(&self) -> &str {
        "Selects optimal model for task based on complexity and budget (U-curve optimization)"
    }

    /// Select the optimal model for a task.
    ///
    /// Complexity and token count are currently fixed placeholders
    /// (`0.5` and `2000`); only the task's remaining budget is read from
    /// the task itself. If no pricing data is available, a hard-coded
    /// default recommendation is returned.
    ///
    /// # Returns
    /// AgentResult whose `data` field carries `model_id`,
    /// `expected_cost_cents`, `confidence`, `reasoning` and `fallbacks`.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        // In practice, this would be passed from ComplexityEstimator
        let complexity = 0.5; // TODO: Get from task metadata
        let estimated_tokens = 2000_u64;

        // Get available budget
        let budget_cents = task.budget().remaining_cents();

        // Fetch pricing for all models
        let models = ctx.pricing.models_by_cost().await;
        if models.is_empty() {
            // Use hardcoded defaults if no pricing available
            let defaults = json!({
                "model_id": "openai/gpt-4.1-mini",
                "expected_cost_cents": 10,
                "confidence": 0.5,
                "reasoning": "Fallback to default model",
                "fallbacks": ["openai/gpt-4o-mini", "anthropic/claude-3-haiku"],
            });
            return AgentResult::success(
                "Using default model (no pricing data available)",
                0,
            )
            .with_data(defaults);
        }

        let recommendation =
            self.select_optimal(&models, complexity, estimated_tokens, budget_cents);

        if let Some(rec) = recommendation {
            let payload = json!({
                "model_id": rec.model_id,
                "expected_cost_cents": rec.expected_cost_cents,
                "confidence": rec.confidence,
                "reasoning": rec.reasoning,
                "fallbacks": rec.fallbacks,
            });
            // Minimal cost for selection itself
            AgentResult::success(&rec.reasoning, 1).with_data(payload)
        } else {
            AgentResult::failure("No suitable model found within budget", 0)
        }
    }
}
impl LeafAgent for ModelSelector {
    /// This leaf provides model selection.
    fn capability(&self) -> LeafCapability {
        let provided = LeafCapability::ModelSelection;
        provided
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pricing entry with the given per-million prompt/completion costs.
    fn make_pricing(id: &str, prompt: f64, completion: f64) -> PricingInfo {
        PricingInfo {
            model_id: id.to_string(),
            prompt_cost_per_million: prompt,
            completion_cost_per_million: completion,
            context_length: 100000,
            max_output_tokens: None,
        }
    }

    /// Cheaper models must be assigned a higher failure probability.
    #[test]
    fn test_expected_cost_u_curve() {
        let selector = ModelSelector::new();
        let complexity = 0.7;
        let tokens = 2000;

        let estimate = |pricing: PricingInfo| {
            selector.calculate_expected_cost(&pricing, complexity, tokens)
        };
        let cheap_cost = estimate(make_pricing("cheap", 0.1, 0.2));
        let medium_cost = estimate(make_pricing("medium", 1.0, 2.0));
        let expensive_cost = estimate(make_pricing("expensive", 10.0, 20.0));

        // For complex tasks, medium should be optimal (U-curve):
        // the cheap model fails often, the expensive one has a high base cost.
        println!("Cheap: {} (fail: {})", cheap_cost.expected_cost_cents, cheap_cost.failure_probability);
        println!("Medium: {} (fail: {})", medium_cost.expected_cost_cents, medium_cost.failure_probability);
        println!("Expensive: {} (fail: {})", expensive_cost.expected_cost_cents, expensive_cost.failure_probability);

        // Basic sanity check: cheap model should have higher failure rate
        assert!(cheap_cost.failure_probability > medium_cost.failure_probability);
    }

    /// A recommendation is produced both with a generous and a tiny budget.
    #[test]
    fn test_select_optimal() {
        let selector = ModelSelector::new();
        let models = vec![
            make_pricing("cheap", 0.1, 0.2),
            make_pricing("medium", 1.0, 2.0),
            make_pricing("expensive", 10.0, 20.0),
        ];

        // For moderate complexity, should pick cost-effective option
        let generous = selector.select_optimal(&models, 0.5, 1000, 1000);
        assert!(generous.is_some());

        // For very low budget, might be forced to pick cheap
        let tight = selector.select_optimal(&models, 0.5, 1000, 1);
        assert!(tight.is_some());
    }
}

347
src/agents/leaf/verifier.rs Normal file
View File

@@ -0,0 +1,347 @@
//! Verification agent - validates task completion.
//!
//! # Verification Strategy (Hybrid)
//! 1. Try programmatic verification first (fast, deterministic)
//! 2. Fall back to LLM verification if needed
//!
//! # Programmatic Checks
//! - File exists
//! - Command succeeds
//! - Output matches pattern
use async_trait::async_trait;
use std::path::Path;
use std::process::Stdio;
use tokio::process::Command;
use crate::agents::{
Agent, AgentContext, AgentId, AgentResult, AgentType, LeafAgent, LeafCapability,
};
use crate::llm::{ChatMessage, Role};
use crate::task::{ProgrammaticCheck, Task, VerificationCriteria, VerificationMethod, VerificationResult};
/// Agent that verifies task completion.
///
/// # Hybrid Verification
/// - Programmatic: Fast, deterministic, no cost
/// - LLM: Flexible, for subjective criteria
///
/// Which of the two is used is decided by the task's verification
/// criteria; the hybrid mode tries the programmatic check first.
pub struct Verifier {
/// Identifier of this agent instance.
id: AgentId,
}
impl Verifier {
/// Construct a verifier with a newly generated agent id.
pub fn new() -> Self {
    let id = AgentId::new();
    Self { id }
}
/// Evaluate one programmatic check against the workspace.
///
/// Relative paths are resolved against `workspace`, and shell commands run
/// through `sh -c` with `workspace` as their working directory. Composite
/// checks (`All` / `Any`) recurse and stop at the first deciding result.
///
/// # Returns
/// `Ok(true)` if check passes, `Ok(false)` if fails, `Err` on error.
///
/// # Errors
/// A command that cannot be spawned or a pattern that is not a valid
/// regular expression yields `Err` with the error text. A file that cannot
/// be read is treated as a failed check, not an error.
async fn run_programmatic_check(
    &self,
    check: &ProgrammaticCheck,
    workspace: &Path,
) -> Result<bool, String> {
    match check {
        ProgrammaticCheck::FileExists { path } => Ok(workspace.join(path).exists()),

        ProgrammaticCheck::FileContains { path, content } => {
            let target = workspace.join(path);
            let found = match tokio::fs::read_to_string(&target).await {
                Ok(text) => text.contains(content),
                Err(_) => false,
            };
            Ok(found)
        }

        ProgrammaticCheck::CommandSucceeds { command } => {
            // Only the exit status matters, so discard the output streams.
            let mut shell = Command::new("sh");
            shell
                .arg("-c")
                .arg(command)
                .current_dir(workspace)
                .stdout(Stdio::null())
                .stderr(Stdio::null());
            let status = shell.status().await.map_err(|err| err.to_string())?;
            Ok(status.success())
        }

        ProgrammaticCheck::CommandOutputMatches { command, pattern } => {
            let mut shell = Command::new("sh");
            shell.arg("-c").arg(command).current_dir(workspace);
            let captured = shell.output().await.map_err(|err| err.to_string())?;

            // Only stdout is matched; invalid UTF-8 is replaced lossily.
            let stdout = String::from_utf8_lossy(&captured.stdout);
            let matcher = regex::Regex::new(pattern).map_err(|err| err.to_string())?;
            Ok(matcher.is_match(&stdout))
        }

        ProgrammaticCheck::DirectoryExists { path } => Ok(workspace.join(path).is_dir()),

        ProgrammaticCheck::FileMatchesRegex { path, pattern } => {
            let target = workspace.join(path);
            // The pattern is compiled only once the file has been read.
            if let Ok(text) = tokio::fs::read_to_string(&target).await {
                let matcher = regex::Regex::new(pattern).map_err(|err| err.to_string())?;
                Ok(matcher.is_match(&text))
            } else {
                Ok(false)
            }
        }

        ProgrammaticCheck::All(checks) => {
            let mut all_passed = true;
            for inner in checks {
                // Boxed because an async fn cannot recurse directly.
                let passed = Box::pin(self.run_programmatic_check(inner, workspace)).await?;
                if !passed {
                    all_passed = false;
                    break;
                }
            }
            Ok(all_passed)
        }

        ProgrammaticCheck::Any(checks) => {
            let mut any_passed = false;
            for inner in checks {
                let passed = Box::pin(self.run_programmatic_check(inner, workspace)).await?;
                if passed {
                    any_passed = true;
                    break;
                }
            }
            Ok(any_passed)
        }
    }
}
/// Ask the LLM whether the task met its success criteria.
///
/// Sends the task description and the criteria in a single request and
/// parses the reply into a verdict.
///
/// # Parameters
/// - `task`: The task that was executed
/// - `success_criteria`: What success looks like
/// - `ctx`: Agent context
///
/// # Returns
/// VerificationResult with the LLM's assessment. If the request itself
/// fails, the result is a failure with a cost of zero.
async fn verify_with_llm(
    &self,
    task: &Task,
    success_criteria: &str,
    ctx: &AgentContext,
) -> VerificationResult {
    let model = "openai/gpt-4.1-mini";

    let prompt = format!(
        r#"You are verifying if a task was completed correctly.
Task: {}
Success Criteria: {}
Based on your assessment, respond with a JSON object:
{{
"passed": true/false,
"reasoning": "explanation of why the task passed or failed"
}}
Respond ONLY with the JSON object."#,
        task.description(),
        success_criteria
    );

    let system_message = ChatMessage {
        role: Role::System,
        content: Some("You are a precise task verifier. Respond only with JSON.".to_string()),
        tool_calls: None,
        tool_call_id: None,
    };
    let user_message = ChatMessage {
        role: Role::User,
        content: Some(prompt),
        tool_calls: None,
        tool_call_id: None,
    };
    let messages = vec![system_message, user_message];

    let reply = ctx.llm.chat_completion(model, &messages, None).await;
    match reply {
        Err(e) => VerificationResult::fail(
            format!("LLM verification failed: {}", e),
            VerificationMethod::Llm { model: model.to_string() },
            0,
        ),
        Ok(response) => {
            let content = response.content.unwrap_or_default();
            self.parse_llm_verification(&content, model)
        }
    }
}
/// Parse LLM verification response.
fn parse_llm_verification(&self, response: &str, model: &str) -> VerificationResult {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(response) {
let passed = json["passed"].as_bool().unwrap_or(false);
let reasoning = json["reasoning"]
.as_str()
.unwrap_or("No reasoning provided")
.to_string();
if passed {
VerificationResult::pass(
reasoning,
VerificationMethod::Llm { model: model.to_string() },
1, // Minimal cost
)
} else {
VerificationResult::fail(
reasoning,
VerificationMethod::Llm { model: model.to_string() },
1,
)
}
} else {
// Try to infer from text
let passed = response.to_lowercase().contains("pass")
|| response.to_lowercase().contains("success")
|| response.to_lowercase().contains("completed");
if passed {
VerificationResult::pass(
response.to_string(),
VerificationMethod::Llm { model: model.to_string() },
1,
)
} else {
VerificationResult::fail(
response.to_string(),
VerificationMethod::Llm { model: model.to_string() },
1,
)
}
}
}
/// Verify `task` using whichever method its verification criteria name.
///
/// - `None`: passes immediately.
/// - `Programmatic`: the check's outcome decides; a check error is a failure.
/// - `LlmBased`: the LLM decides.
/// - `Hybrid`: a passing programmatic check is accepted as-is; a failing
///   or erroring one defers to the LLM.
async fn verify(
    &self,
    task: &Task,
    ctx: &AgentContext,
) -> VerificationResult {
    match task.verification() {
        VerificationCriteria::None => VerificationResult::pass(
            "No verification required",
            VerificationMethod::None,
            0,
        ),

        VerificationCriteria::Programmatic(check) => {
            let outcome = self.run_programmatic_check(check, &ctx.workspace).await;
            match outcome {
                Ok(true) => VerificationResult::pass(
                    "Programmatic check passed",
                    VerificationMethod::Programmatic,
                    0,
                ),
                Ok(false) => VerificationResult::fail(
                    "Programmatic check failed",
                    VerificationMethod::Programmatic,
                    0,
                ),
                Err(e) => VerificationResult::fail(
                    format!("Programmatic check error: {}", e),
                    VerificationMethod::Programmatic,
                    0,
                ),
            }
        }

        VerificationCriteria::LlmBased { success_criteria } => {
            self.verify_with_llm(task, success_criteria, ctx).await
        }

        VerificationCriteria::Hybrid { programmatic, llm_fallback } => {
            // Try programmatic first
            let outcome = self.run_programmatic_check(programmatic, &ctx.workspace).await;
            match outcome {
                Ok(true) => VerificationResult::pass(
                    "Programmatic check passed",
                    VerificationMethod::Programmatic,
                    0,
                ),
                // A failed or errored check falls back to the LLM.
                Ok(false) | Err(_) => self.verify_with_llm(task, llm_fallback, ctx).await,
            }
        }
    }
}
}
impl Default for Verifier {
/// Builds the default verifier by delegating to `Verifier::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for Verifier {
    /// Identifier assigned to this verifier.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::Verifier`.
    fn agent_type(&self) -> AgentType {
        AgentType::Verifier
    }

    /// Short human-readable summary of what this agent does.
    fn description(&self) -> &str {
        "Verifies task completion using programmatic checks and LLM fallback"
    }

    /// Verify `task` and translate the verdict into an `AgentResult`.
    ///
    /// The result is a success when verification passed and a failure
    /// otherwise. In both cases the output is the verifier's reasoning, the
    /// cost is the verification cost, and `data` holds a JSON object with the
    /// keys `passed`, `method` and `reasoning`.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        let verdict = self.verify(task, ctx).await;
        let passed = verdict.passed();

        // The payload has the same shape for both outcomes; only `passed` differs.
        let details = serde_json::json!({
            "passed": passed,
            "method": format!("{:?}", verdict.method()),
            "reasoning": verdict.reasoning(),
        });

        let summary = if passed {
            AgentResult::success(verdict.reasoning(), verdict.cost_cents())
        } else {
            AgentResult::failure(verdict.reasoning(), verdict.cost_cents())
        };
        summary.with_data(details)
    }
}
impl LeafAgent for Verifier {
/// Reports this leaf's capability as `LeafCapability::Verification`.
fn capability(&self) -> LeafCapability {
LeafCapability::Verification
}
}

139
src/agents/mod.rs Normal file
View File

@@ -0,0 +1,139 @@
//! Agents module - the hierarchical agent tree.
//!
//! # Agent Types
//! - **RootAgent**: Top-level orchestrator, receives tasks from API
//! - **NodeAgent**: Intermediate orchestrator, delegates to children
//! - **LeafAgent**: Specialized agents that do actual work
//!
//! # Leaf Agent Specializations
//! - `ComplexityEstimator`: Estimates task difficulty
//! - `ModelSelector`: Chooses optimal model for cost/capability
//! - `TaskExecutor`: Executes tasks using tools
//! - `Verifier`: Validates task completion
//!
//! # Design Principles
//! - Agents communicate synchronously (parent calls child, child returns)
//! - Designed for future async message passing migration
//! - Every `execute` call returns an `AgentResult`; failures are reported with `success == false` and the error text in `output`
mod types;
mod context;
mod tree;
pub mod orchestrator;
pub mod leaf;
pub use types::{AgentId, AgentType, AgentResult, AgentError, Complexity};
pub use context::AgentContext;
pub use tree::{AgentTree, AgentRef};
use async_trait::async_trait;
use crate::task::Task;
/// Base trait for all agents.
///
/// # Invariants
/// - `execute()` returns an `AgentResult` with `success == true` only if the task was actually completed or delegated
/// - `execute()` must never panic; all errors are reported as an `AgentResult` with `success == false`
///
/// # Design for Provability
/// - Preconditions and postconditions are documented
/// - Pure functions are preferred where possible
#[async_trait]
pub trait Agent: Send + Sync {
/// Get the unique identifier for this agent.
fn id(&self) -> &AgentId;
/// Get the type/role of this agent.
fn agent_type(&self) -> AgentType;
/// Execute a task.
///
/// Note that the return type is a plain `AgentResult`, not a `Result`:
/// callers must inspect `success` to distinguish the two outcomes.
///
/// # Preconditions
/// - `task.budget().remaining_cents() > 0` (has budget)
/// - `task.status() == Pending || task.status() == Running`
///
/// # Postconditions
/// - On success: task is completed or delegated appropriately
/// - `result.cost_cents <= task.budget().total_cents()`
///
/// # Failures
/// Returns an `AgentResult` with `success == false` if:
/// - Task cannot be executed (insufficient budget, invalid state)
/// - Execution fails (tool error, LLM error, etc.)
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult;
/// Get a human-readable description of this agent.
///
/// The default implementation returns the placeholder `"Generic agent"`.
fn description(&self) -> &str {
"Generic agent"
}
}
/// Trait for orchestrator agents (Root and Node) that can have children.
///
/// # Child Management
/// Orchestrators can delegate work to child agents.
#[async_trait]
pub trait OrchestratorAgent: Agent {
/// Get references to child agents.
///
/// Returns owned `AgentRef` handles, so the caller may keep them.
fn children(&self) -> Vec<AgentRef>;
/// Find a child agent by its `AgentType`.
///
/// Returns `None` when the orchestrator has no child of that type.
fn find_child(&self, agent_type: AgentType) -> Option<AgentRef>;
/// Delegate a task to a specific child.
///
/// # Preconditions
/// - Child exists and is capable of handling the task
/// - Task has sufficient budget
///
/// # Postconditions
/// - Child's execute() is called
/// - Results are aggregated and returned
async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult;
/// Delegate multiple tasks to appropriate children.
///
/// Returns one `AgentResult` per task.
///
/// # Preconditions
/// - Sum of task budgets <= available budget
/// - All tasks can be matched to capable children
async fn delegate_all(
&self,
tasks: &mut [Task],
ctx: &AgentContext,
) -> Vec<AgentResult>;
}
/// Trait for leaf agents with specialized capabilities.
///
/// Extends `Agent` with a single query describing what the leaf can do.
pub trait LeafAgent: Agent {
/// Get the specific capability of this leaf agent.
fn capability(&self) -> LeafCapability;
}
/// Capabilities of leaf agents.
///
/// Each variant corresponds to one leaf `AgentType`; see
/// `LeafCapability::agent_type` for the mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LeafCapability {
/// Can estimate task complexity
ComplexityEstimation,
/// Can select optimal model for a task
ModelSelection,
/// Can execute tasks using tools
TaskExecution,
/// Can verify task completion
Verification,
}
impl LeafCapability {
/// Get the agent type for this capability.
pub fn agent_type(&self) -> AgentType {
match self {
Self::ComplexityEstimation => AgentType::ComplexityEstimator,
Self::ModelSelection => AgentType::ModelSelector,
Self::TaskExecution => AgentType::TaskExecutor,
Self::Verification => AgentType::Verifier,
}
}
}

View File

@@ -0,0 +1,8 @@
//! Orchestrator agents - Root and Node agents that manage the tree.
mod root;
mod node;
pub use root::RootAgent;
pub use node::NodeAgent;

View File

@@ -0,0 +1,163 @@
//! Node agent - intermediate orchestrator in the agent tree.
//!
//! Node agents are like mini-root agents that can:
//! - Receive delegated tasks from parent
//! - Split complex subtasks further
//! - Delegate to their own children
//! - Aggregate results for parent
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::agents::{
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, OrchestratorAgent,
leaf::{TaskExecutor, Verifier},
};
use crate::task::Task;
/// Node agent - intermediate orchestrator.
///
/// # Purpose
/// Handles subtasks that may still be complex enough
/// to warrant further splitting.
///
/// # Differences from Root
/// - No complexity estimation (parent already decided to split)
/// - Simpler child set (just executor and verifier)
/// - Limited split depth (prevents infinite recursion)
///
/// NOTE(review): nothing in this struct enforces a split depth; presumably
/// the limit lives in the caller or in `AgentContext` — confirm.
pub struct NodeAgent {
// Unique identifier of this node (generated in `new`).
id: AgentId,
/// Name for identification in logs
name: String,
// Child agents
// `task_executor` runs the task, `verifier` checks the outcome afterwards.
task_executor: Arc<TaskExecutor>,
verifier: Arc<Verifier>,
// Child node agents (for further splitting)
// Exposed through `children()` / `find_child()`.
child_nodes: Vec<Arc<NodeAgent>>,
}
impl NodeAgent {
    /// Build a node called `name` with a fresh ID, a default executor, a
    /// default verifier and no child nodes.
    pub fn new(name: impl Into<String>) -> Self {
        let id = AgentId::new();
        let name = name.into();
        let task_executor = Arc::new(TaskExecutor::new());
        let verifier = Arc::new(Verifier::new());
        Self {
            id,
            name,
            task_executor,
            verifier,
            child_nodes: Vec::new(),
        }
    }

    /// Builder-style setter that swaps in a caller-supplied executor.
    pub fn with_executor(self, executor: Arc<TaskExecutor>) -> Self {
        Self {
            task_executor: executor,
            ..self
        }
    }

    /// Register `child` as a nested node for hierarchical delegation.
    pub fn add_child_node(&mut self, child: Arc<NodeAgent>) {
        self.child_nodes.push(child);
    }

    /// Name used to identify this node in logs.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl Default for NodeAgent {
/// Builds a node with the placeholder name `"node"` via `NodeAgent::new`.
fn default() -> Self {
Self::new("node")
}
}
#[async_trait]
impl Agent for NodeAgent {
    /// Identifier assigned to this node.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::Node`.
    fn agent_type(&self) -> AgentType {
        AgentType::Node
    }

    /// Short human-readable summary of what this agent does.
    fn description(&self) -> &str {
        "Intermediate orchestrator for subtask delegation"
    }

    /// Execute `task` with the node's executor, then verify the outcome.
    ///
    /// - If execution fails, the executor's result is returned unchanged and
    ///   no verification is attempted.
    /// - If verification passes, the executor's result is returned with the
    ///   verification cost added to `cost_cents`.
    /// - If verification fails, a failure is returned whose `data` carries
    ///   both the `execution` and the `verification` payloads.
    ///
    /// In every case `cost_cents` covers all work actually performed.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        tracing::debug!("NodeAgent '{}' executing task", self.name);

        // Execute the task
        let mut result = self.task_executor.execute(task, ctx).await;
        if !result.success {
            return result;
        }

        // Verify if criteria specified
        let verification = self.verifier.execute(task, ctx).await;

        // Verification is paid for whether or not it passes, so it must be
        // part of the reported cost on both paths.
        let total_cost = result.cost_cents.saturating_add(verification.cost_cents);

        if verification.success {
            result.cost_cents = total_cost;
            result
        } else {
            AgentResult::failure(
                format!(
                    "Task completed but verification failed: {}",
                    verification.output
                ),
                total_cost,
            )
            .with_data(json!({
                "execution": result.data,
                "verification": verification.data,
            }))
        }
    }
}
#[async_trait]
impl OrchestratorAgent for NodeAgent {
    /// All children in a fixed order: executor, verifier, then the nested
    /// nodes in insertion order.
    fn children(&self) -> Vec<AgentRef> {
        let mut all: Vec<AgentRef> = Vec::with_capacity(2 + self.child_nodes.len());
        all.push(Arc::clone(&self.task_executor) as AgentRef);
        all.push(Arc::clone(&self.verifier) as AgentRef);
        all.extend(
            self.child_nodes
                .iter()
                .map(|node| Arc::clone(node) as AgentRef),
        );
        all
    }

    /// Look up a child by type. For `AgentType::Node` the first registered
    /// child node (if any) is returned; types this node does not own give `None`.
    fn find_child(&self, agent_type: AgentType) -> Option<AgentRef> {
        let found: AgentRef = match agent_type {
            AgentType::TaskExecutor => Arc::clone(&self.task_executor) as AgentRef,
            AgentType::Verifier => Arc::clone(&self.verifier) as AgentRef,
            AgentType::Node => {
                return self
                    .child_nodes
                    .first()
                    .map(|n| Arc::clone(n) as AgentRef);
            }
            _ => return None,
        };
        Some(found)
    }

    /// Hand `task` to `child` and return whatever the child reports.
    async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult {
        child.execute(task, ctx).await
    }

    /// Run every task, one after another, through this node's executor and
    /// collect the results in task order.
    async fn delegate_all(&self, tasks: &mut [Task], ctx: &AgentContext) -> Vec<AgentResult> {
        let mut outcomes = Vec::with_capacity(tasks.len());
        for task in tasks.iter_mut() {
            outcomes.push(self.task_executor.execute(task, ctx).await);
        }
        outcomes
    }
}

View File

@@ -0,0 +1,346 @@
//! Root agent - top-level orchestrator of the agent tree.
//!
//! # Responsibilities
//! 1. Receive tasks from the API
//! 2. Estimate complexity
//! 3. Decide: execute directly or split into subtasks
//! 4. Delegate to appropriate children
//! 5. Aggregate results
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::agents::{
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, Complexity,
OrchestratorAgent,
leaf::{ComplexityEstimator, ModelSelector, TaskExecutor, Verifier},
};
use crate::budget::Budget;
use crate::task::{Task, Subtask, SubtaskPlan, VerificationCriteria};
/// Root agent - the top of the agent tree.
///
/// # Task Processing Flow
/// ```text
/// 1. Estimate complexity (ComplexityEstimator)
/// 2. If simple: execute directly (TaskExecutor)
/// 3. If complex:
/// a. Split into subtasks (LLM-based)
/// b. Select model for each subtask (ModelSelector)
/// c. Execute subtasks (TaskExecutor)
/// d. Verify results (Verifier)
/// 4. Return aggregated result
/// ```
///
/// NOTE(review): `execute_subtasks` in this file only runs each subtask
/// through the `TaskExecutor`; steps 3b and 3d are not performed per subtask
/// there — confirm whether that is intended or still to be implemented.
pub struct RootAgent {
// Unique identifier of the root (generated in `new`).
id: AgentId,
// Child agents
// All four leaves are owned here and exposed through `children()` / `find_child()`.
complexity_estimator: Arc<ComplexityEstimator>,
model_selector: Arc<ModelSelector>,
task_executor: Arc<TaskExecutor>,
verifier: Arc<Verifier>,
}
impl RootAgent {
    /// Create a new root agent with default children.
    pub fn new() -> Self {
        Self {
            id: AgentId::new(),
            complexity_estimator: Arc::new(ComplexityEstimator::new()),
            model_selector: Arc::new(ModelSelector::new()),
            task_executor: Arc::new(TaskExecutor::new()),
            verifier: Arc::new(Verifier::new()),
        }
    }

    /// Estimate complexity of a task.
    ///
    /// Runs the complexity estimator and reads `score`, `reasoning`,
    /// `estimated_tokens` and `should_split` from its `data` payload. Missing
    /// or mistyped fields fall back to `0.5`, `""`, `2000` and `false`. If the
    /// estimator returned no data at all, a moderate estimate is used.
    async fn estimate_complexity(&self, task: &mut Task, ctx: &AgentContext) -> Complexity {
        let result = self.complexity_estimator.execute(task, ctx).await;

        match result.data {
            Some(data) => {
                let score = data["score"].as_f64().unwrap_or(0.5);
                let reasoning = data["reasoning"].as_str().unwrap_or("").to_string();
                let estimated_tokens = data["estimated_tokens"].as_u64().unwrap_or(2000);
                let should_split = data["should_split"].as_bool().unwrap_or(false);

                // The estimator's explicit decision overrides the threshold rule.
                Complexity::new(score, reasoning, estimated_tokens).with_split(should_split)
            }
            None => Complexity::moderate("Could not estimate complexity"),
        }
    }

    /// Split a complex task into subtasks.
    ///
    /// Uses LLM to analyze the task and propose subtasks.
    ///
    /// # Errors
    /// Returns a failed `AgentResult` when the LLM call fails or when its
    /// reply cannot be turned into a valid `SubtaskPlan`.
    async fn split_task(&self, task: &Task, ctx: &AgentContext) -> Result<SubtaskPlan, AgentResult> {
        let prompt = format!(
            r#"You are a task planner. Break down this task into smaller, manageable subtasks.
Task: {}
Respond with a JSON object:
{{
"subtasks": [
{{
"description": "What to do",
"verification": "How to verify it's done",
"weight": 1.0
}}
],
"reasoning": "Why this breakdown makes sense"
}}
Guidelines:
- Each subtask should be independently executable
- Include verification for each subtask
- Weight indicates relative effort (higher = more work)
- Keep subtasks focused and specific
Respond ONLY with the JSON object."#,
            task.description()
        );

        let messages = vec![
            crate::llm::ChatMessage {
                role: crate::llm::Role::System,
                content: Some("You are a precise task planner. Respond only with JSON.".to_string()),
                tool_calls: None,
                tool_call_id: None,
            },
            crate::llm::ChatMessage {
                role: crate::llm::Role::User,
                content: Some(prompt),
                tool_calls: None,
                tool_call_id: None,
            },
        ];

        let response = ctx.llm
            .chat_completion("openai/gpt-4.1-mini", &messages, None)
            .await
            .map_err(|e| AgentResult::failure(format!("LLM error: {}", e), 1))?;

        let content = response.content.unwrap_or_default();
        self.parse_subtask_plan(&content, task.id())
    }

    /// Return the outermost `{ ... }` span of `response`, or the trimmed
    /// response when no such span exists.
    ///
    /// Used to recover the JSON object from replies that wrap it in a
    /// Markdown code fence or surround it with prose.
    fn extract_json_object(response: &str) -> &str {
        let trimmed = response.trim();
        match (trimmed.find('{'), trimmed.rfind('}')) {
            // Both delimiters are single-byte ASCII, so the slice boundaries
            // are always valid char boundaries.
            (Some(start), Some(end)) if start < end => &trimmed[start..=end],
            _ => trimmed,
        }
    }

    /// Parse LLM response into SubtaskPlan.
    ///
    /// The reply is first parsed as-is. If that fails, the outermost
    /// `{ ... }` span is extracted and parsed instead, so replies such as
    /// a fenced ```` ```json ```` block are still accepted.
    ///
    /// Per subtask, a missing `description` or `verification` becomes `""`
    /// and a missing `weight` becomes `1.0`.
    ///
    /// # Errors
    /// Returns a failed `AgentResult` when no JSON can be parsed, when the
    /// `subtasks` array is missing or empty, or when `SubtaskPlan::new`
    /// rejects the plan.
    fn parse_subtask_plan(
        &self,
        response: &str,
        parent_id: crate::task::TaskId,
    ) -> Result<SubtaskPlan, AgentResult> {
        let json = serde_json::from_str::<serde_json::Value>(response)
            .or_else(|_| {
                serde_json::from_str::<serde_json::Value>(Self::extract_json_object(response))
            })
            .map_err(|e| AgentResult::failure(format!("Failed to parse subtasks: {}", e), 0))?;

        let reasoning = json["reasoning"]
            .as_str()
            .unwrap_or("No reasoning provided")
            .to_string();

        let subtasks: Vec<Subtask> = json["subtasks"]
            .as_array()
            .map(|arr| {
                arr.iter()
                    .map(|s| {
                        let desc = s["description"].as_str().unwrap_or("").to_string();
                        let verification = s["verification"].as_str().unwrap_or("");
                        let weight = s["weight"].as_f64().unwrap_or(1.0);

                        Subtask::new(
                            desc,
                            VerificationCriteria::llm_based(verification),
                            weight,
                        )
                    })
                    .collect()
            })
            .unwrap_or_default();

        if subtasks.is_empty() {
            return Err(AgentResult::failure("No subtasks generated", 1));
        }

        SubtaskPlan::new(parent_id, subtasks, reasoning)
            .map_err(|e| AgentResult::failure(format!("Invalid subtask plan: {}", e), 0))
    }

    /// Execute subtasks and aggregate results.
    ///
    /// Subtasks are run one after another through the task executor; a
    /// failing subtask does not stop the remaining ones. The returned result
    /// succeeds only if every subtask succeeded, and its cost is the sum of
    /// all subtask costs.
    async fn execute_subtasks(
        &self,
        subtask_plan: SubtaskPlan,
        parent_budget: &Budget,
        ctx: &AgentContext,
    ) -> AgentResult {
        // Convert plan to tasks
        let mut tasks = match subtask_plan.into_tasks(parent_budget) {
            Ok(t) => t,
            Err(e) => return AgentResult::failure(format!("Failed to create subtasks: {}", e), 0),
        };

        let mut results = Vec::new();
        let mut total_cost = 0u64;

        // Execute each subtask
        for task in &mut tasks {
            let result = self.task_executor.execute(task, ctx).await;
            total_cost += result.cost_cents;
            results.push(result);
        }

        // Aggregate results
        let successes = results.iter().filter(|r| r.success).count();
        let total = results.len();

        if successes == total {
            AgentResult::success(
                format!("All {} subtasks completed successfully", total),
                total_cost,
            )
            .with_data(json!({
                "subtasks_total": total,
                "subtasks_succeeded": successes,
                "results": results.iter().map(|r| &r.output).collect::<Vec<_>>(),
            }))
        } else {
            AgentResult::failure(
                format!("{}/{} subtasks succeeded", successes, total),
                total_cost,
            )
            .with_data(json!({
                "subtasks_total": total,
                "subtasks_succeeded": successes,
                "results": results.iter().map(|r| json!({
                    "success": r.success,
                    "output": &r.output,
                })).collect::<Vec<_>>(),
            }))
        }
    }
}
impl Default for RootAgent {
/// Builds the default root agent by delegating to `RootAgent::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for RootAgent {
    /// Identifier assigned to the root agent.
    fn id(&self) -> &AgentId {
        &self.id
    }

    /// Always `AgentType::Root`.
    fn agent_type(&self) -> AgentType {
        AgentType::Root
    }

    /// Short human-readable summary of what this agent does.
    fn description(&self) -> &str {
        "Root orchestrator: estimates complexity, splits tasks, delegates execution"
    }

    /// Run `task` through the root pipeline.
    ///
    /// 1. Estimate complexity (charged a flat 1 cent).
    /// 2. If the estimate asks for a split and `ctx.can_split()` allows it,
    ///    plan subtasks (flat 2 cents) and execute them in a child context;
    ///    the aggregated subtask result is returned with the overhead added.
    /// 3. Otherwise, or when planning fails, execute the task directly and
    ///    then run the verifier. The result succeeds only if both succeed.
    ///
    /// A planning failure is not fatal: it is logged, its cost is added to
    /// the total, and the task falls back to direct execution.
    async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
        let mut total_cost = 0u64;

        // Step 1: Estimate complexity
        let complexity = self.estimate_complexity(task, ctx).await;
        total_cost += 1; // Complexity estimation cost

        tracing::info!(
            "Task complexity: {:.2} (should_split: {})",
            complexity.score(),
            complexity.should_split()
        );

        // Step 2: Decide execution strategy
        if complexity.should_split() && ctx.can_split() {
            // Complex task: split and delegate
            match self.split_task(task, ctx).await {
                Ok(plan) => {
                    total_cost += 2; // Splitting cost

                    // Execute subtasks
                    let child_ctx = ctx.child_context();
                    let result = self.execute_subtasks(plan, task.budget(), &child_ctx).await;

                    return AgentResult {
                        success: result.success,
                        output: result.output,
                        cost_cents: total_cost + result.cost_cents,
                        model_used: result.model_used,
                        data: result.data,
                    };
                }
                Err(e) => {
                    // The failed planning attempt may already have spent
                    // money (e.g. the LLM was called), so keep it in the total.
                    total_cost += e.cost_cents;

                    // Couldn't split, fall back to direct execution
                    tracing::warn!("Couldn't split task, executing directly: {}", e.output);
                }
            }
        }

        // Simple task or failed to split: execute directly
        let result = self.task_executor.execute(task, ctx).await;

        // Step 3: Verify (if verification criteria specified)
        let verification = self.verifier.execute(task, ctx).await;
        total_cost += verification.cost_cents;

        AgentResult {
            success: result.success && verification.success,
            output: if verification.success {
                result.output
            } else {
                format!("{}\n\nVerification failed: {}", result.output, verification.output)
            },
            cost_cents: total_cost + result.cost_cents,
            model_used: result.model_used,
            data: json!({
                "complexity": complexity.score(),
                "was_split": false,
                "verification": verification.data,
                "execution": result.data,
            }).into(),
        }
    }
}
#[async_trait]
impl OrchestratorAgent for RootAgent {
    /// The four leaf children in a fixed order: complexity estimator, model
    /// selector, task executor, verifier.
    fn children(&self) -> Vec<AgentRef> {
        let mut leaves: Vec<AgentRef> = Vec::with_capacity(4);
        leaves.push(Arc::clone(&self.complexity_estimator) as AgentRef);
        leaves.push(Arc::clone(&self.model_selector) as AgentRef);
        leaves.push(Arc::clone(&self.task_executor) as AgentRef);
        leaves.push(Arc::clone(&self.verifier) as AgentRef);
        leaves
    }

    /// Look up one of the leaf children by type; any other type gives `None`.
    fn find_child(&self, agent_type: AgentType) -> Option<AgentRef> {
        let leaf: AgentRef = match agent_type {
            AgentType::ComplexityEstimator => Arc::clone(&self.complexity_estimator) as AgentRef,
            AgentType::ModelSelector => Arc::clone(&self.model_selector) as AgentRef,
            AgentType::TaskExecutor => Arc::clone(&self.task_executor) as AgentRef,
            AgentType::Verifier => Arc::clone(&self.verifier) as AgentRef,
            _ => return None,
        };
        Some(leaf)
    }

    /// Hand `task` to `child` and return whatever the child reports.
    async fn delegate(&self, task: &mut Task, child: AgentRef, ctx: &AgentContext) -> AgentResult {
        child.execute(task, ctx).await
    }

    /// Run every task, one after another, through the root's executor and
    /// collect the results in task order.
    async fn delegate_all(&self, tasks: &mut [Task], ctx: &AgentContext) -> Vec<AgentResult> {
        let mut outcomes = Vec::with_capacity(tasks.len());
        for task in tasks.iter_mut() {
            outcomes.push(self.task_executor.execute(task, ctx).await);
        }
        outcomes
    }
}

150
src/agents/tree.rs Normal file
View File

@@ -0,0 +1,150 @@
//! Agent tree structure and management.
use std::collections::HashMap;
use std::sync::Arc;
use super::{Agent, AgentId, AgentType};
/// Reference to an agent in the tree.
///
/// A shared, reference-counted handle to any `Agent` implementation;
/// cloning it clones the `Arc`, not the agent.
pub type AgentRef = Arc<dyn Agent>;
/// The agent tree structure.
///
/// # Structure
/// - Root agent at the top
/// - Node agents as intermediate orchestrators
/// - Leaf agents doing specialized work
///
/// # Invariants
/// - Exactly one root agent
/// - All non-root agents have a parent
/// - No cycles in parent-child relationships
///
/// `set_root` rejects a second root and `add_child` rejects an unknown
/// parent or an agent that is already registered.
pub struct AgentTree {
/// All agents indexed by ID
agents: HashMap<AgentId, AgentRef>,
/// Parent-child relationships
///
/// Maps a parent ID to the IDs of its direct children, in insertion order.
/// Every registered agent gets an entry, possibly empty.
children: HashMap<AgentId, Vec<AgentId>>,
/// Root agent ID
///
/// `None` until `set_root` has been called.
root_id: Option<AgentId>,
}
impl AgentTree {
    /// Create a new empty agent tree.
    pub fn new() -> Self {
        Self {
            agents: HashMap::new(),
            children: HashMap::new(),
            root_id: None,
        }
    }

    /// Set the root agent.
    ///
    /// # Preconditions
    /// - No root agent currently set
    ///
    /// # Errors
    /// Returns error if root already exists.
    pub fn set_root(&mut self, agent: AgentRef) -> Result<(), TreeError> {
        if self.root_id.is_some() {
            return Err(TreeError::RootAlreadyExists);
        }

        let id = *agent.id();
        self.agents.insert(id, agent);
        self.children.insert(id, Vec::new());
        self.root_id = Some(id);
        Ok(())
    }

    /// Add a child agent to a parent.
    ///
    /// # Preconditions
    /// - Parent exists in the tree
    /// - Child is not already in the tree
    ///
    /// # Errors
    /// - `TreeError::ParentNotFound` if `parent_id` is not registered
    /// - `TreeError::AgentAlreadyExists` if the child's ID is already registered
    pub fn add_child(&mut self, parent_id: AgentId, child: AgentRef) -> Result<(), TreeError> {
        if !self.agents.contains_key(&parent_id) {
            return Err(TreeError::ParentNotFound(parent_id));
        }

        let child_id = *child.id();
        if self.agents.contains_key(&child_id) {
            return Err(TreeError::AgentAlreadyExists(child_id));
        }

        self.agents.insert(child_id, child);
        self.children.insert(child_id, Vec::new());
        // Every registered agent already owns a child list, but going through
        // the entry API keeps this path free of any panic.
        self.children.entry(parent_id).or_default().push(child_id);
        Ok(())
    }

    /// Get the root agent.
    pub fn root(&self) -> Option<AgentRef> {
        self.root_id.and_then(|id| self.agents.get(&id).cloned())
    }

    /// Get an agent by ID.
    pub fn get(&self, id: &AgentId) -> Option<AgentRef> {
        self.agents.get(id).cloned()
    }

    /// Get children of an agent.
    ///
    /// Returns an empty vector when `id` is unknown or has no children.
    pub fn get_children(&self, id: &AgentId) -> Vec<AgentRef> {
        self.children
            .get(id)
            .into_iter()
            .flatten()
            .filter_map(|child_id| self.agents.get(child_id).cloned())
            .collect()
    }

    /// Find agents by type.
    ///
    /// The order of the returned agents is unspecified (hash-map order).
    pub fn find_by_type(&self, agent_type: AgentType) -> Vec<AgentRef> {
        self.agents
            .values()
            .filter(|a| a.agent_type() == agent_type)
            .cloned()
            .collect()
    }

    /// Get all agents in the tree.
    ///
    /// The order of the returned agents is unspecified (hash-map order).
    pub fn all_agents(&self) -> Vec<AgentRef> {
        self.agents.values().cloned().collect()
    }

    /// Get the number of agents in the tree.
    pub fn len(&self) -> usize {
        self.agents.len()
    }

    /// Check if the tree is empty.
    pub fn is_empty(&self) -> bool {
        self.agents.is_empty()
    }
}
impl Default for AgentTree {
/// Builds an empty tree by delegating to `AgentTree::new()`.
fn default() -> Self {
Self::new()
}
}
/// Errors in tree operations.
#[derive(Debug, thiserror::Error)]
pub enum TreeError {
/// `set_root` was called on a tree that already has a root.
#[error("Root agent already exists")]
RootAlreadyExists,
/// `add_child` was given a parent ID that is not registered in the tree.
#[error("Parent agent not found: {0}")]
ParentNotFound(AgentId),
/// `add_child` was given an agent whose ID is already registered.
#[error("Agent already exists in tree: {0}")]
AgentAlreadyExists(AgentId),
}

248
src/agents/types.rs Normal file
View File

@@ -0,0 +1,248 @@
//! Core types for the agent system.
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Unique identifier for an agent.
///
/// A newtype around a UUID; it is `Copy` and usable as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AgentId(Uuid);
impl AgentId {
/// Create a new unique agent ID.
pub fn new() -> Self {
Self(Uuid::new_v4())
}
/// Create an agent ID from a string (for testing).
pub fn from_str(s: &str) -> Self {
Self(Uuid::parse_str(s).unwrap_or_else(|_| Uuid::new_v4()))
}
}
impl Default for AgentId {
/// Generates a fresh random ID via `AgentId::new()`; two defaults are never equal by design.
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for AgentId {
/// Writes the wrapped UUID using its own `Display` output.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Type of agent in the hierarchy.
///
/// `Root` and `Node` are orchestrators; every other variant is a leaf
/// (see `is_orchestrator` / `is_leaf`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AgentType {
/// Root orchestrator (top of tree)
Root,
/// Intermediate orchestrator (can have children)
Node,
/// Estimates task complexity
ComplexityEstimator,
/// Selects optimal model
ModelSelector,
/// Executes tasks using tools
TaskExecutor,
/// Verifies task completion
Verifier,
}
impl AgentType {
    /// `true` for the types that may own children (`Root` and `Node`).
    pub fn is_orchestrator(&self) -> bool {
        match self {
            Self::Root | Self::Node => true,
            Self::ComplexityEstimator
            | Self::ModelSelector
            | Self::TaskExecutor
            | Self::Verifier => false,
        }
    }

    /// `true` for every type that is not an orchestrator.
    pub fn is_leaf(&self) -> bool {
        let orchestrator = self.is_orchestrator();
        !orchestrator
    }
}
/// Result of an agent executing a task.
///
/// Both successful and failed executions are represented by this struct;
/// `success` tells them apart.
///
/// # Invariants
/// - If `success == true`, the task was completed
/// - `cost_cents` reflects actual cost incurred
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResult {
/// Whether the task was successful
pub success: bool,
/// Output or response from the agent
///
/// For results built with `AgentResult::failure` this holds the error text.
pub output: String,
/// Cost incurred in cents
pub cost_cents: u64,
/// Model used (if any)
pub model_used: Option<String>,
/// Detailed result data (type-specific)
pub data: Option<serde_json::Value>,
}
impl AgentResult {
    /// Build a successful result carrying `output` and the cost incurred.
    /// `model_used` and `data` start out empty.
    pub fn success(output: impl Into<String>, cost_cents: u64) -> Self {
        Self::finished(true, output.into(), cost_cents)
    }

    /// Build a failed result; `error` is stored in the `output` field.
    /// `model_used` and `data` start out empty.
    pub fn failure(error: impl Into<String>, cost_cents: u64) -> Self {
        Self::finished(false, error.into(), cost_cents)
    }

    /// Builder-style setter recording which model produced the result.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model_used: Some(model.into()),
            ..self
        }
    }

    /// Builder-style setter attaching a type-specific JSON payload,
    /// replacing any payload set earlier.
    pub fn with_data(self, data: serde_json::Value) -> Self {
        Self {
            data: Some(data),
            ..self
        }
    }

    /// Shared constructor behind `success` and `failure`.
    fn finished(success: bool, output: String, cost_cents: u64) -> Self {
        Self {
            success,
            output,
            cost_cents,
            model_used: None,
            data: None,
        }
    }
}
/// Complexity estimation for a task.
///
/// # Invariants
/// - `score` is clamped to the range [0.0, 1.0] by `Complexity::new`
/// - `should_split` defaults to `score > SPLIT_THRESHOLD`, but can be
///   overridden explicitly with `with_split`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Complexity {
/// Complexity score: 0.0 = trivial, 1.0 = extremely complex
score: f64,
/// Human-readable explanation
reasoning: String,
/// Whether the task should be split into subtasks
should_split: bool,
/// Estimated token count for this task
estimated_tokens: u64,
}
impl Complexity {
    /// Threshold above which tasks should be split.
    pub const SPLIT_THRESHOLD: f64 = 0.6;

    /// Create a new complexity estimate.
    ///
    /// # Preconditions
    /// - `score` is in [0.0, 1.0] (will be clamped if not)
    ///
    /// A NaN score carries no usable information and is treated as `0.0`;
    /// `f64::clamp` alone would pass NaN through and break the range
    /// guarantee below.
    ///
    /// # Postconditions
    /// - `self.score` is in [0.0, 1.0]
    /// - `self.should_split` is true if score > threshold (0.6)
    pub fn new(score: f64, reasoning: impl Into<String>, estimated_tokens: u64) -> Self {
        let clamped_score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };
        Self {
            score: clamped_score,
            reasoning: reasoning.into(),
            should_split: clamped_score > Self::SPLIT_THRESHOLD,
            estimated_tokens,
        }
    }

    /// Get the complexity score.
    pub fn score(&self) -> f64 {
        self.score
    }

    /// Get the reasoning explanation.
    pub fn reasoning(&self) -> &str {
        &self.reasoning
    }

    /// Check if the task should be split.
    pub fn should_split(&self) -> bool {
        self.should_split
    }

    /// Get estimated token count.
    pub fn estimated_tokens(&self) -> u64 {
        self.estimated_tokens
    }

    /// Create a simple (low complexity) estimate: score 0.2, 500 tokens.
    pub fn simple(reasoning: impl Into<String>) -> Self {
        Self::new(0.2, reasoning, 500)
    }

    /// Create a moderate complexity estimate: score 0.5, 2000 tokens.
    pub fn moderate(reasoning: impl Into<String>) -> Self {
        Self::new(0.5, reasoning, 2000)
    }

    /// Create a complex estimate that should be split: score 0.8, 5000 tokens.
    pub fn complex(reasoning: impl Into<String>) -> Self {
        Self::new(0.8, reasoning, 5000)
    }

    /// Override the should_split decision.
    ///
    /// The score is left untouched, so after this call `should_split` may
    /// disagree with the threshold rule.
    pub fn with_split(mut self, should_split: bool) -> Self {
        self.should_split = should_split;
        self
    }
}
/// Errors that can occur in agent operations.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AgentError {
/// A task-level error, carried as its display text.
#[error("Task error: {0}")]
TaskError(String),
/// The budget could not cover the requested amount.
#[error("Budget exhausted: needed {needed} cents, had {available} cents")]
BudgetExhausted { needed: u64, available: u64 },
/// No agent able to handle the task was found.
#[error("No capable agent found for task")]
NoCapableAgent,
/// An LLM call failed.
#[error("LLM error: {0}")]
LlmError(String),
/// A tool invocation failed.
#[error("Tool error: {0}")]
ToolError(String),
/// The task ran but did not pass verification.
#[error("Verification failed: {0}")]
VerificationFailed(String),
/// An iteration limit was hit; the payload is the number reported in the message.
#[error("Max iterations reached: {0}")]
MaxIterations(usize),
/// Any other internal failure.
#[error("Internal error: {0}")]
Internal(String),
}
impl From<crate::task::TaskError> for AgentError {
/// Wraps the task error's display text in `AgentError::TaskError`.
fn from(e: crate::task::TaskError) -> Self {
Self::TaskError(e.to_string())
}
}
impl From<crate::budget::BudgetError> for AgentError {
/// Maps any budget error to `AgentError::BudgetExhausted`.
///
/// NOTE(review): the source error is ignored and both amounts are reported
/// as 0, so the resulting message carries no real figures. If `BudgetError`
/// exposes the needed/available amounts, they should be forwarded here —
/// confirm against its definition.
fn from(_e: crate::budget::BudgetError) -> Self {
Self::BudgetExhausted {
needed: 0,
available: 0,
}
}
}

View File

@@ -19,8 +19,12 @@ use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use uuid::Uuid;
use crate::agent::Agent;
use crate::agents::{Agent, AgentContext, AgentRef};
use crate::agents::orchestrator::RootAgent;
use crate::budget::ModelPricing;
use crate::config::Config;
use crate::llm::OpenRouterClient;
use crate::tools::ToolRegistry;
use super::types::*;
@@ -28,17 +32,36 @@ use super::types::*;
pub struct AppState {
pub config: Config,
pub tasks: RwLock<HashMap<Uuid, TaskState>>,
pub agent: Agent,
/// The hierarchical root agent
pub root_agent: AgentRef,
/// Shared context for agent execution
pub agent_context: AgentContext,
}
/// Start the HTTP server.
pub async fn serve(config: Config) -> anyhow::Result<()> {
let agent = Agent::new(config.clone());
// Create the root agent (hierarchical)
let root_agent: AgentRef = Arc::new(RootAgent::new());
// Create shared agent context
let llm = Arc::new(OpenRouterClient::new(config.api_key.clone()));
let tools = ToolRegistry::new();
let pricing = Arc::new(ModelPricing::new());
let workspace = config.workspace_path.clone();
let agent_context = AgentContext::new(
config.clone(),
llm,
tools,
pricing,
workspace,
);
let state = Arc::new(AppState {
config: config.clone(),
tasks: RwLock::new(HashMap::new()),
agent,
root_agent,
agent_context,
});
let app = Router::new()
@@ -112,8 +135,8 @@ async fn create_task(
async fn run_agent_task(
state: Arc<AppState>,
task_id: Uuid,
task: String,
model: String,
task_description: String,
_model: String,
workspace_path: std::path::PathBuf,
) {
// Update status to running
@@ -124,24 +147,77 @@ async fn run_agent_task(
}
}
// Run the agent
let result = state.agent.run_task(&task, &model, &workspace_path).await;
// Create a Task object for the hierarchical agent
let budget = crate::budget::Budget::new(1000); // $10 default budget
let verification = crate::task::VerificationCriteria::None;
let task_result = crate::task::Task::new(
task_description.clone(),
verification,
budget,
);
let mut task = match task_result {
Ok(t) => t,
Err(e) => {
let mut tasks = state.tasks.write().await;
if let Some(task_state) = tasks.get_mut(&task_id) {
task_state.status = TaskStatus::Failed;
task_state.result = Some(format!("Failed to create task: {}", e));
}
return;
}
};
// Create context with the specified workspace
let llm = Arc::new(OpenRouterClient::new(state.config.api_key.clone()));
let tools = ToolRegistry::new();
let pricing = Arc::new(ModelPricing::new());
let ctx = AgentContext::new(
state.config.clone(),
llm,
tools,
pricing,
workspace_path,
);
// Run the hierarchical agent
let result = state.root_agent.execute(&mut task, &ctx).await;
// Update task with result
{
let mut tasks = state.tasks.write().await;
if let Some(task_state) = tasks.get_mut(&task_id) {
match result {
Ok((response, log)) => {
task_state.status = TaskStatus::Completed;
task_state.result = Some(response);
task_state.log = log;
}
Err(e) => {
task_state.status = TaskStatus::Failed;
task_state.result = Some(format!("Error: {}", e));
// Add log entries from result data
if let Some(data) = &result.data {
if let Some(tools_used) = data.get("tools_used") {
if let Some(arr) = tools_used.as_array() {
for tool in arr {
task_state.log.push(TaskLogEntry {
timestamp: "0".to_string(),
entry_type: LogEntryType::ToolCall,
content: tool.as_str().unwrap_or("").to_string(),
});
}
}
}
}
// Add final response log
task_state.log.push(TaskLogEntry {
timestamp: "0".to_string(),
entry_type: LogEntryType::Response,
content: result.output.clone(),
});
if result.success {
task_state.status = TaskStatus::Completed;
task_state.result = Some(result.output);
} else {
task_state.status = TaskStatus::Failed;
task_state.result = Some(format!("Error: {}", result.output));
}
}
}
}
@@ -217,4 +293,3 @@ async fn stream_task(
Ok(Sse::new(stream))
}

205
src/budget/allocation.rs Normal file
View File

@@ -0,0 +1,205 @@
//! Budget allocation strategies for subtasks.
//!
//! # Strategies
//! - Proportional: allocate based on estimated complexity
//! - Equal: split evenly among subtasks
//! - Priority: allocate more to critical subtasks
use crate::task::Subtask;
/// Strategy for allocating budget across subtasks.
///
/// Passed to `allocate_budget`, which dispatches to the matching allocator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationStrategy {
/// Allocate proportionally based on subtask weights
Proportional,
/// Allocate equally among all subtasks
Equal,
/// Allocate based on priority (first subtasks get more)
PriorityFirst,
}
/// Allocate budget across subtasks.
///
/// # Preconditions
/// - `subtasks` is non-empty
/// - `total_budget > 0`
///
/// # Postconditions
/// - `result.len() == subtasks.len()`
/// - `result.iter().sum() <= total_budget`
///
/// # Pure Function
/// This is a pure function with no side effects.
pub fn allocate_budget(
subtasks: &[Subtask],
total_budget: u64,
strategy: AllocationStrategy,
) -> Vec<u64> {
if subtasks.is_empty() {
return Vec::new();
}
match strategy {
AllocationStrategy::Proportional => allocate_proportional(subtasks, total_budget),
AllocationStrategy::Equal => allocate_equal(subtasks, total_budget),
AllocationStrategy::PriorityFirst => allocate_priority(subtasks, total_budget),
}
}
/// Allocate proportionally based on weights.
///
/// Each subtask receives `floor(total_budget * weight / total_weight)`; the
/// cents lost to rounding are given to the first subtask.
///
/// Weights that are negative, NaN or infinite are counted as `0.0`. Without
/// this, a negative weight shrinks `total_weight` and inflates the other
/// shares, so the allocations could add up to more than `total_budget`.
/// If no usable weight remains, or the weights sum to a non-finite value,
/// the budget is split equally instead.
///
/// # Invariant
/// Sum of allocations == total_budget (minus rounding)
fn allocate_proportional(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
    let weights: Vec<f64> = subtasks
        .iter()
        .map(|s| {
            if s.weight.is_finite() && s.weight > 0.0 {
                s.weight
            } else {
                0.0
            }
        })
        .collect();
    let total_weight: f64 = weights.iter().sum();

    if !total_weight.is_finite() || total_weight <= 0.0 {
        return allocate_equal(subtasks, total_budget);
    }

    let mut allocations: Vec<u64> = weights
        .iter()
        .map(|w| {
            let proportion = w / total_weight;
            ((total_budget as f64) * proportion).floor() as u64
        })
        .collect();

    // Distribute remainder to maintain total
    let allocated: u64 = allocations.iter().sum();
    let remainder = total_budget.saturating_sub(allocated);

    // Give remainder to the first subtask
    if remainder > 0 {
        if let Some(first) = allocations.first_mut() {
            *first += remainder;
        }
    }

    allocations
}
/// Split `total_budget` evenly across the subtasks.
///
/// Every subtask gets `total_budget / n`; the `total_budget % n` leftover
/// cents are handed out one each to the first subtasks, so the result sums
/// to `total_budget`. An empty slice yields an empty vector.
fn allocate_equal(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
    let count = subtasks.len();
    if count == 0 {
        return Vec::new();
    }

    let divisor = count as u64;
    let (share, extra) = (total_budget / divisor, total_budget % divisor);

    let mut allocations = vec![share; count];
    // `extra < count`, so this touches only a prefix of the vector.
    allocations
        .iter_mut()
        .take(extra as usize)
        .for_each(|slot| *slot += 1);
    allocations
}
/// Allocate with priority to earlier subtasks.
///
/// Earlier subtasks (lower index) get proportionally more, using exponential
/// decay: `weight[i] = 2^(-i)`, i.e. each subtask is weighted half as much as
/// the one before it. These are the same ratios as `2^(n-i-1)`, but they
/// cannot overflow to infinity for long subtask lists.
///
/// Each share is `floor(total_budget * weight[i] / total_weight)`, capped by
/// what is still unassigned; the leftover goes to the first subtask, so for
/// non-empty input the allocations sum to exactly `total_budget`.
fn allocate_priority(subtasks: &[Subtask], total_budget: u64) -> Vec<u64> {
    let n = subtasks.len();
    if n == 0 {
        return Vec::new();
    }

    // 1, 1/2, 1/4, ... — repeated halving stays finite (eventually reaching 0).
    let weights: Vec<f64> = std::iter::successors(Some(1.0_f64), |w| Some(w * 0.5))
        .take(n)
        .collect();
    let total_weight: f64 = weights.iter().sum();

    // Cents not yet handed out; every share is capped by this so float
    // rounding near `u64::MAX` cannot over-allocate or overflow.
    let mut remaining = total_budget;
    let mut allocations: Vec<u64> = weights
        .iter()
        .map(|w| {
            let ideal = ((total_budget as f64) * (w / total_weight)).floor() as u64;
            let share = ideal.min(remaining);
            remaining -= share;
            share
        })
        .collect();

    // Give the flooring remainder to the highest-priority subtask.
    if let Some(first) = allocations.first_mut() {
        *first += remaining;
    }

    allocations
}
/// Estimate a reasonable budget (in cents) for a task based on complexity.
///
/// # Formula
/// Exponential interpolation between 10 and 500 cents:
/// `budget = ceil(10 * 50^score)`, with `score` clamped to `[0.0, 1.0]`.
/// Representative values:
/// - 0.0: 10 cents (simple task)
/// - 0.2: ≈ 22 cents
/// - 0.5: ≈ 71 cents (moderate task)
/// - 0.8: ≈ 229 cents (complex task)
/// - 1.0: 500 cents (very complex task)
///
/// A NaN score is treated as 0.0, so the result is never below the 10-cent
/// minimum.
///
/// # Pure Function
pub fn estimate_budget_for_complexity(complexity_score: f64) -> u64 {
    const MIN_CENTS: f64 = 10.0; // Minimum 10 cents
    const MAX_CENTS: f64 = 500.0; // Maximum 500 cents ($5)

    // `clamp` passes NaN through, and NaN would become 0 in the final cast.
    let score = if complexity_score.is_nan() {
        0.0
    } else {
        complexity_score.clamp(0.0, 1.0)
    };

    // Exponential scaling between the two bounds.
    let budget = MIN_CENTS * (MAX_CENTS / MIN_CENTS).powf(score);
    budget.ceil() as u64
}
// Unit tests for the allocation strategies and the complexity-based estimate.
#[cfg(test)]
mod tests {
use super::*;
use crate::task::VerificationCriteria;
// Build one subtask per weight, all with the same description and no verification.
fn make_subtasks(weights: &[f64]) -> Vec<Subtask> {
weights
.iter()
.map(|&w| Subtask::new("test", VerificationCriteria::None, w))
.collect()
}
#[test]
fn test_proportional_allocation() {
let subtasks = make_subtasks(&[1.0, 2.0, 1.0]);
let allocs = allocate_budget(&subtasks, 100, AllocationStrategy::Proportional);
assert_eq!(allocs.len(), 3);
// Should be roughly 25, 50, 25
assert!(allocs[1] > allocs[0]);
// The whole budget is handed out.
assert_eq!(allocs.iter().sum::<u64>(), 100);
}
#[test]
fn test_equal_allocation() {
let subtasks = make_subtasks(&[1.0, 1.0, 1.0]);
let allocs = allocate_budget(&subtasks, 99, AllocationStrategy::Equal);
assert_eq!(allocs.len(), 3);
assert_eq!(allocs.iter().sum::<u64>(), 99);
}
#[test]
fn test_priority_allocation() {
// Weights are equal on purpose: PriorityFirst ranks by position only.
let subtasks = make_subtasks(&[1.0, 1.0, 1.0]);
let allocs = allocate_budget(&subtasks, 100, AllocationStrategy::PriorityFirst);
assert_eq!(allocs.len(), 3);
assert!(allocs[0] > allocs[2]); // First gets more
assert_eq!(allocs.iter().sum::<u64>(), 100);
}
#[test]
fn test_complexity_budget_estimation() {
// Loose bounds at the low end, midpoint and high end of the score range.
assert!(estimate_budget_for_complexity(0.0) <= 15);
assert!(estimate_budget_for_complexity(0.5) > 50);
assert!(estimate_budget_for_complexity(1.0) >= 450);
}
}

249
src/budget/budget.rs Normal file
View File

@@ -0,0 +1,249 @@
//! Budget tracking for tasks.
//!
//! # Invariants
//! - `allocated_cents <= total_cents` (enforced at all times)
//! - `spent_cents <= allocated_cents` (enforced at all times)
use serde::{Deserialize, Serialize};
/// Budget for a task, tracking total, allocated, and spent amounts.
///
/// All amounts are whole cents stored as `u64`.
///
/// # Invariants
/// - `allocated_cents <= total_cents`
/// - `spent_cents <= allocated_cents`
///
/// # Design for Provability
/// All mutations go through methods that enforce invariants.
/// Direct field access is prevented (fields are private).
///
/// NOTE(review): the derived `Deserialize` builds values without going through
/// those methods, so deserialized data is not checked against the invariants
/// here — confirm that only trusted data is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Budget {
/// Total budget available (in cents)
total_cents: u64,
/// Amount allocated to subtasks (in cents)
allocated_cents: u64,
/// Amount actually spent (in cents)
spent_cents: u64,
}
impl Budget {
    /// Create a new budget with the given total.
    ///
    /// # Postconditions
    /// - `budget.total_cents == total_cents`
    /// - `budget.allocated_cents == 0`
    /// - `budget.spent_cents == 0`
    pub fn new(total_cents: u64) -> Self {
        Self {
            total_cents,
            allocated_cents: 0,
            spent_cents: 0,
        }
    }

    /// Create a budget with unlimited funds (for testing).
    ///
    /// The total is `u64::MAX` cents; nothing is allocated or spent.
    ///
    /// # Warning
    /// This should only be used for testing, not production.
    pub fn unlimited() -> Self {
        Self::new(u64::MAX)
    }

    // Getters

    /// Get the total budget in cents.
    pub fn total_cents(&self) -> u64 {
        self.total_cents
    }

    /// Get the allocated amount in cents.
    pub fn allocated_cents(&self) -> u64 {
        self.allocated_cents
    }

    /// Get the spent amount in cents.
    pub fn spent_cents(&self) -> u64 {
        self.spent_cents
    }

    /// Get the remaining unallocated budget in cents.
    ///
    /// # Property
    /// `remaining_cents() == total_cents - allocated_cents`
    /// (saturating at 0 rather than underflowing).
    pub fn remaining_cents(&self) -> u64 {
        self.total_cents.saturating_sub(self.allocated_cents)
    }

    /// Get the unspent allocated budget in cents.
    ///
    /// # Property
    /// `unspent_cents() == allocated_cents - spent_cents`
    /// (saturating at 0 rather than underflowing).
    pub fn unspent_cents(&self) -> u64 {
        self.allocated_cents.saturating_sub(self.spent_cents)
    }

    /// Check if there's any remaining budget to allocate.
    pub fn has_remaining(&self) -> bool {
        self.remaining_cents() > 0
    }

    /// Check if the budget is exhausted (everything allocated has been spent).
    ///
    /// Note that this is also `true` for a budget with nothing allocated yet,
    /// since then `spent_cents == allocated_cents == 0`.
    pub fn is_exhausted(&self) -> bool {
        self.spent_cents >= self.allocated_cents
    }

    // Mutations with invariant enforcement

    /// Allocate some budget for a subtask.
    ///
    /// # Precondition
    /// `amount <= self.remaining_cents()`
    ///
    /// # Postcondition
    /// On success `self.allocated_cents` increases by exactly `amount`;
    /// on error the budget is unchanged.
    ///
    /// # Errors
    /// Returns `BudgetError::AllocationExceedsTotal` if `amount` is larger
    /// than the remaining unallocated budget.
    pub fn allocate(&mut self, amount: u64) -> Result<(), BudgetError> {
        // Compare against the remainder instead of computing a saturated sum:
        // with `total_cents == u64::MAX` a saturated sum would never exceed
        // the total and an over-allocation would be silently accepted.
        let remaining = self.remaining_cents();
        if amount > remaining {
            return Err(BudgetError::AllocationExceedsTotal {
                requested: amount,
                remaining,
            });
        }
        // Cannot overflow: amount <= total_cents - allocated_cents.
        self.allocated_cents += amount;
        Ok(())
    }

    /// Record spending against the allocated budget.
    ///
    /// # Precondition
    /// `amount <= self.unspent_cents()`
    ///
    /// # Postcondition
    /// On success `self.spent_cents` increases by exactly `amount`;
    /// on error the budget is unchanged.
    ///
    /// # Errors
    /// Returns `BudgetError::SpendingExceedsAllocated` if `amount` is larger
    /// than the unspent allocated budget.
    pub fn spend(&mut self, amount: u64) -> Result<(), BudgetError> {
        // Same reasoning as `allocate`: check against what is actually left.
        let available = self.unspent_cents();
        if amount > available {
            return Err(BudgetError::SpendingExceedsAllocated {
                requested: amount,
                available,
            });
        }
        // Cannot overflow: amount <= allocated_cents - spent_cents.
        self.spent_cents += amount;
        Ok(())
    }

    /// Try to spend, returning how much was actually spent.
    ///
    /// This is a "best effort" version that won't fail,
    /// but may spend less than requested.
    ///
    /// # Postcondition
    /// `result <= amount`
    /// `result <= self.unspent_cents()` (before call)
    pub fn try_spend(&mut self, amount: u64) -> u64 {
        let actual = amount.min(self.unspent_cents());
        // Cannot overflow: actual <= allocated_cents - spent_cents.
        self.spent_cents += actual;
        actual
    }

    /// Check if we can afford a given cost.
    ///
    /// # Returns
    /// `true` if `cost <= self.unspent_cents()`
    pub fn can_afford(&self, cost: u64) -> bool {
        cost <= self.unspent_cents()
    }

    /// Create a sub-budget from this budget.
    ///
    /// # Precondition
    /// `amount <= self.remaining_cents()`
    ///
    /// # Side Effects
    /// Allocates `amount` from this budget.
    ///
    /// # Returns
    /// A new budget with `total_cents == amount` and nothing allocated or spent.
    ///
    /// # Errors
    /// Propagates the error from `allocate`; this budget is then unchanged.
    pub fn create_sub_budget(&mut self, amount: u64) -> Result<Budget, BudgetError> {
        self.allocate(amount)?;
        Ok(Budget::new(amount))
    }
}
impl Default for Budget {
/// Default budget is $1.00 (100 cents).
fn default() -> Self {
Self::new(100)
}
}
/// Errors related to budget operations.
///
/// All amounts carried by the variants are in cents.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BudgetError {
/// Returned by `Budget::allocate` when the request exceeds the unallocated remainder.
#[error("Allocation of {requested} cents exceeds remaining budget of {remaining} cents")]
AllocationExceedsTotal { requested: u64, remaining: u64 },
/// Returned by `Budget::spend` when the request exceeds the unspent allocation.
#[error("Spending of {requested} cents exceeds available budget of {available} cents")]
SpendingExceedsAllocated { requested: u64, available: u64 },
/// NOTE(review): not constructed by `Budget` itself; presumably for callers — confirm usage.
#[error("Insufficient budget: need {needed} cents, have {available} cents")]
InsufficientBudget { needed: u64, available: u64 },
}
// Unit tests for the allocate/spend bookkeeping and sub-budget creation.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_budget_invariants() {
let mut budget = Budget::new(100);
// Initially, nothing is allocated or spent
assert_eq!(budget.remaining_cents(), 100);
assert_eq!(budget.unspent_cents(), 0);
// Allocate some
budget.allocate(50).unwrap();
assert_eq!(budget.remaining_cents(), 50);
assert_eq!(budget.unspent_cents(), 50);
// Spend some
budget.spend(30).unwrap();
assert_eq!(budget.unspent_cents(), 20);
assert_eq!(budget.spent_cents(), 30);
// Can't over-allocate (only 50 cents remain unallocated)
assert!(budget.allocate(60).is_err());
// Can't over-spend (only 20 cents remain unspent)
assert!(budget.spend(30).is_err());
}
#[test]
fn test_sub_budget() {
let mut parent = Budget::new(100);
let child = parent.create_sub_budget(40).unwrap();
// The parent's remainder shrinks by the child's total.
assert_eq!(parent.remaining_cents(), 60);
assert_eq!(child.total_cents(), 40);
assert_eq!(child.remaining_cents(), 40);
}
}

15
src/budget/mod.rs Normal file
View File

@@ -0,0 +1,15 @@
//! Budget module - cost tracking and model pricing.
//!
//! # Key Concepts
//! - Budget: tracks total and allocated costs for a task
//! - Pricing: fetches and caches OpenRouter model pricing
//! - Allocation: algorithms for distributing budget across subtasks
mod budget;
mod pricing;
mod allocation;
pub use budget::{Budget, BudgetError};
pub use pricing::{ModelPricing, PricingInfo};
pub use allocation::{AllocationStrategy, allocate_budget};

239
src/budget/pricing.rs Normal file
View File

@@ -0,0 +1,239 @@
//! OpenRouter pricing information and caching.
//!
//! Fetches real-time pricing from OpenRouter API to enable
//! cost-aware model selection.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
/// Pricing information for a single model.
///
/// Prices are in US dollars per one million tokens; `calculate_cost_cents`
/// converts them to cents for a given token count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
/// Model identifier (e.g., "openai/gpt-4.1-mini")
pub model_id: String,
/// Cost per 1M input tokens in dollars
pub prompt_cost_per_million: f64,
/// Cost per 1M output tokens in dollars
pub completion_cost_per_million: f64,
/// Context window size in tokens
pub context_length: u64,
/// Maximum output tokens (`None` when unknown)
pub max_output_tokens: Option<u64>,
}
impl PricingInfo {
    /// Compute the cost, in whole cents, of a call with the given token counts.
    ///
    /// # Formula
    /// `cost = (input_tokens * prompt_rate + output_tokens * completion_rate) / 1_000_000 * 100`
    ///
    /// The result is rounded up to the next whole cent.
    pub fn calculate_cost_cents(&self, input_tokens: u64, output_tokens: u64) -> u64 {
        const TOKENS_PER_MILLION: f64 = 1_000_000.0;
        const CENTS_PER_DOLLAR: f64 = 100.0;

        let prompt_dollars =
            (input_tokens as f64) * self.prompt_cost_per_million / TOKENS_PER_MILLION;
        let completion_dollars =
            (output_tokens as f64) * self.completion_cost_per_million / TOKENS_PER_MILLION;

        ((prompt_dollars + completion_dollars) * CENTS_PER_DOLLAR).ceil() as u64
    }

    /// Estimate the cost, in whole cents, for estimated token counts.
    ///
    /// Takes the exact cost from `calculate_cost_cents` and pads it by a 20%
    /// safety margin to absorb estimation errors, rounding up.
    pub fn estimate_cost_cents(&self, estimated_input: u64, estimated_output: u64) -> u64 {
        const SAFETY_FACTOR: f64 = 1.2;

        let exact_cents = self.calculate_cost_cents(estimated_input, estimated_output) as f64;
        (exact_cents * SAFETY_FACTOR).ceil() as u64
    }

    /// Mean of the prompt and completion rates, in dollars per single token.
    ///
    /// Only meant for rough comparisons between models.
    pub fn average_cost_per_token(&self) -> f64 {
        let mean_per_million =
            (self.prompt_cost_per_million + self.completion_cost_per_million) / 2.0;
        mean_per_million / 1_000_000.0
    }
}
/// Model pricing cache and fetcher.
///
/// Holds a map from model id to `PricingInfo` behind an async `RwLock`, plus
/// the HTTP client used to refresh that map from the OpenRouter API.
pub struct ModelPricing {
/// Cached pricing data, keyed by model id
cache: Arc<RwLock<HashMap<String, PricingInfo>>>,
/// HTTP client for fetching pricing
client: reqwest::Client,
}
impl ModelPricing {
    /// Create a new pricing manager with an empty cache.
    pub fn new() -> Self {
        Self::with_pricing(HashMap::new())
    }

    /// Create with pre-populated pricing data (for testing or offline use).
    pub fn with_pricing(pricing: HashMap<String, PricingInfo>) -> Self {
        Self {
            cache: Arc::new(RwLock::new(pricing)),
            client: reqwest::Client::new(),
        }
    }

    /// Get pricing for a specific model.
    ///
    /// Lookup order:
    /// 1. The cache.
    /// 2. On a cache miss, the whole cache is refreshed from the API and the
    ///    lookup is repeated; if the refresh succeeded but the model is still
    ///    unknown, `None` is returned.
    /// 3. If the refresh failed, the hardcoded defaults for common models.
    pub async fn get_pricing(&self, model_id: &str) -> Option<PricingInfo> {
        // Check cache first. The read guard is released at the end of this
        // statement, before `refresh_pricing` needs the write lock.
        let cached = self.cache.read().await.get(model_id).cloned();
        if cached.is_some() {
            return cached;
        }

        // If not in cache, try to fetch all models
        if self.refresh_pricing().await.is_ok() {
            return self.cache.read().await.get(model_id).cloned();
        }

        // Fall back to hardcoded defaults for common models
        self.default_pricing(model_id)
    }

    /// Get a snapshot (clone) of all cached pricing info.
    pub async fn all_pricing(&self) -> HashMap<String, PricingInfo> {
        self.cache.read().await.clone()
    }

    /// Refresh pricing from OpenRouter API.
    ///
    /// Fetches the model list and inserts or overwrites one cache entry per
    /// model. OpenRouter reports prices as decimal strings in USD *per token*;
    /// they are scaled to USD per million tokens, the unit `PricingInfo` uses.
    ///
    /// # Errors
    /// - `PricingError::NetworkError` if the request could not be sent
    /// - `PricingError::ApiError` on a non-success HTTP status
    /// - `PricingError::ParseError` if the body is not the expected JSON
    pub async fn refresh_pricing(&self) -> Result<(), PricingError> {
        const TOKENS_PER_MILLION: f64 = 1_000_000.0;

        let response = self
            .client
            .get("https://openrouter.ai/api/v1/models")
            .send()
            .await
            .map_err(|e| PricingError::NetworkError(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            return Err(PricingError::ApiError(format!("Status: {}", status)));
        }

        let data: OpenRouterModelsResponse = response
            .json()
            .await
            .map_err(|e| PricingError::ParseError(e.to_string()))?;

        // Take the write lock only after all network I/O is done.
        let mut cache = self.cache.write().await;
        for model in data.data {
            let info = PricingInfo {
                model_id: model.id.clone(),
                prompt_cost_per_million: parse_price(&model.pricing.prompt) * TOKENS_PER_MILLION,
                completion_cost_per_million: parse_price(&model.pricing.completion)
                    * TOKENS_PER_MILLION,
                // Assume a small 4096-token window when the API omits it.
                context_length: model.context_length.unwrap_or(4096),
                max_output_tokens: model
                    .top_provider
                    .as_ref()
                    .and_then(|p| p.max_completion_tokens),
            };
            cache.insert(model.id, info);
        }

        Ok(())
    }

    /// Get default pricing for common models (fallback).
    ///
    /// An exact id match wins. Otherwise the entry whose short name (the part
    /// after the last `/`) is contained in `model_id` is used, preferring the
    /// longest such short name so that e.g. "gpt-4o-mini" is not mistaken for
    /// "gpt-4o". The returned `model_id` is the id that was asked for.
    fn default_pricing(&self, model_id: &str) -> Option<PricingInfo> {
        // Hardcoded defaults for when API is unavailable:
        // (id, prompt $/1M tokens, completion $/1M tokens, context length)
        const DEFAULTS: [(&str, f64, f64, u64); 6] = [
            ("openai/gpt-4.1-mini", 0.40, 1.60, 1_000_000),
            ("openai/gpt-4.1", 2.50, 10.00, 1_000_000),
            ("openai/gpt-4o", 2.50, 10.00, 128_000),
            ("openai/gpt-4o-mini", 0.15, 0.60, 128_000),
            ("anthropic/claude-3.5-sonnet", 3.00, 15.00, 200_000),
            ("anthropic/claude-3-haiku", 0.25, 1.25, 200_000),
        ];

        /// The part of a model id after the provider prefix.
        fn short_name(id: &str) -> &str {
            id.rsplit('/').next().unwrap_or(id)
        }

        let matched = DEFAULTS
            .iter()
            .find(|(id, ..)| *id == model_id)
            .or_else(|| {
                DEFAULTS
                    .iter()
                    .filter(|(id, ..)| model_id.contains(short_name(id)))
                    .max_by_key(|(id, ..)| short_name(id).len())
            });

        matched.map(|&(_, prompt, completion, context)| PricingInfo {
            model_id: model_id.to_string(),
            prompt_cost_per_million: prompt,
            completion_cost_per_million: completion,
            context_length: context,
            max_output_tokens: None,
        })
    }

    /// Get cached models sorted by average cost per token (cheapest first).
    pub async fn models_by_cost(&self) -> Vec<PricingInfo> {
        let mut models: Vec<PricingInfo> = self.cache.read().await.values().cloned().collect();
        // `total_cmp` is a total order even if a price is NaN, which a
        // comparator passed to `sort_by` must provide.
        models.sort_by(|a, b| {
            a.average_cost_per_token()
                .total_cmp(&b.average_cost_per_token())
        });
        models
    }
}
impl Default for ModelPricing {
fn default() -> Self {
Self::new()
}
}
/// Parse price string from OpenRouter API.
///
/// Surrounding whitespace is ignored. Anything that is not a finite decimal
/// number (unparseable text, `NaN`, `inf`) yields `0.0`, so a bad value can
/// never poison later cost arithmetic or sorting.
fn parse_price(price: &str) -> f64 {
    price
        .trim()
        .parse::<f64>()
        .ok()
        .filter(|value| value.is_finite())
        .unwrap_or(0.0)
}
/// OpenRouter API response structures.
///
/// Top-level body of the models endpoint: a `data` array of model entries.
#[derive(Debug, Deserialize)]
struct OpenRouterModelsResponse {
/// One entry per model
data: Vec<OpenRouterModel>,
}
/// A single model entry from the models response (only the fields used here).
#[derive(Debug, Deserialize)]
struct OpenRouterModel {
/// Model identifier; used as the cache key
id: String,
/// Prompt/completion prices as strings
pricing: OpenRouterPricing,
/// Context window size in tokens, if reported
context_length: Option<u64>,
/// Provider details, if reported
top_provider: Option<OpenRouterProvider>,
}
/// Price fields of a model entry; both are strings converted via `parse_price`.
#[derive(Debug, Deserialize)]
struct OpenRouterPricing {
/// Price for input (prompt) tokens
prompt: String,
/// Price for output (completion) tokens
completion: String,
}
/// Provider details of a model entry (only the field used here).
#[derive(Debug, Deserialize)]
struct OpenRouterProvider {
/// Maximum number of completion tokens, if reported
max_completion_tokens: Option<u64>,
}
/// Pricing-related errors.
///
/// Each variant carries the underlying failure rendered as a string.
#[derive(Debug, thiserror::Error)]
pub enum PricingError {
/// The HTTP request could not be sent or completed
#[error("Network error: {0}")]
NetworkError(String),
/// The API answered with a non-success HTTP status
#[error("API error: {0}")]
ApiError(String),
/// The response body could not be decoded
#[error("Parse error: {0}")]
ParseError(String),
}

View File

@@ -4,31 +4,44 @@
//!
//! This library provides:
//! - An HTTP API for task submission and monitoring
//! - A tool-based agent loop for autonomous code editing
//! - A hierarchical agent tree for complex task handling
//! - Tool-based execution for autonomous code editing
//! - Integration with OpenRouter for LLM access
//!
//! ## Architecture
//! ## Architecture (v2: Hierarchical Agent Tree)
//!
//! The agent follows the "tools in a loop" pattern:
//! 1. Receive a task via the API
//! 2. Build context with system prompt and available tools
//! 3. Call LLM, parse response, execute any tool calls
//! 4. Feed results back to LLM, repeat until task complete
//!
//! ## Example
//!
//! ```rust,ignore
//! use open_agent::{config::Config, agent::Agent};
//!
//! let config = Config::from_env()?;
//! let agent = Agent::new(config);
//! let result = agent.run_task("Create a hello world script").await?;
//! ```text
//! ┌─────────────┐
//! │ RootAgent │
//! └──────┬──────┘
//! ┌─────────────────┼─────────────────┐
//! ▼ ▼ ▼
//! ┌───────────────┐ ┌─────────────┐ ┌─────────────┐
//! │ Complexity │ │ Model │ │ Task │
//! │ Estimator │ │ Selector │ │ Executor │
//! └───────────────┘ └─────────────┘ └─────────────┘
//! ```
//!
//! ## Task Flow
//! 1. Receive task via API
//! 2. Estimate complexity (should we split?)
//! 3. Select optimal model (U-curve cost optimization)
//! 4. Execute (directly or via subtasks)
//! 5. Verify completion (programmatic + LLM hybrid)
//!
//! ## Modules
//! - `agents`: Hierarchical agent tree (Root, Node, Leaf agents)
//! - `task`: Task, subtask, and verification types
//! - `budget`: Cost tracking and model pricing
//! - `agent`: Original simple agent (kept for compatibility)
pub mod api;
pub mod agent;
pub mod agents;
pub mod budget;
pub mod config;
pub mod llm;
pub mod task;
pub mod tools;
pub use config::Config;

15
src/task/mod.rs Normal file
View File

@@ -0,0 +1,15 @@
//! Task module - defines tasks, subtasks, and verification criteria.
//!
//! This module is designed with formal verification in mind:
//! - All types use algebraic data types with exhaustive matching
//! - Invariants are documented and enforced in constructors
//! - Pure functions are separated from IO operations
pub mod task;
mod subtask;
mod verification;
pub use task::{Task, TaskId, TaskStatus, TaskError};
pub use subtask::{Subtask, SubtaskPlan, SubtaskPlanError};
pub use verification::{VerificationCriteria, VerificationResult, VerificationMethod, ProgrammaticCheck};

249
src/task/subtask.rs Normal file
View File

@@ -0,0 +1,249 @@
//! Subtask definitions and splitting logic.
//!
//! When a task is too complex, it can be split into subtasks.
//! Each subtask has:
//! - A description of what to do
//! - Verification criteria (how to check it's done)
//! - A budget allocation
use serde::{Deserialize, Serialize};
use super::{Task, TaskId, VerificationCriteria};
use crate::budget::Budget;
/// A planned subtask before it becomes a full Task.
///
/// # Purpose
/// Represents the output of task splitting before budget allocation.
/// Once budgets are assigned, these become full `Task` objects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subtask {
/// Description of what this subtask should accomplish
pub description: String,
/// How to verify this subtask is complete
pub verification: VerificationCriteria,
/// Relative weight for budget allocation (higher = more budget)
pub weight: f64,
/// Dependencies: indices (positions in the owning plan's subtask list) of
/// subtasks that must complete first
pub dependencies: Vec<usize>,
}
impl Subtask {
/// Create a new subtask with no dependencies.
///
/// # Preconditions
/// - `description` is non-empty
/// - `weight > 0.0`
pub fn new(
description: impl Into<String>,
verification: VerificationCriteria,
weight: f64,
) -> Self {
Self {
description: description.into(),
verification,
weight: weight.max(0.01), // Ensure positive weight
dependencies: Vec::new(),
}
}
/// Add a dependency on another subtask (by index).
pub fn with_dependency(mut self, index: usize) -> Self {
self.dependencies.push(index);
self
}
/// Add multiple dependencies.
pub fn with_dependencies(mut self, indices: Vec<usize>) -> Self {
self.dependencies.extend(indices);
self
}
}
/// A plan for splitting a task into subtasks.
///
/// # Invariants
/// - `subtasks` is non-empty
/// - All dependency indices are valid (< subtasks.len())
/// - No circular dependencies (`execution_order` reports a cycle as
///   `SubtaskPlanError::CircularDependency`)
///
/// NOTE(review): the derived `Deserialize` does not run the checks in
/// `SubtaskPlan::new` — confirm that only trusted data is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubtaskPlan {
/// The parent task ID
parent_id: TaskId,
/// The subtasks to create
subtasks: Vec<Subtask>,
/// Reasoning for the split
reasoning: String,
}
impl SubtaskPlan {
/// Create a new subtask plan.
///
/// # Preconditions
/// - `subtasks` is non-empty
/// - All dependency indices are valid
///
/// # Errors
/// Returns `Err` if preconditions are violated.
pub fn new(
parent_id: TaskId,
subtasks: Vec<Subtask>,
reasoning: impl Into<String>,
) -> Result<Self, SubtaskPlanError> {
if subtasks.is_empty() {
return Err(SubtaskPlanError::EmptySubtasks);
}
// Validate dependency indices
for (i, subtask) in subtasks.iter().enumerate() {
for &dep in &subtask.dependencies {
if dep >= subtasks.len() {
return Err(SubtaskPlanError::InvalidDependency {
subtask_index: i,
dependency_index: dep
});
}
if dep == i {
return Err(SubtaskPlanError::SelfDependency { subtask_index: i });
}
}
}
// TODO: Check for circular dependencies (would need topological sort)
Ok(Self {
parent_id,
subtasks,
reasoning: reasoning.into(),
})
}
/// Get the parent task ID.
pub fn parent_id(&self) -> TaskId {
self.parent_id
}
/// Get the subtasks.
pub fn subtasks(&self) -> &[Subtask] {
&self.subtasks
}
/// Get the reasoning for the split.
pub fn reasoning(&self) -> &str {
&self.reasoning
}
/// Convert this plan into actual Task objects with allocated budgets.
///
/// # Preconditions
/// - `total_budget.remaining() > 0`
///
/// # Postconditions
/// - Sum of subtask budgets <= total_budget
/// - Each subtask has parent_id set to self.parent_id
///
/// # Budget Allocation
/// Budget is allocated proportionally based on subtask weights.
pub fn into_tasks(self, total_budget: &Budget) -> Result<Vec<Task>, SubtaskPlanError> {
let total_weight: f64 = self.subtasks.iter().map(|s| s.weight).sum();
if total_weight <= 0.0 {
return Err(SubtaskPlanError::ZeroTotalWeight);
}
let available = total_budget.remaining_cents();
self.subtasks
.into_iter()
.map(|subtask| {
// Allocate budget proportionally
let proportion = subtask.weight / total_weight;
let allocated = ((available as f64) * proportion) as u64;
let budget = Budget::new(allocated);
Task::new_subtask(
subtask.description,
subtask.verification,
budget,
self.parent_id,
).map_err(|e| SubtaskPlanError::TaskCreation(e.to_string()))
})
.collect()
}
/// Get execution order respecting dependencies (topological sort).
///
/// # Returns
/// Vector of subtask indices in valid execution order.
///
/// # Errors
/// Returns `Err` if there are circular dependencies.
pub fn execution_order(&self) -> Result<Vec<usize>, SubtaskPlanError> {
let n = self.subtasks.len();
let mut in_degree = vec![0usize; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
// Build adjacency list and compute in-degrees
for (i, subtask) in self.subtasks.iter().enumerate() {
for &dep in &subtask.dependencies {
adj[dep].push(i);
in_degree[i] += 1;
}
}
// Kahn's algorithm for topological sort
let mut queue: Vec<usize> = in_degree
.iter()
.enumerate()
.filter(|(_, &d)| d == 0)
.map(|(i, _)| i)
.collect();
let mut order = Vec::with_capacity(n);
while let Some(node) = queue.pop() {
order.push(node);
for &next in &adj[node] {
in_degree[next] -= 1;
if in_degree[next] == 0 {
queue.push(next);
}
}
}
if order.len() != n {
Err(SubtaskPlanError::CircularDependency)
} else {
Ok(order)
}
}
}
/// Errors in subtask plan creation or execution.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SubtaskPlanError {
/// The plan was given no subtasks
#[error("Subtask list cannot be empty")]
EmptySubtasks,
/// A dependency index does not refer to any subtask in the plan
#[error("Subtask {subtask_index} has invalid dependency index {dependency_index}")]
InvalidDependency { subtask_index: usize, dependency_index: usize },
/// A subtask lists its own index as a dependency
#[error("Subtask {subtask_index} depends on itself")]
SelfDependency { subtask_index: usize },
/// The dependency graph contains a cycle, so no execution order exists
#[error("Circular dependency detected in subtask plan")]
CircularDependency,
/// The subtask weights sum to zero or less, so no proportional split is possible
#[error("Total weight of subtasks is zero")]
ZeroTotalWeight,
/// Creating a `Task` from a subtask failed; carries the underlying message
#[error("Failed to create task: {0}")]
TaskCreation(String),
}

291
src/task/task.rs Normal file
View File

@@ -0,0 +1,291 @@
//! Core Task type with budget and verification criteria.
//!
//! # Invariants
//! - `budget.allocated_cents <= budget.total_cents`
//! - `id` is unique within an agent tree execution
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::verification::VerificationCriteria;
use crate::budget::Budget;
/// Unique identifier for a task.
///
/// A newtype around a `Uuid`; cheap to copy and usable as a hash-map key.
///
/// # Properties
/// - Globally unique within an execution context
/// - Immutable once created
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(Uuid);
impl TaskId {
/// Create a new unique task ID.
///
/// # Postcondition
/// Returns a fresh ID that has never been used before in this process.
pub fn new() -> Self {
Self(Uuid::new_v4())
}
/// Get the inner UUID.
pub fn as_uuid(&self) -> Uuid {
self.0
}
}
impl Default for TaskId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for TaskId {
    /// Writes the wrapped value using its own `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.0))
    }
}
/// Status of a task in its lifecycle.
///
/// # State Machine
/// ```text
/// Pending -> Running -> Completed
///    |          \-> Failed
///    |          \-> Cancelled
///    \-> Cancelled
/// ```
/// The transitions themselves are enforced by the methods on `Task`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
/// Task is waiting to be executed
Pending,
/// Task is currently being executed
Running,
/// Task completed successfully
Completed,
/// Task failed with an error; `reason` describes why
Failed { reason: String },
/// Task was cancelled before completion
Cancelled,
}
impl TaskStatus {
    /// Report whether the task has reached a final state.
    ///
    /// # Returns
    /// `true` for Completed, Failed and Cancelled; `false` for Pending and
    /// Running.
    pub fn is_terminal(&self) -> bool {
        match self {
            TaskStatus::Completed | TaskStatus::Failed { .. } | TaskStatus::Cancelled => true,
            TaskStatus::Pending | TaskStatus::Running => false,
        }
    }

    /// Report whether the task can still make progress.
    ///
    /// # Returns
    /// `true` for Pending and Running; `false` for every terminal state.
    pub fn is_active(&self) -> bool {
        match self {
            TaskStatus::Pending | TaskStatus::Running => true,
            TaskStatus::Completed | TaskStatus::Failed { .. } | TaskStatus::Cancelled => false,
        }
    }
}
/// A task to be executed by an agent.
///
/// # Invariants
/// - `budget.allocated_cents <= budget.total_cents`
/// - If `parent_id.is_some()`, this is a subtask
///
/// # Design for Provability
/// - Fields are private; `status` changes only through the explicit transition
///   methods and `budget` only through `budget_mut`
/// - The description is checked for emptiness at construction time
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
/// Unique identifier for this task
id: TaskId,
/// Human-readable description of what to accomplish
description: String,
/// How to verify the task was completed correctly
verification: VerificationCriteria,
/// Budget constraints for this task
budget: Budget,
/// Parent task ID if this is a subtask
parent_id: Option<TaskId>,
/// Current status
status: TaskStatus,
}
impl Task {
/// Create a new task with the given parameters.
///
/// # Preconditions
/// - `budget.allocated_cents <= budget.total_cents`
/// - `description` is non-empty
///
/// # Postconditions
/// - Returns a task with `status == Pending`
/// - `task.id` is a fresh unique identifier
///
/// # Errors
/// Returns `Err` if preconditions are violated.
pub fn new(
description: String,
verification: VerificationCriteria,
budget: Budget,
) -> Result<Self, TaskError> {
if description.is_empty() {
return Err(TaskError::EmptyDescription);
}
// Budget invariant is enforced by Budget::new()
Ok(Self {
id: TaskId::new(),
description,
verification,
budget,
parent_id: None,
status: TaskStatus::Pending,
})
}
/// Create a subtask with a parent reference.
///
/// # Preconditions
/// - Same as `new()`
/// - `parent_id` refers to an existing task
pub fn new_subtask(
description: String,
verification: VerificationCriteria,
budget: Budget,
parent_id: TaskId,
) -> Result<Self, TaskError> {
let mut task = Self::new(description, verification, budget)?;
task.parent_id = Some(parent_id);
Ok(task)
}
// Getters - all return references to preserve immutability semantics
pub fn id(&self) -> TaskId {
self.id
}
pub fn description(&self) -> &str {
&self.description
}
pub fn verification(&self) -> &VerificationCriteria {
&self.verification
}
pub fn budget(&self) -> &Budget {
&self.budget
}
pub fn budget_mut(&mut self) -> &mut Budget {
&mut self.budget
}
pub fn parent_id(&self) -> Option<TaskId> {
self.parent_id
}
pub fn status(&self) -> &TaskStatus {
&self.status
}
/// Check if this task is a subtask (has a parent).
pub fn is_subtask(&self) -> bool {
self.parent_id.is_some()
}
// State transitions - explicit and validated
/// Transition the task to Running state.
///
/// # Precondition
/// `self.status == Pending`
///
/// # Errors
/// Returns `Err` if the task is not in Pending state.
pub fn start(&mut self) -> Result<(), TaskError> {
match &self.status {
TaskStatus::Pending => {
self.status = TaskStatus::Running;
Ok(())
}
other => Err(TaskError::InvalidTransition {
from: format!("{:?}", other),
to: "Running".to_string(),
}),
}
}
/// Transition the task to Completed state.
///
/// # Precondition
/// `self.status == Running`
pub fn complete(&mut self) -> Result<(), TaskError> {
match &self.status {
TaskStatus::Running => {
self.status = TaskStatus::Completed;
Ok(())
}
other => Err(TaskError::InvalidTransition {
from: format!("{:?}", other),
to: "Completed".to_string(),
}),
}
}
/// Transition the task to Failed state.
///
/// # Precondition
/// `self.status == Running`
pub fn fail(&mut self, reason: String) -> Result<(), TaskError> {
match &self.status {
TaskStatus::Running => {
self.status = TaskStatus::Failed { reason };
Ok(())
}
other => Err(TaskError::InvalidTransition {
from: format!("{:?}", other),
to: "Failed".to_string(),
}),
}
}
/// Transition the task to Cancelled state.
///
/// # Precondition
/// `self.status.is_active()`
pub fn cancel(&mut self) -> Result<(), TaskError> {
if self.status.is_active() {
self.status = TaskStatus::Cancelled;
Ok(())
} else {
Err(TaskError::InvalidTransition {
from: format!("{:?}", self.status),
to: "Cancelled".to_string(),
})
}
}
}
/// Errors that can occur during task operations.
#[derive(Debug, Clone, thiserror::Error)]
pub enum TaskError {
/// Returned by `Task::new` when the description is the empty string
#[error("Task description cannot be empty")]
EmptyDescription,
/// A lifecycle method was called in a status that does not allow it;
/// `from` is the `Debug` rendering of the current status, `to` the target
#[error("Invalid state transition from {from} to {to}")]
InvalidTransition { from: String, to: String },
/// NOTE(review): not constructed by `Task` itself; presumably used by callers — confirm.
#[error("Budget error: {0}")]
BudgetError(String),
}

215
src/task/verification.rs Normal file
View File

@@ -0,0 +1,215 @@
//! Verification criteria and results for tasks.
//!
//! Supports hybrid verification: programmatic checks with LLM fallback.
//!
//! # Design Principles
//! - Prefer programmatic verification when possible (deterministic, fast)
//! - Use LLM verification for subjective or complex assessments
//! - Hybrid mode tries programmatic first, falls back to LLM
use serde::{Deserialize, Serialize};
/// Programmatic checks that can be performed without LLM.
///
/// This type only describes a check; running it happens elsewhere.
/// `All` and `Any` nest, so checks form a tree.
///
/// # Exhaustive Matching
/// All variants must be handled explicitly - no catch-all allowed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProgrammaticCheck {
/// Check if a file exists at the given path
FileExists { path: String },
/// Check if a file contains specific content
FileContains { path: String, content: String },
/// Check if a command exits with code 0
CommandSucceeds { command: String },
/// Check if a command output matches expected pattern
CommandOutputMatches { command: String, pattern: String },
/// Check if a directory exists
DirectoryExists { path: String },
/// Check if a file matches a regex pattern
FileMatchesRegex { path: String, pattern: String },
/// Multiple checks that must all pass
All(Vec<ProgrammaticCheck>),
/// At least one check must pass
Any(Vec<ProgrammaticCheck>),
}
impl ProgrammaticCheck {
/// Create a file exists check.
pub fn file_exists(path: impl Into<String>) -> Self {
Self::FileExists { path: path.into() }
}
/// Create a command succeeds check.
pub fn command_succeeds(command: impl Into<String>) -> Self {
Self::CommandSucceeds { command: command.into() }
}
/// Create an "all must pass" composite check.
pub fn all(checks: Vec<ProgrammaticCheck>) -> Self {
Self::All(checks)
}
/// Create an "any must pass" composite check.
pub fn any(checks: Vec<ProgrammaticCheck>) -> Self {
Self::Any(checks)
}
}
/// How to verify a task was completed correctly.
///
/// # Variants
/// - `Programmatic`: Use only programmatic checks (fast, deterministic)
/// - `LlmBased`: Use LLM to verify (flexible, slower)
/// - `Hybrid`: Try programmatic first, fall back to LLM if inconclusive
/// - `None`: No verification required (use with caution)
///
/// # Default
/// The `Default` implementation for this type yields `None`.
///
/// # Serialization
/// No `#[serde(...)]` attributes are applied, so serde's default enum
/// representation is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VerificationCriteria {
/// Programmatic verification only
Programmatic(ProgrammaticCheck),
/// LLM-based verification with a prompt describing success criteria
LlmBased {
/// Prompt describing what "success" looks like
success_criteria: String,
},
/// Try programmatic first, fall back to LLM
Hybrid {
/// Programmatic check to try first
programmatic: ProgrammaticCheck,
/// LLM prompt to use if programmatic is inconclusive
llm_fallback: String,
},
/// No verification (task is considered complete when agent says so)
///
/// Inside this type's own `impl` blocks, write `Self::None` or
/// `VerificationCriteria::None`; a bare `None` resolves to `Option::None`.
None,
}
impl VerificationCriteria {
/// Create a programmatic-only verification.
pub fn programmatic(check: ProgrammaticCheck) -> Self {
Self::Programmatic(check)
}
/// Create an LLM-based verification.
pub fn llm_based(success_criteria: impl Into<String>) -> Self {
Self::LlmBased {
success_criteria: success_criteria.into(),
}
}
/// Create a hybrid verification.
pub fn hybrid(programmatic: ProgrammaticCheck, llm_fallback: impl Into<String>) -> Self {
Self::Hybrid {
programmatic,
llm_fallback: llm_fallback.into(),
}
}
/// Create a no-verification criteria.
pub fn none() -> Self {
Self::None
}
/// Check if this verification requires LLM access.
///
/// # Returns
/// `true` if LLM may be needed (LlmBased or Hybrid)
pub fn may_require_llm(&self) -> bool {
matches!(self, Self::LlmBased { .. } | Self::Hybrid { .. })
}
/// Check if this verification is purely programmatic.
pub fn is_programmatic_only(&self) -> bool {
matches!(self, Self::Programmatic(_))
}
}
impl Default for VerificationCriteria {
fn default() -> Self {
Self::None
}
}
/// Result of a verification attempt.
///
/// All fields are private. Values are built with the `pass` / `fail`
/// constructors and read through the accessor methods of the same names as
/// the fields.
///
/// # Invariants
/// - If `passed == true`, the task is considered successfully completed
/// - `reasoning` should always explain why the verification passed or failed
///
/// NOTE(review): `Deserialize` is derived, so a deserialized value can carry
/// any field combination; the invariants above are conventions, not enforced
/// by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationResult {
/// Whether the verification passed
passed: bool,
/// Explanation of why the verification passed or failed
reasoning: String,
/// Which method was used for verification
method: VerificationMethod,
/// Cost in cents if LLM was used
// NOTE(review): presumably 0 when no LLM call was made — confirm with callers.
cost_cents: u64,
}
impl VerificationResult {
/// Create a passing result.
///
/// # Postcondition
/// `result.passed == true`
pub fn pass(reasoning: impl Into<String>, method: VerificationMethod, cost_cents: u64) -> Self {
Self {
passed: true,
reasoning: reasoning.into(),
method,
cost_cents,
}
}
/// Create a failing result.
///
/// # Postcondition
/// `result.passed == false`
pub fn fail(reasoning: impl Into<String>, method: VerificationMethod, cost_cents: u64) -> Self {
Self {
passed: false,
reasoning: reasoning.into(),
method,
cost_cents,
}
}
pub fn passed(&self) -> bool {
self.passed
}
pub fn reasoning(&self) -> &str {
&self.reasoning
}
pub fn method(&self) -> &VerificationMethod {
&self.method
}
pub fn cost_cents(&self) -> u64 {
self.cost_cents
}
}
/// Method used for verification.
///
/// Stored in a `VerificationResult` to record how its outcome was reached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VerificationMethod {
/// Programmatic check was used
Programmatic,
/// LLM was used for verification
// NOTE(review): `model` is presumably the identifier of the model that was
// called — confirm against the code that builds this variant.
Llm { model: String },
/// No verification was performed
None,
}