biz-bud/src/biz_bud/nodes/error_handling/analyzer.py
Latest commit 7a84d75d8e by Travis Vasceannie (2025-09-28 13:45:52 -04:00):
Refactor type safety checks and enhance error handling across various modules
- Update typing in error handling and validation nodes to improve type safety.
- Refactor cache decorators for async compatibility and cleanup functionality.
- Enhance URL processing and validation logic with improved type checks.
- Centralize error handling and recovery mechanisms in nodes.
- Simplify and standardize function signatures across multiple modules for consistency.
- Resolve linting issues and ensure compliance with type safety standards.


"""Error analyzer node for classifying errors and determining recovery strategies."""
import re
from collections.abc import Mapping
from typing import Literal, TypedDict, cast

from langchain_core.runnables import RunnableConfig

from biz_bud.core import ErrorCategory, ErrorInfo
from biz_bud.core.langgraph import standard_node
from biz_bud.core.utils.regex_security import search_safe
from biz_bud.logging import get_logger
from biz_bud.services.llm.client import LangchainLLMClient
from biz_bud.states.error_handling import (
    ErrorAnalysis,
    ErrorContext,
    ErrorHandlingState,
)

logger = get_logger(__name__)
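
# Analysis happens in two stages: a fast rule-based classification keyed on the
# error category, optionally followed by an LLM pass when the error is rated high
# or critical, or when no recovery actions could be suggested.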
class ErrorAnalysisUpdate(TypedDict):
"""Typed update returned by the analyzer node."""
error_analysis: ErrorAnalysis
class ErrorAnalysisDelta(TypedDict, total=False):
"""Optional fields that can be enriched by LLM analysis."""
error_type: str
criticality: Literal["low", "medium", "high", "critical"]
can_continue: bool
suggested_actions: list[str]
root_cause: str | None
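
# Illustrative example (hypothetical values): an LLM pass might produce a delta
# such as {"root_cause": "input exceeded the model context window",
#          "suggested_actions": ["trim_context", "chunk_input"]},
# which is then merged field-by-field over the rule-based ErrorAnalysis below.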
@standard_node("error_analyzer_node")
async def error_analyzer_node(
state: ErrorHandlingState, config: RunnableConfig | None
) -> ErrorAnalysisUpdate:
"""Analyze error criticality and determine recovery strategies.
Uses both rule-based logic and LLM analysis to understand the error
and suggest appropriate recovery actions.
Args:
state: Current error handling state
config: Configuration dictionary
Returns:
Dictionary with error analysis results
"""
error = state["current_error"]
context = state["error_context"]
node_name = context["node_name"]
logger.info(f"Analyzing error from node: {node_name}")
# First, apply rule-based classification
initial_analysis = _rule_based_analysis(error, context)
logger.info(
f"Rule-based analysis - Type: {initial_analysis['error_type']}, "
f"Criticality: {initial_analysis['criticality']}"
)
# For complex errors, enhance with LLM analysis if enabled
config_dict = _get_configurable_section(config)
error_handling_config = _get_mapping(config_dict.get("error_handling"))
if (
initial_analysis["criticality"] in ["high", "critical"]
or not initial_analysis["suggested_actions"]
) and bool(error_handling_config.get("enable_llm_analysis", True)):
try:
from biz_bud.services.factory import get_global_factory
factory = await get_global_factory()
llm_client = await factory.get_llm_for_node("error_analyzer")
if isinstance(llm_client, LangchainLLMClient):
enhanced_analysis = await _llm_error_analysis(
llm_client, error, context, initial_analysis
)
else:
# Handle wrapped client case
logger.warning("Unexpected LLM client type, skipping LLM analysis")
enhanced_analysis = ErrorAnalysisDelta()
merged_analysis: ErrorAnalysis = ErrorAnalysis(
error_type=_get_str(
enhanced_analysis, "error_type", initial_analysis["error_type"]
),
criticality=_get_literal(
enhanced_analysis,
"criticality",
initial_analysis["criticality"],
),
can_continue=_get_bool(
enhanced_analysis, "can_continue", initial_analysis["can_continue"]
),
suggested_actions=_get_list(
enhanced_analysis,
"suggested_actions",
initial_analysis["suggested_actions"],
),
root_cause=_get_optional_str(
enhanced_analysis,
"root_cause",
initial_analysis.get("root_cause"),
),
)
initial_analysis = merged_analysis
logger.info("Enhanced analysis with LLM insights")
except Exception as e:
logger.warning(f"LLM analysis failed, using rule-based only: {e}")
return {"error_analysis": initial_analysis}
def _rule_based_analysis(error: ErrorInfo, context: ErrorContext) -> ErrorAnalysis:
"""Apply rule-based error classification.
Args:
error: Error information
context: Error context
Returns:
Error analysis result
"""
error_type = error.get("error_type", "unknown")
error_message = error.get("message", "")
error_category = error.get("category", ErrorCategory.UNKNOWN)
raw_details = error.get("details")
if isinstance(raw_details, Mapping):
error_category = raw_details.get("category", error_category)
# Analyze based on error category
if error_category == ErrorCategory.LLM:
return _analyze_llm_error(error_message)
elif error_category == ErrorCategory.CONFIGURATION:
return _analyze_config_error(error_message)
elif error_category == ErrorCategory.TOOL:
return _analyze_tool_error(error_message)
elif error_category == ErrorCategory.NETWORK:
return _analyze_network_error(error_message)
elif error_category == ErrorCategory.VALIDATION:
return _analyze_validation_error(error_message)
elif error_category == ErrorCategory.RATE_LIMIT:
return _analyze_rate_limit_error(error_message)
elif error_category == ErrorCategory.AUTHENTICATION:
return _analyze_auth_error(error_message)
else:
return _analyze_generic_error(error_message, error_type)
def _analyze_llm_error(error_message: str) -> ErrorAnalysis:
"""Analyze LLM-specific errors."""
error_lower = error_message.lower()
if any(term in error_lower for term in ["rate limit", "quota exceeded", "429"]):
return ErrorAnalysis(
error_type="rate_limit",
criticality="medium",
can_continue=True,
suggested_actions=["retry_with_backoff", "switch_llm_provider"],
root_cause="API rate limit exceeded",
)
elif any(term in error_lower for term in ["context length", "too long", "token"]):
return ErrorAnalysis(
error_type="context_overflow",
criticality="high",
can_continue=True,
suggested_actions=["trim_context", "chunk_input", "use_larger_model"],
root_cause="Input exceeds model context window",
)
elif any(
term in error_lower for term in ["invalid api key", "unauthorized", "403"]
):
return ErrorAnalysis(
error_type="authentication",
criticality="critical",
can_continue=False,
suggested_actions=["verify_api_credentials", "rotate_api_key"],
root_cause="Invalid or expired API credentials",
)
else:
return ErrorAnalysis(
error_type="llm_general",
criticality="medium",
can_continue=True,
suggested_actions=["retry", "switch_llm_provider", "fallback"],
root_cause="General LLM service error",
)
def _analyze_config_error(error_message: str) -> ErrorAnalysis:
"""Analyze configuration errors."""
return ErrorAnalysis(
error_type="configuration",
criticality="critical",
can_continue=False,
suggested_actions=["verify_config", "restore_defaults", "check_env_vars"],
root_cause="Invalid or missing configuration",
)
def _analyze_tool_error(error_message: str) -> ErrorAnalysis:
"""Analyze tool-specific errors."""
error_lower = error_message.lower()
if any(term in error_lower for term in ["not found", "404", "missing"]):
return ErrorAnalysis(
error_type="tool_not_found",
criticality="high",
can_continue=True,
suggested_actions=["skip", "use_alternative_tool", "install_tool"],
root_cause="Required tool or resource not found",
)
else:
return ErrorAnalysis(
error_type="tool_execution",
criticality="medium",
can_continue=True,
suggested_actions=["retry", "use_fallback_tool", "skip"],
root_cause="Tool execution failed",
)
def _analyze_network_error(error_message: str) -> ErrorAnalysis:
"""Analyze network-related errors."""
error_lower = error_message.lower()
if any(term in error_lower for term in ["timeout", "timed out"]):
return ErrorAnalysis(
error_type="timeout",
criticality="low",
can_continue=True,
suggested_actions=["retry_with_increased_timeout", "check_network"],
root_cause="Network request timed out",
)
elif any(term in error_lower for term in ["connection", "refused", "unreachable"]):
return ErrorAnalysis(
error_type="connection_error",
criticality="medium",
can_continue=True,
suggested_actions=["retry", "check_connectivity", "use_cache"],
root_cause="Network connection failed",
)
else:
return ErrorAnalysis(
error_type="network_general",
criticality="medium",
can_continue=True,
suggested_actions=["retry", "check_network", "fallback"],
root_cause="General network error",
)
def _analyze_validation_error(error_message: str) -> ErrorAnalysis:
"""Analyze validation errors."""
return ErrorAnalysis(
error_type="validation",
criticality="medium",
can_continue=True,
suggested_actions=["fix_data_format", "use_defaults", "skip_validation"],
root_cause="Data validation failed",
)
def _analyze_rate_limit_error(error_message: str) -> ErrorAnalysis:
"""Analyze rate limit errors."""
if wait_match := search_safe(r"(\d+)\s*(second|minute|hour)", error_message.lower()):
wait_time = f"{wait_match[1]} {wait_match[2]}s"
else:
wait_time = None
return ErrorAnalysis(
error_type="rate_limit",
criticality="medium",
can_continue=True,
suggested_actions=["retry_with_backoff", "switch_provider", "queue_request"],
root_cause=f"Rate limit exceeded{f' - wait {wait_time}' if wait_time else ''}",
)
def _analyze_auth_error(error_message: str) -> ErrorAnalysis:
"""Analyze authentication errors."""
return ErrorAnalysis(
error_type="authentication",
criticality="critical",
can_continue=False,
suggested_actions=["verify_credentials", "rotate_api_key", "check_permissions"],
root_cause="Authentication failed",
)
def _analyze_generic_error(error_message: str, error_type: str) -> ErrorAnalysis:
"""Analyze generic/unknown errors."""
# Try to determine criticality from error message
error_lower = error_message.lower()
criticality = "medium"
if any(term in error_lower for term in ["critical", "fatal", "severe"]):
criticality = "critical"
elif any(term in error_lower for term in ["warning", "minor"]):
criticality = "low"
return ErrorAnalysis(
error_type=error_type or "unknown",
criticality=criticality,
can_continue=criticality != "critical",
suggested_actions=["retry", "log_and_continue", "manual_intervention"],
root_cause="Unclassified error occurred",
)
async def _llm_error_analysis(
llm_client: LangchainLLMClient,
error: ErrorInfo,
context: ErrorContext,
initial_analysis: ErrorAnalysis,
) -> ErrorAnalysisDelta:
"""Enhance error analysis using LLM.
Args:
llm_client: LLM client for analysis
error: Error information
context: Error context
initial_analysis: Initial rule-based analysis
Returns:
Enhanced analysis fields
"""
from biz_bud.prompts.error_handling import ERROR_ANALYSIS_PROMPT
prompt = ERROR_ANALYSIS_PROMPT.format(
error_type=error.get("error_type", "unknown"),
error_message=error.get("message", ""),
context=context,
attempts=context["execution_count"],
)
try:
response = await llm_client.llm_chat(prompt)
# Parse LLM response and extract enhancements
# This is a simplified version - in production, use structured output
enhanced: ErrorAnalysisDelta = {}
content = response.lower()
# Extract root cause if more detailed
if "root cause:" in content:
if root_cause_match := search_safe(r"root cause:\s*(.+?)(?:\n|$)", content):
enhanced["root_cause"] = root_cause_match[1].strip()
# Extract additional suggested actions
if "suggested actions:" in content or "recommendations:" in content:
if action_section := search_safe(
r"(?:suggested actions|recommendations):\s*(.+?)(?:\n\n|$)",
content,
flags=re.DOTALL,
):
action_lines = action_section[1].strip().split("\n")
actions = []
for line in action_lines:
if line := line.strip("- *").strip():
actions.append(line.lower().replace(" ", "_"))
if actions:
enhanced["suggested_actions"] = [
*initial_analysis["suggested_actions"],
*actions,
]
return enhanced
except Exception as e:
logger.warning(f"LLM analysis failed: {e}")
return ErrorAnalysisDelta()
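
# Defensive coercion helpers: config sections and LLM-derived mappings arrive as
# untyped objects, so each accessor checks the runtime type and falls back to a
# caller-supplied default.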
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
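    """Return the "configurable" section of the config, or an empty mapping."""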
if config is None:
return {}
raw_configurable = config.get("configurable")
if isinstance(raw_configurable, Mapping):
return cast(Mapping[str, object], raw_configurable)
return {}
def _get_mapping(value: object) -> Mapping[str, object]:
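    """Return the value as a mapping if it is one, otherwise an empty mapping."""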
if isinstance(value, Mapping):
return cast(Mapping[str, object], value)
return {}
def _get_str(
mapping: Mapping[str, object], key: str, default: str
) -> str:
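    """Return the string stored under key, or the default if missing or mistyped."""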
value = mapping.get(key, default)
return value if isinstance(value, str) else default
def _get_bool(
mapping: Mapping[str, object], key: str, default: bool
) -> bool:
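    """Return the bool stored under key, or the default if missing or mistyped."""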
value = mapping.get(key, default)
if isinstance(value, bool):
return value
return default
def _get_literal(
mapping: Mapping[str, object],
key: str,
default: Literal["low", "medium", "high", "critical"],
) -> Literal["low", "medium", "high", "critical"]:
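    """Return the criticality literal stored under key, or the default if invalid."""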
value = mapping.get(key, default)
if value in {"low", "medium", "high", "critical"}:
return cast(Literal["low", "medium", "high", "critical"], value)
return default
def _get_list(
mapping: Mapping[str, object], key: str, default: list[str]
) -> list[str]:
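    """Return a copied list of strings stored under key, or a copy of the default."""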
value = mapping.get(key, default)
if isinstance(value, list) and all(isinstance(item, str) for item in value):
return list(value)
return list(default)
def _get_optional_str(
mapping: Mapping[str, object], key: str, default: str | None
) -> str | None:
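    """Return the string stored under key, or the default when missing or not a string."""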
    value = mapping.get(key)
    if isinstance(value, str):
        return value
    return default