feat: Add model resolver and fix remaining build issues
- Add resolver.rs for model name resolution - Update budget/mod.rs exports - Fix remaining compilation errors
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,8 @@ This script:
|
||||
1. Fetches all models from OpenRouter API
|
||||
2. Fetches benchmark metadata from ZeroEval API
|
||||
3. For key benchmarks in each category, fetches model scores
|
||||
4. Creates a merged JSON with benchmark scores per category
|
||||
4. Auto-detects model families and tracks latest versions
|
||||
5. Creates a merged JSON with benchmark scores per category
|
||||
|
||||
Categories tracked:
|
||||
- code: Coding benchmarks (SWE-bench, HumanEval, etc.)
|
||||
@@ -14,13 +15,20 @@ Categories tracked:
|
||||
- reasoning: Reasoning benchmarks (GPQA, MMLU, etc.)
|
||||
- tool_calling: Tool/function calling benchmarks
|
||||
- long_context: Long context benchmarks
|
||||
|
||||
Model families tracked:
|
||||
- claude-sonnet, claude-haiku, claude-opus (Anthropic)
|
||||
- gpt-4, gpt-4-mini (OpenAI)
|
||||
- gemini-pro, gemini-flash (Google)
|
||||
- And more...
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.request import Request, urlopen
|
||||
from urllib.error import URLError, HTTPError
|
||||
from collections import defaultdict
|
||||
@@ -55,6 +63,55 @@ KEY_BENCHMARKS = {
|
||||
]
|
||||
}
|
||||
|
||||
# Model family detection table.
# Every entry is (regex_pattern, family_name, tier). Entries are tried in list
# order with re.match, and the first pattern that matches decides the family.
# Tier: "flagship" (best), "mid" (balanced), "fast" (cheap/fast)

# Anthropic Claude - both orderings are covered:
# claude-<line>-<version> and claude-<version>-<line>.
_ANTHROPIC_FAMILY_PATTERNS = [
    (r"^anthropic/claude-opus-(\d+\.?\d*)$", "claude-opus", "flagship"),
    (r"^anthropic/claude-(\d+\.?\d*)-opus$", "claude-opus", "flagship"),
    (r"^anthropic/claude-sonnet-(\d+\.?\d*)$", "claude-sonnet", "mid"),
    (r"^anthropic/claude-(\d+\.?\d*)-sonnet$", "claude-sonnet", "mid"),
    (r"^anthropic/claude-haiku-(\d+\.?\d*)$", "claude-haiku", "fast"),
    (r"^anthropic/claude-(\d+\.?\d*)-haiku$", "claude-haiku", "fast"),
]

# OpenAI GPT and o-series. Patterns without a trailing "$" also accept
# suffixed IDs that start with the same prefix.
_OPENAI_FAMILY_PATTERNS = [
    (r"^openai/gpt-4\.1$", "gpt-4", "mid"),
    (r"^openai/gpt-4o$", "gpt-4", "mid"),
    (r"^openai/gpt-4-turbo", "gpt-4", "mid"),
    (r"^openai/gpt-4\.1-mini$", "gpt-4-mini", "fast"),
    (r"^openai/gpt-4o-mini$", "gpt-4-mini", "fast"),
    (r"^openai/o1$", "o1", "flagship"),
    (r"^openai/o1-preview", "o1", "flagship"),
    (r"^openai/o1-mini", "o1-mini", "mid"),
    (r"^openai/o3-mini", "o3-mini", "mid"),
]

# Google Gemini. The negative lookahead keeps "-flash-lite" IDs out of the
# gemini-flash family.
_GOOGLE_FAMILY_PATTERNS = [
    (r"^google/gemini-(\d+\.?\d*)-pro", "gemini-pro", "mid"),
    (r"^google/gemini-pro", "gemini-pro", "mid"),
    (r"^google/gemini-(\d+\.?\d*)-flash(?!-lite)", "gemini-flash", "fast"),
    (r"^google/gemini-flash", "gemini-flash", "fast"),
]

# DeepSeek
_DEEPSEEK_FAMILY_PATTERNS = [
    (r"^deepseek/deepseek-chat", "deepseek-chat", "mid"),
    (r"^deepseek/deepseek-coder", "deepseek-coder", "mid"),
    (r"^deepseek/deepseek-r1$", "deepseek-r1", "flagship"),
]

# Mistral
_MISTRAL_FAMILY_PATTERNS = [
    (r"^mistralai/mistral-large", "mistral-large", "mid"),
    (r"^mistralai/mistral-medium", "mistral-medium", "mid"),
    (r"^mistralai/mistral-small", "mistral-small", "fast"),
]

# Meta Llama
_META_FAMILY_PATTERNS = [
    (r"^meta-llama/llama-3\.3-70b", "llama-3-70b", "mid"),
    (r"^meta-llama/llama-3\.2-90b", "llama-3-90b", "mid"),
    (r"^meta-llama/llama-3\.1-405b", "llama-3-405b", "flagship"),
]

# Qwen
_QWEN_FAMILY_PATTERNS = [
    (r"^qwen/qwen-2\.5-72b", "qwen-72b", "mid"),
    (r"^qwen/qwq-32b", "qwq", "mid"),
]

# Concatenation order is the matching order; keep vendors in this sequence.
MODEL_FAMILY_PATTERNS = (
    _ANTHROPIC_FAMILY_PATTERNS
    + _OPENAI_FAMILY_PATTERNS
    + _GOOGLE_FAMILY_PATTERNS
    + _DEEPSEEK_FAMILY_PATTERNS
    + _MISTRAL_FAMILY_PATTERNS
    + _META_FAMILY_PATTERNS
    + _QWEN_FAMILY_PATTERNS
)
|
||||
|
||||
HEADERS = {
|
||||
"Accept": "application/json",
|
||||
"Origin": "https://llm-stats.com",
|
||||
@@ -121,6 +178,75 @@ def normalize_model_id(model_id: str) -> str:
|
||||
return "-".join(filtered)
|
||||
|
||||
|
||||
def extract_version(model_id: str) -> Tuple[float, str]:
    """
    Extract a numeric version from a model ID so family members can be ranked.

    The ID is probed with a fixed list of regular expressions, in priority
    order; the first one that matches supplies the version.

    Args:
        model_id: Model identifier, e.g. "anthropic/claude-sonnet-4.5".

    Returns:
        (version, model_id). A higher version is treated as a newer model.
        The version is 0.0 when no numeric component is found.

    Note:
        The version is a float, so "4.10" parses as 4.1 and ranks below "4.9".
    """
    version_patterns = (
        r"-(\d+\.?\d*)-",        # number between hyphens, e.g. claude-3.5-sonnet
        r"-(\d+\.?\d*)$",        # hyphenated number at the end, e.g. claude-sonnet-4.5
        r"(\d+\.?\d*)$",         # bare trailing number, e.g. openai/o1
        r"/[a-z]+-(\d+\.?\d*)",  # number right after the name prefix, e.g. openai/gpt-4o
    )

    for pattern in version_patterns:
        match = re.search(pattern, model_id)
        if match:
            # The captured text is always digits with an optional fractional
            # part, which float() accepts, so no error handling is needed.
            return (float(match.group(1)), model_id)

    # No numeric component: rank such models lowest within their family.
    return (0.0, model_id)
|
||||
|
||||
|
||||
def infer_model_families(models: List[dict]) -> Dict[str, dict]:
    """
    Infer model families from the OpenRouter model list.

    Each model's "id" is tested against MODEL_FAMILY_PATTERNS in order; the
    first matching pattern assigns the model to a family. Models that match
    no pattern are left out. Within a family, members are ranked by the
    version returned from extract_version, highest first.

    Args:
        models: Model dicts; only the "id" key is read. Entries whose id is
            missing, empty, or not a string are skipped.

    Returns:
        A dict like:
        {
            "claude-sonnet": {
                "latest": "anthropic/claude-sonnet-4.5",
                "members": ["anthropic/claude-sonnet-4.5", ...],
                "tier": "mid"
            }
        }
    """
    grouped: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    family_tiers: Dict[str, str] = {}

    for model in models:
        model_id = model.get("id")
        # A null or non-string id would make re.match raise TypeError, and an
        # empty id cannot match any pattern, so such entries are skipped.
        if not isinstance(model_id, str) or not model_id:
            continue

        for pattern, family_name, tier in MODEL_FAMILY_PATTERNS:
            if re.match(pattern, model_id):
                version, _ = extract_version(model_id)
                grouped[family_name].append((model_id, version))
                family_tiers[family_name] = tier
                break  # first matching pattern wins

    result: Dict[str, dict] = {}
    for family_name, members in grouped.items():
        # Highest version first. sorted() is stable, so members with equal
        # versions keep their input order.
        ranked = sorted(members, key=lambda entry: entry[1], reverse=True)
        member_ids = [member_id for member_id, _ in ranked]

        # A family only exists once a member was appended, so member_ids is
        # never empty here.
        result[family_name] = {
            "latest": member_ids[0],
            "members": member_ids,
            "tier": family_tiers.get(family_name, "mid"),
        }

    return result
|
||||
|
||||
|
||||
def build_model_score_map(benchmarks_data: Dict[str, dict]) -> Dict[str, dict]:
|
||||
"""
|
||||
Build a map from normalized model names to their benchmark scores.
|
||||
@@ -182,6 +308,52 @@ def calculate_category_averages(scores: dict) -> dict:
|
||||
return averages
|
||||
|
||||
|
||||
def generate_aliases(families: Dict[str, dict]) -> Dict[str, str]:
    """
    Build a lookup of alternative model names that all point at the latest
    member of their family.

    This lets an outdated name such as "claude-3.5-sonnet" be resolved to the
    current "anthropic/claude-sonnet-4.5".

    Args:
        families: Mapping of family name to a dict with "latest" (model ID)
            and "members" (list of model IDs).

    Returns:
        Mapping of alias -> latest model ID. Later insertions overwrite
        earlier ones when two aliases collide.
    """
    # Extra hand-picked spellings for selected families.
    common_variations = {
        "claude-sonnet": ("sonnet", "claude sonnet"),
        "claude-haiku": ("haiku", "claude haiku"),
        "claude-opus": ("opus", "claude opus"),
        "gpt-4": ("gpt4", "gpt-4o"),
        "gpt-4-mini": ("gpt4-mini", "gpt-4o-mini"),
    }

    alias_map: Dict[str, str] = {}

    for family_name, family_info in families.items():
        newest = family_info["latest"]

        # Every superseded member points at the newest one.
        for member in family_info["members"]:
            if member == newest:
                continue
            alias_map[member] = newest

            # Short form: the part after the vendor prefix.
            # NOTE(review): indentation was lost in the reviewed copy; this
            # assumes short forms are added for superseded members only -
            # confirm against the original script.
            if "/" in member:
                alias_map[member.rsplit("/", 1)[-1]] = newest

        # The family name itself resolves to the newest member.
        alias_map[family_name] = newest

        for variation in common_variations.get(family_name, ()):
            alias_map[variation] = newest

    return alias_map
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("OpenRouter + ZeroEval Benchmark Merger")
|
||||
@@ -199,7 +371,18 @@ def main():
|
||||
json.dump({"data": openrouter_models}, f)
|
||||
print(f"Saved raw OpenRouter models to {or_path}")
|
||||
|
||||
# Step 2: Fetch all benchmark metadata
|
||||
# Step 2: Infer model families
|
||||
print("\nInferring model families...")
|
||||
families = infer_model_families(openrouter_models)
|
||||
print(f" Found {len(families)} model families:")
|
||||
for name, info in sorted(families.items()):
|
||||
print(f" - {name}: {info['latest']} ({len(info['members'])} members, tier={info['tier']})")
|
||||
|
||||
# Generate aliases
|
||||
aliases = generate_aliases(families)
|
||||
print(f" Generated {len(aliases)} aliases for auto-upgrade")
|
||||
|
||||
# Step 3: Fetch all benchmark metadata
|
||||
all_benchmarks = fetch_all_benchmarks()
|
||||
if not all_benchmarks:
|
||||
print("Failed to fetch benchmarks, exiting.")
|
||||
@@ -214,7 +397,7 @@ def main():
|
||||
# Build benchmark ID lookup
|
||||
benchmark_lookup = {b["benchmark_id"]: b for b in all_benchmarks}
|
||||
|
||||
# Step 3: Fetch scores for key benchmarks in each category
|
||||
# Step 4: Fetch scores for key benchmarks in each category
|
||||
print("\nFetching benchmark scores by category...")
|
||||
benchmarks_data = {}
|
||||
|
||||
@@ -245,12 +428,12 @@ def main():
|
||||
|
||||
time.sleep(0.2) # Rate limiting
|
||||
|
||||
# Step 4: Build model score map
|
||||
# Step 5: Build model score map
|
||||
print("\nBuilding model score map...")
|
||||
model_scores = build_model_score_map(benchmarks_data)
|
||||
print(f" Found scores for {len(model_scores)} unique model IDs")
|
||||
|
||||
# Step 5: Merge with OpenRouter models
|
||||
# Step 6: Merge with OpenRouter models
|
||||
print("\nMerging with OpenRouter models...")
|
||||
merged_models = []
|
||||
matched_count = 0
|
||||
@@ -281,12 +464,14 @@ def main():
|
||||
|
||||
print(f" Matched {matched_count}/{len(openrouter_models)} models with benchmarks")
|
||||
|
||||
# Step 6: Save merged data
|
||||
# Step 7: Save merged data with families
|
||||
output = {
|
||||
"generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"total_models": len(merged_models),
|
||||
"models_with_benchmarks": matched_count,
|
||||
"categories": list(KEY_BENCHMARKS.keys()),
|
||||
"families": families,
|
||||
"aliases": aliases,
|
||||
"models": merged_models
|
||||
}
|
||||
|
||||
@@ -295,14 +480,21 @@ def main():
|
||||
json.dump(output, f, indent=2)
|
||||
print(f"\n✓ Saved merged data to {output_path}")
|
||||
|
||||
# Step 7: Create summary
|
||||
# Step 8: Create summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
print(f"Total OpenRouter models: {len(openrouter_models)}")
|
||||
print(f"Models with benchmark data: {matched_count}")
|
||||
print(f"Model families detected: {len(families)}")
|
||||
print(f"Aliases generated: {len(aliases)}")
|
||||
print(f"Categories tracked: {', '.join(KEY_BENCHMARKS.keys())}")
|
||||
|
||||
# Show family info
|
||||
print("\nModel families (latest versions):")
|
||||
for name, info in sorted(families.items()):
|
||||
print(f" - {name}: {info['latest']}")
|
||||
|
||||
# Show some example matches
|
||||
print("\nExample matched models:")
|
||||
for m in merged_models[:10]:
|
||||
|
||||
@@ -523,13 +523,11 @@ Use `search_memory` when you encounter a problem you might have solved before or
|
||||
if let Some(tool_calls) = &response.tool_calls {
|
||||
if !tool_calls.is_empty() {
|
||||
// Add assistant message with tool calls
|
||||
// Preserve reasoning_details for models that require it (Gemini 3, Claude 3.7+)
|
||||
messages.push(ChatMessage {
|
||||
role: Role::Assistant,
|
||||
content: response.content.clone().map(MessageContent::text),
|
||||
tool_calls: Some(tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
reasoning_details: response.reasoning_details.clone(),
|
||||
});
|
||||
|
||||
// Check for repetitive actions
|
||||
@@ -717,7 +715,6 @@ Use `search_memory` when you encounter a problem you might have solved before or
|
||||
content: Some(message_content),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call.id.clone()),
|
||||
reasoning_details: None,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -859,11 +856,25 @@ impl Agent for TaskExecutor {
|
||||
|
||||
async fn execute(&self, task: &mut Task, ctx: &AgentContext) -> AgentResult {
|
||||
// Use model selected during planning, otherwise fall back to default.
|
||||
let selected = task
|
||||
.analysis()
|
||||
.selected_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| ctx.config.default_model.clone());
|
||||
// If falling back to default, resolve it to latest version first.
|
||||
let selected = if let Some(model) = task.analysis().selected_model.clone() {
|
||||
model
|
||||
} else {
|
||||
// Resolve default model to latest version
|
||||
if let Some(resolver) = &ctx.resolver {
|
||||
let resolver = resolver.read().await;
|
||||
let resolved = resolver.resolve(&ctx.config.default_model);
|
||||
if resolved.upgraded {
|
||||
tracing::info!(
|
||||
"Executor: default model auto-upgraded: {} → {}",
|
||||
resolved.original, resolved.resolved
|
||||
);
|
||||
}
|
||||
resolved.resolved
|
||||
} else {
|
||||
ctx.config.default_model.clone()
|
||||
}
|
||||
};
|
||||
let model = selected.as_str();
|
||||
|
||||
let result = self.run_loop(task, model, ctx).await;
|
||||
@@ -909,11 +920,26 @@ impl Agent for TaskExecutor {
|
||||
impl TaskExecutor {
|
||||
/// Execute a task and return detailed execution result for retry analysis.
|
||||
pub async fn execute_with_signals(&self, task: &mut Task, ctx: &AgentContext) -> (AgentResult, ExecutionSignals) {
|
||||
let selected = task
|
||||
.analysis()
|
||||
.selected_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| ctx.config.default_model.clone());
|
||||
// Use model selected during planning, otherwise fall back to default.
|
||||
// If falling back to default, resolve it to latest version first.
|
||||
let selected = if let Some(model) = task.analysis().selected_model.clone() {
|
||||
model
|
||||
} else {
|
||||
// Resolve default model to latest version
|
||||
if let Some(resolver) = &ctx.resolver {
|
||||
let resolver = resolver.read().await;
|
||||
let resolved = resolver.resolve(&ctx.config.default_model);
|
||||
if resolved.upgraded {
|
||||
tracing::info!(
|
||||
"Executor: default model auto-upgraded: {} → {}",
|
||||
resolved.original, resolved.resolved
|
||||
);
|
||||
}
|
||||
resolved.resolved
|
||||
} else {
|
||||
ctx.config.default_model.clone()
|
||||
}
|
||||
};
|
||||
let model = selected.as_str();
|
||||
|
||||
let result = self.run_loop(task, model, ctx).await;
|
||||
|
||||
@@ -540,8 +540,20 @@ impl Agent for ModelSelector {
|
||||
let models = ctx.pricing.models_by_cost_filtered(true).await;
|
||||
|
||||
if models.is_empty() {
|
||||
// Fall back to configured default model
|
||||
let default_model = ctx.config.default_model.clone();
|
||||
// Fall back to configured default model (after resolving to latest)
|
||||
let default_model = if let Some(resolver) = &ctx.resolver {
|
||||
let resolver = resolver.read().await;
|
||||
let resolved = resolver.resolve(&ctx.config.default_model);
|
||||
if resolved.upgraded {
|
||||
tracing::info!(
|
||||
"Default model auto-upgraded: {} → {}",
|
||||
resolved.original, resolved.resolved
|
||||
);
|
||||
}
|
||||
resolved.resolved
|
||||
} else {
|
||||
ctx.config.default_model.clone()
|
||||
};
|
||||
|
||||
// Record on task analysis
|
||||
{
|
||||
@@ -565,15 +577,43 @@ impl Agent for ModelSelector {
|
||||
}));
|
||||
}
|
||||
|
||||
// Get user-requested model - if specified, use it directly if available
|
||||
// Get user-requested model - if specified, resolve to latest version and use it
|
||||
let requested_model = task.analysis().requested_model.clone();
|
||||
|
||||
// If user explicitly requested a model and it's available, use it directly
|
||||
if let Some(ref req_model) = requested_model {
|
||||
// Auto-upgrade outdated model names using the resolver
|
||||
let (resolved_model, was_upgraded) = if let Some(ref req_model) = requested_model {
|
||||
if let Some(resolver) = &ctx.resolver {
|
||||
let resolver = resolver.read().await;
|
||||
let resolved = resolver.resolve(req_model);
|
||||
if resolved.upgraded {
|
||||
tracing::info!(
|
||||
"Model auto-upgraded: {} → {} ({})",
|
||||
resolved.original,
|
||||
resolved.resolved,
|
||||
resolved.reason.as_deref().unwrap_or("family upgrade")
|
||||
);
|
||||
}
|
||||
(Some(resolved.resolved), resolved.upgraded)
|
||||
} else {
|
||||
(Some(req_model.clone()), false)
|
||||
}
|
||||
} else {
|
||||
(None, false)
|
||||
};
|
||||
|
||||
// If user explicitly requested a model (possibly upgraded) and it's available, use it directly
|
||||
if let Some(ref req_model) = resolved_model {
|
||||
if models.iter().any(|m| &m.model_id == req_model) {
|
||||
let upgrade_note = if was_upgraded {
|
||||
format!(" (auto-upgraded from {})", requested_model.as_deref().unwrap_or("unknown"))
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"Using user-requested model directly: {} (not optimizing)",
|
||||
req_model
|
||||
"Using requested model directly: {}{}",
|
||||
req_model,
|
||||
upgrade_note
|
||||
);
|
||||
|
||||
// Record selection in analysis
|
||||
@@ -584,17 +624,19 @@ impl Agent for ModelSelector {
|
||||
}
|
||||
|
||||
return AgentResult::success(
|
||||
&format!("Using user-requested model: {}", req_model),
|
||||
&format!("Using requested model: {}{}", req_model, upgrade_note),
|
||||
1,
|
||||
)
|
||||
.with_data(json!({
|
||||
"model_id": req_model,
|
||||
"expected_cost_cents": 50,
|
||||
"confidence": 1.0,
|
||||
"reasoning": format!("User explicitly requested model: {}", req_model),
|
||||
"reasoning": format!("User requested model: {}{}", req_model, upgrade_note),
|
||||
"fallbacks": [],
|
||||
"used_historical_data": false,
|
||||
"used_benchmark_data": false,
|
||||
"was_upgraded": was_upgraded,
|
||||
"original_model": requested_model,
|
||||
"task_type": format!("{:?}", task_type),
|
||||
}));
|
||||
}
|
||||
@@ -607,7 +649,7 @@ impl Agent for ModelSelector {
|
||||
budget_cents,
|
||||
task_type,
|
||||
historical_stats.as_ref(),
|
||||
requested_model.as_deref(),
|
||||
resolved_model.as_deref(),
|
||||
ctx,
|
||||
).await {
|
||||
Some(rec) => {
|
||||
|
||||
@@ -12,36 +12,36 @@ use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::{
|
||||
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, Complexity, OrchestratorAgent,
|
||||
leaf::{ComplexityEstimator, ModelSelector, TaskExecutor, Verifier},
|
||||
Agent, AgentContext, AgentId, AgentRef, AgentResult, AgentType, Complexity, OrchestratorAgent,
|
||||
};
|
||||
use crate::budget::Budget;
|
||||
use crate::llm::{ChatMessage, Role};
|
||||
use crate::task::{Task, Subtask, SubtaskPlan, VerificationCriteria};
|
||||
use crate::task::{Subtask, SubtaskPlan, Task, VerificationCriteria};
|
||||
|
||||
/// Node agent - intermediate orchestrator.
|
||||
///
|
||||
///
|
||||
/// # Purpose
|
||||
/// Handles subtasks that may still be complex enough
|
||||
/// to warrant further splitting. Now with full recursive
|
||||
/// splitting capabilities like RootAgent.
|
||||
///
|
||||
///
|
||||
/// # Recursive Splitting
|
||||
/// NodeAgent can estimate complexity of its subtasks and
|
||||
/// recursively split them if they're still too complex,
|
||||
/// respecting the `max_split_depth` in context.
|
||||
pub struct NodeAgent {
|
||||
id: AgentId,
|
||||
|
||||
|
||||
/// Name for identification in logs
|
||||
name: String,
|
||||
|
||||
|
||||
// Child agents - full pipeline for recursive splitting
|
||||
complexity_estimator: Arc<ComplexityEstimator>,
|
||||
model_selector: Arc<ModelSelector>,
|
||||
task_executor: Arc<TaskExecutor>,
|
||||
verifier: Arc<Verifier>,
|
||||
|
||||
|
||||
// Child node agents (for further splitting)
|
||||
child_nodes: Vec<Arc<NodeAgent>>,
|
||||
}
|
||||
@@ -79,22 +79,25 @@ impl NodeAgent {
|
||||
/// Estimate complexity of a task.
|
||||
async fn estimate_complexity(&self, task: &mut Task, ctx: &AgentContext) -> Complexity {
|
||||
let result = self.complexity_estimator.execute(task, ctx).await;
|
||||
|
||||
|
||||
if let Some(data) = result.data {
|
||||
let score = data["score"].as_f64().unwrap_or(0.5);
|
||||
let reasoning = data["reasoning"].as_str().unwrap_or("").to_string();
|
||||
let estimated_tokens = data["estimated_tokens"].as_u64().unwrap_or(2000);
|
||||
let should_split = data["should_split"].as_bool().unwrap_or(false);
|
||||
|
||||
Complexity::new(score, reasoning, estimated_tokens)
|
||||
.with_split(should_split)
|
||||
|
||||
Complexity::new(score, reasoning, estimated_tokens).with_split(should_split)
|
||||
} else {
|
||||
Complexity::moderate("Could not estimate complexity")
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a complex task into subtasks.
|
||||
async fn split_task(&self, task: &Task, ctx: &AgentContext) -> Result<SubtaskPlan, AgentResult> {
|
||||
async fn split_task(
|
||||
&self,
|
||||
task: &Task,
|
||||
ctx: &AgentContext,
|
||||
) -> Result<SubtaskPlan, AgentResult> {
|
||||
let prompt = format!(
|
||||
r#"You are a task planner. Break down this task into smaller, manageable subtasks.
|
||||
|
||||
@@ -135,11 +138,15 @@ Respond ONLY with the JSON object."#,
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::new(Role::System, "You are a precise task planner. Respond only with JSON."),
|
||||
ChatMessage::new(
|
||||
Role::System,
|
||||
"You are a precise task planner. Respond only with JSON.",
|
||||
),
|
||||
ChatMessage::new(Role::User, prompt),
|
||||
];
|
||||
|
||||
let response = ctx.llm
|
||||
let response = ctx
|
||||
.llm
|
||||
.chat_completion("openai/gpt-4.1-mini", &messages, None)
|
||||
.await
|
||||
.map_err(|e| AgentResult::failure(format!("LLM error: {}", e), 1))?;
|
||||
@@ -151,7 +158,7 @@ Respond ONLY with the JSON object."#,
|
||||
/// Extract JSON from LLM response (handles markdown code blocks).
|
||||
fn extract_json(response: &str) -> String {
|
||||
let trimmed = response.trim();
|
||||
|
||||
|
||||
// Check for markdown code block
|
||||
if trimmed.starts_with("```") {
|
||||
// Find the end of the opening fence
|
||||
@@ -163,7 +170,7 @@ Respond ONLY with the JSON object."#,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Try to find JSON object in the response
|
||||
if let Some(start) = trimmed.find('{') {
|
||||
if let Some(end) = trimmed.rfind('}') {
|
||||
@@ -172,7 +179,7 @@ Respond ONLY with the JSON object."#,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Return as-is if no extraction needed
|
||||
trimmed.to_string()
|
||||
}
|
||||
@@ -184,8 +191,16 @@ Respond ONLY with the JSON object."#,
|
||||
parent_id: crate::task::TaskId,
|
||||
) -> Result<SubtaskPlan, AgentResult> {
|
||||
let extracted = Self::extract_json(response);
|
||||
let json: serde_json::Value = serde_json::from_str(&extracted)
|
||||
.map_err(|e| AgentResult::failure(format!("Failed to parse subtasks: {} (raw: {}...)", e, response.chars().take(100).collect::<String>()), 0))?;
|
||||
let json: serde_json::Value = serde_json::from_str(&extracted).map_err(|e| {
|
||||
AgentResult::failure(
|
||||
format!(
|
||||
"Failed to parse subtasks: {} (raw: {}...)",
|
||||
e,
|
||||
response.chars().take(100).collect::<String>()
|
||||
),
|
||||
0,
|
||||
)
|
||||
})?;
|
||||
|
||||
let reasoning = json["reasoning"]
|
||||
.as_str()
|
||||
@@ -200,7 +215,7 @@ Respond ONLY with the JSON object."#,
|
||||
let desc = s["description"].as_str().unwrap_or("").to_string();
|
||||
let verification = s["verification"].as_str().unwrap_or("");
|
||||
let weight = s["weight"].as_f64().unwrap_or(1.0);
|
||||
|
||||
|
||||
// Parse dependencies array
|
||||
let dependencies: Vec<usize> = s["dependencies"]
|
||||
.as_array()
|
||||
@@ -210,12 +225,9 @@ Respond ONLY with the JSON object."#,
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Subtask::new(
|
||||
desc,
|
||||
VerificationCriteria::llm_based(verification),
|
||||
weight,
|
||||
).with_dependencies(dependencies)
|
||||
|
||||
Subtask::new(desc, VerificationCriteria::llm_based(verification), weight)
|
||||
.with_dependencies(dependencies)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
@@ -266,7 +278,7 @@ Respond ONLY with the JSON object."#,
|
||||
|
||||
// Create a child NodeAgent for this subtask (recursive)
|
||||
let child_node = NodeAgent::new(format!("{}-sub", self.name));
|
||||
|
||||
|
||||
// Execute through the child node (which may split further)
|
||||
let result = child_node.execute(task, &child_ctx).await;
|
||||
total_cost += result.cost_cents;
|
||||
@@ -278,19 +290,21 @@ Respond ONLY with the JSON object."#,
|
||||
let successes = results.iter().filter(|r| r.success).count();
|
||||
let total = results.len();
|
||||
|
||||
// Concatenate successful outputs for meaningful aggregation
|
||||
let combined_output = Self::concatenate_outputs(&results);
|
||||
|
||||
if successes == total {
|
||||
AgentResult::success(
|
||||
format!("All {} subtasks completed successfully", total),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
AgentResult::success(combined_output, total_cost).with_data(json!({
|
||||
"subtasks_total": total,
|
||||
"subtasks_succeeded": successes,
|
||||
"results": results.iter().map(|r| &r.output).collect::<Vec<_>>(),
|
||||
}))
|
||||
} else {
|
||||
AgentResult::failure(
|
||||
format!("{}/{} subtasks succeeded", successes, total),
|
||||
format!(
|
||||
"{}/{} subtasks succeeded\n\n{}",
|
||||
successes, total, combined_output
|
||||
),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
@@ -304,6 +318,31 @@ Respond ONLY with the JSON object."#,
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenate subtask outputs into a single string.
|
||||
/// Used for intermediate aggregation (RootAgent handles final synthesis).
|
||||
fn concatenate_outputs(results: &[AgentResult]) -> String {
|
||||
let outputs: Vec<String> = results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, r)| r.success && !r.output.is_empty())
|
||||
.map(|(i, r)| {
|
||||
if results.len() == 1 {
|
||||
r.output.clone()
|
||||
} else {
|
||||
format!("### Part {}\n{}", i + 1, r.output)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if outputs.is_empty() {
|
||||
"No output generated.".to_string()
|
||||
} else if outputs.len() == 1 {
|
||||
outputs.into_iter().next().unwrap()
|
||||
} else {
|
||||
outputs.join("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute with tree updates for visualization.
|
||||
/// This method updates the parent's tree structure as this node executes.
|
||||
pub async fn execute_with_tree(
|
||||
@@ -315,7 +354,7 @@ Respond ONLY with the JSON object."#,
|
||||
emit_ctx: &AgentContext,
|
||||
) -> AgentResult {
|
||||
use crate::api::control::AgentTreeNode;
|
||||
|
||||
|
||||
let mut total_cost = 0u64;
|
||||
|
||||
tracing::info!(
|
||||
@@ -326,7 +365,11 @@ Respond ONLY with the JSON object."#,
|
||||
);
|
||||
|
||||
// Step 1: Estimate complexity
|
||||
ctx.emit_phase("estimating_complexity", Some("Analyzing subtask..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"estimating_complexity",
|
||||
Some("Analyzing subtask..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
let complexity = self.estimate_complexity(task, ctx).await;
|
||||
total_cost += 1;
|
||||
|
||||
@@ -346,15 +389,21 @@ Respond ONLY with the JSON object."#,
|
||||
|
||||
// Step 2: Decide execution strategy
|
||||
if complexity.should_split() && ctx.can_split() {
|
||||
ctx.emit_phase("splitting_task", Some("Decomposing subtask..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"splitting_task",
|
||||
Some("Decomposing subtask..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
tracing::info!("NodeAgent '{}' splitting task into sub-subtasks", self.name);
|
||||
|
||||
match self.split_task(task, ctx).await {
|
||||
Ok(plan) => {
|
||||
total_cost += 2;
|
||||
|
||||
|
||||
// Add child nodes to this node in the tree
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == node_id) {
|
||||
if let Some(parent_node) =
|
||||
root_tree.children.iter_mut().find(|n| n.id == node_id)
|
||||
{
|
||||
for (i, subtask) in plan.subtasks().iter().enumerate() {
|
||||
let child_node = AgentTreeNode::new(
|
||||
&format!("{}-sub-{}", node_id, i + 1),
|
||||
@@ -367,7 +416,7 @@ Respond ONLY with the JSON object."#,
|
||||
}
|
||||
}
|
||||
emit_ctx.emit_tree(root_tree.clone());
|
||||
|
||||
|
||||
let subtask_count = plan.subtasks().len();
|
||||
tracing::info!(
|
||||
"NodeAgent '{}' created {} sub-subtasks",
|
||||
@@ -378,8 +427,18 @@ Respond ONLY with the JSON object."#,
|
||||
// Execute subtasks recursively with tree updates
|
||||
let child_ctx = ctx.child_context();
|
||||
let requested_model = task.analysis().requested_model.as_deref();
|
||||
let result = self.execute_subtasks_with_tree(plan, task.budget(), &child_ctx, node_id, root_tree, emit_ctx, requested_model).await;
|
||||
|
||||
let result = self
|
||||
.execute_subtasks_with_tree(
|
||||
plan,
|
||||
task.budget(),
|
||||
&child_ctx,
|
||||
node_id,
|
||||
root_tree,
|
||||
emit_ctx,
|
||||
requested_model,
|
||||
)
|
||||
.await;
|
||||
|
||||
return AgentResult {
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
@@ -407,7 +466,7 @@ Respond ONLY with the JSON object."#,
|
||||
"Task Executor",
|
||||
"Execute subtask",
|
||||
)
|
||||
.with_status("running")
|
||||
.with_status("running"),
|
||||
);
|
||||
parent_node.children.push(
|
||||
AgentTreeNode::new(
|
||||
@@ -416,13 +475,17 @@ Respond ONLY with the JSON object."#,
|
||||
"Verifier",
|
||||
"Verify result",
|
||||
)
|
||||
.with_status("pending")
|
||||
.with_status("pending"),
|
||||
);
|
||||
}
|
||||
emit_ctx.emit_tree(root_tree.clone());
|
||||
|
||||
// Select model
|
||||
ctx.emit_phase("selecting_model", Some("Choosing model..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"selecting_model",
|
||||
Some("Choosing model..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
let sel_result = self.model_selector.execute(task, ctx).await;
|
||||
total_cost += sel_result.cost_cents;
|
||||
|
||||
@@ -433,8 +496,16 @@ Respond ONLY with the JSON object."#,
|
||||
|
||||
// Update executor status
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == node_id) {
|
||||
if let Some(exec_node) = parent_node.children.iter_mut().find(|n| n.id == format!("{}-executor", node_id)) {
|
||||
exec_node.status = if result.success { "completed".to_string() } else { "failed".to_string() };
|
||||
if let Some(exec_node) = parent_node
|
||||
.children
|
||||
.iter_mut()
|
||||
.find(|n| n.id == format!("{}-executor", node_id))
|
||||
{
|
||||
exec_node.status = if result.success {
|
||||
"completed".to_string()
|
||||
} else {
|
||||
"failed".to_string()
|
||||
};
|
||||
exec_node.budget_spent = result.cost_cents;
|
||||
}
|
||||
}
|
||||
@@ -444,18 +515,21 @@ Respond ONLY with the JSON object."#,
|
||||
task.set_last_output(result.output.clone());
|
||||
|
||||
if !result.success {
|
||||
return AgentResult::failure(result.output, total_cost)
|
||||
.with_data(json!({
|
||||
"node_name": self.name,
|
||||
"complexity": complexity.score(),
|
||||
"was_split": false,
|
||||
"execution": result.data,
|
||||
}));
|
||||
return AgentResult::failure(result.output, total_cost).with_data(json!({
|
||||
"node_name": self.name,
|
||||
"complexity": complexity.score(),
|
||||
"was_split": false,
|
||||
"execution": result.data,
|
||||
}));
|
||||
}
|
||||
|
||||
// Verify
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == node_id) {
|
||||
if let Some(ver_node) = parent_node.children.iter_mut().find(|n| n.id == format!("{}-verifier", node_id)) {
|
||||
if let Some(ver_node) = parent_node
|
||||
.children
|
||||
.iter_mut()
|
||||
.find(|n| n.id == format!("{}-verifier", node_id))
|
||||
{
|
||||
ver_node.status = "running".to_string();
|
||||
}
|
||||
}
|
||||
@@ -467,8 +541,16 @@ Respond ONLY with the JSON object."#,
|
||||
|
||||
// Update verifier status
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == node_id) {
|
||||
if let Some(ver_node) = parent_node.children.iter_mut().find(|n| n.id == format!("{}-verifier", node_id)) {
|
||||
ver_node.status = if verification.success { "completed".to_string() } else { "failed".to_string() };
|
||||
if let Some(ver_node) = parent_node
|
||||
.children
|
||||
.iter_mut()
|
||||
.find(|n| n.id == format!("{}-verifier", node_id))
|
||||
{
|
||||
ver_node.status = if verification.success {
|
||||
"completed".to_string()
|
||||
} else {
|
||||
"failed".to_string()
|
||||
};
|
||||
ver_node.budget_spent = verification.cost_cents;
|
||||
}
|
||||
}
|
||||
@@ -531,10 +613,16 @@ Respond ONLY with the JSON object."#,
|
||||
|
||||
for (i, task) in tasks.iter_mut().enumerate() {
|
||||
let subtask_id = format!("{}-sub-{}", parent_node_id, i + 1);
|
||||
|
||||
|
||||
// Update subtask status to running
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == parent_node_id) {
|
||||
if let Some(child_node) = parent_node.children.iter_mut().find(|n| n.id == subtask_id) {
|
||||
if let Some(parent_node) = root_tree
|
||||
.children
|
||||
.iter_mut()
|
||||
.find(|n| n.id == parent_node_id)
|
||||
{
|
||||
if let Some(child_node) =
|
||||
parent_node.children.iter_mut().find(|n| n.id == subtask_id)
|
||||
{
|
||||
child_node.status = "running".to_string();
|
||||
}
|
||||
}
|
||||
@@ -552,9 +640,19 @@ Respond ONLY with the JSON object."#,
|
||||
total_cost += result.cost_cents;
|
||||
|
||||
// Update subtask status
|
||||
if let Some(parent_node) = root_tree.children.iter_mut().find(|n| n.id == parent_node_id) {
|
||||
if let Some(child_node) = parent_node.children.iter_mut().find(|n| n.id == subtask_id) {
|
||||
child_node.status = if result.success { "completed".to_string() } else { "failed".to_string() };
|
||||
if let Some(parent_node) = root_tree
|
||||
.children
|
||||
.iter_mut()
|
||||
.find(|n| n.id == parent_node_id)
|
||||
{
|
||||
if let Some(child_node) =
|
||||
parent_node.children.iter_mut().find(|n| n.id == subtask_id)
|
||||
{
|
||||
child_node.status = if result.success {
|
||||
"completed".to_string()
|
||||
} else {
|
||||
"failed".to_string()
|
||||
};
|
||||
child_node.budget_spent = result.cost_cents;
|
||||
}
|
||||
}
|
||||
@@ -566,18 +664,20 @@ Respond ONLY with the JSON object."#,
|
||||
let successes = results.iter().filter(|r| r.success).count();
|
||||
let total = results.len();
|
||||
|
||||
// Concatenate successful outputs for meaningful aggregation
|
||||
let combined_output = Self::concatenate_outputs(&results);
|
||||
|
||||
if successes == total {
|
||||
AgentResult::success(
|
||||
format!("All {} sub-subtasks completed successfully", total),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
AgentResult::success(combined_output, total_cost).with_data(json!({
|
||||
"subtasks_total": total,
|
||||
"subtasks_succeeded": successes,
|
||||
}))
|
||||
} else {
|
||||
AgentResult::failure(
|
||||
format!("{}/{} sub-subtasks succeeded", successes, total),
|
||||
format!(
|
||||
"{}/{} sub-subtasks succeeded\n\n{}",
|
||||
successes, total, combined_output
|
||||
),
|
||||
total_cost,
|
||||
)
|
||||
.with_data(json!({
|
||||
@@ -619,7 +719,11 @@ impl Agent for NodeAgent {
|
||||
);
|
||||
|
||||
// Step 1: Estimate complexity
|
||||
ctx.emit_phase("estimating_complexity", Some("Analyzing subtask..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"estimating_complexity",
|
||||
Some("Analyzing subtask..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
let complexity = self.estimate_complexity(task, ctx).await;
|
||||
total_cost += 1;
|
||||
|
||||
@@ -634,13 +738,17 @@ impl Agent for NodeAgent {
|
||||
// Step 2: Decide execution strategy
|
||||
if complexity.should_split() && ctx.can_split() {
|
||||
// Complex subtask: split further recursively
|
||||
ctx.emit_phase("splitting_task", Some("Decomposing subtask..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"splitting_task",
|
||||
Some("Decomposing subtask..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
tracing::info!("NodeAgent '{}' splitting task into sub-subtasks", self.name);
|
||||
|
||||
match self.split_task(task, ctx).await {
|
||||
Ok(plan) => {
|
||||
total_cost += 2; // Splitting cost
|
||||
|
||||
|
||||
let subtask_count = plan.subtasks().len();
|
||||
tracing::info!(
|
||||
"NodeAgent '{}' created {} sub-subtasks",
|
||||
@@ -650,8 +758,10 @@ impl Agent for NodeAgent {
|
||||
|
||||
// Execute subtasks recursively
|
||||
let requested_model = task.analysis().requested_model.as_deref();
|
||||
let result = self.execute_subtasks(plan, task.budget(), ctx, requested_model).await;
|
||||
|
||||
let result = self
|
||||
.execute_subtasks(plan, task.budget(), ctx, requested_model)
|
||||
.await;
|
||||
|
||||
return AgentResult {
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
@@ -672,7 +782,11 @@ impl Agent for NodeAgent {
|
||||
|
||||
// Simple task or failed to split: execute directly
|
||||
// Select model
|
||||
ctx.emit_phase("selecting_model", Some("Choosing model..."), Some(&self.name));
|
||||
ctx.emit_phase(
|
||||
"selecting_model",
|
||||
Some("Choosing model..."),
|
||||
Some(&self.name),
|
||||
);
|
||||
let sel_result = self.model_selector.execute(task, ctx).await;
|
||||
total_cost += sel_result.cost_cents;
|
||||
|
||||
@@ -685,13 +799,12 @@ impl Agent for NodeAgent {
|
||||
task.set_last_output(result.output.clone());
|
||||
|
||||
if !result.success {
|
||||
return AgentResult::failure(result.output, total_cost)
|
||||
.with_data(json!({
|
||||
"node_name": self.name,
|
||||
"complexity": complexity.score(),
|
||||
"was_split": false,
|
||||
"execution": result.data,
|
||||
}));
|
||||
return AgentResult::failure(result.output, total_cost).with_data(json!({
|
||||
"node_name": self.name,
|
||||
"complexity": complexity.score(),
|
||||
"was_split": false,
|
||||
"execution": result.data,
|
||||
}));
|
||||
}
|
||||
|
||||
// Verify
|
||||
@@ -747,7 +860,9 @@ impl OrchestratorAgent for NodeAgent {
|
||||
|
||||
fn find_child(&self, agent_type: AgentType) -> Option<AgentRef> {
|
||||
match agent_type {
|
||||
AgentType::ComplexityEstimator => Some(Arc::clone(&self.complexity_estimator) as AgentRef),
|
||||
AgentType::ComplexityEstimator => {
|
||||
Some(Arc::clone(&self.complexity_estimator) as AgentRef)
|
||||
}
|
||||
AgentType::ModelSelector => Some(Arc::clone(&self.model_selector) as AgentRef),
|
||||
AgentType::TaskExecutor => Some(Arc::clone(&self.task_executor) as AgentRef),
|
||||
AgentType::Verifier => Some(Arc::clone(&self.verifier) as AgentRef),
|
||||
@@ -772,4 +887,3 @@ impl OrchestratorAgent for NodeAgent {
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,16 +6,19 @@
|
||||
//! - Allocation: algorithms for distributing budget across subtasks
|
||||
//! - Retry: smart retry strategies for budget overflow
|
||||
//! - Benchmarks: model capability scores for task-aware selection
|
||||
//! - Resolver: auto-upgrade outdated model names to latest equivalents
|
||||
|
||||
mod budget;
|
||||
mod pricing;
|
||||
mod allocation;
|
||||
mod retry;
|
||||
pub mod benchmarks;
|
||||
pub mod resolver;
|
||||
|
||||
pub use budget::{Budget, BudgetError};
|
||||
pub use pricing::{ModelPricing, PricingInfo};
|
||||
pub use allocation::{AllocationStrategy, allocate_budget};
|
||||
pub use retry::{ExecutionSignals, FailureAnalysis, FailureMode, RetryRecommendation, RetryConfig};
|
||||
pub use benchmarks::{TaskType, BenchmarkRegistry, SharedBenchmarkRegistry, load_benchmarks};
|
||||
pub use resolver::{ModelResolver, ModelFamily, ResolvedModel, SharedModelResolver, load_resolver};
|
||||
|
||||
|
||||
@@ -208,34 +208,73 @@ impl ModelPricing {
|
||||
/// - Models with $0 pricing
|
||||
/// - "Lite" or small model variants
|
||||
/// - Models not in the explicit allowlist
|
||||
///
|
||||
/// # Model Allowlist Maintenance
|
||||
/// This list should be kept in sync with the model families defined in
|
||||
/// `models_with_benchmarks.json` (generated by `scripts/merge_benchmarks.py`).
|
||||
/// The ModelResolver auto-upgrades outdated model names to latest versions.
|
||||
pub async fn models_by_cost_filtered(&self, require_tools: bool) -> Vec<PricingInfo> {
|
||||
// Explicitly allowed model patterns (exact match or prefix with version suffix like -001)
|
||||
// These are the ONLY models that will be considered for task execution
|
||||
// These are the ONLY models that will be considered for task execution.
|
||||
//
|
||||
// IMPORTANT: Keep in sync with MODEL_FAMILY_PATTERNS in scripts/merge_benchmarks.py
|
||||
// When new model versions are released, add them here and run the merge script.
|
||||
const CAPABLE_MODEL_BASES: &[&str] = &[
|
||||
// Claude family (all sizes work great)
|
||||
// === Anthropic Claude ===
|
||||
// Flagship tier
|
||||
"anthropic/claude-opus-4.5",
|
||||
"anthropic/claude-opus-4",
|
||||
// Mid tier (balanced cost/performance)
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"anthropic/claude-3.7-sonnet",
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
// Fast tier (cheap/fast)
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"anthropic/claude-3.5-haiku",
|
||||
"anthropic/claude-3-haiku",
|
||||
// OpenAI GPT-4 family
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/gpt-4-turbo",
|
||||
|
||||
// === OpenAI ===
|
||||
// Flagship tier
|
||||
"openai/o1",
|
||||
"openai/o1-preview",
|
||||
// Mid tier
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4-turbo",
|
||||
"openai/o1-mini",
|
||||
"openai/o3-mini",
|
||||
// Fast tier
|
||||
"openai/gpt-4.1-mini",
|
||||
// Google Gemini (large models ONLY - no lite/flash-lite)
|
||||
"google/gemini-pro",
|
||||
"google/gemini-1.5-pro",
|
||||
"openai/gpt-4o-mini",
|
||||
|
||||
// === Google Gemini ===
|
||||
// Mid tier (large models ONLY - no lite/flash-lite)
|
||||
"google/gemini-2.5-pro",
|
||||
// Mistral large models
|
||||
"google/gemini-1.5-pro",
|
||||
"google/gemini-pro",
|
||||
// Fast tier
|
||||
"google/gemini-2.0-flash",
|
||||
"google/gemini-1.5-flash",
|
||||
|
||||
// === Mistral ===
|
||||
"mistralai/mistral-large",
|
||||
"mistralai/mistral-medium",
|
||||
// DeepSeek large
|
||||
"mistralai/mistral-small",
|
||||
|
||||
// === DeepSeek ===
|
||||
"deepseek/deepseek-r1",
|
||||
"deepseek/deepseek-chat",
|
||||
"deepseek/deepseek-coder",
|
||||
|
||||
// === Meta Llama ===
|
||||
"meta-llama/llama-3.3-70b",
|
||||
"meta-llama/llama-3.2-90b",
|
||||
"meta-llama/llama-3.1-405b",
|
||||
|
||||
// === Qwen ===
|
||||
"qwen/qwen-2.5-72b",
|
||||
"qwen/qwq-32b",
|
||||
];
|
||||
|
||||
// Patterns to exclude even if they match an allowed base
|
||||
|
||||
345
src/budget/resolver.rs
Normal file
345
src/budget/resolver.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
//! Model resolver for auto-upgrading outdated model names.
|
||||
//!
|
||||
//! # Problem
|
||||
//! AI models often suggest outdated model versions (e.g., "claude-3.5-sonnet")
|
||||
//! because their training data is stale. Newer models are typically cheaper and
|
||||
//! smarter, so we want to automatically upgrade to the latest equivalent.
|
||||
//!
|
||||
//! # Solution
|
||||
//! The `ModelResolver` maintains a mapping of:
|
||||
//! - Model families (claude-sonnet, gpt-4, etc.) with their latest versions
|
||||
//! - Aliases from old model IDs to new ones
|
||||
//!
|
||||
//! When a model is requested, the resolver:
|
||||
//! 1. Checks if it's an outdated family member
|
||||
//! 2. Returns the latest equivalent with upgrade info
|
||||
//!
|
||||
//! # Data Source
|
||||
//! Families and aliases are loaded from `models_with_benchmarks.json`,
|
||||
//! which is auto-generated by `scripts/merge_benchmarks.py`.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Information about a model family.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelFamily {
|
||||
/// The latest (recommended) model in this family
|
||||
pub latest: String,
|
||||
/// All members of this family (sorted by version, latest first)
|
||||
pub members: Vec<String>,
|
||||
/// Performance tier: "flagship", "mid", or "fast"
|
||||
pub tier: String,
|
||||
}
|
||||
|
||||
/// Outcome of running a model ID through the resolver.
#[derive(Debug, Clone)]
pub struct ResolvedModel {
    /// The model ID exactly as the caller supplied it.
    pub original: String,
    /// The model ID to use; identical to `original` when nothing changed.
    pub resolved: String,
    /// Whether the resolver reports this result as an upgrade.
    pub upgraded: bool,
    /// Human-readable explanation of the upgrade, when there was one.
    pub reason: Option<String>,
    /// Name of the family the model belongs to, when known.
    pub family: Option<String>,
}

impl ResolvedModel {
    /// Builds a pass-through result: `model_id` is both the original and the
    /// resolved ID, with no reason and no family attached.
    pub fn unchanged(model_id: &str) -> Self {
        let id = String::from(model_id);
        Self {
            resolved: id.clone(),
            original: id,
            upgraded: false,
            reason: None,
            family: None,
        }
    }

    /// Builds a result recording that `original` was replaced by `resolved`
    /// for the given `reason`, optionally tagged with the owning `family`.
    pub fn upgraded(original: &str, resolved: &str, reason: &str, family: Option<&str>) -> Self {
        Self {
            original: original.to_owned(),
            resolved: resolved.to_owned(),
            upgraded: true,
            reason: Some(reason.to_owned()),
            family: family.map(str::to_owned),
        }
    }
}
|
||||
|
||||
/// Model resolver with family-based auto-upgrade.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ModelResolver {
|
||||
/// Model families: family_name -> ModelFamily
|
||||
families: HashMap<String, ModelFamily>,
|
||||
/// Direct aliases: old_model_id -> new_model_id
|
||||
aliases: HashMap<String, String>,
|
||||
/// Reverse lookup: model_id -> family_name
|
||||
model_to_family: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ModelResolver {
|
||||
/// Create an empty resolver.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Load resolver data from the benchmark JSON file.
|
||||
pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self, String> {
|
||||
let content = std::fs::read_to_string(path.as_ref())
|
||||
.map_err(|e| format!("Failed to read resolver data: {}", e))?;
|
||||
|
||||
Self::load_from_json(&content)
|
||||
}
|
||||
|
||||
/// Load resolver data from JSON string.
|
||||
pub fn load_from_json(json: &str) -> Result<Self, String> {
|
||||
#[derive(Deserialize)]
|
||||
struct BenchmarkFile {
|
||||
#[serde(default)]
|
||||
families: HashMap<String, ModelFamily>,
|
||||
#[serde(default)]
|
||||
aliases: HashMap<String, String>,
|
||||
}
|
||||
|
||||
let data: BenchmarkFile = serde_json::from_str(json)
|
||||
.map_err(|e| format!("Failed to parse resolver data: {}", e))?;
|
||||
|
||||
let mut resolver = Self {
|
||||
families: data.families.clone(),
|
||||
aliases: data.aliases,
|
||||
model_to_family: HashMap::new(),
|
||||
};
|
||||
|
||||
// Build reverse lookup
|
||||
for (family_name, family) in &data.families {
|
||||
for member in &family.members {
|
||||
resolver
|
||||
.model_to_family
|
||||
.insert(member.clone(), family_name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Loaded model resolver: {} families, {} aliases",
|
||||
resolver.families.len(),
|
||||
resolver.aliases.len()
|
||||
);
|
||||
|
||||
Ok(resolver)
|
||||
}
|
||||
|
||||
/// Resolve a potentially outdated model ID to the latest equivalent.
|
||||
///
|
||||
/// # Examples
|
||||
/// - "claude-3.5-sonnet" → "anthropic/claude-sonnet-4.5" (upgraded)
|
||||
/// - "anthropic/claude-sonnet-4.5" → "anthropic/claude-sonnet-4.5" (unchanged)
|
||||
/// - "gpt-4o" → "openai/gpt-4.1" (upgraded)
|
||||
/// - "unknown-model" → "unknown-model" (unchanged, not in families)
|
||||
pub fn resolve(&self, model_id: &str) -> ResolvedModel {
|
||||
// 1. Check direct alias first (covers short names and old versions)
|
||||
if let Some(target) = self.aliases.get(model_id) {
|
||||
let family = self.model_to_family.get(target).map(|s| s.as_str());
|
||||
return ResolvedModel::upgraded(
|
||||
model_id,
|
||||
target,
|
||||
&format!("Alias resolved to latest"),
|
||||
family,
|
||||
);
|
||||
}
|
||||
|
||||
// 2. Check if model is in a family but not the latest
|
||||
if let Some(family_name) = self.model_to_family.get(model_id) {
|
||||
if let Some(family) = self.families.get(family_name) {
|
||||
if model_id != family.latest {
|
||||
return ResolvedModel::upgraded(
|
||||
model_id,
|
||||
&family.latest,
|
||||
&format!("Upgraded to latest {} model", family_name),
|
||||
Some(family_name),
|
||||
);
|
||||
} else {
|
||||
// Already the latest
|
||||
return ResolvedModel {
|
||||
original: model_id.to_string(),
|
||||
resolved: model_id.to_string(),
|
||||
upgraded: false,
|
||||
reason: None,
|
||||
family: Some(family_name.clone()),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Try fuzzy matching by normalizing the model name
|
||||
let normalized = Self::normalize(model_id);
|
||||
if let Some(target) = self.aliases.get(&normalized) {
|
||||
let family = self.model_to_family.get(target).map(|s| s.as_str());
|
||||
return ResolvedModel::upgraded(
|
||||
model_id,
|
||||
target,
|
||||
"Fuzzy match to latest",
|
||||
family,
|
||||
);
|
||||
}
|
||||
|
||||
// 4. Try to match family name directly
|
||||
for (family_name, family) in &self.families {
|
||||
if normalized.contains(family_name) || family_name.contains(&normalized) {
|
||||
return ResolvedModel::upgraded(
|
||||
model_id,
|
||||
&family.latest,
|
||||
&format!("Matched to {} family", family_name),
|
||||
Some(family_name),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. No match - return as-is
|
||||
ResolvedModel::unchanged(model_id)
|
||||
}
|
||||
|
||||
/// Check if a model ID exists and is the latest in its family.
|
||||
pub fn is_latest(&self, model_id: &str) -> bool {
|
||||
if let Some(family_name) = self.model_to_family.get(model_id) {
|
||||
if let Some(family) = self.families.get(family_name) {
|
||||
return model_id == family.latest;
|
||||
}
|
||||
}
|
||||
// Unknown models are considered "latest" (no upgrade available)
|
||||
true
|
||||
}
|
||||
|
||||
/// Get the family a model belongs to.
|
||||
pub fn get_family(&self, model_id: &str) -> Option<&ModelFamily> {
|
||||
self.model_to_family
|
||||
.get(model_id)
|
||||
.and_then(|name| self.families.get(name))
|
||||
}
|
||||
|
||||
/// Get all model families.
|
||||
pub fn families(&self) -> &HashMap<String, ModelFamily> {
|
||||
&self.families
|
||||
}
|
||||
|
||||
/// Get all known latest model IDs (one per family).
|
||||
pub fn latest_models(&self) -> Vec<&str> {
|
||||
self.families.values().map(|f| f.latest.as_str()).collect()
|
||||
}
|
||||
|
||||
/// Get all model IDs in a tier ("flagship", "mid", "fast").
|
||||
pub fn models_by_tier(&self, tier: &str) -> Vec<&str> {
|
||||
self.families
|
||||
.values()
|
||||
.filter(|f| f.tier == tier)
|
||||
.map(|f| f.latest.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Normalize a model ID for fuzzy matching.
|
||||
fn normalize(model_id: &str) -> String {
|
||||
model_id
|
||||
.to_lowercase()
|
||||
.replace([':', '-', '_', '.', '/'], "")
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe model resolver wrapper.
|
||||
pub type SharedModelResolver = Arc<RwLock<ModelResolver>>;
|
||||
|
||||
/// Create a shared model resolver, loading from default path.
|
||||
pub fn load_resolver(workspace_dir: &str) -> SharedModelResolver {
|
||||
let path = format!("{}/models_with_benchmarks.json", workspace_dir);
|
||||
|
||||
match ModelResolver::load_from_file(&path) {
|
||||
Ok(resolver) => {
|
||||
tracing::info!("Loaded model resolver from {}", path);
|
||||
Arc::new(RwLock::new(resolver))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to load resolver: {}. Using empty resolver.", e);
|
||||
Arc::new(RwLock::new(ModelResolver::new()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier of the newest Sonnet model in the fixture.
    const LATEST_SONNET: &str = "anthropic/claude-sonnet-4.5";

    /// Builds a resolver with two families and four aliases.
    fn test_resolver() -> ModelResolver {
        let fixture = r#"{
            "families": {
                "claude-sonnet": {
                    "latest": "anthropic/claude-sonnet-4.5",
                    "members": ["anthropic/claude-sonnet-4.5", "anthropic/claude-3.7-sonnet", "anthropic/claude-3.5-sonnet"],
                    "tier": "mid"
                },
                "gpt-4": {
                    "latest": "openai/gpt-4.1",
                    "members": ["openai/gpt-4.1", "openai/gpt-4o"],
                    "tier": "mid"
                }
            },
            "aliases": {
                "claude-3.5-sonnet": "anthropic/claude-sonnet-4.5",
                "sonnet": "anthropic/claude-sonnet-4.5",
                "gpt-4o": "openai/gpt-4.1",
                "gpt4": "openai/gpt-4.1"
            }
        }"#;
        ModelResolver::load_from_json(fixture).unwrap()
    }

    /// A short alias is rewritten to the full latest model ID.
    #[test]
    fn test_resolve_alias() {
        let outcome = test_resolver().resolve("claude-3.5-sonnet");

        assert!(outcome.upgraded);
        assert_eq!(outcome.resolved, LATEST_SONNET);
    }

    /// An older family member is upgraded to the family's latest model.
    #[test]
    fn test_resolve_family_member() {
        let outcome = test_resolver().resolve("anthropic/claude-3.7-sonnet");

        assert!(outcome.upgraded);
        assert_eq!(outcome.resolved, LATEST_SONNET);
    }

    /// The latest model of a family resolves to itself.
    #[test]
    fn test_resolve_latest_unchanged() {
        let outcome = test_resolver().resolve(LATEST_SONNET);

        assert!(!outcome.upgraded);
        assert_eq!(outcome.resolved, LATEST_SONNET);
    }

    /// An ID that matches nothing is passed through untouched.
    #[test]
    fn test_resolve_unknown_unchanged() {
        let outcome = test_resolver().resolve("some-unknown-model");

        assert!(!outcome.upgraded);
        assert_eq!(outcome.resolved, "some-unknown-model");
    }

    /// `is_latest` is true for the newest member and for unknown models.
    #[test]
    fn test_is_latest() {
        let resolver = test_resolver();

        assert!(resolver.is_latest(LATEST_SONNET));
        assert!(!resolver.is_latest("anthropic/claude-3.5-sonnet"));
        // Unknown = no upgrade available, so it counts as latest.
        assert!(resolver.is_latest("unknown-model"));
    }
}
|
||||
@@ -8,7 +8,7 @@
|
||||
mod error;
|
||||
mod openrouter;
|
||||
|
||||
pub use error::{LlmError, LlmErrorKind, RetryConfig, classify_http_status};
|
||||
pub use error::{classify_http_status, LlmError, LlmErrorKind, RetryConfig};
|
||||
pub use openrouter::OpenRouterClient;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -122,10 +122,6 @@ pub struct ChatMessage {
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
/// Reasoning details for models with extended thinking (Gemini 3, Claude 3.7+).
|
||||
/// Must be preserved from responses and passed back in subsequent requests.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
@@ -136,7 +132,6 @@ impl ChatMessage {
|
||||
content: Some(MessageContent::text(content)),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
reasoning_details: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +142,6 @@ impl ChatMessage {
|
||||
content: Some(MessageContent::text_and_image(text, image_url)),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
reasoning_details: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,9 +193,6 @@ pub struct ChatResponse {
|
||||
pub finish_reason: Option<String>,
|
||||
pub usage: Option<TokenUsage>,
|
||||
pub model: Option<String>,
|
||||
/// Reasoning details for models with extended thinking (Gemini 3, Claude 3.7+).
|
||||
/// Must be preserved and passed back in subsequent requests for tool calling.
|
||||
pub reasoning_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Token usage information (if provided by the upstream provider).
|
||||
@@ -260,4 +251,3 @@ pub trait LlmClient: Send + Sync {
|
||||
self.chat_completion(model, messages, tools).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use super::error::{classify_http_status, LlmError, LlmErrorKind, RetryConfig};
|
||||
use super::{ChatMessage, ChatOptions, ChatResponse, LlmClient, TokenUsage, ToolCall, ToolDefinition};
|
||||
use super::{
|
||||
ChatMessage, ChatOptions, ChatResponse, LlmClient, TokenUsage, ToolCall, ToolDefinition,
|
||||
};
|
||||
|
||||
const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
|
||||
|
||||
@@ -65,10 +67,7 @@ impl OpenRouterClient {
|
||||
}
|
||||
|
||||
/// Execute a single request without retry.
|
||||
async fn execute_request(
|
||||
&self,
|
||||
request: &OpenRouterRequest,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
async fn execute_request(&self, request: &OpenRouterRequest) -> Result<ChatResponse, LlmError> {
|
||||
let response = match self
|
||||
.client
|
||||
.post(OPENROUTER_API_URL)
|
||||
@@ -119,7 +118,6 @@ impl OpenRouterClient {
|
||||
.usage
|
||||
.map(|u| TokenUsage::new(u.prompt_tokens, u.completion_tokens)),
|
||||
model: parsed.model.or_else(|| Some(request.model.clone())),
|
||||
reasoning_details: choice.message.reasoning_details,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -153,8 +151,8 @@ impl OpenRouterClient {
|
||||
return Ok(response);
|
||||
}
|
||||
Err(error) => {
|
||||
let should_retry =
|
||||
self.retry_config.should_retry(&error) && attempt < self.retry_config.max_retries;
|
||||
let should_retry = self.retry_config.should_retry(&error)
|
||||
&& attempt < self.retry_config.max_retries;
|
||||
|
||||
if should_retry {
|
||||
let delay = error.suggested_delay(attempt);
|
||||
@@ -280,10 +278,6 @@ struct OpenRouterChoice {
|
||||
struct OpenRouterMessage {
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
/// Reasoning details for models that support extended thinking (Gemini 3, Claude 3.7+, etc.)
|
||||
/// Must be preserved and passed back in subsequent requests for tool calling to work.
|
||||
#[serde(default)]
|
||||
reasoning_details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Usage data (OpenAI-compatible).
|
||||
@@ -294,4 +288,3 @@ struct OpenRouterUsage {
|
||||
#[serde(rename = "total_tokens")]
|
||||
_total_tokens: u64,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user