- Update typing in error handling and validation nodes to improve type safety.
- Refactor cache decorators for async compatibility and cleanup functionality.
- Enhance URL processing and validation logic with improved type checks.
- Centralize error handling and recovery mechanisms in nodes.
- Simplify and standardize function signatures across multiple modules for consistency.
- Resolve linting issues and ensure compliance with type safety standards.
441 lines
15 KiB
Python
441 lines
15 KiB
Python
"""Error analyzer node for classifying errors and determining recovery strategies."""
|
|
|
|
import re
|
|
from collections.abc import Mapping
|
|
from typing import Literal, TypedDict, cast
|
|
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
|
from biz_bud.core import ErrorCategory, ErrorInfo
|
|
from biz_bud.core.langgraph import standard_node
|
|
from biz_bud.core.utils.regex_security import search_safe
|
|
from biz_bud.logging import get_logger
|
|
from biz_bud.services.llm.client import LangchainLLMClient
|
|
from biz_bud.states.error_handling import (
|
|
ErrorAnalysis,
|
|
ErrorContext,
|
|
ErrorHandlingState,
|
|
)
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class ErrorAnalysisUpdate(TypedDict):
    """Typed update returned by the analyzer node."""

    # Final analysis (rule-based, optionally merged with LLM enrichment)
    # that the node returns for merging into the graph state.
    error_analysis: ErrorAnalysis
|
|
|
|
|
|
class ErrorAnalysisDelta(TypedDict, total=False):
    """Optional fields that can be enriched by LLM analysis."""

    # All keys are optional (total=False): the LLM may supply any subset,
    # and missing keys fall back to the rule-based values during the merge.
    error_type: str
    criticality: Literal["low", "medium", "high", "critical"]
    can_continue: bool
    suggested_actions: list[str]
    root_cause: str | None
|
|
|
|
|
|
@standard_node("error_analyzer_node")
async def error_analyzer_node(
    state: ErrorHandlingState, config: RunnableConfig | None
) -> ErrorAnalysisUpdate:
    """Analyze error criticality and determine recovery strategies.

    Uses both rule-based logic and LLM analysis to understand the error
    and suggest appropriate recovery actions.

    Args:
        state: Current error handling state
        config: Configuration dictionary

    Returns:
        Dictionary with error analysis results

    """
    error = state["current_error"]
    context = state["error_context"]

    logger.info(f"Analyzing error from node: {context['node_name']}")

    # Start from deterministic, rule-based classification.
    analysis = _rule_based_analysis(error, context)
    logger.info(
        f"Rule-based analysis - Type: {analysis['error_type']}, "
        f"Criticality: {analysis['criticality']}"
    )

    # LLM enrichment is opt-out via config and only attempted for severe
    # errors or when the rules produced no suggested actions.
    settings = _get_mapping(_get_configurable_section(config).get("error_handling"))
    llm_enabled = bool(settings.get("enable_llm_analysis", True))
    needs_llm = (
        analysis["criticality"] in ("high", "critical")
        or not analysis["suggested_actions"]
    )

    if llm_enabled and needs_llm:
        try:
            from biz_bud.services.factory import get_global_factory

            factory = await get_global_factory()
            client = await factory.get_llm_for_node("error_analyzer")
            if isinstance(client, LangchainLLMClient):
                delta = await _llm_error_analysis(client, error, context, analysis)
            else:
                # Handle wrapped client case
                logger.warning("Unexpected LLM client type, skipping LLM analysis")
                delta = ErrorAnalysisDelta()

            # Merge: any field present and well-typed in the delta wins;
            # otherwise keep the rule-based value.
            analysis = ErrorAnalysis(
                error_type=_get_str(delta, "error_type", analysis["error_type"]),
                criticality=_get_literal(
                    delta, "criticality", analysis["criticality"]
                ),
                can_continue=_get_bool(
                    delta, "can_continue", analysis["can_continue"]
                ),
                suggested_actions=_get_list(
                    delta, "suggested_actions", analysis["suggested_actions"]
                ),
                root_cause=_get_optional_str(
                    delta, "root_cause", analysis.get("root_cause")
                ),
            )
            logger.info("Enhanced analysis with LLM insights")
        except Exception as e:
            # LLM enrichment is best-effort; never let it break error handling.
            logger.warning(f"LLM analysis failed, using rule-based only: {e}")

    return {"error_analysis": analysis}
|
|
|
|
|
|
def _rule_based_analysis(error: ErrorInfo, context: ErrorContext) -> ErrorAnalysis:
    """Apply rule-based error classification.

    Args:
        error: Error information
        context: Error context

    Returns:
        Error analysis result

    """
    error_type = error.get("error_type", "unknown")
    message = error.get("message", "")
    category = error.get("category", ErrorCategory.UNKNOWN)

    # A category stashed in the details mapping overrides the top-level one.
    details = error.get("details")
    if isinstance(details, Mapping):
        category = details.get("category", category)

    # Dispatch to the category-specific analyzer; unmatched categories fall
    # through to the generic analyzer. Comparison via == mirrors the
    # original elif chain (tolerates non-enum category values).
    dispatch = (
        (ErrorCategory.LLM, _analyze_llm_error),
        (ErrorCategory.CONFIGURATION, _analyze_config_error),
        (ErrorCategory.TOOL, _analyze_tool_error),
        (ErrorCategory.NETWORK, _analyze_network_error),
        (ErrorCategory.VALIDATION, _analyze_validation_error),
        (ErrorCategory.RATE_LIMIT, _analyze_rate_limit_error),
        (ErrorCategory.AUTHENTICATION, _analyze_auth_error),
    )
    for candidate, analyzer in dispatch:
        if category == candidate:
            return analyzer(message)
    return _analyze_generic_error(message, error_type)
|
|
|
|
|
|
def _analyze_llm_error(error_message: str) -> ErrorAnalysis:
    """Analyze LLM-specific errors.

    Checks the message (case-insensitively) for rate-limit, context-window,
    and credential markers, in that order, before falling back to a general
    LLM classification.
    """
    lowered = error_message.lower()

    def mentions(*terms: str) -> bool:
        # True when any marker term appears in the lowered message.
        return any(term in lowered for term in terms)

    if mentions("rate limit", "quota exceeded", "429"):
        return ErrorAnalysis(
            error_type="rate_limit",
            criticality="medium",
            can_continue=True,
            suggested_actions=["retry_with_backoff", "switch_llm_provider"],
            root_cause="API rate limit exceeded",
        )
    if mentions("context length", "too long", "token"):
        return ErrorAnalysis(
            error_type="context_overflow",
            criticality="high",
            can_continue=True,
            suggested_actions=["trim_context", "chunk_input", "use_larger_model"],
            root_cause="Input exceeds model context window",
        )
    if mentions("invalid api key", "unauthorized", "403"):
        # Credential failures are terminal: retries cannot succeed.
        return ErrorAnalysis(
            error_type="authentication",
            criticality="critical",
            can_continue=False,
            suggested_actions=["verify_api_credentials", "rotate_api_key"],
            root_cause="Invalid or expired API credentials",
        )
    return ErrorAnalysis(
        error_type="llm_general",
        criticality="medium",
        can_continue=True,
        suggested_actions=["retry", "switch_llm_provider", "fallback"],
        root_cause="General LLM service error",
    )
|
|
|
|
|
|
def _analyze_config_error(error_message: str) -> ErrorAnalysis:
    """Analyze configuration errors.

    The message text is not inspected: any configuration error is treated
    as critical and non-recoverable without operator intervention.
    """
    analysis: ErrorAnalysis = {
        "error_type": "configuration",
        "criticality": "critical",
        "can_continue": False,
        "suggested_actions": ["verify_config", "restore_defaults", "check_env_vars"],
        "root_cause": "Invalid or missing configuration",
    }
    return analysis
|
|
|
|
|
|
def _analyze_tool_error(error_message: str) -> ErrorAnalysis:
    """Analyze tool-specific errors.

    Distinguishes a missing tool/resource from a plain execution failure
    by scanning the lowered message for not-found markers.
    """
    lowered = error_message.lower()
    not_found_markers = ("not found", "404", "missing")

    if any(marker in lowered for marker in not_found_markers):
        return ErrorAnalysis(
            error_type="tool_not_found",
            criticality="high",
            can_continue=True,
            suggested_actions=["skip", "use_alternative_tool", "install_tool"],
            root_cause="Required tool or resource not found",
        )
    return ErrorAnalysis(
        error_type="tool_execution",
        criticality="medium",
        can_continue=True,
        suggested_actions=["retry", "use_fallback_tool", "skip"],
        root_cause="Tool execution failed",
    )
|
|
|
|
|
|
def _analyze_network_error(error_message: str) -> ErrorAnalysis:
    """Analyze network-related errors.

    Classifies timeouts (low severity, retryable) and connection failures
    (medium severity) before falling back to a general network error.
    """
    lowered = error_message.lower()

    if any(marker in lowered for marker in ("timeout", "timed out")):
        return ErrorAnalysis(
            error_type="timeout",
            criticality="low",
            can_continue=True,
            suggested_actions=["retry_with_increased_timeout", "check_network"],
            root_cause="Network request timed out",
        )
    if any(marker in lowered for marker in ("connection", "refused", "unreachable")):
        return ErrorAnalysis(
            error_type="connection_error",
            criticality="medium",
            can_continue=True,
            suggested_actions=["retry", "check_connectivity", "use_cache"],
            root_cause="Network connection failed",
        )
    return ErrorAnalysis(
        error_type="network_general",
        criticality="medium",
        can_continue=True,
        suggested_actions=["retry", "check_network", "fallback"],
        root_cause="General network error",
    )
|
|
|
|
|
|
def _analyze_validation_error(error_message: str) -> ErrorAnalysis:
    """Analyze validation errors.

    The message text is not inspected: validation failures are uniformly
    medium severity and recoverable.
    """
    analysis: ErrorAnalysis = {
        "error_type": "validation",
        "criticality": "medium",
        "can_continue": True,
        "suggested_actions": ["fix_data_format", "use_defaults", "skip_validation"],
        "root_cause": "Data validation failed",
    }
    return analysis
|
|
|
|
|
|
def _analyze_rate_limit_error(error_message: str) -> ErrorAnalysis:
    """Analyze rate limit errors.

    Extracts an advisory wait duration (e.g. "30 seconds") from the message
    when present and folds it into the root cause text.
    """
    match = search_safe(r"(\d+)\s*(second|minute|hour)", error_message.lower())
    # Build the " - wait N units" suffix only when a duration was found.
    suffix = f" - wait {match[1]} {match[2]}s" if match else ""
    return ErrorAnalysis(
        error_type="rate_limit",
        criticality="medium",
        can_continue=True,
        suggested_actions=["retry_with_backoff", "switch_provider", "queue_request"],
        root_cause=f"Rate limit exceeded{suffix}",
    )
|
|
|
|
|
|
def _analyze_auth_error(error_message: str) -> ErrorAnalysis:
    """Analyze authentication errors.

    The message text is not inspected: authentication failures are always
    critical and block further execution.
    """
    analysis: ErrorAnalysis = {
        "error_type": "authentication",
        "criticality": "critical",
        "can_continue": False,
        "suggested_actions": [
            "verify_credentials",
            "rotate_api_key",
            "check_permissions",
        ],
        "root_cause": "Authentication failed",
    }
    return analysis
|
|
|
|
|
|
def _analyze_generic_error(error_message: str, error_type: str) -> ErrorAnalysis:
    """Analyze generic/unknown errors.

    Infers severity from keywords in the message; defaults to medium when
    no severity marker is present. Execution may continue unless the
    inferred severity is critical.
    """
    lowered = error_message.lower()

    if any(word in lowered for word in ("critical", "fatal", "severe")):
        severity = "critical"
    elif any(word in lowered for word in ("warning", "minor")):
        severity = "low"
    else:
        severity = "medium"

    return ErrorAnalysis(
        error_type=error_type or "unknown",
        criticality=severity,
        can_continue=severity != "critical",
        suggested_actions=["retry", "log_and_continue", "manual_intervention"],
        root_cause="Unclassified error occurred",
    )
|
|
|
|
|
|
async def _llm_error_analysis(
    llm_client: LangchainLLMClient,
    error: ErrorInfo,
    context: ErrorContext,
    initial_analysis: ErrorAnalysis,
) -> ErrorAnalysisDelta:
    """Enhance error analysis using LLM.

    Args:
        llm_client: LLM client for analysis
        error: Error information
        context: Error context
        initial_analysis: Initial rule-based analysis

    Returns:
        Enhanced analysis fields

    """
    from biz_bud.prompts.error_handling import ERROR_ANALYSIS_PROMPT

    prompt = ERROR_ANALYSIS_PROMPT.format(
        error_type=error.get("error_type", "unknown"),
        error_message=error.get("message", ""),
        context=context,
        attempts=context["execution_count"],
    )

    try:
        response = await llm_client.llm_chat(prompt)

        # Parse LLM response and extract enhancements
        # This is a simplified version - in production, use structured output
        delta: ErrorAnalysisDelta = {}
        content = response.lower()

        # Root cause: take the first "root cause:" line when present.
        if "root cause:" in content and (
            cause := search_safe(r"root cause:\s*(.+?)(?:\n|$)", content)
        ):
            delta["root_cause"] = cause[1].strip()

        # Suggested actions: parse the bulleted section, normalize each
        # line to snake_case, and append to the rule-based actions.
        if ("suggested actions:" in content or "recommendations:" in content) and (
            section := search_safe(
                r"(?:suggested actions|recommendations):\s*(.+?)(?:\n\n|$)",
                content,
                flags=re.DOTALL,
            )
        ):
            actions = [
                cleaned.lower().replace(" ", "_")
                for raw in section[1].strip().split("\n")
                if (cleaned := raw.strip("- *").strip())
            ]
            if actions:
                delta["suggested_actions"] = [
                    *initial_analysis["suggested_actions"],
                    *actions,
                ]

        return delta

    except Exception as e:
        # Enrichment is best-effort; an empty delta leaves the rule-based
        # analysis untouched.
        logger.warning(f"LLM analysis failed: {e}")
        return ErrorAnalysisDelta()
|
|
|
|
|
|
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
    """Return the ``configurable`` mapping from a runnable config.

    Yields an empty mapping when the config is absent or the section is
    missing or not a mapping.
    """
    if config is None:
        return {}
    section = config.get("configurable")
    if not isinstance(section, Mapping):
        return {}
    return cast(Mapping[str, object], section)
|
|
|
|
|
|
def _get_mapping(value: object) -> Mapping[str, object]:
|
|
if isinstance(value, Mapping):
|
|
return cast(Mapping[str, object], value)
|
|
return {}
|
|
|
|
|
|
def _get_str(
|
|
mapping: Mapping[str, object], key: str, default: str
|
|
) -> str:
|
|
value = mapping.get(key, default)
|
|
return value if isinstance(value, str) else default
|
|
|
|
|
|
def _get_bool(
|
|
mapping: Mapping[str, object], key: str, default: bool
|
|
) -> bool:
|
|
value = mapping.get(key, default)
|
|
if isinstance(value, bool):
|
|
return value
|
|
return default
|
|
|
|
|
|
def _get_literal(
|
|
mapping: Mapping[str, object],
|
|
key: str,
|
|
default: Literal["low", "medium", "high", "critical"],
|
|
) -> Literal["low", "medium", "high", "critical"]:
|
|
value = mapping.get(key, default)
|
|
if value in {"low", "medium", "high", "critical"}:
|
|
return cast(Literal["low", "medium", "high", "critical"], value)
|
|
return default
|
|
|
|
|
|
def _get_list(
|
|
mapping: Mapping[str, object], key: str, default: list[str]
|
|
) -> list[str]:
|
|
value = mapping.get(key, default)
|
|
if isinstance(value, list) and all(isinstance(item, str) for item in value):
|
|
return list(value)
|
|
return list(default)
|
|
|
|
|
|
def _get_optional_str(
|
|
mapping: Mapping[str, object], key: str, default: str | None
|
|
) -> str | None:
|
|
value = mapping.get(key)
|
|
if value is None or isinstance(value, str):
|
|
return value if value is not None else default
|
|
return default
|