Refine (#58)
* feat: implement async factory functions and optimize blocking I/O operations
* refactor: improve thread safety and error handling in graph factory caching system
* Update src/biz_bud/agents/buddy_agent.py
  Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
* Update src/biz_bud/services/factory/service_factory.py
  Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
* Update src/biz_bud/core/edge_helpers/error_handling.py
  Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
* refactor: extract validation helpers and improve code organization in workflow modules
* refactor: extract shared code into helper methods and reuse across modules
* fix: prevent task cancellation leaks and improve async handling across core services
* refactor: replace async methods with sync in URL processing and error handling
* refactor: replace emoji and unsafe regex with plain text and secure regex operations
* refactor: consolidate timeout handling and regex safety checks with cross-platform support
* refactor: consolidate async detection and error normalization utilities across core modules
* Update src/biz_bud/core/utils/url_analyzer.py
  Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
* fix: prevent config override in LLM kwargs to maintain custom parameters
* refactor: centralize regex security and async context handling with improved error handling
* Update scripts/checks/audit_core_dependencies.py
  Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Update scripts/checks/audit_core_dependencies.py
  Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Update scripts/checks/audit_core_dependencies.py
  Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Update scripts/checks/audit_core_dependencies.py
  Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* fix: correct regex pattern for alternation in repeated groups by removing escaped bracket

---------

Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
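The change that recurs most often in the diff below replaces direct calls to a tool's private `_arun` method with the public `ainvoke` API. A minimal sketch of that calling pattern, assuming a langchain_core-style BaseTool; the EchoTool class and its fields are illustrative and not part of this commit:

import asyncio

from langchain_core.tools import BaseTool


class EchoTool(BaseTool):
    name: str = "echo"
    description: str = "Illustrative tool that echoes its query."

    def _run(self, query: str) -> str:
        return f"echo: {query}"

    async def _arun(self, query: str) -> str:
        return f"echo: {query}"


async def main() -> None:
    tool = EchoTool()
    # Old pattern (reaches into the private implementation hook):
    #   result = await tool._arun(query="hello")
    # New pattern (public API; goes through the tool's normal invocation path):
    result = await tool.ainvoke({"query": "hello"})
    print(result)


asyncio.run(main())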
@@ -2,5 +2,5 @@ projectKey=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f
serverUrl=http://sonar.lab
serverVersion=25.7.0.110598
dashboardUrl=http://sonar.lab/dashboard?id=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f
ceTaskId=342a04f5-5ffe-426a-afc4-c68a4e92f0cb
ceTaskUrl=http://sonar.lab/api/ce/task?id=342a04f5-5ffe-426a-afc4-c68a4e92f0cb
ceTaskId=cfa37333-abfb-4a0a-b950-db0c0f741f5f
ceTaskUrl=http://sonar.lab/api/ce/task?id=cfa37333-abfb-4a0a-b950-db0c0f741f5f

@@ -20,8 +20,11 @@ DISALLOWED_IMPORTS: Dict[str, str] = {
"requests": "biz_bud.core.networking.http_client.HTTPClient",
"httpx": "biz_bud.core.networking.http_client.HTTPClient",
"aiohttp": "biz_bud.core.networking.http_client.HTTPClient",
"asyncio.gather": "biz_bud.core.utils.async_utils.gather_with_concurrency",
"asyncio.gather": "biz_bud.core.utils.gather_with_concurrency",
"threading.Lock": "asyncio.Lock (use pure async patterns)",
# Removed hashlib and json.dumps - too many false positives
"asyncio.iscoroutinefunction": "biz_bud.core.networking.async_utils.detect_async_context",
"inspect.iscoroutinefunction": "biz_bud.core.networking.async_utils.detect_async_context",
}

# Direct instantiation of service clients or tools that should come from the factory.
@@ -48,6 +51,26 @@ DISALLOWED_EXCEPTIONS: Set[str] = {
"NotImplementedError",
}

# Function patterns that should use core utilities
DISALLOWED_PATTERNS: Dict[str, str] = {
"_generate_config_hash": "biz_bud.core.config.loader.generate_config_hash",
"_normalize_errors": "biz_bud.core.utils.normalize_errors_to_list",
# Removed _create_initial_state - domain-specific state creators are legitimate
"_extract_state_data": "biz_bud.core.utils.graph_helpers.extract_state_update_data",
"_format_input": "biz_bud.core.utils.graph_helpers.format_raw_input",
"_process_query": "biz_bud.core.utils.graph_helpers.process_state_query",
}

# Variable names that suggest manual caching implementations
CACHE_VARIABLE_PATTERNS: List[str] = [
"_graph_cache",
"_compiled_graphs",
"_graph_instances",
"_cached_graphs",
"graph_cache",
"compiled_cache",
]

# --- File Path Patterns for Exemptions ---

# Core infrastructure files that can use networking libraries directly
@@ -171,6 +194,9 @@ class InfrastructureVisitor(ast.NodeVisitor):
node,
f"Disallowed import '{alias.name}'. Please use '{suggestion}'."
)
# Special handling for hashlib - store that it was imported
elif alias.name == "hashlib":
self.imported_names[alias.asname or "hashlib"] = "hashlib"
self.generic_visit(node)

def visit_If(self, node: ast.If) -> None:
@@ -181,6 +207,22 @@ class InfrastructureVisitor(ast.NodeVisitor):
if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING":
self.in_type_checking = True

# Check for manual error normalization patterns
if isinstance(node.test, ast.Call) and isinstance(node.test.func, ast.Name) and node.test.func.id == 'isinstance':
# Check if testing for list type on something with 'error' in variable name
if len(node.test.args) >= 2:
var_arg = node.test.args[0]
type_arg = node.test.args[1]
if isinstance(var_arg, ast.Name) and 'error' in var_arg.id.lower():
if isinstance(type_arg, ast.Name) and type_arg.id == 'list':
# Skip if we're in the normalize_errors_to_list function itself
if '/core/utils/__init__.py' not in self.filepath:
# This is likely error normalization code
self._add_violation(
node,
"Manual error normalization pattern detected. Use 'biz_bud.core.utils.normalize_errors_to_list' instead."
)

self.generic_visit(node)
self.in_type_checking = was_in_type_checking

@@ -237,10 +279,12 @@ class InfrastructureVisitor(ast.NodeVisitor):

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
"""Track if we're in a Pydantic field validator or a _get_client method."""
self._check_function_pattern(node)
self._visit_function_def(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
"""Track if we're in a Pydantic field validator or a _get_client method (async version)."""
self._check_function_pattern(node)
self._visit_function_def(node)

def _visit_function_def(self, node) -> None:
@@ -316,6 +360,41 @@ class InfrastructureVisitor(ast.NodeVisitor):
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
if target.value.id == 'state':
self._add_violation(node, STATE_MUTATION_MESSAGE)

# Check for manual state dict creation patterns
if isinstance(node.value, ast.Dict):
# Check if assigning to a variable that looks like initial state creation
for target in node.targets:
if isinstance(target, ast.Name) and ('initial' in target.id.lower() and 'state' in target.id.lower()):
# Check if it's creating a full initial state structure
if self._is_initial_state_dict(node.value):
self._add_violation(
node,
"Direct state dict creation detected. Use 'biz_bud.core.utils.graph_helpers.create_initial_state_dict' instead."
)

# Check for manual cache implementations
for target in node.targets:
if isinstance(target, ast.Name):
for pattern in CACHE_VARIABLE_PATTERNS:
if pattern in target.id and not self._has_filepath_pattern(["/core/caching/", "cache_manager.py", "/graphs/graph.py"]):
self._add_violation(
node,
f"Manual cache implementation '{target.id}' detected. Use 'biz_bud.core.caching.cache_manager.GraphCache' instead."
)
break

# Also check for self.attribute assignments
elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == 'self':
for pattern in CACHE_VARIABLE_PATTERNS:
if pattern in target.attr and not self._has_filepath_pattern(["/core/caching/", "cache_manager.py", "/graphs/graph.py"]):
self._add_violation(
node,
f"Manual cache implementation 'self.{target.attr}' detected. Use 'biz_bud.core.caching.cache_manager.GraphCache' instead."
)
break


self.generic_visit(node)

def _should_skip_call_validation(self) -> bool:
@@ -366,6 +445,19 @@ class InfrastructureVisitor(ast.NodeVisitor):
if parent_name == 'state' and attr_name == 'update':
self._add_violation(node, STATE_UPDATE_MESSAGE)

# Note: We don't check for iscoroutinefunction usage because:
# - detect_async_context() checks if we're IN an async context
# - iscoroutinefunction() checks if a function IS async
# These serve different purposes and both are legitimate

# Check for hashlib usage with config-related patterns
if parent_name in self.imported_names and self.imported_names[parent_name] == "hashlib" and self._is_likely_config_hashing(node):
self._add_violation(
node,
"Using hashlib for config hashing. Use 'biz_bud.core.config.loader.generate_config_hash' instead."
)


def visit_Call(self, node: ast.Call) -> None:
"""
Checks for:
@@ -381,10 +473,81 @@ class InfrastructureVisitor(ast.NodeVisitor):
|
||||
self._check_attribute_calls(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _check_function_pattern(self, node) -> None:
|
||||
"""Check if function name matches a disallowed pattern."""
|
||||
for pattern, suggestion in DISALLOWED_PATTERNS.items():
|
||||
if node.name == pattern or node.name.endswith(pattern):
|
||||
# Skip if it's in a core infrastructure file
|
||||
if self._has_filepath_pattern(FACTORY_SERVICE_PATTERNS):
|
||||
continue
|
||||
|
||||
self._add_violation(
|
||||
node,
|
||||
f"Function '{node.name}' appears to duplicate core functionality. Use '{suggestion}' instead."
|
||||
)
|
||||
|
||||
def _is_likely_config_hashing(self, node: ast.Call) -> bool:
|
||||
"""Check if hashlib usage is likely for config hashing."""
|
||||
# Skip if in infrastructure files where hashlib is allowed
|
||||
if self._has_filepath_pattern(["/core/config/", "/core/caching/", "/core/errors/", "/core/networking/"]):
|
||||
return False
|
||||
|
||||
# Check if the function name or context suggests config hashing
|
||||
# This is a simple heuristic - could be improved
|
||||
if hasattr(node, 'args') and node.args:
|
||||
for arg in node.args:
|
||||
# Check if any argument contains 'config' in variable name
|
||||
if isinstance(arg, ast.Name) and 'config' in arg.id.lower():
|
||||
return True
|
||||
# Check if json.dumps is used on the argument (common pattern)
|
||||
if isinstance(arg, ast.Call) and isinstance(arg.func, ast.Attribute) and (arg.func.attr == 'dumps' and isinstance(arg.func.value, ast.Name) and arg.func.value.id == 'json'):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def _is_initial_state_dict(self, dict_node: ast.Dict) -> bool:
|
||||
"""Check if a dict literal looks like an initial state creation."""
|
||||
if not hasattr(dict_node, 'keys'):
|
||||
return False
|
||||
|
||||
# Count how many initial state keys are present
|
||||
initial_state_keys = {
|
||||
'raw_input', 'parsed_input', 'messages', 'initial_input',
|
||||
'thread_id', 'config', 'input_metadata', 'context',
|
||||
'status', 'errors', 'run_metadata', 'is_last_step', 'final_result'
|
||||
}
|
||||
|
||||
found_keys = set()
|
||||
for key in dict_node.keys:
|
||||
if isinstance(key, ast.Constant) and key.value in initial_state_keys:
|
||||
found_keys.add(key.value)
|
||||
|
||||
# If it has at least 5 of the initial state keys, it's likely an initial state
|
||||
return len(found_keys) >= 5
|
||||
|
||||
|
||||
def audit_directory(directory: str) -> Dict[str, List[Tuple[int, str]]]:
|
||||
"""Scans a directory for Python files and audits them."""
|
||||
all_violations: Dict[str, List[Tuple[int, str]]] = {}
|
||||
|
||||
# Handle single file
|
||||
if os.path.isfile(directory) and directory.endswith(".py"):
|
||||
try:
|
||||
with open(directory, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
tree = ast.parse(source_code, filename=directory)
|
||||
|
||||
visitor = InfrastructureVisitor(directory)
|
||||
visitor.visit(tree)
|
||||
|
||||
if visitor.violations:
|
||||
all_violations[directory] = visitor.violations
|
||||
except (SyntaxError, ValueError) as e:
|
||||
all_violations[directory] = [(0, f"ERROR: Could not parse file: {e}")]
|
||||
return all_violations
|
||||
|
||||
# Handle directory
|
||||
for root, _, files in os.walk(directory):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
|
||||
@@ -2,785 +2,23 @@

This module provides factories and parsers for managing execution records,
parsing plans, and formatting responses in the Buddy agent.

NOTE: This module is now a compatibility layer that delegates to the refactored
workflow utilities to avoid code duplication.
"""
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
# Removed broken core import
|
||||
from biz_bud.states.buddy import ExecutionRecord
|
||||
from biz_bud.states.planner import ExecutionPlan, QueryStep
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ExecutionRecordFactory:
|
||||
"""Factory for creating standardized execution records."""
|
||||
|
||||
@staticmethod
|
||||
def create_success_record(
|
||||
step_id: str,
|
||||
graph_name: str,
|
||||
start_time: float,
|
||||
result: Any,
|
||||
) -> ExecutionRecord:
|
||||
"""Create an execution record for a successful execution.
|
||||
|
||||
Args:
|
||||
step_id: The ID of the executed step
|
||||
graph_name: Name of the graph that was executed
|
||||
start_time: Timestamp when execution started
|
||||
result: The result of the execution
|
||||
|
||||
Returns:
|
||||
ExecutionRecord for a successful execution
|
||||
"""
|
||||
return ExecutionRecord(
|
||||
step_id=step_id,
|
||||
graph_name=graph_name,
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
status="completed",
|
||||
result=result,
|
||||
error=None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_failure_record(
|
||||
step_id: str,
|
||||
graph_name: str,
|
||||
start_time: float,
|
||||
error: str | Exception,
|
||||
) -> ExecutionRecord:
|
||||
"""Create an execution record for a failed execution.
|
||||
|
||||
Args:
|
||||
step_id: The ID of the executed step
|
||||
graph_name: Name of the graph that was executed
|
||||
start_time: Timestamp when execution started
|
||||
error: The error that occurred
|
||||
|
||||
Returns:
|
||||
ExecutionRecord for a failed execution
|
||||
"""
|
||||
return ExecutionRecord(
|
||||
step_id=step_id,
|
||||
graph_name=graph_name,
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
status="failed",
|
||||
result=None,
|
||||
error=str(error),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_skipped_record(
|
||||
step_id: str,
|
||||
graph_name: str,
|
||||
reason: str = "Dependencies not met",
|
||||
) -> ExecutionRecord:
|
||||
"""Create an execution record for a skipped step.
|
||||
|
||||
Args:
|
||||
step_id: The ID of the skipped step
|
||||
graph_name: Name of the graph that would have been executed
|
||||
reason: Reason for skipping
|
||||
|
||||
Returns:
|
||||
ExecutionRecord for a skipped execution
|
||||
"""
|
||||
current_time = time.time()
|
||||
return ExecutionRecord(
|
||||
step_id=step_id,
|
||||
graph_name=graph_name,
|
||||
start_time=current_time,
|
||||
end_time=current_time,
|
||||
status="skipped",
|
||||
result=None,
|
||||
error=reason,
|
||||
)
|
||||
|
||||
|
||||
class PlanParser:
|
||||
"""Parser for converting planner output into structured execution plans."""
|
||||
|
||||
# Regex pattern for parsing plan steps
|
||||
STEP_PATTERN = re.compile(r"Step (\w+): ([^\n]+)\n\s*- Graph: (\w+)")
|
||||
|
||||
@staticmethod
|
||||
def parse_planner_result(result: str | dict[str, Any]) -> ExecutionPlan | None:
|
||||
"""Parse a planner result into an ExecutionPlan.
|
||||
|
||||
Expected format:
|
||||
Step 1: Description here
|
||||
- Graph: graph_name
|
||||
|
||||
Args:
|
||||
result: The planner output string
|
||||
|
||||
Returns:
|
||||
ExecutionPlan if parsing successful, None otherwise
|
||||
"""
|
||||
if not result:
|
||||
logger.warning("Empty planner result provided")
|
||||
return None
|
||||
|
||||
# Handle dict results from planner tools
|
||||
if isinstance(result, dict):
|
||||
# If result contains an 'execution_plan' field, convert it to ExecutionPlan
|
||||
if "execution_plan" in result and isinstance(
|
||||
result["execution_plan"], dict
|
||||
):
|
||||
execution_plan_data: dict[str, Any] = result["execution_plan"]
|
||||
|
||||
# Validate required fields in execution_plan
|
||||
if "steps" not in execution_plan_data:
|
||||
logger.warning(
|
||||
"execution_plan missing required 'steps' field, creating fallback plan"
|
||||
)
|
||||
execution_plan_data["steps"] = []
|
||||
|
||||
if not isinstance(execution_plan_data["steps"], list):
|
||||
logger.warning(
|
||||
"execution_plan 'steps' field is not a list, creating fallback plan"
|
||||
)
|
||||
execution_plan_data["steps"] = []
|
||||
|
||||
logger.info(
|
||||
f"Found structured execution_plan with {len(execution_plan_data['steps'])} steps"
|
||||
)
|
||||
|
||||
# Convert to proper ExecutionPlan object
|
||||
steps = []
|
||||
for step_data in execution_plan_data.get("steps", []):
|
||||
if isinstance(step_data, dict):
|
||||
# Ensure dependencies is a list of strings
|
||||
dependencies_raw = step_data.get("dependencies", [])
|
||||
if isinstance(dependencies_raw, str):
|
||||
dependencies = [dependencies_raw]
|
||||
elif isinstance(dependencies_raw, list):
|
||||
dependencies = [str(dep) for dep in dependencies_raw]
|
||||
else:
|
||||
dependencies = []
|
||||
|
||||
# Validate priority with proper type checking
|
||||
priority_raw = step_data.get("priority", "medium")
|
||||
if not isinstance(priority_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid priority type: {type(priority_raw)}, using default 'medium'"
|
||||
)
|
||||
priority: Literal["high", "medium", "low"] = "medium"
|
||||
elif priority_raw not in ["high", "medium", "low"]:
|
||||
logger.warning(
|
||||
f"Invalid priority value: {priority_raw}, using default 'medium'"
|
||||
)
|
||||
priority = "medium"
|
||||
else:
|
||||
priority = priority_raw # type: ignore[assignment] # Validated above
|
||||
|
||||
# Validate status with proper type checking
|
||||
status_raw = step_data.get("status", "pending")
|
||||
valid_statuses: list[str] = [
|
||||
"pending",
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"blocked",
|
||||
]
|
||||
if not isinstance(status_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid status type: {type(status_raw)}, using default 'pending'"
|
||||
)
|
||||
status: Literal[
|
||||
"pending",
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"blocked",
|
||||
] = "pending"
|
||||
elif status_raw not in valid_statuses:
|
||||
logger.warning(
|
||||
f"Invalid status value: {status_raw}, using default 'pending'"
|
||||
)
|
||||
status = "pending"
|
||||
else:
|
||||
status = status_raw # type: ignore[assignment] # Validated above
|
||||
|
||||
# Validate results field
|
||||
results_raw = step_data.get("results")
|
||||
if results_raw is not None and not isinstance(
|
||||
results_raw, dict
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid results type: {type(results_raw)}, setting to None"
|
||||
)
|
||||
results = None
|
||||
else:
|
||||
results = results_raw
|
||||
|
||||
# Ensure all string fields are properly converted with validation
|
||||
step_id_raw = step_data.get("id", f"step_{len(steps) + 1}")
|
||||
if not isinstance(step_id_raw, (str, int)):
|
||||
logger.warning(
|
||||
f"Invalid step_id type: {type(step_id_raw)}, using default"
|
||||
)
|
||||
step_id = f"step_{len(steps) + 1}"
|
||||
else:
|
||||
step_id = str(step_id_raw)
|
||||
|
||||
description_raw = step_data.get("description", "")
|
||||
if not isinstance(description_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid description type: {type(description_raw)}, converting to string"
|
||||
)
|
||||
description = str(description_raw)
|
||||
|
||||
agent_name_raw = step_data.get("agent_name", "main")
|
||||
if agent_name_raw is not None and not isinstance(
|
||||
agent_name_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid agent_name type: {type(agent_name_raw)}, converting to string"
|
||||
)
|
||||
agent_name = (
|
||||
str(agent_name_raw) if agent_name_raw is not None else None
|
||||
)
|
||||
|
||||
agent_role_prompt_raw = step_data.get("agent_role_prompt")
|
||||
if agent_role_prompt_raw is not None and not isinstance(
|
||||
agent_role_prompt_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid agent_role_prompt type: {type(agent_role_prompt_raw)}, converting to string"
|
||||
)
|
||||
agent_role_prompt = (
|
||||
str(agent_role_prompt_raw)
|
||||
if agent_role_prompt_raw is not None
|
||||
else None
|
||||
)
|
||||
|
||||
error_message_raw = step_data.get("error_message")
|
||||
if error_message_raw is not None and not isinstance(
|
||||
error_message_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid error_message type: {type(error_message_raw)}, converting to string"
|
||||
)
|
||||
error_message = (
|
||||
str(error_message_raw)
|
||||
if error_message_raw is not None
|
||||
else None
|
||||
)
|
||||
|
||||
step = QueryStep(
|
||||
id=step_id,
|
||||
description=description,
|
||||
agent_name=agent_name,
|
||||
dependencies=dependencies,
|
||||
priority=priority,
|
||||
query=str(
|
||||
step_data.get("query", step_data.get("description", ""))
|
||||
),
|
||||
status=status,
|
||||
agent_role_prompt=agent_role_prompt,
|
||||
results=results,
|
||||
error_message=error_message,
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
if steps:
|
||||
# Validate current_step_id with proper type checking
|
||||
current_step_id_raw = execution_plan_data.get("current_step_id")
|
||||
if current_step_id_raw is not None:
|
||||
if not isinstance(current_step_id_raw, (str, int)):
|
||||
logger.warning(
|
||||
f"Invalid current_step_id type: {type(current_step_id_raw)}, setting to None"
|
||||
)
|
||||
current_step_id = None
|
||||
else:
|
||||
current_step_id = str(current_step_id_raw)
|
||||
else:
|
||||
current_step_id = None
|
||||
|
||||
# Validate completed_steps with proper error handling
|
||||
completed_steps_raw = execution_plan_data.get("completed_steps", [])
|
||||
completed_steps = []
|
||||
if not isinstance(completed_steps_raw, list):
|
||||
logger.warning(
|
||||
f"Invalid completed_steps type: {type(completed_steps_raw)}, using empty list"
|
||||
)
|
||||
else:
|
||||
for step in completed_steps_raw:
|
||||
if isinstance(step, (str, int)):
|
||||
completed_steps.append(str(step))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid completed step type: {type(step)}, skipping"
|
||||
)
|
||||
|
||||
# Validate failed_steps with proper error handling
|
||||
failed_steps_raw = execution_plan_data.get("failed_steps", [])
|
||||
failed_steps = []
|
||||
if not isinstance(failed_steps_raw, list):
|
||||
logger.warning(
|
||||
f"Invalid failed_steps type: {type(failed_steps_raw)}, using empty list"
|
||||
)
|
||||
else:
|
||||
for step in failed_steps_raw:
|
||||
if isinstance(step, (str, int)):
|
||||
failed_steps.append(str(step))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid failed step type: {type(step)}, skipping"
|
||||
)
|
||||
|
||||
# Validate can_execute_parallel with proper type checking
|
||||
can_execute_parallel_raw = execution_plan_data.get(
|
||||
"can_execute_parallel", False
|
||||
)
|
||||
if not isinstance(can_execute_parallel_raw, (bool, int, str)):
|
||||
logger.warning(
|
||||
f"Invalid can_execute_parallel type: {type(can_execute_parallel_raw)}, using False"
|
||||
)
|
||||
can_execute_parallel = False
|
||||
else:
|
||||
try:
|
||||
can_execute_parallel = bool(can_execute_parallel_raw)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Cannot convert can_execute_parallel to bool: {can_execute_parallel_raw}, using False"
|
||||
)
|
||||
can_execute_parallel = False
|
||||
|
||||
# Validate execution_mode with proper type checking
|
||||
execution_mode_raw = execution_plan_data.get(
|
||||
"execution_mode", "sequential"
|
||||
)
|
||||
valid_modes = ["sequential", "parallel", "hybrid"]
|
||||
if not isinstance(execution_mode_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid execution_mode type: {type(execution_mode_raw)}, using 'sequential'"
|
||||
)
|
||||
execution_mode: Literal["sequential", "parallel", "hybrid"] = (
|
||||
"sequential"
|
||||
)
|
||||
elif execution_mode_raw not in valid_modes:
|
||||
logger.warning(
|
||||
f"Invalid execution_mode value: {execution_mode_raw}, using 'sequential'"
|
||||
)
|
||||
execution_mode = "sequential"
|
||||
else:
|
||||
execution_mode = execution_mode_raw # type: ignore[assignment] # Validated above
|
||||
|
||||
# Add fallback step if no valid steps found in structured plan
|
||||
if not steps:
|
||||
logger.warning(
|
||||
"No valid steps found in structured execution plan. Creating fallback step."
|
||||
)
|
||||
fallback_step = QueryStep(
|
||||
id="fallback_1",
|
||||
description="Process user query with available tools",
|
||||
agent_name="research",
|
||||
dependencies=[],
|
||||
priority="high",
|
||||
query="Process user request",
|
||||
status="pending",
|
||||
agent_role_prompt=None,
|
||||
results=None,
|
||||
error_message=None,
|
||||
)
|
||||
steps = [fallback_step]
|
||||
|
||||
return ExecutionPlan(
|
||||
steps=steps,
|
||||
current_step_id=current_step_id,
|
||||
completed_steps=completed_steps,
|
||||
failed_steps=failed_steps,
|
||||
can_execute_parallel=can_execute_parallel,
|
||||
execution_mode=execution_mode,
|
||||
)
|
||||
|
||||
# Try common keys that might contain the plan text
|
||||
# First try standard text keys
|
||||
plan_text = next(
|
||||
(
|
||||
result[key]
|
||||
for key in ["content", "response", "plan", "output", "text"]
|
||||
if key in result and isinstance(result[key], str)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# If no direct text found, try structured keys
|
||||
if (
|
||||
not plan_text
|
||||
and "step_results" in result
|
||||
and isinstance(result["step_results"], list)
|
||||
):
|
||||
# Try to reconstruct plan from step list
|
||||
plan_parts: list[str] = []
|
||||
for i, step in enumerate(result["step_results"], 1):
|
||||
if isinstance(step, dict):
|
||||
desc = step.get("description", step.get("query", f"Step {i}"))
|
||||
graph = step.get("agent_name", step.get("graph", "main"))
|
||||
plan_parts.append(f"Step {i}: {desc}\n- Graph: {graph}")
|
||||
elif isinstance(step, str):
|
||||
plan_parts.append(f"Step {i}: {step}\n- Graph: main")
|
||||
if plan_parts:
|
||||
plan_text = "\n".join(plan_parts)
|
||||
|
||||
# Handle 'summary' - could contain plan summary we can use
|
||||
if (
|
||||
not plan_text
|
||||
and "summary" in result
|
||||
and isinstance(result["summary"], str)
|
||||
and "step" in result["summary"].lower()
|
||||
and "graph" in result["summary"].lower()
|
||||
):
|
||||
plan_text = result["summary"]
|
||||
|
||||
if not plan_text:
|
||||
logger.warning(
|
||||
f"Could not extract plan text from dict result. Keys: {list(result.keys())}"
|
||||
)
|
||||
# Log the structure for debugging
|
||||
logger.debug(f"Result structure: {result}")
|
||||
return None
|
||||
|
||||
result = plan_text
|
||||
|
||||
# Ensure result is a string at this point
|
||||
if not isinstance(result, str):
|
||||
logger.warning(
|
||||
f"Result is not a string after processing. Type: {type(result)}"
|
||||
)
|
||||
return None
|
||||
|
||||
steps: list[QueryStep] = []
|
||||
|
||||
for match in PlanParser.STEP_PATTERN.finditer(result):
|
||||
step_id = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
graph_name = match.group(3)
|
||||
|
||||
step = QueryStep(
|
||||
id=step_id,
|
||||
description=description,
|
||||
agent_name=graph_name,
|
||||
dependencies=[], # Could be enhanced to parse dependencies
|
||||
priority="medium", # Default priority
|
||||
query=description, # Use description as query by default
|
||||
status="pending", # Required field
|
||||
agent_role_prompt=None, # Required field
|
||||
results=None, # Required field
|
||||
error_message=None, # Required field
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
if not steps:
|
||||
logger.warning(
|
||||
"No valid steps found in planner result. Creating fallback plan."
|
||||
)
|
||||
# Create a fallback plan with a basic step to prevent complete failure
|
||||
fallback_step = QueryStep(
|
||||
id="fallback_1",
|
||||
description="Process user query with available tools",
|
||||
agent_name="research", # Default to research graph
|
||||
dependencies=[],
|
||||
priority="high",
|
||||
query=str(result)[:500]
|
||||
if result
|
||||
else "Process user request", # Truncate if too long
|
||||
status="pending",
|
||||
agent_role_prompt=None,
|
||||
results=None,
|
||||
error_message=None,
|
||||
)
|
||||
steps = [fallback_step]
|
||||
logger.info("Created fallback execution plan with basic research step")
|
||||
|
||||
return ExecutionPlan(
|
||||
steps=steps,
|
||||
current_step_id=None,
|
||||
completed_steps=[],
|
||||
failed_steps=[],
|
||||
can_execute_parallel=False,
|
||||
execution_mode="sequential",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_dependencies(result: str) -> dict[str, list[str]]:
|
||||
"""Parse dependencies from planner result.
|
||||
|
||||
This is a placeholder for more sophisticated dependency parsing.
|
||||
|
||||
Args:
|
||||
result: The planner output string
|
||||
|
||||
Returns:
|
||||
Dictionary mapping step IDs to their dependencies
|
||||
"""
|
||||
# For now, return empty dependencies
|
||||
# Could be enhanced to parse "depends on Step X" patterns
|
||||
return {}
|
||||
|
||||
|
||||
class ResponseFormatter:
|
||||
"""Formatter for creating final responses from execution results."""
|
||||
|
||||
@staticmethod
|
||||
def format_final_response(
|
||||
query: str,
|
||||
synthesis: str,
|
||||
execution_history: list[ExecutionRecord],
|
||||
completed_steps: list[str],
|
||||
adaptation_count: int = 0,
|
||||
) -> str:
|
||||
"""Format the final response for the user.
|
||||
|
||||
Args:
|
||||
query: Original user query
|
||||
synthesis: Synthesized results
|
||||
execution_history: List of execution records
|
||||
completed_steps: List of completed step IDs
|
||||
adaptation_count: Number of adaptations made
|
||||
|
||||
Returns:
|
||||
Formatted response string
|
||||
"""
|
||||
# Calculate execution statistics
|
||||
total_executions = len(execution_history)
|
||||
status_counts = {"completed": 0, "failed": 0}
|
||||
for record in execution_history:
|
||||
status = record.get("status")
|
||||
if status in status_counts:
|
||||
status_counts[status] += 1
|
||||
|
||||
successful_executions = status_counts["completed"]
|
||||
failed_executions = status_counts["failed"]
|
||||
|
||||
# Check for records missing status and handle them explicitly
|
||||
if records_without_status := [
|
||||
record for record in execution_history if "status" not in record
|
||||
]:
|
||||
raise ValidationError(
|
||||
f"Found {len(records_without_status)} execution records missing 'status' key. "
|
||||
f"All execution records must include a 'status' key. Offending records: {records_without_status}"
|
||||
)
|
||||
|
||||
# Build the response
|
||||
response_parts = [
|
||||
"# Buddy Orchestration Complete",
|
||||
"",
|
||||
f"**Query**: {query}",
|
||||
"",
|
||||
"**Execution Summary**:",
|
||||
f"- Total steps executed: {total_executions}",
|
||||
f"- Successfully completed: {successful_executions}",
|
||||
]
|
||||
|
||||
if failed_executions > 0:
|
||||
response_parts.append(f"- Failed executions: {failed_executions}")
|
||||
|
||||
if adaptation_count > 0:
|
||||
response_parts.append(f"- Adaptations made: {adaptation_count}")
|
||||
|
||||
response_parts.extend(
|
||||
[
|
||||
"",
|
||||
"**Results**:",
|
||||
synthesis,
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(response_parts)
|
||||
|
||||
@staticmethod
|
||||
def format_error_response(
|
||||
query: str,
|
||||
error: str,
|
||||
partial_results: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Format an error response for the user.
|
||||
|
||||
Args:
|
||||
query: Original user query
|
||||
error: Error message
|
||||
partial_results: Any partial results obtained
|
||||
|
||||
Returns:
|
||||
Formatted error response string
|
||||
"""
|
||||
response_parts = [
|
||||
"# Buddy Orchestration Error",
|
||||
"",
|
||||
f"**Query**: {query}",
|
||||
"",
|
||||
f"**Error**: {error}",
|
||||
]
|
||||
|
||||
if partial_results:
|
||||
response_parts.extend(
|
||||
[
|
||||
"",
|
||||
"**Partial Results**:",
|
||||
"Some information was gathered before the error occurred:",
|
||||
str(partial_results),
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(response_parts)
|
||||
|
||||
@staticmethod
|
||||
def format_streaming_update(
|
||||
phase: str,
|
||||
step: QueryStep | None = None,
|
||||
message: str | None = None,
|
||||
) -> str:
|
||||
"""Format a streaming update message.
|
||||
|
||||
Args:
|
||||
phase: Current orchestration phase
|
||||
step: Current step being executed (if any)
|
||||
message: Optional additional message
|
||||
|
||||
Returns:
|
||||
Formatted streaming update
|
||||
"""
|
||||
if step:
|
||||
return f"[{phase}] Executing step {step.get('id', 'unknown')}: {step.get('description', 'Unknown step')}\n"
|
||||
elif message:
|
||||
return f"[{phase}] {message}\n"
|
||||
else:
|
||||
return f"[{phase}] "
|
||||
|
||||
|
||||
class IntermediateResultsConverter:
|
||||
"""Converter for transforming intermediate results into various formats."""
|
||||
|
||||
@staticmethod
|
||||
def to_extracted_info(
|
||||
intermediate_results: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], list[dict[str, str]]]:
|
||||
"""Convert intermediate results to extracted_info format for synthesis.
|
||||
|
||||
Args:
|
||||
intermediate_results: Dictionary of step_id -> result mappings
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_info dict, sources list)
|
||||
"""
|
||||
logger.info(
|
||||
f"Converting {len(intermediate_results)} intermediate results to extracted_info format"
|
||||
)
|
||||
logger.debug(f"Intermediate results keys: {list(intermediate_results.keys())}")
|
||||
|
||||
extracted_info: dict[str, dict[str, Any]] = {}
|
||||
sources: list[dict[str, str]] = []
|
||||
|
||||
for step_id, result in intermediate_results.items():
|
||||
logger.debug(f"Processing step {step_id}: {type(result).__name__}")
|
||||
|
||||
if isinstance(result, str):
|
||||
logger.debug(f"String result for step {step_id}, length: {len(result)}")
|
||||
# Extract key information from result string
|
||||
extracted_info[step_id] = {
|
||||
"content": result,
|
||||
"summary": f"{result[:300]}..." if len(result) > 300 else result,
|
||||
"key_points": [f"{result[:200]}..."]
|
||||
if len(result) > 200
|
||||
else [result],
|
||||
"facts": [],
|
||||
}
|
||||
sources.append(
|
||||
{
|
||||
"key": step_id,
|
||||
"url": f"step_{step_id}",
|
||||
"title": f"Step {step_id} Results",
|
||||
}
|
||||
)
|
||||
elif isinstance(result, dict):
|
||||
logger.debug(
|
||||
f"Dict result for step {step_id}, keys: {list(result.keys())}"
|
||||
)
|
||||
# Handle dictionary results - extract actual content
|
||||
content = None
|
||||
|
||||
# Try to extract meaningful content from various possible keys
|
||||
for content_key in [
|
||||
"synthesis",
|
||||
"final_response",
|
||||
"content",
|
||||
"response",
|
||||
"result",
|
||||
"output",
|
||||
]:
|
||||
if content_key in result and result[content_key]:
|
||||
content = str(result[content_key])
|
||||
logger.debug(
|
||||
f"Found content in key '{content_key}' for step {step_id}"
|
||||
)
|
||||
break
|
||||
|
||||
# If no content found, stringify the whole result
|
||||
if not content:
|
||||
content = str(result)
|
||||
logger.debug(
|
||||
f"No specific content key found, using stringified result for step {step_id}"
|
||||
)
|
||||
|
||||
# Extract key points if available
|
||||
key_points = result.get("key_points", [])
|
||||
if not key_points and content:
|
||||
# Create key points from content
|
||||
key_points = (
|
||||
[f"{content[:200]}..."] if len(content) > 200 else [content]
|
||||
)
|
||||
|
||||
extracted_info[step_id] = {
|
||||
"content": content,
|
||||
"summary": result.get(
|
||||
"summary",
|
||||
f"{content[:300]}..." if len(content) > 300 else content,
|
||||
),
|
||||
"key_points": key_points,
|
||||
"facts": result.get("facts", []),
|
||||
}
|
||||
sources.append(
|
||||
{
|
||||
"key": str(step_id),
|
||||
"url": str(result.get("url", f"step_{step_id}")),
|
||||
"title": str(result.get("title", f"Step {step_id} Results")),
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected result type for step {step_id}: {type(result).__name__}"
|
||||
)
|
||||
# Handle other types by converting to string
|
||||
content_str: str = str(result)
|
||||
summary = (
|
||||
f"{content_str[:300]}..." if len(content_str) > 300 else content_str
|
||||
)
|
||||
extracted_info[step_id] = {
|
||||
"content": content_str,
|
||||
"summary": summary,
|
||||
"key_points": [content_str],
|
||||
"facts": [],
|
||||
}
|
||||
sources.append(
|
||||
{
|
||||
"key": step_id,
|
||||
"url": f"step_{step_id}",
|
||||
"title": f"Step {step_id} Results",
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Conversion complete: {len(extracted_info)} extracted_info entries, {len(sources)} sources"
|
||||
)
|
||||
return extracted_info, sources
|
||||
# Import refactored utilities instead of duplicating code
from biz_bud.tools.capabilities.workflow.execution import (
ExecutionRecordFactory,
IntermediateResultsConverter,
ResponseFormatter,
)
from biz_bud.tools.capabilities.workflow.planning import PlanParser

# Re-export for backward compatibility
__all__ = [
"ExecutionRecordFactory",
"PlanParser",
"ResponseFormatter",
"IntermediateResultsConverter",
]

@@ -565,7 +565,7 @@ async def buddy_orchestrator_node(
"Invalid query for tool execution: query is empty"
)

search_result = await search_tool._arun(query=user_query)
search_result = await search_tool.ainvoke({"query": user_query})

# Validate search result
if search_result is None:
@@ -682,7 +682,10 @@ async def buddy_orchestrator_node(
if "capability_summary" in state:
planner_context["capability_summary"] = state["capability_summary"] # type: ignore[index]

plan_result = await planner._arun(query=user_query, context=planner_context)
plan_result = await planner.ainvoke({
"query": user_query,
"context": planner_context
})

# Parse plan using PlanParser
if execution_plan := PlanParser.parse_planner_result(plan_result):
@@ -816,7 +819,10 @@ async def buddy_executor_node(

factory = await get_global_factory()
executor = await factory.create_node_tool(graph_name)
result = await executor._arun(query=step_query, context=context)
result = await executor.ainvoke({
"query": step_query,
"context": context
})

# Process enhanced tool result if it includes messages
tool_messages = []
@@ -1010,11 +1016,11 @@ async def buddy_synthesizer_node(
factory = await get_global_factory()
synthesizer = await factory.create_node_tool("synthesize_search_results")

synthesis_result = await synthesizer._arun(
query=user_query,
extracted_info=extracted_info,
sources=sources,
)
synthesis_result = await synthesizer.ainvoke({
"query": user_query,
"extracted_info": extracted_info,
"sources": sources
})

# Process enhanced tool result if it includes messages
synthesis_messages = []

@@ -1,5 +1,6 @@
"""Cache manager for LLM operations."""

import asyncio
import hashlib
import json
import pickle
@@ -144,7 +145,7 @@ class LLMCache[T]:
if hasattr(backend_class, '__orig_bases__'):
orig_bases = getattr(backend_class, '__orig_bases__', ())
for base in orig_bases:
if hasattr(base, '__args__') and base.__args__ and base.__args__[0] == bytes:
if hasattr(base, '__args__') and base.__args__ and base.__args__[0] == bytes:
return True

# Check for bytes-only signature by attempting to inspect the set method
@@ -215,3 +216,102 @@ class LLMCache[T]:
|
||||
await self._backend.clear()
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache clear failed: {e}")
|
||||
|
||||
|
||||
class GraphCache[T]:
|
||||
"""Cache manager specifically designed for LangGraph graph instances.
|
||||
|
||||
This manager handles caching of compiled graph instances with configuration-based
|
||||
keys and provides thread-safe access patterns for multi-tenant scenarios.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: CacheBackend[T] | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph cache manager.
|
||||
|
||||
Args:
|
||||
backend: Cache backend to use (defaults to in-memory)
|
||||
cache_dir: Directory for file-based cache (if backend not provided)
|
||||
ttl: Time-to-live in seconds
|
||||
"""
|
||||
if backend is None:
|
||||
# Use in-memory backend for graphs by default
|
||||
from .memory import InMemoryCache
|
||||
self._backend: CacheBackend[T] = InMemoryCache[T](ttl=ttl)
|
||||
else:
|
||||
self._backend = backend
|
||||
|
||||
# Thread-safe access
|
||||
import asyncio
|
||||
self._lock = asyncio.Lock()
|
||||
self._cache: dict[str, T] = {}
|
||||
|
||||
async def get_or_create(
|
||||
self,
|
||||
key: str,
|
||||
factory_func: Any,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> T:
|
||||
"""Get cached graph or create new one using factory function.
|
||||
|
||||
This method provides thread-safe access to cached graphs with
|
||||
automatic creation if not present.
|
||||
|
||||
Args:
|
||||
key: Cache key (typically config hash)
|
||||
factory_func: Function to create graph if not cached
|
||||
*args: Arguments for factory function
|
||||
**kwargs: Keyword arguments for factory function
|
||||
|
||||
Returns:
|
||||
Cached or newly created graph instance
|
||||
"""
|
||||
async with self._lock:
|
||||
# Check in-memory cache first
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
|
||||
# Check backend cache
|
||||
cached = await self._backend.get(key)
|
||||
if cached is not None:
|
||||
self._cache[key] = cached
|
||||
return cached
|
||||
|
||||
# Create new instance
|
||||
logger.info(f"Creating new graph instance for key: {key}")
|
||||
|
||||
# Handle async factory functions
|
||||
if asyncio.iscoroutinefunction(factory_func):
|
||||
instance = await factory_func(*args, **kwargs)
|
||||
else:
|
||||
instance = factory_func(*args, **kwargs)
|
||||
|
||||
# Cache the instance
|
||||
self._cache[key] = instance
|
||||
await self._backend.set(key, instance, ttl=None)
|
||||
|
||||
logger.info(f"Successfully cached graph for key: {key}")
|
||||
return instance
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cached graphs."""
|
||||
async with self._lock:
|
||||
self._cache.clear()
|
||||
await self._backend.clear()
|
||||
logger.info("Cleared graph cache")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics for monitoring.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
return {
|
||||
"memory_cache_size": len(self._cache),
|
||||
"cache_keys": list(self._cache.keys()),
|
||||
}
|
||||
|
||||
@@ -30,11 +30,11 @@ class _DefaultCacheManager:
"""Thread-safe manager for the default cache instance using task-based pattern."""

def __init__(self) -> None:
self._cache_instance: InMemoryCache | None = None
self._cache_instance: InMemoryCache[Any] | None = None
self._creation_lock = asyncio.Lock()
self._initializing_task: asyncio.Task[InMemoryCache] | None = None
self._initializing_task: asyncio.Task[InMemoryCache[Any]] | None = None

async def get_cache(self) -> InMemoryCache:
async def get_cache(self) -> InMemoryCache[Any]:
"""Get or create the default cache instance with race-condition-free init."""
# Fast path - cache already exists
if self._cache_instance is not None:
@@ -52,10 +52,11 @@ class _DefaultCacheManager:
task = self._initializing_task
else:
# Create new initialization task
async def create_cache() -> InMemoryCache:
return InMemoryCache()
def create_cache() -> InMemoryCache[Any]:
return InMemoryCache[Any]()

task = asyncio.create_task(create_cache())
# Use asyncio.to_thread for sync function
task = asyncio.create_task(asyncio.to_thread(create_cache))
self._initializing_task = task

# Wait for initialization to complete (outside the lock)
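Condensed, the task-based initialization above follows the shape below. This is a minimal sketch: _LazyResource is an illustrative stand-in for _DefaultCacheManager, and a plain dict stands in for InMemoryCache.

import asyncio


class _LazyResource:
    """Minimal sketch of a task-based lazy singleton (illustrative)."""

    def __init__(self) -> None:
        self._instance: dict | None = None
        self._lock = asyncio.Lock()
        self._task: asyncio.Task[dict] | None = None

    async def get(self) -> dict:
        if self._instance is not None:  # fast path, no lock needed
            return self._instance
        async with self._lock:
            if self._task is None:
                # Sync factory offloaded to a worker thread, as in the diff above.
                self._task = asyncio.create_task(asyncio.to_thread(dict))
            task = self._task
        self._instance = await task  # await outside the lock
        return self._instance


async def main() -> None:
    res = _LazyResource()
    print(await res.get() is await res.get())  # True: one shared instance


asyncio.run(main())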
@@ -189,6 +190,9 @@ async def _get_cached_value(
cached_bytes = await backend.get(cache_key)
if cached_bytes is not None:
return pickle.loads(cached_bytes)
except asyncio.CancelledError:
# Always re-raise CancelledError to allow proper cancellation
raise
except Exception:
pass
return None
@@ -205,6 +209,9 @@ async def _store_cache_value(
if hasattr(backend, 'set'):
serialized = pickle.dumps(result)
await backend.set(cache_key, serialized, ttl=ttl)
except asyncio.CancelledError:
# Always re-raise CancelledError to allow proper cancellation
raise
except Exception:
pass

@@ -4,33 +4,37 @@ import asyncio
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Final
from typing import Final, Generic, TypeVar

from .base import GenericCacheBackend as CacheBackend

T = TypeVar("T")


@dataclass
class CacheEntry:
class CacheEntry(Generic[T]):
"""A single cache entry with expiration."""

value: bytes
value: T
expires_at: float | None


class InMemoryCache(CacheBackend[bytes]):
class InMemoryCache(CacheBackend[T], Generic[T]):
"""In-memory cache backend using a dictionary."""

def __init__(self, max_size: int | None = None) -> None:
def __init__(self, max_size: int | None = None, ttl: int | None = None) -> None:
"""Initialize in-memory cache.

Args:
max_size: Maximum number of entries (None for unlimited)
ttl: Default time-to-live in seconds for entries (None for no expiry)
"""
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._cache: OrderedDict[str, CacheEntry[T]] = OrderedDict()
self._lock = asyncio.Lock()
self.max_size: Final = max_size
self.default_ttl: Final = ttl

async def get(self, key: str) -> bytes | None:
async def get(self, key: str) -> T | None:
"""Retrieve value from cache."""

async with self._lock:
@@ -51,12 +55,15 @@ class InMemoryCache(CacheBackend[bytes]):
async def set(
self,
key: str,
value: bytes,
value: T,
ttl: int | None = None,
) -> None:
"""Store value in cache."""

expires_at = time.time() + ttl if ttl is not None else None
# Use provided ttl or default ttl
effective_ttl = ttl if ttl is not None else self.default_ttl
expires_at = time.time() + effective_ttl if effective_ttl is not None else None

async with self._lock:
# If max_size is 0, don't store anything
if self.max_size == 0:

@@ -495,6 +495,116 @@ class CleanupRegistry:
|
||||
logger.error(f"Failed to create {service_class.__name__} with dependencies: {e}")
|
||||
raise
|
||||
|
||||
async def cleanup_caches(self, cache_names: list[str] | None = None) -> None:
|
||||
"""Cleanup specific or all registered caches.
|
||||
|
||||
This method provides a centralized way to cleanup various caches
|
||||
used throughout the application, including graph caches, service
|
||||
factory caches, and state template caches.
|
||||
|
||||
Args:
|
||||
cache_names: List of specific cache names to clean. If None,
|
||||
cleans all registered caches.
|
||||
|
||||
Example:
|
||||
# Clean specific caches
|
||||
await registry.cleanup_caches(["graph_cache", "service_factory_cache"])
|
||||
|
||||
# Clean all caches
|
||||
await registry.cleanup_caches()
|
||||
"""
|
||||
logger.info(f"Starting cache cleanup: {cache_names or 'all caches'}")
|
||||
|
||||
# Define standard cache cleanup functions
|
||||
cache_cleanup_funcs: dict[str, CleanupFunction] = {
|
||||
"graph_cache": self._cleanup_graph_cache,
|
||||
"service_factory_cache": self._cleanup_service_factory_cache,
|
||||
"state_template_cache": self._cleanup_state_template_cache,
|
||||
"llm_cache": self._cleanup_llm_cache,
|
||||
}
|
||||
|
||||
# Add any registered cache cleanup functions
|
||||
for name, func in self._cleanup_functions.items():
|
||||
if name.endswith("_cache") and name not in cache_cleanup_funcs:
|
||||
cache_cleanup_funcs[name] = func
|
||||
|
||||
# Determine which caches to clean
|
||||
if cache_names:
|
||||
caches_to_clean = {
|
||||
name: func for name, func in cache_cleanup_funcs.items()
|
||||
if name in cache_names
|
||||
}
|
||||
else:
|
||||
caches_to_clean = cache_cleanup_funcs
|
||||
|
||||
# Clean caches concurrently
|
||||
cleanup_tasks = []
|
||||
for cache_name, cleanup_func in caches_to_clean.items():
|
||||
logger.debug(f"Cleaning cache: {cache_name}")
|
||||
cleanup_tasks.append(cleanup_func())
|
||||
|
||||
if cleanup_tasks:
|
||||
results = await gather_with_concurrency(
|
||||
5, *cleanup_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# Log any failures
|
||||
for cache_name, result in zip(caches_to_clean.keys(), results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Failed to clean {cache_name}: {result}")
|
||||
else:
|
||||
logger.debug(f"Successfully cleaned {cache_name}")
|
||||
|
||||
logger.info("Cache cleanup completed")
|
||||
|
||||
async def _cleanup_graph_cache(self) -> None:
|
||||
"""Clean up graph cache instances."""
|
||||
try:
|
||||
# Delegate to the graph module's cleanup function
|
||||
if "cleanup_graph_cache" in self._cleanup_functions:
|
||||
await self._cleanup_functions["cleanup_graph_cache"]()
|
||||
else:
|
||||
logger.debug("No graph cache cleanup function registered")
|
||||
except Exception as e:
|
||||
logger.error(f"Graph cache cleanup failed: {e}")
|
||||
raise
|
||||
|
||||
async def _cleanup_service_factory_cache(self) -> None:
|
||||
"""Clean up service factory cache."""
|
||||
try:
|
||||
# Delegate to factory cleanup
|
||||
if "cleanup_service_factory" in self._cleanup_functions:
|
||||
await self._cleanup_functions["cleanup_service_factory"]()
|
||||
else:
|
||||
logger.debug("No service factory cleanup function registered")
|
||||
except Exception as e:
|
||||
logger.error(f"Service factory cache cleanup failed: {e}")
|
||||
raise
|
||||
|
||||
async def _cleanup_state_template_cache(self) -> None:
|
||||
"""Clean up state template cache."""
|
||||
try:
|
||||
# This would be registered by graph module
|
||||
if "cleanup_state_templates" in self._cleanup_functions:
|
||||
await self._cleanup_functions["cleanup_state_templates"]()
|
||||
else:
|
||||
logger.debug("No state template cleanup function registered")
|
||||
except Exception as e:
|
||||
logger.error(f"State template cache cleanup failed: {e}")
|
||||
raise
|
||||
|
||||
async def _cleanup_llm_cache(self) -> None:
|
||||
"""Clean up LLM response cache."""
|
||||
try:
|
||||
# This would be registered by LLM services
|
||||
if "cleanup_llm_cache" in self._cleanup_functions:
|
||||
await self._cleanup_functions["cleanup_llm_cache"]()
|
||||
else:
|
||||
logger.debug("No LLM cache cleanup function registered")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM cache cleanup failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Global cleanup registry instance
|
||||
_cleanup_registry: CleanupRegistry | None = None
|
||||
|
||||
@@ -17,9 +17,11 @@ Precedence hierarchy (highest to lowest):
from __future__ import annotations

import asyncio
import hashlib
import json
import os
from pathlib import Path
from typing import Any, TypedDict
from typing import Any, TypedDict, cast

import yaml
from dotenv import dotenv_values
@@ -652,3 +654,79 @@ async def load_config_async(
|
||||
yaml_path=yaml_path,
|
||||
overrides=overrides
|
||||
)
|
||||
|
||||
|
||||
def generate_config_hash(config: AppConfig | dict[str, Any]) -> str:
|
||||
"""Generate a deterministic hash from application configuration.
|
||||
|
||||
This function creates a SHA-256 hash for configuration-based caching,
|
||||
including all relevant fields that can affect graph creation or service
|
||||
initialization to avoid cache misses or collisions.
|
||||
|
||||
Args:
|
||||
config: Application configuration (AppConfig instance or dict)
|
||||
|
||||
Returns:
|
||||
Configuration hash string (first 16 chars of SHA-256 hash)
|
||||
|
||||
Example:
|
||||
```python
|
||||
config = load_config()
|
||||
cache_key = generate_config_hash(config)
|
||||
# Use cache_key for caching graphs, services, etc.
|
||||
```
|
||||
"""
|
||||
# Use Pydantic's model_dump with include set for relevant fields
|
||||
if hasattr(config, 'model_dump') and callable(getattr(config, 'model_dump')):
|
||||
config_dict = config.model_dump(include={ # type: ignore[union-attr]
|
||||
"llm_config": {"profile", "model", "max_retries", "temperature", "timeout"},
|
||||
"tools_enabled": True,
|
||||
"cache_enabled": True,
|
||||
"logging_level": True,
|
||||
"debug_mode": True,
|
||||
"redis_config": True,
|
||||
"postgres_config": True,
|
||||
"qdrant_config": True,
|
||||
})
|
||||
else:
|
||||
config_dict = cast(dict[str, Any], config)
|
||||
|
||||
# Flatten out each *_config into a boolean *_enabled field
|
||||
config_data = {}
|
||||
|
||||
# Handle LLM config fields
|
||||
if "llm_config" in config_dict and isinstance(config_dict["llm_config"], dict):
|
||||
llm_config = config_dict["llm_config"]
|
||||
config_data.update({
|
||||
"llm_profile": llm_config.get("profile", "default"),
|
||||
"model_name": llm_config.get("model", "default"),
|
||||
"max_retries": llm_config.get("max_retries", 3),
|
||||
"temperature": llm_config.get("temperature", 0.7),
|
||||
"timeout": llm_config.get("timeout", 30),
|
||||
})
|
||||
else:
|
||||
config_data.update({
|
||||
"llm_profile": "default",
|
||||
"model_name": "default",
|
||||
"max_retries": 3,
|
||||
"temperature": 0.7,
|
||||
"timeout": 30,
|
||||
})
|
||||
|
||||
# Handle other config fields
|
||||
config_data.update({
|
||||
"tools_enabled": config_dict.get("tools_enabled", True),
|
||||
"cache_enabled": config_dict.get("cache_enabled", True),
|
||||
"logging_level": config_dict.get("logging_level", "INFO"),
|
||||
"debug_mode": config_dict.get("debug_mode", False),
|
||||
})
|
||||
|
||||
# Handle service flags
|
||||
for service in ("redis", "postgres", "qdrant"):
|
||||
config_data[f"{service}_enabled"] = bool(config_dict.get(f"{service}_config"))
|
||||
|
||||
# Use canonical JSON serialization with sorted keys for deterministic hashing
|
||||
config_str = json.dumps(config_data, sort_keys=True, separators=(',', ':'))
|
||||
|
||||
# Use SHA-256 for collision resistance
|
||||
return hashlib.sha256(config_str.encode('utf-8')).hexdigest()[:16]
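A minimal sketch of how the hash above can key a process-local cache; `_graph_cache` and `build_graph` are hypothetical stand-ins, not part of this change, and `generate_config_hash` is the function defined above.

```python
from collections.abc import Callable
from typing import Any

# Hypothetical process-local cache keyed by the configuration hash.
_graph_cache: dict[str, Any] = {}


def get_or_build_graph(config: dict[str, Any], build_graph: Callable[[dict[str, Any]], Any]) -> Any:
    """Build the graph once per distinct configuration and reuse it afterwards."""
    cache_key = generate_config_hash(config)
    if cache_key not in _graph_cache:
        _graph_cache[cache_key] = build_graph(config)
    return _graph_cache[cache_key]
```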
|
||||
|
||||
@@ -1,5 +1,6 @@
"""Application-level configuration models."""

import asyncio
from collections.abc import Generator
from typing import Any

@@ -203,8 +204,7 @@ class AppConfig(BaseModel):

def __await__(self) -> Generator[Any, None, "AppConfig"]:
"""Make AppConfig awaitable (no-op, returns self)."""

async def _noop() -> "AppConfig":
return self

return _noop().__await__()
# Create a completed future with self as the result
future: asyncio.Future["AppConfig"] = asyncio.Future()
future.set_result(self)
return future.__await__()

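For illustration, the completed-future approach above makes a plain object transparently awaitable without spawning a throwaway coroutine; a self-contained sketch with a stand-in class (not the real `AppConfig`):

```python
import asyncio
from collections.abc import Generator
from typing import Any


class AwaitableValue:
    """Stand-in showing the completed-future pattern used above."""

    def __await__(self) -> Generator[Any, None, "AwaitableValue"]:
        future: asyncio.Future["AwaitableValue"] = asyncio.Future()
        future.set_result(self)  # already resolved, so awaiting is a no-op
        return future.__await__()


async def main() -> None:
    value = AwaitableValue()
    same_value = await value  # callers may await the object or use it directly
    assert same_value is value


asyncio.run(main())
```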
@@ -4,6 +4,8 @@ from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from biz_bud.core.errors import ConfigurationError
|
||||
|
||||
|
||||
class APIConfigModel(BaseModel):
|
||||
"""Pydantic model for API configuration parameters.
|
||||
@@ -124,7 +126,7 @@ class DatabaseConfigModel(BaseModel):
def connection_string(self) -> str:
"""Generate PostgreSQL connection string from configuration."""
if not all([self.postgres_user, self.postgres_password, self.postgres_host, self.postgres_port, self.postgres_db]):
raise ValueError("All PostgreSQL connection parameters must be set")
raise ConfigurationError("All PostgreSQL connection parameters must be set")
return f"postgresql://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"


@@ -52,10 +52,12 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
def router(state: dict[str, Any]) -> str:
errors = state.get(error_key, [])

if isinstance(errors, list):
error_count = len(errors)
else:
error_count = 1 if errors else 0
# Normalize errors inline to avoid circular import
if not errors:
errors = []
elif not isinstance(errors, list):
errors = [errors]
error_count = len(errors)

return error_target if error_count >= threshold else success_target


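A stand-alone sketch of the normalize-then-count routing shown above, with hypothetical node names; the real logic lives inside `EdgeHelpers` and this only illustrates the behaviour.

```python
from typing import Any


def route_on_error_count(
    state: dict[str, Any],
    error_key: str = "errors",
    threshold: int = 1,
    error_target: str = "handle_error",
    success_target: str = "continue",
) -> str:
    """Route to the error branch once the normalized error count reaches the threshold."""
    errors = state.get(error_key, [])
    # Normalize: None/empty -> [], scalar -> [scalar], list stays as-is.
    if not errors:
        errors = []
    elif not isinstance(errors, list):
        errors = [errors]
    return error_target if len(errors) >= threshold else success_target


assert route_on_error_count({"errors": "boom"}) == "handle_error"
assert route_on_error_count({"errors": []}) == "continue"
```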
@@ -44,10 +44,7 @@ def detect_errors_list(
|
||||
errors = getattr(state, errors_key, [])
|
||||
|
||||
# Check if we have any errors
|
||||
if len(errors) > 0:
|
||||
return error_target
|
||||
|
||||
return success_target
|
||||
return error_target if len(errors) > 0 else success_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -378,9 +378,9 @@ def health_check(
|
||||
degraded_count += 1
|
||||
else:
|
||||
# Lower is better (default for most metrics including unknown ones)
|
||||
if threshold_val == 0.0:
|
||||
if abs(threshold_val) < 1e-9: # Use tolerance for float comparison
|
||||
# Special case for zero threshold - any positive value is degraded
|
||||
if value > 0.0:
|
||||
if value > 1e-9:
|
||||
degraded_count += 1
|
||||
elif value > threshold_val * 1.5: # 150% of threshold
|
||||
unhealthy_count += 1
|
||||
|
||||
@@ -95,7 +95,7 @@ class SecureGraphRouter:
)

# SECURITY: Validate factory function
await self.execution_manager.validate_factory_function(
self.execution_manager.validate_factory_function(
factory_function, validated_graph_name
)


@@ -15,6 +15,13 @@ from typing import Any
|
||||
from .base import ErrorCategory, ErrorInfo, ErrorNamespace, create_error_info
|
||||
|
||||
|
||||
# Lazy imports to avoid circular dependency
|
||||
def _get_regex_utils():
|
||||
"""Lazy import for regex security utilities to avoid circular dependency."""
|
||||
from biz_bud.core.utils.regex_security import search_safe, sub_safe
|
||||
return search_safe, sub_safe
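The helper above defers the import until first use to break a circular dependency; a generic, runnable sketch of the same pattern, with `re` standing in for the heavier (or circular) module:

```python
def _get_regex_funcs():
    """Defer the import until first call so module import order no longer matters."""
    import re  # a circular or expensive dependency would go here instead
    return re.search, re.sub


def sanitize(text: str) -> str:
    _, sub = _get_regex_funcs()
    return sub(r"\d", "*", text)


assert sanitize("room 42") == "room **"
```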
|
||||
|
||||
|
||||
class ErrorMessageFormatter:
|
||||
"""Formatter for standardized error messages with security sanitization."""
|
||||
|
||||
@@ -198,21 +205,23 @@ class ErrorMessageFormatter:
|
||||
# Network errors
|
||||
if "connection" in message or "connect" in message:
|
||||
details["template_type"] = "connection_failed"
|
||||
# Try to find any host:port pattern first
|
||||
if host_match := re.search(r"([a-zA-Z0-9.-]+:\d+)", str(exception)):
|
||||
details["resource"] = host_match[1]
|
||||
elif host_match := re.search(
|
||||
# Try to find any host:port pattern first using safe regex
|
||||
search_safe, _ = _get_regex_utils()
|
||||
if host_match := search_safe(r"([a-zA-Z0-9.-]+:\d+)", str(exception)):
|
||||
details["resource"] = host_match.group(1)
|
||||
elif host_match := search_safe(
|
||||
r"(?:to|from|at)\s+([a-zA-Z0-9.-]+)", str(exception)
|
||||
):
|
||||
details["resource"] = host_match[1]
|
||||
details["resource"] = host_match.group(1)
|
||||
|
||||
elif "timeout" in message or "timeout" in details["exception_type"].lower():
|
||||
details["template_type"] = "timeout"
|
||||
# Try to extract timeout duration
|
||||
if timeout_match := re.search(
|
||||
# Try to extract timeout duration using safe regex
|
||||
search_safe, _ = _get_regex_utils()
|
||||
if timeout_match := search_safe(
|
||||
r"(\d+(?:\.\d+)?)\s*(?:s|sec|seconds?)", str(exception)
|
||||
):
|
||||
details["timeout"] = timeout_match[1]
|
||||
details["timeout"] = timeout_match.group(1)
|
||||
|
||||
elif "auth" in message or "credential" in message or "unauthorized" in message:
|
||||
details["template_type"] = "invalid_credentials"
|
||||
@@ -222,27 +231,30 @@ class ErrorMessageFormatter:
|
||||
|
||||
elif "rate limit" in message or "quota" in message or "too many" in message:
|
||||
details["template_type"] = "quota_exceeded"
|
||||
if retry_match := re.search(
|
||||
search_safe, _ = _get_regex_utils()
|
||||
if retry_match := search_safe(
|
||||
r"(?:retry|wait)\s+(?:after\s+)?(\d+)", message
|
||||
):
|
||||
details["retry_after"] = retry_match[1]
|
||||
details["retry_after"] = retry_match.group(1)
|
||||
|
||||
elif "missing" in message and ("field" in message or "required" in message):
|
||||
details["template_type"] = "missing_field"
|
||||
if field_match := re.search(
|
||||
search_safe, _ = _get_regex_utils()
|
||||
if field_match := search_safe(
|
||||
r"(?:field|parameter|argument)\s+['\"`]?([a-zA-Z_]\w*)['\"`]?",
|
||||
str(exception),
|
||||
):
|
||||
details["field"] = field_match[1]
|
||||
details["field"] = field_match.group(1)
|
||||
|
||||
elif "invalid" in message and "format" in message:
|
||||
details["template_type"] = "invalid_format"
|
||||
|
||||
elif "token" in message and ("limit" in message or "exceed" in message):
|
||||
details["template_type"] = "context_overflow"
|
||||
if token_match := re.search(r"(\d+)\s*tokens?.*?(\d+)", str(exception)):
|
||||
details["tokens"] = token_match[1]
|
||||
details["limit"] = token_match[2]
|
||||
search_safe, _ = _get_regex_utils()
|
||||
if token_match := search_safe(r"(\d+)\s*tokens?.*?(\d+)", str(exception)):
|
||||
details["tokens"] = token_match.group(1)
|
||||
details["limit"] = token_match.group(2)
|
||||
|
||||
return details
|
||||
|
||||
@@ -263,15 +275,16 @@ class ErrorMessageFormatter:
|
||||
"""
|
||||
message = str(error)
|
||||
|
||||
# Apply sanitization rules
|
||||
# Apply sanitization rules using safe regex operations
|
||||
_, sub_safe = _get_regex_utils()
|
||||
for pattern, replacement, flags, user_only in cls.SANITIZE_RULES:
|
||||
if not user_only or for_user:
|
||||
message = re.sub(pattern, replacement, message, flags)
|
||||
message = sub_safe(pattern, replacement, message, flags=flags)
|
||||
|
||||
# Apply logging-only rules when not for user
|
||||
# Apply logging-only rules when not for user using safe regex operations
|
||||
if not for_user:
|
||||
for pattern, replacement, flags in cls.LOGGING_ONLY_RULES:
|
||||
message = re.sub(pattern, replacement, message, flags)
|
||||
message = sub_safe(pattern, replacement, message, flags=flags)
|
||||
|
||||
if for_user:
|
||||
# Stack traces (keep only the error message part)
|
||||
@@ -404,15 +417,15 @@ def format_error_for_user(error: ErrorInfo) -> str:
|
||||
if category and category != "unknown":
|
||||
# Capitalize category for display
|
||||
category_display = category.capitalize()
|
||||
user_message += f"\n🏷️ Category: {category_display}"
|
||||
user_message += f"\n[Category: {category_display}]"
|
||||
|
||||
if suggestion := error["context"].get("suggestion"):
|
||||
user_message += f"\n💡 {suggestion}"
|
||||
user_message += f"\nSuggestion: {suggestion}"
|
||||
|
||||
# Add node information if relevant
|
||||
node = error["node"]
|
||||
if node and node not in sanitized_message:
|
||||
user_message += f"\n📍 Location: {node}"
|
||||
user_message += f"\nLocation: {node}"
|
||||
|
||||
return user_message
|
||||
|
||||
|
||||
@@ -242,10 +242,12 @@ def get_error_summary(state: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
# Add state-specific information
|
||||
state_errors = state.get("errors", [])
|
||||
if isinstance(state_errors, list):
|
||||
aggregator_summary["state_error_count"] = len(state_errors)
|
||||
else:
|
||||
aggregator_summary["state_error_count"] = 0
|
||||
# Normalize errors locally to avoid circular import
|
||||
if not state_errors:
|
||||
state_errors = []
|
||||
elif not isinstance(state_errors, list):
|
||||
state_errors = [state_errors]
|
||||
aggregator_summary["state_error_count"] = len(state_errors)
|
||||
|
||||
return aggregator_summary
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ class LLMExceptionHandler:
|
||||
else:
|
||||
raise exception_instance
|
||||
|
||||
async def handle_llm_exception(
|
||||
def handle_llm_exception(
|
||||
self,
|
||||
exc: Exception,
|
||||
attempt: int,
|
||||
|
||||
@@ -235,9 +235,9 @@ class ErrorRouter:
|
||||
break
|
||||
|
||||
# Apply final action
|
||||
return await self._apply_action(final_action, final_error, context)
|
||||
return self._apply_action(final_action, final_error, context)
|
||||
|
||||
async def _apply_action(
|
||||
def _apply_action(
|
||||
self,
|
||||
action: RouteAction,
|
||||
error: ErrorInfo | None,
|
||||
|
||||
@@ -302,7 +302,7 @@ class RouterConfig:
|
||||
|
||||
|
||||
# Example custom handlers
|
||||
async def retry_handler(error: ErrorInfo, context: dict[str, Any]) -> ErrorInfo | None:
|
||||
def retry_handler(error: ErrorInfo, context: dict[str, Any]) -> ErrorInfo | None:
|
||||
"""Handle transient errors with retry logic."""
|
||||
retry_count = context.get("retry_count", 0)
|
||||
max_retries = context.get("max_retries", 3)
|
||||
@@ -333,7 +333,7 @@ async def retry_handler(error: ErrorInfo, context: dict[str, Any]) -> ErrorInfo
|
||||
return None
|
||||
|
||||
|
||||
async def notification_handler(
|
||||
def notification_handler(
|
||||
error: ErrorInfo, context: dict[str, Any]
|
||||
) -> ErrorInfo | None:
|
||||
"""Send notifications for critical errors."""
|
||||
|
||||
@@ -444,7 +444,7 @@ class ConsoleMetricsClient:
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Print increment metric."""
|
||||
print(f"📊 METRIC: {metric} +{value} tags={tags or {}}")
|
||||
print(f"METRIC: {metric} +{value} tags={tags or {}}")
|
||||
|
||||
def gauge(
|
||||
self,
|
||||
@@ -453,7 +453,7 @@ class ConsoleMetricsClient:
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Print gauge metric."""
|
||||
print(f"📊 METRIC: {metric} ={value} tags={tags or {}}")
|
||||
print(f"METRIC: {metric} ={value} tags={tags or {}}")
|
||||
|
||||
def histogram(
|
||||
self,
|
||||
|
||||
@@ -5,6 +5,8 @@ best practices including state immutability, cross-cutting concerns,
|
||||
configuration management, and graph orchestration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .cross_cutting import (
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
@@ -56,4 +58,72 @@ __all__ = [
|
||||
"create_config_injected_node",
|
||||
"extract_config_from_state",
|
||||
"update_node_to_use_config",
|
||||
# Type compatibility utilities
|
||||
"create_type_safe_wrapper",
|
||||
"wrap_for_langgraph",
|
||||
]
|
||||
|
||||
|
||||
def create_type_safe_wrapper(func: Any) -> Any:
|
||||
"""Create a type-safe wrapper for functions to avoid LangGraph typing issues.
|
||||
|
||||
This utility helps wrap functions that need to cast their state parameter
|
||||
to avoid typing conflicts in LangGraph's strict type system.
|
||||
|
||||
Args:
|
||||
func: Function to wrap
|
||||
|
||||
Returns:
|
||||
Wrapped function with proper type casting
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Original function with specific state type
|
||||
def my_router(state: InputState) -> str:
|
||||
return route_error_severity(state) # Type error!
|
||||
|
||||
# Create type-safe wrapper
|
||||
safe_router = create_type_safe_wrapper(route_error_severity)
|
||||
|
||||
# Use in graph
|
||||
builder.add_conditional_edges(
|
||||
"node",
|
||||
safe_router,
|
||||
{...}
|
||||
)
|
||||
```
|
||||
"""
|
||||
def wrapper(state: Any) -> Any:
|
||||
"""Type-safe wrapper that casts state to target type."""
|
||||
return func(state)
|
||||
|
||||
# Preserve function metadata
|
||||
wrapper.__name__ = f"{func.__name__}_wrapped"
|
||||
wrapper.__doc__ = func.__doc__
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def wrap_for_langgraph() -> Any:
|
||||
"""Decorator to create type-safe wrappers for LangGraph conditional edges.
|
||||
|
||||
This decorator helps avoid typing issues when using functions as
|
||||
conditional edges in LangGraph by properly casting the state parameter.
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
```python
|
||||
@wrap_for_langgraph()
|
||||
def route_by_error(state: InputState) -> str:
|
||||
return route_error_severity(state)
|
||||
|
||||
# Now safe to use in graph
|
||||
builder.add_conditional_edges("node", route_by_error, {...})
|
||||
```
|
||||
"""
|
||||
def decorator(func: Any) -> Any:
|
||||
return create_type_safe_wrapper(func)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -7,6 +7,7 @@ nodes and tools in the Business Buddy framework.
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
@@ -29,6 +30,70 @@ class NodeMetric(TypedDict):
|
||||
last_error: str | None
|
||||
|
||||
|
||||
def _log_execution_start(node_name: str, context: dict[str, Any]) -> float:
|
||||
"""Log the start of node execution and return start time.
|
||||
|
||||
Args:
|
||||
node_name: Name of the node being executed
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
Start time for duration calculation
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(
|
||||
f"Node '{node_name}' started",
|
||||
extra={
|
||||
"node_name": node_name,
|
||||
"run_id": context.get("run_id"),
|
||||
"user_id": context.get("user_id"),
|
||||
},
|
||||
)
|
||||
return start_time
|
||||
|
||||
|
||||
def _log_execution_success(node_name: str, start_time: float, context: dict[str, Any]) -> None:
|
||||
"""Log successful node execution.
|
||||
|
||||
Args:
|
||||
node_name: Name of the node
|
||||
start_time: Execution start time
|
||||
context: Execution context
|
||||
"""
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"Node '{node_name}' completed successfully",
|
||||
extra={
|
||||
"node_name": node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _log_execution_error(
|
||||
node_name: str, start_time: float, error: Exception, context: dict[str, Any]
|
||||
) -> None:
|
||||
"""Log node execution error.
|
||||
|
||||
Args:
|
||||
node_name: Name of the node
|
||||
start_time: Execution start time
|
||||
error: The exception that occurred
|
||||
context: Execution context
|
||||
"""
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Node '{node_name}' failed: {str(error)}",
|
||||
extra={
|
||||
"node_name": node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"error_type": type(error).__name__,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def log_node_execution(
|
||||
node_name: str | None = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
@@ -49,90 +114,28 @@ def log_node_execution(
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
start_time = time.time()
|
||||
|
||||
# Extract context from RunnableConfig if available
|
||||
context = _extract_context_from_args(args, kwargs)
|
||||
|
||||
logger.info(
|
||||
f"Node '{actual_node_name}' started",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"run_id": context.get("run_id"),
|
||||
"user_id": context.get("user_id"),
|
||||
},
|
||||
)
|
||||
start_time = _log_execution_start(actual_node_name, context)
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(
|
||||
f"Node '{actual_node_name}' completed successfully",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
|
||||
_log_execution_success(actual_node_name, start_time, context)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.error(
|
||||
f"Node '{actual_node_name}' failed: {str(e)}",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"error_type": type(e).__name__,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
_log_execution_error(actual_node_name, start_time, e, context)
|
||||
raise
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
start_time = time.time()
|
||||
context = _extract_context_from_args(args, kwargs)
|
||||
|
||||
logger.info(
|
||||
f"Node '{actual_node_name}' started",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"run_id": context.get("run_id"),
|
||||
"user_id": context.get("user_id"),
|
||||
},
|
||||
)
|
||||
start_time = _log_execution_start(actual_node_name, context)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(
|
||||
f"Node '{actual_node_name}' completed successfully",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
|
||||
_log_execution_success(actual_node_name, start_time, context)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.error(
|
||||
f"Node '{actual_node_name}' failed: {str(e)}",
|
||||
extra={
|
||||
"node_name": actual_node_name,
|
||||
"duration_ms": elapsed_ms,
|
||||
"error_type": type(e).__name__,
|
||||
"run_id": context.get("run_id"),
|
||||
},
|
||||
)
|
||||
_log_execution_error(actual_node_name, start_time, e, context)
|
||||
raise
|
||||
|
||||
# Return appropriate wrapper based on function type
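A condensed, hypothetical stand-in for the refactored decorator: it shows the extracted start/finish helpers and the same coroutine-based wrapper dispatch, while the real `log_node_execution` additionally extracts run context and structured extras and logs failures.

```python
import asyncio
import functools
import logging
import time
from typing import Any, Callable

logger = logging.getLogger(__name__)


def logged(node_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Minimal stand-in: one decorator, shared helpers, two wrappers."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def _start() -> float:
            logger.info("Node '%s' started", node_name)
            return time.time()

        def _finish(start: float) -> None:
            logger.info("Node '%s' completed in %.1f ms", node_name, (time.time() - start) * 1000)

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            start = _start()
            result = await func(*args, **kwargs)
            _finish(start)
            return result

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            start = _start()
            result = func(*args, **kwargs)
            _finish(start)
            return result

        # Same dispatch rule as the real decorator: pick the wrapper by coroutine check.
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator


@logged("extract")
async def extract_node(state: dict) -> dict:
    return {**state, "extracted": True}
```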
|
||||
@@ -287,6 +290,58 @@ def track_metrics(
|
||||
return decorator
|
||||
|
||||
|
||||
def _handle_error(
|
||||
error: Exception,
|
||||
func_name: str,
|
||||
args: tuple[Any, ...],
|
||||
error_handler: Callable[[Exception], Any] | None,
|
||||
fallback_value: Any
|
||||
) -> Any:
|
||||
"""Common error handling logic for both sync and async functions.
|
||||
|
||||
Args:
|
||||
error: The exception that was raised
|
||||
func_name: Name of the function that failed
|
||||
args: Function arguments (to extract state)
|
||||
error_handler: Optional custom error handler
|
||||
fallback_value: Value to return on error (if not re-raising)
|
||||
|
||||
Returns:
|
||||
Fallback value or re-raises the exception
|
||||
"""
|
||||
# Log the error
|
||||
logger.error(
|
||||
f"Error in {func_name}: {str(error)}",
|
||||
exc_info=True,
|
||||
extra={"function": func_name, "error_type": type(error).__name__},
|
||||
)
|
||||
|
||||
# Call custom error handler if provided
|
||||
if error_handler:
|
||||
error_handler(error)
|
||||
|
||||
# Update state with error if available
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
if state and "errors" in state:
|
||||
if not isinstance(state["errors"], list):
|
||||
state["errors"] = []
|
||||
|
||||
state["errors"].append(
|
||||
{
|
||||
"node": func_name,
|
||||
"error": str(error),
|
||||
"type": type(error).__name__,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Return fallback value or re-raise
|
||||
if fallback_value is not None:
|
||||
return fallback_value
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def handle_errors(
|
||||
error_handler: Callable[[Exception], Any] | None = None, fallback_value: Any = None
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
@@ -309,70 +364,14 @@ def handle_errors(
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Log the error
|
||||
logger.error(
|
||||
f"Error in {func.__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
extra={"function": func.__name__, "error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
# Call custom error handler if provided
|
||||
if error_handler:
|
||||
error_handler(e)
|
||||
|
||||
# Update state with error if available
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
if state and "errors" in state:
|
||||
if not isinstance(state["errors"], list):
|
||||
state["errors"] = []
|
||||
|
||||
state["errors"].append(
|
||||
{
|
||||
"node": func.__name__,
|
||||
"error": str(e),
|
||||
"type": type(e).__name__,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Return fallback value or re-raise
|
||||
if fallback_value is not None:
|
||||
return fallback_value
|
||||
else:
|
||||
raise
|
||||
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in {func.__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
extra={"function": func.__name__, "error_type": type(e).__name__},
|
||||
)
|
||||
|
||||
if error_handler:
|
||||
error_handler(e)
|
||||
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
if state and "errors" in state:
|
||||
if not isinstance(state["errors"], list):
|
||||
state["errors"] = []
|
||||
|
||||
state["errors"].append(
|
||||
{
|
||||
"node": func.__name__,
|
||||
"error": str(e),
|
||||
"type": type(e).__name__,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
if fallback_value is not None:
|
||||
return fallback_value
|
||||
else:
|
||||
raise
|
||||
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
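The refactor funnels both wrappers through `_handle_error`, relying on the fact that a bare `raise` inside a helper called from an `except` block re-raises the active exception; a tiny self-contained demonstration with illustrative names:

```python
from typing import Any


def _shared_handler(error: Exception, fallback: Any) -> Any:
    """Called from inside an except block; the bare raise re-raises the active exception."""
    if fallback is not None:
        return fallback
    raise


def safe_divide(a: float, b: float, fallback: float | None = None) -> float | None:
    try:
        return a / b
    except ZeroDivisionError as exc:
        return _shared_handler(exc, fallback)


assert safe_divide(1, 0, fallback=0.0) == 0.0  # fallback path
try:
    safe_divide(1, 0)  # no fallback: the original ZeroDivisionError propagates
except ZeroDivisionError:
    pass
```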
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ enabling consistent configuration injection across all nodes and tools.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -51,7 +52,7 @@ def configure_graph_with_injection(
|
||||
|
||||
# Create wrapper that injects config
|
||||
if callable(node_func):
|
||||
callable_node = cast(Callable[..., object], node_func)
|
||||
callable_node = cast(Callable[..., Any] | Callable[..., Awaitable[Any]], node_func)
|
||||
wrapped_node = create_config_injected_node(callable_node, base_config)
|
||||
# Replace the node
|
||||
graph_builder.nodes[node_name] = wrapped_node
|
||||
@@ -60,7 +61,7 @@ def configure_graph_with_injection(
|
||||
|
||||
|
||||
def create_config_injected_node(
|
||||
node_func: Callable[..., object], base_config: RunnableConfig
|
||||
node_func: Callable[..., Any] | Callable[..., Awaitable[Any]], base_config: RunnableConfig
|
||||
) -> Any:
|
||||
"""Create a node wrapper that injects RunnableConfig.
|
||||
|
||||
@@ -96,8 +97,9 @@ def create_config_injected_node(
|
||||
merged_config = base_config
|
||||
|
||||
# Call original node with merged config
|
||||
if inspect.iscoroutinefunction(node_func):
|
||||
return await node_func(state, config=merged_config)
|
||||
if asyncio.iscoroutinefunction(node_func):
|
||||
coro_result = node_func(state, config=merged_config)
|
||||
return await cast(Awaitable[Any], coro_result)
|
||||
else:
|
||||
return node_func(state, config=merged_config)
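A self-contained sketch of the config-injection wrapper above, using plain dicts in place of `RunnableConfig`; the node and keys are illustrative only.

```python
import asyncio
from typing import Any, Awaitable, Callable


def inject_config(
    node: Callable[..., Any] | Callable[..., Awaitable[Any]],
    base_config: dict[str, Any],
) -> Callable[..., Awaitable[Any]]:
    """Wrap a node so every call receives base_config merged with the call-time config."""

    async def wrapper(state: dict[str, Any], config: dict[str, Any] | None = None) -> Any:
        merged = {**base_config, **(config or {})}
        if asyncio.iscoroutinefunction(node):
            return await node(state, config=merged)
        return node(state, config=merged)

    return wrapper


async def shout(state: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    return {"text": state["text"].upper(), "tags": config["tags"]}


wrapped = inject_config(shout, {"tags": ["demo"]})
print(asyncio.run(wrapped({"text": "hi"})))  # {'text': 'HI', 'tags': ['demo']}
```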
|
||||
|
||||
@@ -145,8 +147,9 @@ def update_node_to_use_config(
|
||||
# noqa: ARG001
|
||||
) -> object:
|
||||
# Call original without config (for backward compatibility)
|
||||
if inspect.iscoroutinefunction(node_func):
|
||||
return await node_func(state)
|
||||
if asyncio.iscoroutinefunction(node_func):
|
||||
coro_result = node_func(state)
|
||||
return await cast(Awaitable[Any], coro_result)
|
||||
else:
|
||||
return node_func(state)
|
||||
|
||||
|
||||
@@ -605,6 +605,7 @@ def validate_state_schema(state: dict[str, Any], schema: type) -> None:
|
||||
)
|
||||
|
||||
if key in state:
|
||||
state[key]
|
||||
# Access the value for potential type checking
|
||||
_value = state[key]
|
||||
# Basic type checking (would be more sophisticated in practice)
|
||||
# Skip validation for now to avoid complex type checking
|
||||
|
||||
@@ -13,12 +13,15 @@ from .api_client import (
|
||||
proxied_rate_limited_request,
|
||||
)
|
||||
from .async_utils import (
|
||||
AsyncContextInfo,
|
||||
ChainLink,
|
||||
RateLimiter,
|
||||
detect_async_context,
|
||||
gather_with_concurrency,
|
||||
process_items_in_parallel,
|
||||
retry_async,
|
||||
run_async_chain,
|
||||
run_in_appropriate_context,
|
||||
to_async,
|
||||
with_timeout,
|
||||
)
|
||||
@@ -50,12 +53,15 @@ __all__ = [
|
||||
"create_api_client",
|
||||
"proxied_rate_limited_request",
|
||||
# Async utilities
|
||||
"AsyncContextInfo",
|
||||
"ChainLink",
|
||||
"RateLimiter",
|
||||
"detect_async_context",
|
||||
"gather_with_concurrency",
|
||||
"process_items_in_parallel",
|
||||
"retry_async",
|
||||
"run_async_chain",
|
||||
"run_in_appropriate_context",
|
||||
"to_async",
|
||||
"with_timeout",
|
||||
# HTTP Client
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from contextvars import copy_context
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from biz_bud.core.errors import BusinessBuddyError, ValidationError
|
||||
@@ -278,3 +282,252 @@ async def run_async_chain[T]( # noqa: D103
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
return cast(T, result)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncContextInfo:
|
||||
"""Information about the current execution context."""
|
||||
|
||||
is_async: bool
|
||||
detection_method: str
|
||||
has_running_loop: bool
|
||||
loop_info: str | None = None
|
||||
|
||||
|
||||
def detect_async_context() -> AsyncContextInfo:
|
||||
"""Detect if we're currently in an async context.
|
||||
|
||||
This function uses multiple methods to reliably detect whether
|
||||
code is running in an async context, which helps avoid common
|
||||
issues with mixing sync and async code.
|
||||
|
||||
Returns:
|
||||
AsyncContextInfo with details about the current context
|
||||
"""
|
||||
# Method 1: Check if we're in a coroutine using inspect
|
||||
current_frame = inspect.currentframe()
|
||||
try:
|
||||
frame = current_frame
|
||||
while frame is not None:
|
||||
# Check if frame is from a coroutine
|
||||
if frame.f_code.co_flags & inspect.CO_COROUTINE:
|
||||
return AsyncContextInfo(
|
||||
is_async=True,
|
||||
detection_method="coroutine_frame",
|
||||
has_running_loop=_check_running_loop(),
|
||||
)
|
||||
# Also check for async generator
|
||||
if frame.f_code.co_flags & inspect.CO_ASYNC_GENERATOR:
|
||||
return AsyncContextInfo(
|
||||
is_async=True,
|
||||
detection_method="async_generator_frame",
|
||||
has_running_loop=_check_running_loop(),
|
||||
)
|
||||
frame = frame.f_back
|
||||
finally:
|
||||
# Clean up frame references to avoid reference cycles
|
||||
del current_frame
|
||||
frame = None # Initialize before potential use
|
||||
if 'frame' in locals():
|
||||
del frame
|
||||
|
||||
# Method 2: Check for running event loop
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
return AsyncContextInfo(
|
||||
is_async=True,
|
||||
detection_method="running_loop",
|
||||
has_running_loop=True,
|
||||
loop_info=f"Loop type: {type(loop).__name__}",
|
||||
)
|
||||
except RuntimeError:
|
||||
# No running loop
|
||||
pass
|
||||
|
||||
# Method 3: Check context vars for async markers
|
||||
ctx = copy_context()
|
||||
for var in ctx:
|
||||
if hasattr(var, 'name') and 'async' in str(var.name).lower():
|
||||
return AsyncContextInfo(
|
||||
is_async=True,
|
||||
detection_method="context_vars",
|
||||
has_running_loop=_check_running_loop(),
|
||||
)
|
||||
|
||||
# Not in async context
|
||||
return AsyncContextInfo(
|
||||
is_async=False,
|
||||
detection_method="none",
|
||||
has_running_loop=False,
|
||||
)
|
||||
|
||||
|
||||
def _check_running_loop() -> bool:
|
||||
"""Check if there's a running event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.is_running()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
async def run_in_appropriate_context(
|
||||
async_func: Callable[..., Awaitable[T]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Run an async function in the appropriate context.
|
||||
|
||||
This function handles the common pattern of needing to run async
|
||||
code from potentially sync contexts, automatically detecting the
|
||||
current context and using the appropriate execution method.
|
||||
|
||||
Args:
|
||||
async_func: The async function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The result of the async function
|
||||
|
||||
Raises:
|
||||
RuntimeError: If unable to execute the async function
|
||||
"""
|
||||
context_info = detect_async_context()
|
||||
|
||||
if context_info.is_async and context_info.has_running_loop:
|
||||
# Already in async context with running loop - just await
|
||||
return await async_func(*args, **kwargs)
|
||||
|
||||
# Not in async context or no running loop - need to run in new loop
|
||||
try:
|
||||
# Try to get existing event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
# Create new loop if closed
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
# No event loop, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the async function
|
||||
return loop.run_until_complete(async_func(*args, **kwargs))
|
||||
except RuntimeError as e:
|
||||
# If asyncio.run is available and works better, use it
|
||||
if sys.version_info >= (3, 7):
|
||||
try:
|
||||
return asyncio.run(cast(Coroutine[Any, Any, T], async_func(*args, **kwargs)))
|
||||
except RuntimeError:
|
||||
pass
|
||||
raise RuntimeError(f"Failed to execute async function: {e}") from e
|
||||
|
||||
|
||||
def create_async_sync_wrapper(
|
||||
sync_resolver_func: Callable[..., T],
|
||||
async_resolver_func: Callable[..., Awaitable[T]],
|
||||
) -> tuple[Callable[..., T], Callable[..., Awaitable[T]]]:
|
||||
"""Create async and sync wrapper functions for factory pattern.
|
||||
|
||||
This function consolidates the duplicated async wrapper logic found in
|
||||
factory functions, following DRY principles and ensuring consistency.
|
||||
|
||||
Args:
|
||||
sync_resolver_func: Function to resolve configuration synchronously
|
||||
async_resolver_func: Function to resolve configuration asynchronously
|
||||
|
||||
Returns:
|
||||
Tuple of (sync_factory, async_factory) functions that handle
|
||||
RunnableConfig processing and proper async/sync execution
|
||||
|
||||
Example:
|
||||
```python
|
||||
def resolve_config_sync(runnable_config):
|
||||
return load_config()
|
||||
|
||||
async def resolve_config_async(runnable_config):
|
||||
return await load_config_async()
|
||||
|
||||
sync_factory, async_factory = create_async_sync_wrapper(
|
||||
resolve_config_sync,
|
||||
resolve_config_async
|
||||
)
|
||||
```
|
||||
"""
|
||||
def create_sync_factory():
|
||||
def sync_factory(config: dict[str, Any]) -> Any:
|
||||
"""Create factory with RunnableConfig (optimized)."""
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
runnable_config = RunnableConfig(configurable=config.get("configurable", {}))
|
||||
resolved_config = sync_resolver_func(runnable_config=runnable_config)
|
||||
|
||||
# Return the resolved configuration
|
||||
# Actual factory creation logic should be handled by the caller
|
||||
return resolved_config
|
||||
|
||||
return sync_factory
|
||||
|
||||
def create_async_factory():
|
||||
async def async_factory(config: dict[str, Any]) -> Any:
|
||||
"""Create factory with RunnableConfig (async optimized)."""
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
runnable_config = RunnableConfig(configurable=config.get("configurable", {}))
|
||||
resolved_config = await async_resolver_func(runnable_config=runnable_config)
|
||||
|
||||
# Return the resolved configuration
|
||||
# Actual factory creation logic should be handled by the caller
|
||||
return resolved_config
|
||||
|
||||
return async_factory
|
||||
|
||||
return create_sync_factory(), create_async_factory()
|
||||
|
||||
|
||||
def handle_sync_async_context(
|
||||
app_config: Any,
|
||||
service_factory: Any,
|
||||
async_func: Callable[..., Awaitable[T]],
|
||||
sync_fallback: Callable[[], T],
|
||||
) -> T:
|
||||
"""Handle sync/async context detection for graph creation.
|
||||
|
||||
This function uses centralized context detection to reliably determine
|
||||
execution context and choose the appropriate graph creation method.
|
||||
|
||||
Args:
|
||||
app_config: Application configuration
|
||||
service_factory: Service factory instance
|
||||
async_func: Async function to execute (e.g., create_graph_with_services)
|
||||
sync_fallback: Synchronous fallback function if async fails
|
||||
|
||||
Returns:
|
||||
Result from either async or sync execution
|
||||
|
||||
Example:
|
||||
```python
|
||||
graph = handle_sync_async_context(
|
||||
app_config,
|
||||
service_factory,
|
||||
lambda: create_graph_with_services(app_config, service_factory),
|
||||
lambda: get_graph() # sync fallback
|
||||
)
|
||||
```
|
||||
"""
|
||||
context_info = detect_async_context()
|
||||
|
||||
if context_info.is_async:
|
||||
# We're in async context but called synchronously
|
||||
# Create non-optimized sync version to avoid event loop conflicts
|
||||
return sync_fallback()
|
||||
# No async context detected - safe to run async code
|
||||
try:
|
||||
# Use centralized context handler
|
||||
return asyncio.run(cast(Coroutine[Any, Any, T], async_func(app_config, service_factory)))
|
||||
except RuntimeError:
|
||||
# Fallback to sync creation if async fails
|
||||
return sync_fallback()
|
||||
|
||||
@@ -188,7 +188,7 @@ class CircuitBreaker:
self.stats = CircuitBreakerStats()
self._lock = asyncio.Lock()

async def _should_attempt_reset(self) -> bool:
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset the circuit breaker."""
if self.stats.state != CircuitBreakerState.OPEN:
return False
@@ -243,7 +243,7 @@ class CircuitBreaker:
"""
# Check if we should attempt reset
if self.stats.state == CircuitBreakerState.OPEN:
if await self._should_attempt_reset():
if self._should_attempt_reset():
async with self._lock:
logger.info("CircuitBreaker state transition: OPEN -> HALF_OPEN")
self.stats.state = CircuitBreakerState.HALF_OPEN
@@ -360,10 +360,8 @@ async def retry_with_backoff(
config.jitter,
)

if asyncio.iscoroutinefunction(func):
await asyncio.sleep(delay)
else:
time.sleep(delay)
# Always use async sleep since we're in an async context
await asyncio.sleep(delay)

if last_exception:
raise last_exception

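A self-contained sketch of the always-async backoff pattern adopted above; the helper name and parameters here are illustrative and distinct from the module's real `retry_with_backoff`.

```python
import asyncio
import random
from collections.abc import Awaitable, Callable
from typing import Any


async def retry_with_async_sleep(
    func: Callable[[], Awaitable[Any]],
    attempts: int = 3,
    base_delay: float = 0.1,
    jitter: float = 0.05,
) -> Any:
    """Retry an async callable with exponential backoff; always sleeps asynchronously."""
    last_exc: Exception | None = None
    for attempt in range(attempts):
        try:
            return await func()
        except Exception as exc:  # the real helper filters on retryable error types
            last_exc = exc
            delay = base_delay * (2 ** attempt) + random.uniform(0, jitter)
            await asyncio.sleep(delay)  # never time.sleep(): that would block the event loop
    assert last_exc is not None
    raise last_exc


async def flaky(counter: list[int]) -> str:
    counter[0] += 1
    if counter[0] < 3:
        raise ConnectionError("transient")
    return "ok"


calls = [0]
print(asyncio.run(retry_with_async_sleep(lambda: flaky(calls))))  # "ok" after two retries
```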
@@ -267,11 +267,13 @@ class ServiceLifecycleManager:
|
||||
|
||||
try:
|
||||
# Stop health monitoring first
|
||||
cancelled_error = None
|
||||
try:
|
||||
await self._stop_health_monitoring()
|
||||
except asyncio.CancelledError:
|
||||
except asyncio.CancelledError as e:
|
||||
# Health monitoring cancellation is expected during shutdown
|
||||
logger.debug("Health monitoring task cancelled during shutdown")
|
||||
cancelled_error = e # Store for re-raising after cleanup
|
||||
|
||||
# Shutdown services in reverse dependency order
|
||||
await asyncio.wait_for(
|
||||
@@ -284,6 +286,10 @@ class ServiceLifecycleManager:
|
||||
|
||||
logger.info(f"Service lifecycle shutdown completed in {self._shutdown_time:.2f}s")
|
||||
|
||||
# Re-raise CancelledError if it occurred during health monitoring
|
||||
if cancelled_error is not None:
|
||||
raise cancelled_error
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Service shutdown timed out after {timeout}s, forcing shutdown")
|
||||
await self._force_shutdown_all()
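A small sketch of the shutdown pattern above: remember a `CancelledError`, finish cleanup, then re-raise so cancellation still propagates; the names are illustrative, not the real `ServiceLifecycleManager` API.

```python
import asyncio
from collections.abc import Awaitable, Callable


async def shutdown(monitor: asyncio.Task, close_services: Callable[[], Awaitable[None]]) -> None:
    """Finish service cleanup even if the monitor task was cancelled, then re-raise."""
    cancelled: asyncio.CancelledError | None = None
    try:
        await monitor
    except asyncio.CancelledError as exc:
        cancelled = exc  # remember it so cancellation is not silently swallowed
    await close_services()
    if cancelled is not None:
        raise cancelled


async def main() -> None:
    monitor = asyncio.create_task(asyncio.sleep(10))
    monitor.cancel()

    async def close_services() -> None:
        print("services closed")

    try:
        await shutdown(monitor, close_services)
    except asyncio.CancelledError:
        print("cancellation propagated after cleanup")


asyncio.run(main())
```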
|
||||
|
||||
@@ -590,7 +590,7 @@ def log_alert_handler(message: str) -> None:
|
||||
|
||||
def console_alert_handler(message: str) -> None:
|
||||
"""Alert handler that prints to console."""
|
||||
print(f"🚨 SERVICE ALERT: {message}")
|
||||
print(f"SERVICE ALERT: {message}")
|
||||
|
||||
|
||||
# Health check utilities
|
||||
|
||||
@@ -34,7 +34,7 @@ class URLDiscoverer:
|
||||
)
|
||||
self.config = config or {}
|
||||
|
||||
async def discover_urls(self, base_url: str) -> list[str]:
|
||||
def discover_urls(self, base_url: str) -> list[str]:
|
||||
"""Discover URLs from a base URL.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -37,7 +37,7 @@ class LegacyURLValidator:
|
||||
# Accept additional kwargs for backward compatibility
|
||||
self.config.update(kwargs)
|
||||
|
||||
async def validate_url(self, url: str, level: str = "standard") -> dict[str, Any]:
|
||||
def validate_url(self, url: str, level: str = "standard") -> dict[str, Any]:
|
||||
"""Validate a URL.
|
||||
|
||||
Args:
|
||||
@@ -68,7 +68,7 @@ class LegacyURLValidator:
|
||||
"""
|
||||
return self.config.get("level", "standard")
|
||||
|
||||
async def batch_validate_urls(self, urls: list[str]) -> list[dict[str, Any]]:
|
||||
def batch_validate_urls(self, urls: list[str]) -> list[dict[str, Any]]:
|
||||
"""Batch validate multiple URLs (for test compatibility).
|
||||
|
||||
Args:
|
||||
@@ -79,7 +79,7 @@ class LegacyURLValidator:
|
||||
"""
|
||||
results = []
|
||||
for url in urls:
|
||||
result = await self.validate_url(url)
|
||||
result = self.validate_url(url)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
"""Core utilities for Business Buddy framework."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Context detection utilities - re-export from networking for convenience
|
||||
from ..networking import AsyncContextInfo, detect_async_context, run_in_appropriate_context
|
||||
from ..networking.async_utils import gather_with_concurrency
|
||||
from .json_extractor import (
|
||||
JsonExtractionError,
|
||||
@@ -45,6 +49,10 @@ __all__ = [
|
||||
"create_lazy_loader",
|
||||
# Async utilities
|
||||
"gather_with_concurrency",
|
||||
# Context detection utilities
|
||||
"AsyncContextInfo",
|
||||
"detect_async_context",
|
||||
"run_in_appropriate_context",
|
||||
# JSON extraction utilities
|
||||
"JsonExtractorCore",
|
||||
"JsonRecoveryStrategies",
|
||||
@@ -79,4 +87,105 @@ __all__ = [
|
||||
"is_git_repo_url",
|
||||
"is_pdf_url",
|
||||
"get_url_type",
|
||||
# Error utilities
|
||||
"normalize_errors_to_list",
|
||||
# Regex security utilities
|
||||
"compile_safe_regex",
|
||||
"get_safe_regex_compiler",
|
||||
]
|
||||
|
||||
|
||||
def normalize_errors_to_list(errors: Any) -> list[Any]:
|
||||
"""Normalize errors to a list format for consistent processing.
|
||||
|
||||
This function handles various error formats and ensures the result
|
||||
is always a list, making error handling more predictable across
|
||||
the codebase.
|
||||
|
||||
Args:
|
||||
errors: Error value that could be a list, single value, or None
|
||||
|
||||
Returns:
|
||||
List of errors (empty list if no errors)
|
||||
|
||||
Examples:
|
||||
>>> normalize_errors_to_list(None)
|
||||
[]
|
||||
>>> normalize_errors_to_list("error")
|
||||
["error"]
|
||||
>>> normalize_errors_to_list(["error1", "error2"])
|
||||
["error1", "error2"]
|
||||
>>> normalize_errors_to_list({"message": "error"})
|
||||
[{"message": "error"}]
|
||||
"""
|
||||
# Handle None or falsy values
|
||||
if not errors:
|
||||
return []
|
||||
|
||||
# Already a list - return as-is
|
||||
if isinstance(errors, list):
|
||||
return errors
|
||||
|
||||
# String or bytes - special handling
|
||||
if isinstance(errors, (str, bytes)):
|
||||
return [errors]
|
||||
|
||||
# Check for Mapping types (e.g., dict) and wrap them in a list
|
||||
from collections.abc import Mapping
|
||||
if isinstance(errors, Mapping):
|
||||
return [errors]
|
||||
|
||||
# Check if it's iterable but not string/bytes/mapping
|
||||
try:
|
||||
# Try to iterate
|
||||
iter(errors)
|
||||
# It's iterable - convert to list
|
||||
return list(errors)
|
||||
except TypeError:
|
||||
# Not iterable - wrap in list
|
||||
return [errors]
|
||||
|
||||
|
||||
# Centralized regex security utilities
|
||||
_safe_regex_compiler = None
|
||||
|
||||
|
||||
def get_safe_regex_compiler():
|
||||
"""Get the global safe regex compiler instance.
|
||||
|
||||
Returns:
|
||||
SafeRegexCompiler: Global compiler instance with consistent settings
|
||||
"""
|
||||
global _safe_regex_compiler
|
||||
if _safe_regex_compiler is None:
|
||||
from biz_bud.core.utils.regex_security import SafeRegexCompiler
|
||||
_safe_regex_compiler = SafeRegexCompiler(
|
||||
max_pattern_length=1000, # Reasonable limit for most patterns
|
||||
default_timeout=2.0 # Conservative timeout
|
||||
)
|
||||
return _safe_regex_compiler
|
||||
|
||||
|
||||
def compile_safe_regex(pattern: str, flags: int = 0, timeout: float | None = None):
|
||||
"""Compile a regex pattern using centralized security validation.
|
||||
|
||||
This function should be used instead of re.compile() throughout the codebase
|
||||
to ensure consistent ReDoS protection and security validation.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to compile
|
||||
flags: Regex compilation flags (same as re.compile)
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Pattern: Compiled regex pattern
|
||||
|
||||
Raises:
|
||||
RegexSecurityError: If pattern is potentially dangerous
|
||||
|
||||
Example:
|
||||
>>> pattern = compile_safe_regex(r'\\d{3}-\\d{3}-\\d{4}')
|
||||
>>> match = pattern.search('123-456-7890')
|
||||
"""
|
||||
compiler = get_safe_regex_compiler()
|
||||
return compiler.compile_safe(pattern, flags, timeout)
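A short usage sketch of the centralized helpers; it assumes `search_safe` and `sub_safe` mirror `re.search`/`re.sub` apart from the added safety checks, which matches how they are called elsewhere in this change, and that the package is importable in-repo.

```python
from biz_bud.core.utils import compile_safe_regex
from biz_bud.core.utils.regex_security import search_safe, sub_safe

# Compile once with ReDoS validation, then reuse like a normal re.Pattern.
phone = compile_safe_regex(r"\d{3}-\d{3}-\d{4}")
assert phone.search("call 123-456-7890") is not None

# Drop-in style replacements for re.search / re.sub on untrusted text.
if match := search_safe(r"([a-zA-Z0-9.-]+:\d+)", "connect to db.internal:5432 failed"):
    print(match.group(1))                 # db.internal:5432
print(sub_safe(r"\d", "*", "pin 1234"))   # pin ****
```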
|
||||
|
||||
src/biz_bud/core/utils/graph_helpers.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
"""Helper functions for graph creation and state initialization.
|
||||
|
||||
This module consolidates common patterns used in graph creation to reduce
|
||||
code duplication and improve maintainability.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def process_state_query(
|
||||
query: str | None,
|
||||
messages: list[dict[str, Any]] | None,
|
||||
state_update: dict[str, Any] | None,
|
||||
default_query: str
|
||||
) -> str:
|
||||
"""Process and extract query from various sources.
|
||||
|
||||
Args:
|
||||
query: Direct query string
|
||||
messages: Message history to extract query from
|
||||
state_update: State update that may contain query
|
||||
default_query: Default query to use if none found
|
||||
|
||||
Returns:
|
||||
Processed query string
|
||||
"""
|
||||
# Extract query from messages if not provided
|
||||
if query is None and messages:
|
||||
# Look for the last human/user message
|
||||
for msg in reversed(messages):
|
||||
# Handle different message formats
|
||||
role = msg.get("role")
|
||||
msg_type = msg.get("type")
|
||||
if role in ("user", "human") or msg_type == "human":
|
||||
query = msg.get("content", "")
|
||||
break
|
||||
|
||||
return query or default_query
|
||||
|
||||
|
||||
def format_raw_input(
|
||||
raw_input: str | dict[str, Any] | None,
|
||||
user_query: str
|
||||
) -> tuple[str, str]:
|
||||
"""Format raw input into a consistent string format.
|
||||
|
||||
Args:
|
||||
raw_input: Raw input data (string, dict, or None).
|
||||
If a dict, must be JSON serializable.
|
||||
user_query: User query to use as default
|
||||
|
||||
Returns:
|
||||
Tuple of (raw_input_str, extracted_query)
|
||||
|
||||
Raises:
|
||||
TypeError: If raw_input is a dict but not JSON serializable.
|
||||
"""
|
||||
if raw_input is None:
|
||||
return f'{{"query": "{user_query}"}}', user_query
|
||||
|
||||
if isinstance(raw_input, dict):
|
||||
# If raw_input has a query field, use it
|
||||
extracted_query = raw_input.get("query", user_query)
|
||||
# Avoid json.dumps for simple cases
|
||||
if len(raw_input) == 1 and "query" in raw_input:
|
||||
return f'{{"query": "{raw_input["query"]}"}}', extracted_query
|
||||
else:
|
||||
try:
|
||||
return json.dumps(raw_input), extracted_query
|
||||
except (TypeError, ValueError) as e:
|
||||
# Fallback: show error message in string
|
||||
raw_input_str = f"<non-serializable dict: {e}>"
|
||||
return raw_input_str, extracted_query
|
||||
|
||||
# For unsupported types
|
||||
if not isinstance(raw_input, str):
|
||||
raw_input_str = f"<unsupported type: {type(raw_input).__name__}>"
|
||||
return raw_input_str, user_query
|
||||
|
||||
return str(raw_input), user_query
|
||||
|
||||
|
||||
def extract_state_update_data(
|
||||
state_update: dict[str, Any] | None,
|
||||
messages: list[dict[str, Any]] | None,
|
||||
raw_input: str | dict[str, Any] | None,
|
||||
thread_id: str | None
|
||||
) -> tuple[list[dict[str, Any]] | None, str | dict[str, Any] | None, str | None]:
|
||||
"""Extract data from state update if provided.
|
||||
|
||||
Args:
|
||||
state_update: State update dictionary from LangGraph API
|
||||
messages: Existing messages
|
||||
raw_input: Existing raw input
|
||||
thread_id: Existing thread ID
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, raw_input, thread_id)
|
||||
"""
|
||||
if not state_update:
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
# Extract messages if present
|
||||
if "messages" in state_update and not messages:
|
||||
messages = state_update["messages"]
|
||||
|
||||
# Extract other fields if present
|
||||
if "raw_input" in state_update and not raw_input:
|
||||
raw_input = state_update["raw_input"]
|
||||
if "thread_id" in state_update and not thread_id:
|
||||
thread_id = state_update["thread_id"]
|
||||
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
|
||||
def create_initial_state_dict(
|
||||
raw_input_str: str,
|
||||
user_query: str,
|
||||
messages: list[dict[str, Any]],
|
||||
thread_id: str,
|
||||
config_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Create the initial state dictionary.
|
||||
|
||||
Args:
|
||||
raw_input_str: Formatted raw input string
|
||||
user_query: User query
|
||||
messages: Message history
|
||||
thread_id: Thread identifier
|
||||
config_dict: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Initial state dictionary
|
||||
"""
|
||||
return {
|
||||
"raw_input": raw_input_str,
|
||||
"parsed_input": {
|
||||
"raw_payload": {
|
||||
"query": user_query,
|
||||
},
|
||||
"user_query": user_query,
|
||||
},
|
||||
"messages": messages,
|
||||
"initial_input": {
|
||||
"query": user_query,
|
||||
},
|
||||
"thread_id": thread_id,
|
||||
"config": config_dict,
|
||||
"input_metadata": {},
|
||||
"context": {},
|
||||
"status": "pending",
|
||||
"errors": [],
|
||||
"run_metadata": {},
|
||||
"is_last_step": False,
|
||||
"final_result": None,
|
||||
}
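A hedged end-to-end sketch of composing the helpers in this new module; the message shape, thread id, and empty config are illustrative only.

```python
from biz_bud.core.utils.graph_helpers import (
    create_initial_state_dict,
    extract_state_update_data,
    format_raw_input,
    process_state_query,
)

state_update = {"messages": [{"role": "user", "content": "Summarize Q3 revenue"}]}

messages, raw_input, thread_id = extract_state_update_data(state_update, None, None, None)
query = process_state_query(None, messages, state_update, default_query="hello")
raw_input_str, query = format_raw_input(raw_input, query)
state = create_initial_state_dict(raw_input_str, query, messages or [], thread_id or "thread-1", {})

print(state["parsed_input"]["user_query"])  # Summarize Q3 revenue
```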
|
||||
@@ -11,11 +11,17 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
# Lazy imports to avoid circular dependency
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
from jsonschema import ValidationError as JsonSchemaValidationError
|
||||
@@ -28,9 +34,6 @@ except ImportError:
|
||||
# Make it available as a constant
|
||||
JSONSCHEMA_AVAILABLE = _jsonschema_available
|
||||
|
||||
from biz_bud.core.errors import ErrorContext, JsonParsingError
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -47,6 +50,14 @@ JsonDict = dict[str, JsonValue]
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_error_context():
|
||||
"""Lazy import for ErrorContext to avoid circular dependency."""
|
||||
from biz_bud.core.errors import ErrorContext
|
||||
return ErrorContext
|
||||
|
||||
|
||||
|
||||
|
||||
class JsonExtractionErrorType(Enum):
|
||||
"""Types of JSON extraction errors with structured categorization."""
|
||||
|
||||
@@ -60,7 +71,7 @@ class JsonExtractionErrorType(Enum):
|
||||
UNKNOWN_ERROR = "unknown_error"
|
||||
|
||||
|
||||
class JsonExtractionError(JsonParsingError):
|
||||
class JsonExtractionError(Exception):
|
||||
"""Specialized exception for JSON extraction failures with detailed context.
|
||||
|
||||
This extends the base JsonParsingError with specific functionality
|
||||
@@ -74,29 +85,47 @@ class JsonExtractionError(JsonParsingError):
|
||||
pattern_attempted: str | None = None,
|
||||
recovery_strategies_tried: list[str] | None = None,
|
||||
extraction_time_ms: float | None = None,
|
||||
context: ErrorContext | None = None,
|
||||
context: object | None = None, # ErrorContext | None but avoiding import
|
||||
cause: Exception | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize JSON extraction error with detailed extraction context."""
|
||||
super().__init__(
|
||||
message,
|
||||
error_type=error_type,
|
||||
context=context,
|
||||
cause=cause,
|
||||
**kwargs,
|
||||
)
|
||||
# Create ErrorContext if not provided
|
||||
if context is None:
|
||||
ErrorContext = _get_error_context()
|
||||
context = ErrorContext()
|
||||
|
||||
super().__init__(message)
|
||||
|
||||
# Store attributes for compatibility with JsonParsingError
|
||||
self.error_type = error_type
|
||||
self.context = context
|
||||
self.cause = cause
|
||||
self.pattern_attempted = pattern_attempted
|
||||
self.recovery_strategies_tried = recovery_strategies_tried or []
|
||||
self.extraction_time_ms = extraction_time_ms
|
||||
|
||||
# Store additional kwargs for compatibility
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
# Add extraction-specific metadata
|
||||
if pattern_attempted:
|
||||
self.context.metadata["pattern_attempted"] = pattern_attempted
|
||||
if recovery_strategies_tried:
|
||||
self.context.metadata["recovery_strategies_tried"] = recovery_strategies_tried
|
||||
if extraction_time_ms is not None:
|
||||
self.context.metadata["extraction_time_ms"] = extraction_time_ms
|
||||
if pattern_attempted and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["pattern_attempted"] = pattern_attempted
|
||||
if recovery_strategies_tried and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["recovery_strategies_tried"] = recovery_strategies_tried
|
||||
if extraction_time_ms is not None and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["extraction_time_ms"] = extraction_time_ms
|
||||
|
||||
def to_log_context(self) -> dict[str, object]:
|
||||
"""Generate context for logging (compatibility with BusinessBuddyError)."""
|
||||
return {
|
||||
"error_type": str(self.error_type) if hasattr(self, 'error_type') else None,
|
||||
"message": str(self),
|
||||
"pattern_attempted": self.pattern_attempted,
|
||||
"recovery_strategies_tried": self.recovery_strategies_tried,
|
||||
"extraction_time_ms": self.extraction_time_ms,
|
||||
}
|
||||
|
||||
def to_extraction_context(self) -> dict[str, object]:
|
||||
"""Generate extraction-specific context for logging and debugging."""
|
||||
@@ -119,7 +148,7 @@ class JsonValidationError(JsonExtractionError):
|
||||
validation_rule: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_value: object = None,
|
||||
context: ErrorContext | None = None,
|
||||
context: object | None = None, # ErrorContext | None but avoiding import
|
||||
cause: Exception | None = None,
|
||||
):
|
||||
"""Initialize JSON validation error with validation context."""
|
||||
@@ -135,14 +164,14 @@ class JsonValidationError(JsonExtractionError):
|
||||
self.actual_value = actual_value
|
||||
|
||||
# Add validation-specific metadata
|
||||
if schema_path:
|
||||
self.context.metadata["schema_path"] = schema_path
|
||||
if validation_rule:
|
||||
self.context.metadata["validation_rule"] = validation_rule
|
||||
if expected_type:
|
||||
self.context.metadata["expected_type"] = expected_type
|
||||
if actual_value is not None:
|
||||
self.context.metadata["actual_value_type"] = type(actual_value).__name__
|
||||
if schema_path and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["schema_path"] = schema_path
|
||||
if validation_rule and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["validation_rule"] = validation_rule
|
||||
if expected_type and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["expected_type"] = expected_type
|
||||
if actual_value is not None and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["actual_value_type"] = type(actual_value).__name__
|
||||
|
||||
|
||||
class JsonRecoveryError(JsonExtractionError):
|
||||
@@ -154,7 +183,7 @@ class JsonRecoveryError(JsonExtractionError):
|
||||
failed_strategy: str | None = None,
|
||||
strategy_error: str | None = None,
|
||||
original_text_length: int | None = None,
|
||||
context: ErrorContext | None = None,
|
||||
context: object | None = None, # ErrorContext | None but avoiding import
|
||||
cause: Exception | None = None,
|
||||
):
|
||||
"""Initialize JSON recovery error with strategy context."""
|
||||
@@ -169,12 +198,12 @@ class JsonRecoveryError(JsonExtractionError):
|
||||
self.original_text_length = original_text_length
|
||||
|
||||
# Add recovery-specific metadata
|
||||
if failed_strategy:
|
||||
self.context.metadata["failed_strategy"] = failed_strategy
|
||||
if strategy_error:
|
||||
self.context.metadata["strategy_error"] = strategy_error
|
||||
if original_text_length is not None:
|
||||
self.context.metadata["original_text_length"] = original_text_length
|
||||
if failed_strategy and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["failed_strategy"] = failed_strategy
|
||||
if strategy_error and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["strategy_error"] = strategy_error
|
||||
if original_text_length is not None and hasattr(self, 'context') and hasattr(self.context, 'metadata'):
|
||||
getattr(self.context, 'metadata')["original_text_length"] = original_text_length
|
||||
|
||||
|
||||
class JsonExtractionConfig(TypedDict, total=False):
|
||||
@@ -225,12 +254,20 @@ class JsonExtractionResult(TypedDict):
|
||||
validation_errors: list[str] | None
|
||||
|
||||
|
||||
# Extraction patterns - compiled for performance
|
||||
JSON_CODE_BLOCK_PATTERN = re.compile(
|
||||
# Extraction patterns - compiled for performance with centralized ReDoS protection
|
||||
from biz_bud.core.utils.regex_security import SafeRegexCompiler
|
||||
|
||||
# Create a local compiler instance to avoid circular imports
|
||||
_compiler = SafeRegexCompiler(max_pattern_length=1000, default_timeout=2.0)
|
||||
|
||||
JSON_CODE_BLOCK_PATTERN = _compiler.compile_safe(
|
||||
r"```(?:json|javascript)?\s*\n?({.*?})\s*\n?```", re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
JSON_BRACES_PATTERN = re.compile(r"({[^{}]*(?:{[^{}]*}[^{}]*)*})", re.DOTALL)
|
||||
JSON_OBJECT_PATTERN = re.compile(r"(\{(?:[^{}]|{[^{}]*})*\})", re.DOTALL)
|
||||
# SECURITY FIX: Replace ReDoS-vulnerable patterns with bounded quantifiers that cap backtracking
|
||||
# Old pattern: r"({[^{}]*(?:{[^{}]*}[^{}]*)*})" - vulnerable to polynomial backtracking
|
||||
JSON_BRACES_PATTERN = _compiler.compile_safe(r"(\{[^{}]{0,1000}(?:\{[^{}]{0,1000}\}[^{}]{0,1000}){0,50}\})", re.DOTALL)
|
||||
# Old pattern: r"(\{(?:[^{}]|{[^{}]*})*\})" - vulnerable to exponential backtracking
|
||||
JSON_OBJECT_PATTERN = _compiler.compile_safe(r"(\{(?:[^{}]{0,1000}|\{[^{}]{0,1000}\}){0,100}\})", re.DOTALL)
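# Illustration (not part of the module): the explicit bounds are what remove the
# ReDoS risk - each run of non-brace characters is capped at 1000 and nesting at
# 50/100 repetitions, so worst-case matching work stays bounded instead of growing
# exponentially. The bounded pattern still finds simple embedded objects, e.g.
#     JSON_BRACES_PATTERN.search('prefix {"a": 1} suffix').group(1)  # -> '{"a": 1}'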
|
||||
|
||||
|
||||
class JsonRecoveryStrategies:
|
||||
@@ -317,7 +354,7 @@ class JsonRecoveryStrategies:
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_log_pattern() -> re.Pattern[str]:
|
||||
"""Get compiled regex pattern for log suffixes with caching."""
|
||||
return re.compile(r"\s+\S+\.py:\d+\s*$")
|
||||
return _compiler.compile_safe(r"\s+\S+\.py:\d+\s*$")
|
||||
|
||||
@classmethod
|
||||
def fix_truncation(cls, text: str) -> str:
|
||||
@@ -350,7 +387,10 @@ class JsonRecoveryStrategies:
|
||||
@lru_cache(maxsize=16)
|
||||
def _get_escape_patterns() -> tuple[re.Pattern[str], ...]:
|
||||
"""Get compiled regex patterns for escape fixing with caching."""
|
||||
unescaped_quote_pattern = re.compile(r'(?<!\\)"(?=[^"]*"[^"]*$)')
|
||||
# SECURITY FIX: Replace ReDoS-vulnerable lookahead pattern with safer approach
|
||||
# Old pattern: r'(?<!\\)"(?=[^"]*"[^"]*$)' - vulnerable to quadratic backtracking
|
||||
# Use simpler approach that doesn't rely on complex lookaheads
|
||||
unescaped_quote_pattern = _compiler.compile_safe(r'(?<!\\)"(?!")')
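# Illustration: r'(?<!\\)"(?!")' matches any double quote that is neither preceded by
# a backslash nor immediately followed by another quote; both lookarounds are
# fixed-width, so the pattern cannot backtrack catastrophically.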
|
||||
return (unescaped_quote_pattern,)
|
||||
|
||||
@classmethod
|
||||
@@ -498,6 +538,66 @@ class JsonExtractorCore:
|
||||
if config:
|
||||
self.config.update(config)
|
||||
|
||||
def _validate_input_security(self, text: str, max_length: int = 50000) -> None:
|
||||
"""Validate input for security issues to prevent ReDoS attacks.
|
||||
|
||||
Args:
|
||||
text: Input text to validate
|
||||
max_length: Maximum allowed input length in characters (default: 50000, ~50KB).
|
||||
This limit is set to prevent potential ReDoS attacks and excessive memory usage.
|
||||
Override if your use case requires larger input, but ensure security implications are considered.
|
||||
|
||||
Raises:
|
||||
ValidationError: If input exceeds the maximum allowed length
|
||||
"""
|
||||
if len(text) > max_length:
|
||||
raise ValidationError(f"Input text too long ({len(text)} chars), max allowed: {max_length}")
|
||||
|
||||
# Check for patterns that could cause ReDoS
|
||||
dangerous_patterns = [
|
||||
r'(\{){10,}', # Many opening braces
|
||||
r'(\[){10,}', # Many opening brackets
|
||||
r'("){20,}', # Many quotes
|
||||
r'(<){10,}', # Many angle brackets
|
||||
]
|
||||
|
||||
for pattern in dangerous_patterns:
|
||||
if re.search(pattern, text):
|
||||
logger.warning(f"Input contains potentially dangerous pattern: {pattern}")
|
||||
# Don't reject, but log for monitoring
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def _regex_timeout_context(self, timeout_seconds: float = 1.0):
|
||||
"""Context manager to timeout regex operations to prevent ReDoS.
|
||||
|
||||
Args:
|
||||
timeout_seconds: Maximum execution time allowed
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Raises:
|
||||
TimeoutError: If regex execution exceeds timeout
|
||||
"""
|
||||
def timeout_handler(signum: int, frame) -> None:
|
||||
raise TimeoutError(f"Regex execution exceeded {timeout_seconds}s timeout")
|
||||
|
||||
# Set up timeout alarm (Unix/Linux only)
|
||||
old_handler = None
|
||||
try:
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(int(timeout_seconds))
|
||||
yield
|
||||
except AttributeError:
|
||||
# Windows doesn't support SIGALRM, skip timeout protection
|
||||
yield
|
||||
finally:
|
||||
with contextlib.suppress(AttributeError, NameError):
|
||||
signal.alarm(0)
|
||||
if 'old_handler' in locals() and old_handler is not None:
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
def _validate_against_schema(self, data: JsonDict) -> tuple[bool, list[str]]:
|
||||
"""Validate extracted JSON against provided schema.
|
||||
|
||||
@@ -628,6 +728,20 @@ class JsonExtractorCore:
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# SECURITY: Validate input to prevent ReDoS attacks
|
||||
try:
|
||||
self._validate_input_security(text)
|
||||
except (ValueError, ValidationError) as e:
|
||||
logger.error(f"Security validation failed: {e}")
|
||||
return self._create_result(
|
||||
data=None,
|
||||
success=False,
|
||||
error_type=JsonExtractionErrorType.EXTRACTION_FAILED,
|
||||
recovery_used=False,
|
||||
extraction_time_ms=0.0,
|
||||
pattern_used=None,
|
||||
)
|
||||
|
||||
if self.config.get("log_performance", True):
|
||||
logger.debug(
|
||||
"Starting JSON extraction",
|
||||
@@ -742,31 +856,43 @@ class JsonExtractorCore:
|
||||
)
|
||||
|
||||
elif pattern_name == "code_blocks":
|
||||
# Try to find JSON in code blocks
|
||||
matches = JSON_CODE_BLOCK_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
# Try to find JSON in code blocks with timeout protection
|
||||
try:
|
||||
with self._regex_timeout_context(1.0):
|
||||
matches = JSON_CODE_BLOCK_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex timeout in {pattern_name} pattern")
|
||||
|
||||
elif pattern_name == "object_pattern":
|
||||
# Try the improved JSON object pattern
|
||||
matches = JSON_OBJECT_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
# Try the improved JSON object pattern with timeout protection
|
||||
try:
|
||||
with self._regex_timeout_context(1.0):
|
||||
matches = JSON_OBJECT_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex timeout in {pattern_name} pattern")
|
||||
|
||||
elif pattern_name == "braces_pattern":
|
||||
# Try to find JSON objects directly in the text with basic pattern
|
||||
matches = JSON_BRACES_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
# Try to find JSON objects directly in the text with timeout protection
|
||||
try:
|
||||
with self._regex_timeout_context(1.0):
|
||||
matches = JSON_BRACES_PATTERN.findall(text)
|
||||
for match in matches:
|
||||
result = self._try_parse_with_recovery(match)
|
||||
if result["data"] is not None:
|
||||
result["recovery_used"] = recovery_used
|
||||
return result
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex timeout in {pattern_name} pattern")
|
||||
|
||||
elif pattern_name == "balanced_extraction":
|
||||
# Try to extract by finding balanced braces
|
||||
|
||||
@@ -213,7 +213,7 @@ class AsyncFactoryManager[T]:
|
||||
# Always clear the weak reference
|
||||
self._factory_ref = None
|
||||
|
||||
async def check_factory_health(self) -> bool:
|
||||
def check_factory_health(self) -> bool:
|
||||
"""Check if the current factory is healthy and functional.
|
||||
|
||||
Returns:
|
||||
@@ -255,7 +255,7 @@ class AsyncFactoryManager[T]:
|
||||
A healthy factory instance.
|
||||
"""
|
||||
# Check if current factory is healthy
|
||||
if await self.check_factory_health():
|
||||
if self.check_factory_health():
|
||||
if self._factory_ref is None:
|
||||
raise StateError("Factory reference is None after health check passed", error_code=ErrorNamespace.STATE_CONSISTENCY_ERROR)
|
||||
factory = self._factory_ref()
|
||||
|
||||
src/biz_bud/core/utils/regex_security.py (new file, 586 lines)
@@ -0,0 +1,586 @@
|
||||
"""Regex security utilities for preventing ReDoS attacks.
|
||||
|
||||
This module provides comprehensive regex security validation and safe execution
|
||||
utilities to prevent Regular Expression Denial of Service (ReDoS) attacks
|
||||
throughout the codebase.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FutureTimeoutError
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Pattern
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Thread pool for timeout execution (reused for efficiency)
|
||||
_timeout_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="regex_timeout")
|
||||
|
||||
|
||||
def _execute_with_timeout(func: Callable[[], Any], timeout_seconds: float, operation_type: str) -> Any:
|
||||
"""Execute a function with timeout protection (cross-platform).
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
timeout_seconds: Maximum execution time
|
||||
operation_type: Type of operation (for error messages)
|
||||
|
||||
Returns:
|
||||
Result of the function
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
future = _timeout_executor.submit(func)
|
||||
try:
|
||||
return future.result(timeout=timeout_seconds)
|
||||
except FutureTimeoutError as e:
|
||||
# Attempt to cancel the future (may not always work for CPU-bound tasks)
|
||||
future.cancel()
|
||||
raise TimeoutError(
|
||||
f"Regex {operation_type} exceeded {timeout_seconds}s timeout"
|
||||
) from e
|
||||
except Exception:
|
||||
# Re-raise any exception from the function
|
||||
raise
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _cross_platform_timeout(timeout_seconds: float, operation_type: str):
|
||||
"""Cross-platform timeout mechanism.
|
||||
|
||||
For Unix/Linux, uses efficient signal-based timeout.
|
||||
For Windows, operations must be wrapped with _execute_with_timeout.
|
||||
|
||||
Args:
|
||||
timeout_seconds: Maximum execution time in seconds
|
||||
operation_type: Type of operation (for error messages)
|
||||
|
||||
Yields:
|
||||
None - operations should be performed within the context
|
||||
|
||||
Raises:
|
||||
TimeoutError: If operation exceeds timeout
|
||||
"""
|
||||
# Check if we can use signal-based timeout (Unix/Linux)
|
||||
if hasattr(signal, 'SIGALRM') and sys.platform != 'win32':
|
||||
# Use efficient signal-based timeout on Unix/Linux
|
||||
def timeout_handler(signum: int, frame: Any) -> None:
|
||||
raise TimeoutError(f"Regex {operation_type} exceeded {timeout_seconds}s timeout")
|
||||
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(int(timeout_seconds))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
else:
|
||||
# On Windows, we can't use signals, so operations need to be
|
||||
# executed through _execute_with_timeout
|
||||
logger.debug(
|
||||
"Windows platform detected - timeout protection requires thread execution"
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
class RegexSecurityError(Exception):
|
||||
"""Exception raised when regex pattern poses security risk."""
|
||||
|
||||
def __init__(self, message: str, pattern: str, reason: str):
|
||||
"""Initialize regex security error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
pattern: The problematic regex pattern
|
||||
reason: Specific reason for rejection
|
||||
"""
|
||||
super().__init__(message)
|
||||
self.pattern = pattern
|
||||
self.reason = reason
|
||||
|
||||
|
||||
class SafeRegexCompiler:
|
||||
"""Safe regex compiler with ReDoS protection."""
|
||||
|
||||
def __init__(self, max_pattern_length: int = 500, default_timeout: float = 1.0):
|
||||
"""Initialize safe regex compiler.
|
||||
|
||||
Args:
|
||||
max_pattern_length: Maximum allowed pattern length
|
||||
default_timeout: Default timeout for regex execution
|
||||
"""
|
||||
self.max_pattern_length = max_pattern_length
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
def _run_with_timeout(self, func: Callable[[], Any], timeout: float, operation_type: str) -> Any:
|
||||
"""Run a function with timeout protection using unified timeout mechanism.
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
timeout: Timeout in seconds
|
||||
operation_type: Type of operation (for error messages)
|
||||
|
||||
Returns:
|
||||
Result of the function
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
if sys.platform == 'win32':
|
||||
return _execute_with_timeout(func, timeout, operation_type)
|
||||
# Unix/Linux: signal-based
|
||||
with _cross_platform_timeout(timeout, operation_type):
|
||||
return func()
|
||||
|
||||
def compile_safe(self, pattern: str, flags: int = 0, timeout: float | None = None) -> Pattern[str]:
|
||||
"""Safely compile regex pattern with validation.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to compile
|
||||
flags: Regex compilation flags
|
||||
timeout: Execution timeout (uses default if None)
|
||||
|
||||
Returns:
|
||||
Compiled regex pattern
|
||||
|
||||
Raises:
|
||||
RegexSecurityError: If pattern is potentially dangerous
|
||||
"""
|
||||
# Validate pattern length
|
||||
if len(pattern) > self.max_pattern_length:
|
||||
raise RegexSecurityError(
|
||||
f"Regex pattern too long: {len(pattern)} chars (max {self.max_pattern_length})",
|
||||
pattern,
|
||||
"pattern_length"
|
||||
)
|
||||
|
||||
# Check for dangerous constructs
|
||||
self._validate_pattern_safety(pattern)
|
||||
|
||||
# Compile with timeout protection
|
||||
try:
|
||||
compiled_pattern = self._run_with_timeout(
|
||||
lambda: re.compile(pattern, flags),
|
||||
timeout or self.default_timeout,
|
||||
"compilation"
|
||||
)
|
||||
logger.debug(f"Successfully compiled safe regex: {pattern}")
|
||||
return compiled_pattern
|
||||
except TimeoutError as e:
|
||||
raise RegexSecurityError(
|
||||
f"Regex compilation timed out: {pattern}",
|
||||
pattern,
|
||||
"compilation_timeout",
|
||||
) from e
|
||||
except re.error as e:
|
||||
raise RegexSecurityError(
|
||||
f"Invalid regex pattern: {e}", pattern, "invalid_syntax"
|
||||
) from e
|
||||
|
||||
def _validate_pattern_safety(self, pattern: str) -> None:
|
||||
"""Validate pattern for ReDoS vulnerabilities.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to validate
|
||||
|
||||
Raises:
|
||||
RegexSecurityError: If dangerous constructs found
|
||||
"""
|
||||
dangerous_constructs = [
|
||||
# Nested quantifiers (exponential complexity)
|
||||
(r'\*\*', "nested_quantifier_**"),
|
||||
(r'\+\+', "nested_quantifier_++"),
|
||||
(r'\?\?', "nested_quantifier_??"),
|
||||
(r'\{\d*,?\d*\}\*', "quantifier_after_quantifier"),
|
||||
(r'\{\d*,?\d*\}\+', "quantifier_after_quantifier"),
|
||||
|
||||
# Dangerous group patterns
|
||||
(r'\([^)]*\)\*\*', "group_with_nested_quantifiers"),
|
||||
(r'\([^)]*\)\+\+', "group_with_nested_quantifiers"),
|
||||
(r'\([^)]*\|[^)]*\)\*', "alternation_in_repeated_group"),
|
||||
(r'\([^)]*\|[^)]*\)\+', "alternation_in_repeated_group"),
|
||||
|
||||
# Multiple .* patterns (polynomial complexity)
|
||||
(r'(\.\*){2,}', "multiple_wildcard_patterns"),
|
||||
(r'\.\*.*\.\*', "multiple_wildcard_patterns"),
|
||||
|
||||
# Complex lookarounds that can cause backtracking
|
||||
(r'\(\?\=.*\.\*.*\)', "complex_positive_lookahead"),
|
||||
(r'\(\?\!.*\.\*.*\)', "complex_negative_lookahead"),
|
||||
(r'\(\?\<\=.*\.\*.*\)', "complex_positive_lookbehind"),
|
||||
(r'\(\?\<\!.*\.\*.*\)', "complex_negative_lookbehind"),
|
||||
|
||||
# Potentially explosive character classes
|
||||
(r'\[[^\]]{50,}\]', "oversized_character_class"),
|
||||
]
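# Illustration: r"(a|b)+" is rejected here as alternation_in_repeated_group, and
# r".*x.*" as multiple_wildcard_patterns.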
|
||||
|
||||
for construct_pattern, reason in dangerous_constructs:
|
||||
try:
|
||||
if re.search(construct_pattern, pattern):
|
||||
raise RegexSecurityError(
|
||||
f"Potentially dangerous regex construct detected: {reason}",
|
||||
pattern,
|
||||
reason
|
||||
)
|
||||
except re.error:
|
||||
# If we can't validate the construct pattern, skip it
|
||||
logger.warning(f"Failed to validate construct pattern: {construct_pattern}")
|
||||
continue
|
||||
|
||||
@contextmanager
|
||||
def _compilation_timeout(self, timeout_seconds: float):
|
||||
"""Context manager for regex compilation timeout (cross-platform).
|
||||
|
||||
This implementation uses threading for cross-platform compatibility,
|
||||
providing timeout protection on both Unix/Linux and Windows systems.
|
||||
|
||||
Args:
|
||||
timeout_seconds: Maximum compilation time
|
||||
|
||||
Raises:
|
||||
TimeoutError: If compilation exceeds timeout
|
||||
"""
|
||||
# Use cross-platform timeout mechanism
|
||||
with _cross_platform_timeout(timeout_seconds, "compilation"):
|
||||
yield
|
||||
|
||||
|
||||
class SafeRegexExecutor:
|
||||
"""Safe regex executor with timeout and input validation."""
|
||||
|
||||
def __init__(self, default_timeout: float = 1.0, max_input_length: int = 100000):
|
||||
"""Initialize safe regex executor.
|
||||
|
||||
Args:
|
||||
default_timeout: Default timeout for regex operations
|
||||
max_input_length: Maximum input text length
|
||||
"""
|
||||
self.default_timeout = default_timeout
|
||||
self.max_input_length = max_input_length
|
||||
|
||||
def _run_with_timeout(self, func: Callable[[], Any], timeout: float, operation_type: str) -> Any:
|
||||
"""Run a function with timeout protection using unified timeout mechanism.
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
timeout: Timeout in seconds
|
||||
operation_type: Type of operation (for error messages)
|
||||
|
||||
Returns:
|
||||
Result of the function
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
if sys.platform == 'win32':
|
||||
return _execute_with_timeout(func, timeout, operation_type)
|
||||
# Unix/Linux: signal-based
|
||||
with _cross_platform_timeout(timeout, operation_type):
|
||||
return func()
|
||||
|
||||
def search_safe(self, pattern: Pattern[str], text: str, timeout: float | None = None) -> re.Match[str] | None:
|
||||
"""Safely execute regex search with timeout protection.
|
||||
|
||||
Args:
|
||||
pattern: Compiled regex pattern
|
||||
text: Text to search
|
||||
timeout: Execution timeout (uses default if None)
|
||||
|
||||
Returns:
|
||||
Match object or None if no match found
|
||||
|
||||
Raises:
|
||||
ValueError: If input is too long
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
self._validate_input(text)
|
||||
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
# On Windows, use thread-based timeout
|
||||
def search_func():
|
||||
return pattern.search(text)
|
||||
|
||||
return _execute_with_timeout(
|
||||
search_func,
|
||||
timeout or self.default_timeout,
|
||||
"search"
|
||||
)
|
||||
else:
|
||||
# On Unix/Linux, use signal-based timeout
|
||||
with self._execution_timeout(timeout or self.default_timeout):
|
||||
return pattern.search(text)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex search timeout with pattern: {pattern.pattern}")
|
||||
raise
|
||||
|
||||
def findall_safe(self, pattern: Pattern[str], text: str, timeout: float | None = None) -> list[str]:
|
||||
"""Safely execute regex findall with timeout protection.
|
||||
|
||||
Args:
|
||||
pattern: Compiled regex pattern
|
||||
text: Text to search
|
||||
timeout: Execution timeout (uses default if None)
|
||||
|
||||
Returns:
|
||||
List of all matches
|
||||
|
||||
Raises:
|
||||
ValueError: If input is too long
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
self._validate_input(text)
|
||||
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
# On Windows, use thread-based timeout
|
||||
def findall_func():
|
||||
return pattern.findall(text)
|
||||
|
||||
return _execute_with_timeout(
|
||||
findall_func,
|
||||
timeout or self.default_timeout,
|
||||
"findall"
|
||||
)
|
||||
else:
|
||||
# On Unix/Linux, use signal-based timeout
|
||||
with self._execution_timeout(timeout or self.default_timeout):
|
||||
return pattern.findall(text)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex findall timeout with pattern: {pattern.pattern}")
|
||||
raise
|
||||
|
||||
def sub_safe(self, pattern: Pattern[str], repl: str, text: str,
|
||||
count: int = 0, timeout: float | None = None) -> str:
|
||||
"""Safely execute regex substitution with timeout protection.
|
||||
|
||||
Args:
|
||||
pattern: Compiled regex pattern
|
||||
repl: Replacement string
|
||||
text: Text to process
|
||||
count: Maximum number of substitutions (0 = all)
|
||||
timeout: Execution timeout (uses default if None)
|
||||
|
||||
Returns:
|
||||
Text with substitutions made
|
||||
|
||||
Raises:
|
||||
ValueError: If input is too long
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
self._validate_input(text)
|
||||
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
# On Windows, use thread-based timeout
|
||||
def sub_func():
|
||||
return pattern.sub(repl, text, count)
|
||||
|
||||
return _execute_with_timeout(
|
||||
sub_func,
|
||||
timeout or self.default_timeout,
|
||||
"substitution"
|
||||
)
|
||||
else:
|
||||
# On Unix/Linux, use signal-based timeout
|
||||
with self._execution_timeout(timeout or self.default_timeout):
|
||||
return pattern.sub(repl, text, count)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Regex substitution timeout with pattern: {pattern.pattern}")
|
||||
raise
|
||||
|
||||
def _validate_input(self, text: str) -> None:
|
||||
"""Validate input text for safety.
|
||||
|
||||
Args:
|
||||
text: Input text to validate
|
||||
|
||||
Raises:
|
||||
ValidationError: If input exceeds the maximum allowed length (dangerous patterns are only logged)
|
||||
"""
|
||||
if len(text) > self.max_input_length:
|
||||
raise ValidationError(f"Input text too long: {len(text)} chars (max {self.max_input_length})")
|
||||
|
||||
# Check for patterns that could amplify ReDoS attacks
|
||||
dangerous_input_patterns = [
|
||||
(lambda t: t.count('{') > 1000, "too_many_braces"),
|
||||
(lambda t: t.count('[') > 1000, "too_many_brackets"),
|
||||
(lambda t: t.count('(') > 1000, "too_many_parentheses"),
|
||||
(lambda t: t.count('<') > 1000, "too_many_angle_brackets"),
|
||||
(lambda t: t.count('"') > 2000, "too_many_quotes"),
|
||||
]
|
||||
|
||||
for check_func, reason in dangerous_input_patterns:
|
||||
if check_func(text):
|
||||
logger.warning(f"Input contains potentially dangerous pattern: {reason}")
|
||||
# Don't raise exception, just log warning for monitoring
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def _execution_timeout(self, timeout_seconds: float):
|
||||
"""Context manager for regex execution timeout (cross-platform).
|
||||
|
||||
This implementation uses threading for cross-platform compatibility,
|
||||
providing timeout protection on both Unix/Linux and Windows systems.
|
||||
|
||||
Args:
|
||||
timeout_seconds: Maximum execution time
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds timeout
|
||||
"""
|
||||
# Use cross-platform timeout mechanism
|
||||
with _cross_platform_timeout(timeout_seconds, "execution"):
|
||||
yield
|
||||
|
||||
|
||||
# Global instances for convenience
|
||||
_global_compiler: SafeRegexCompiler | None = None
|
||||
_global_executor: SafeRegexExecutor | None = None
|
||||
|
||||
|
||||
def get_safe_compiler() -> SafeRegexCompiler:
|
||||
"""Get global safe regex compiler instance."""
|
||||
global _global_compiler
|
||||
if _global_compiler is None:
|
||||
_global_compiler = SafeRegexCompiler()
|
||||
return _global_compiler
|
||||
|
||||
|
||||
def get_safe_executor() -> SafeRegexExecutor:
|
||||
"""Get global safe regex executor instance."""
|
||||
global _global_executor
|
||||
if _global_executor is None:
|
||||
_global_executor = SafeRegexExecutor()
|
||||
return _global_executor
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def compile_safe(pattern: str, flags: int = 0, timeout: float | None = None) -> Pattern[str]:
|
||||
"""Safely compile regex pattern with ReDoS protection.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to compile
|
||||
flags: Regex compilation flags
|
||||
timeout: Compilation timeout
|
||||
|
||||
Returns:
|
||||
Compiled regex pattern
|
||||
|
||||
Raises:
|
||||
RegexSecurityError: If pattern is potentially dangerous
|
||||
"""
|
||||
return get_safe_compiler().compile_safe(pattern, flags, timeout)
|
||||
|
||||
|
||||
def search_safe(pattern: str | Pattern[str], text: str,
|
||||
flags: int = 0, timeout: float | None = None) -> re.Match[str] | None:
|
||||
"""Safely search text with regex pattern.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern (string or compiled)
|
||||
text: Text to search
|
||||
flags: Regex flags (ignored if pattern is already compiled)
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
Match object or None
|
||||
"""
|
||||
if isinstance(pattern, str):
|
||||
pattern = compile_safe(pattern, flags, timeout)
|
||||
|
||||
return get_safe_executor().search_safe(pattern, text, timeout)
|
||||
|
||||
|
||||
def findall_safe(pattern: str | Pattern[str], text: str,
|
||||
flags: int = 0, timeout: float | None = None) -> list[str]:
|
||||
"""Safely find all matches with regex pattern.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern (string or compiled)
|
||||
text: Text to search
|
||||
flags: Regex flags (ignored if pattern is already compiled)
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
List of all matches
|
||||
"""
|
||||
if isinstance(pattern, str):
|
||||
pattern = compile_safe(pattern, flags, timeout)
|
||||
|
||||
return get_safe_executor().findall_safe(pattern, text, timeout)
|
||||
|
||||
|
||||
def sub_safe(pattern: str | Pattern[str], repl: str, text: str,
|
||||
count: int = 0, flags: int = 0, timeout: float | None = None) -> str:
|
||||
"""Safely substitute matches with regex pattern.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern (string or compiled)
|
||||
repl: Replacement string
|
||||
text: Text to process
|
||||
count: Maximum substitutions (0 = all)
|
||||
flags: Regex flags (ignored if pattern is already compiled)
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
Text with substitutions made
|
||||
"""
|
||||
if isinstance(pattern, str):
|
||||
pattern = compile_safe(pattern, flags, timeout)
|
||||
|
||||
return get_safe_executor().sub_safe(pattern, repl, text, count, timeout)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def is_pattern_safe(pattern: str) -> bool:
|
||||
"""Check if regex pattern is safe (cached for performance).
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to check
|
||||
|
||||
Returns:
|
||||
True if pattern is safe, False otherwise
|
||||
"""
|
||||
try:
|
||||
compile_safe(pattern)
|
||||
return True
|
||||
except RegexSecurityError:
|
||||
return False
|
||||
|
||||
|
||||
def cleanup_regex_executor() -> None:
|
||||
"""Clean up the thread pool executor.
|
||||
|
||||
This should be called during application shutdown to ensure
|
||||
all threads are properly terminated.
|
||||
"""
|
||||
global _timeout_executor
|
||||
if _timeout_executor:
|
||||
_timeout_executor.shutdown(wait=True)
|
||||
logger.info("Regex timeout executor shutdown complete")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RegexSecurityError",
|
||||
"SafeRegexCompiler",
|
||||
"SafeRegexExecutor",
|
||||
"get_safe_compiler",
|
||||
"get_safe_executor",
|
||||
"compile_safe",
|
||||
"search_safe",
|
||||
"findall_safe",
|
||||
"sub_safe",
|
||||
"is_pattern_safe",
|
||||
"cleanup_regex_executor",
|
||||
]
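A brief usage sketch of the new module, based only on the functions exported in `__all__` above (the date pattern and example text are illustrative):

```python
from biz_bud.core.utils.regex_security import (
    cleanup_regex_executor,
    compile_safe,
    is_pattern_safe,
    search_safe,
)

# Vet patterns before they reach a hot path.
assert is_pattern_safe(r"\d{4}-\d{2}-\d{2}")   # bounded pattern, accepted
assert not is_pattern_safe(r"(a|b)+$")         # alternation in a repeated group, rejected

# Compile once, then search with timeout protection.
date_pattern = compile_safe(r"\d{4}-\d{2}-\d{2}")
match = search_safe(date_pattern, "released 2024-05-01", timeout=0.5)

# Shut the timeout thread pool down at application exit.
cleanup_regex_executor()
```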
|
||||
@@ -14,7 +14,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from typing import Any, Literal, Pattern, TypedDict, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Cache decorator removed - requires explicit backend now
|
||||
@@ -42,6 +42,56 @@ logger = get_logger(__name__)
|
||||
# URL type literals for type safety
|
||||
URLType = Literal["webpage", "pdf", "image", "video", "git_repo", "sitemap", "unknown"]
|
||||
|
||||
|
||||
def _compile_safe_regex(pattern: str) -> Pattern[str] | None:
|
||||
"""Safely compile regex pattern with validation to prevent ReDoS attacks.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to compile
|
||||
|
||||
Returns:
|
||||
Compiled regex pattern or None if validation fails
|
||||
|
||||
Raises:
|
||||
ValueError: If pattern is potentially dangerous
|
||||
"""
|
||||
# Validate pattern length
|
||||
if len(pattern) > 500:
|
||||
logger.warning(f"Regex pattern too long ({len(pattern)} chars), max 500")
|
||||
return None
|
||||
|
||||
# Check for dangerous regex constructs that could cause ReDoS
|
||||
dangerous_constructs = [
|
||||
r'\*\*', # Nested quantifiers like **
|
||||
r'\+\+', # Nested quantifiers like ++
|
||||
r'\?\?', # Nested quantifiers like ??
|
||||
r'\([^)]*\)\*\*', # Group with nested quantifiers
|
||||
r'\([^)]*\|[^)]*\)\*', # Alternation in repeated group
|
||||
r'\([^)]*\+\)[^)]*\+', # Multiple nested quantifiers
|
||||
r'(\.\*){2,}', # Multiple .* patterns
|
||||
]
|
||||
|
||||
for construct in dangerous_constructs:
|
||||
try:
|
||||
if re.search(construct, pattern):
|
||||
logger.warning(f"Potentially dangerous regex construct detected: {construct}")
|
||||
return None
|
||||
except re.error:
|
||||
# If we can't even validate the pattern, it's likely problematic
|
||||
logger.warning(f"Invalid regex construct in validation pattern: {construct}")
|
||||
continue
|
||||
|
||||
# Try to compile the pattern using centralized security validation
|
||||
try:
|
||||
from biz_bud.core.utils.regex_security import SafeRegexCompiler
|
||||
compiler = SafeRegexCompiler(max_pattern_length=1000, default_timeout=2.0)
|
||||
compiled_pattern = compiler.compile_safe(pattern, re.IGNORECASE)
|
||||
logger.debug(f"Successfully compiled safe regex pattern: {pattern}")
|
||||
return compiled_pattern
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compile regex pattern '{pattern}': {e}")
|
||||
return None
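# Illustration (not part of the module): safe rules come back compiled and
# case-insensitive, unsafe ones come back as None, e.g.
#     _compile_safe_regex(r"\.pdf$")       # -> compiled pattern
#     _compile_safe_regex(r"(a|b)+\.pdf")  # -> None (alternation in a repeated group)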
|
||||
|
||||
# Git hosting providers
|
||||
GIT_HOSTS = ["github.com", "gitlab.com", "bitbucket.org", "codeberg.org", "sr.ht"]
|
||||
|
||||
@@ -233,33 +283,43 @@ def _apply_custom_rules(
|
||||
URLAnalysisResult if rule matches, None otherwise
|
||||
"""
|
||||
for rule in rules:
|
||||
if re.search(rule["pattern"], url, re.IGNORECASE):
|
||||
extension = _get_file_extension(parsed_url.path)
|
||||
# SECURITY FIX: Use safe regex compilation to prevent ReDoS attacks
|
||||
safe_pattern = _compile_safe_regex(rule["pattern"])
|
||||
if safe_pattern is None:
|
||||
logger.warning(f"Skipping rule '{rule['name']}' due to unsafe regex pattern")
|
||||
continue
|
||||
|
||||
# Generate additional fields for backward compatibility
|
||||
normalized_url = normalize_url(url)
|
||||
is_url_valid = is_valid_url(url)
|
||||
try:
|
||||
if safe_pattern.search(url):
|
||||
extension = _get_file_extension(parsed_url.path)
|
||||
|
||||
result: URLAnalysisResult = {
|
||||
"url": url,
|
||||
"type": rule["url_type"],
|
||||
"url_type": rule["url_type"], # Backward compatibility alias
|
||||
"domain": parsed_url.netloc.lower(),
|
||||
"path": parsed_url.path.lower(),
|
||||
"is_git_repo": rule["url_type"] == "git_repo",
|
||||
"is_pdf": rule["url_type"] == "pdf",
|
||||
"is_image": rule["url_type"] == "image",
|
||||
"is_video": rule["url_type"] == "video",
|
||||
"is_sitemap": rule["url_type"] == "sitemap",
|
||||
"requires_javascript": False, # Custom rules override heuristic
|
||||
"javascript_framework": False, # Default for custom rules
|
||||
"file_extension": extension,
|
||||
"normalized_url": normalized_url,
|
||||
"is_valid": is_url_valid,
|
||||
"is_processable": is_url_valid,
|
||||
"metadata": {"classification_rule": rule["name"], **rule["metadata"]},
|
||||
}
|
||||
return result
|
||||
# Generate additional fields for backward compatibility
|
||||
normalized_url = normalize_url(url)
|
||||
is_url_valid = is_valid_url(url)
|
||||
|
||||
result: URLAnalysisResult = {
|
||||
"url": url,
|
||||
"type": rule["url_type"],
|
||||
"url_type": rule["url_type"], # Backward compatibility alias
|
||||
"domain": parsed_url.netloc.lower(),
|
||||
"path": parsed_url.path.lower(),
|
||||
"is_git_repo": rule["url_type"] == "git_repo",
|
||||
"is_pdf": rule["url_type"] == "pdf",
|
||||
"is_image": rule["url_type"] == "image",
|
||||
"is_video": rule["url_type"] == "video",
|
||||
"is_sitemap": rule["url_type"] == "sitemap",
|
||||
"requires_javascript": False, # Custom rules override heuristic
|
||||
"javascript_framework": False, # Default for custom rules
|
||||
"file_extension": extension,
|
||||
"normalized_url": normalized_url,
|
||||
"is_valid": is_url_valid,
|
||||
"is_processable": is_url_valid,
|
||||
"metadata": {"classification_rule": rule["name"], **rule["metadata"]},
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying custom rule '{rule['name']}': {e}")
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
@@ -624,8 +684,6 @@ class URLAnalyzer:
|
||||
try:
|
||||
# Validate legacy config if provided - URLAnalyzerConfig is always a dict
|
||||
if config is not None:
|
||||
pass # config is guaranteed to be URLAnalyzerConfig (TypedDict) which is a dict
|
||||
|
||||
# Validate specific config fields
|
||||
if "cache_size" in config:
|
||||
cache_size = config["cache_size"]
|
||||
@@ -655,33 +713,28 @@ class URLAnalyzer:
|
||||
)
|
||||
|
||||
# Validate processor config if provided
|
||||
if processor_config is not None:
|
||||
if isinstance(processor_config, dict):
|
||||
# Validate dict-based processor config
|
||||
validation_level = processor_config.get("validation_level")
|
||||
if validation_level and validation_level not in ["basic", "standard", "strict"]:
|
||||
from biz_bud.core.errors.specialized_exceptions import URLConfigurationError
|
||||
raise URLConfigurationError(
|
||||
"validation_level must be 'basic', 'standard', or 'strict'",
|
||||
config_field="validation_level",
|
||||
config_value=validation_level,
|
||||
requirement="basic|standard|strict",
|
||||
)
|
||||
# For other types, assume they're valid processor config objects
|
||||
# (validation will happen during component initialization)
|
||||
if processor_config is not None and isinstance(processor_config, dict):
|
||||
validation_level = processor_config.get("validation_level")
|
||||
if validation_level and validation_level not in ["basic", "standard", "strict"]:
|
||||
from biz_bud.core.errors.specialized_exceptions import URLConfigurationError
|
||||
raise URLConfigurationError(
|
||||
"validation_level must be 'basic', 'standard', or 'strict'",
|
||||
config_field="validation_level",
|
||||
config_value=validation_level,
|
||||
requirement="basic|standard|strict",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if "URLConfigurationError" in str(type(e)):
|
||||
raise # Re-raise configuration errors
|
||||
else:
|
||||
# Wrap other exceptions as configuration errors
|
||||
from biz_bud.core.errors.specialized_exceptions import URLConfigurationError
|
||||
raise URLConfigurationError(
|
||||
f"Configuration validation failed: {e}",
|
||||
config_field="unknown",
|
||||
config_value=str(e),
|
||||
requirement="valid configuration",
|
||||
) from e
|
||||
# Wrap other exceptions as configuration errors
|
||||
from biz_bud.core.errors.specialized_exceptions import URLConfigurationError
|
||||
raise URLConfigurationError(
|
||||
f"Configuration validation failed: {e}",
|
||||
config_field="unknown",
|
||||
config_value=str(e),
|
||||
requirement="valid configuration",
|
||||
) from e
|
||||
|
||||
def _ensure_components_initialized(self) -> None:
|
||||
"""Initialize URL processing components if not already done."""
|
||||
@@ -781,7 +834,7 @@ class URLAnalyzer:
|
||||
logger.warning(
|
||||
"Enhanced processing not available, falling back to legacy mode"
|
||||
)
|
||||
return await self._legacy_process_urls(urls, **options)
|
||||
return self._legacy_process_urls(urls, **options)
|
||||
|
||||
try:
|
||||
# Use the URLProcessor for enhanced processing
|
||||
@@ -1002,7 +1055,7 @@ class URLAnalyzer:
|
||||
)
|
||||
return url
|
||||
|
||||
async def deduplicate_urls(self, urls: list[str], **options: Any) -> list[str]:
|
||||
def deduplicate_urls(self, urls: list[str], **options: Any) -> list[str]:
|
||||
"""Remove duplicate URLs using intelligent matching with asyncio support.
|
||||
|
||||
Args:
|
||||
@@ -1060,7 +1113,7 @@ class URLAnalyzer:
|
||||
"""Synchronous version of URL deduplication for backward compatibility."""
|
||||
return self.deduplicate_urls_legacy(urls)
|
||||
|
||||
async def _legacy_process_urls(
|
||||
def _legacy_process_urls(
|
||||
self, urls: list[str], **options: Any
|
||||
) -> dict[str, Any]:
|
||||
"""Legacy processing implementation for fallback."""
|
||||
|
||||
@@ -122,7 +122,9 @@ def should_skip_content(url: str) -> bool:
|
||||
|
||||
# PDF filetype skip logic: match .pdf at end of path, before
|
||||
# query/fragment, or encoded .pdf (case-insensitive, all encodings)
|
||||
pdf_url_pattern = re.compile(
|
||||
from biz_bud.core.utils.regex_security import SafeRegexCompiler
|
||||
compiler = SafeRegexCompiler(max_pattern_length=1000, default_timeout=2.0)
|
||||
pdf_url_pattern = compiler.compile_safe(
|
||||
r"(\.pdf($|[\?#]))|(/pdf($|[\?#]))|(\?pdf($|[\?#]))", re.IGNORECASE
|
||||
)
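# Illustration: this matches URLs whose path or query signals a PDF, e.g.
# ".../report.pdf", ".../report.PDF?dl=1", or ".../pdf#page=2" - each alternative
# only fires when ".pdf", "/pdf", or "?pdf" sits at the end of the string or
# directly before "?" or "#".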
|
||||
parsed = urlparse(decoded_url)
|
||||
@@ -212,7 +214,7 @@ def preprocess_content(arg1: str | None, arg2: str) -> str:
|
||||
return ""
|
||||
content = arg2
|
||||
cleaned = re.sub(r"(?is)<.*?>", "", content)
|
||||
# Remove control characters (U+0000–U+001F and DEL)
|
||||
# Remove control characters (U+0000-U+001F and DEL)
|
||||
cleaned = re.sub(r"[\x00-\x1F\x7F]+", " ", cleaned)
|
||||
# Normalize whitespace: collapse all whitespace to a single space, then trim
|
||||
cleaned = re.sub(r"\s+", " ", cleaned)
|
||||
|
||||
@@ -23,10 +23,10 @@ from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.core.errors import NetworkError, ValidationError
|
||||
from biz_bud.core.networking.http_client import HTTPClient, HTTPClientConfig
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
# --- Logger setup ---
|
||||
@@ -110,18 +110,27 @@ async def download_document(url: str, timeout: int = 30) -> bytes | None:
|
||||
"""Download document content from the provided URL.
|
||||
|
||||
Returns the document content as bytes on success.
|
||||
Raises requests.RequestException if the download fails.
|
||||
Raises NetworkError if the download fails.
|
||||
"""
|
||||
logger.info(f"Downloading document from {url}")
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
lambda: requests.get(url, stream=True, timeout=timeout)
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error downloading document: {e}")
|
||||
# Get HTTP client with custom timeout
|
||||
config = HTTPClientConfig(timeout=timeout)
|
||||
client = await HTTPClient.get_or_create_client(config)
|
||||
|
||||
# Download the document
|
||||
response = await client.get(url)
|
||||
|
||||
# Check if response was successful
|
||||
if response["status_code"] >= 400:
|
||||
raise NetworkError(f"HTTP {response['status_code']}: Failed to download document from {url}")
|
||||
|
||||
return response["content"]
|
||||
except NetworkError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading document: {e}")
|
||||
raise NetworkError(f"Failed to download document from {url}: {e}") from e
|
||||
|
||||
|
||||
def create_temp_document_file(content: bytes, file_extension: str) -> tuple[str, str]:
|
||||
|
||||
@@ -139,7 +139,7 @@ def validate_node_input(input_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
if is_validated(func):
|
||||
return func
|
||||
|
||||
is_async: bool = inspect.iscoroutinefunction(func)
|
||||
is_async: bool = asyncio.iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
|
||||
@@ -198,7 +198,7 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
def decorator(func: F) -> F:
|
||||
if is_validated(func):
|
||||
return func
|
||||
is_async: bool = inspect.iscoroutinefunction(func)
|
||||
is_async: bool = asyncio.iscoroutinefunction(func)
|
||||
if is_async:
|
||||
|
||||
@functools.wraps(func)
|
||||
@@ -292,7 +292,7 @@ def validated_node[F: Callable[..., object]](
|
||||
return decorator(_func)
|
||||
|
||||
|
||||
async def validate_graph(graph: object, graph_id: str = "unknown") -> bool:
|
||||
def validate_graph(graph: object, graph_id: str = "unknown") -> bool:
|
||||
"""Validate that a graph has all required methods for compatibility.
|
||||
|
||||
Args:
|
||||
@@ -321,7 +321,7 @@ async def ensure_graph_compatibility(
|
||||
Adds default implementations for any missing required methods.
|
||||
"""
|
||||
with contextlib.suppress(GraphValidationError):
|
||||
await validate_graph(graph, graph_id)
|
||||
validate_graph(graph, graph_id)
|
||||
return graph # Already compatible
|
||||
|
||||
def make_not_implemented(
|
||||
@@ -339,7 +339,7 @@ async def ensure_graph_compatibility(
|
||||
setattr(graph, method_name, make_not_implemented(method_name, graph_id))
|
||||
|
||||
try:
|
||||
await validate_graph(graph, graph_id)
|
||||
validate_graph(graph, graph_id)
|
||||
return graph
|
||||
except GraphValidationError:
|
||||
return graph
|
||||
@@ -369,7 +369,7 @@ async def validate_all_graphs(
|
||||
# Skip functions that require arguments for now
|
||||
all_valid = False
|
||||
continue
|
||||
await validate_graph(graph, name)
|
||||
validate_graph(graph, name)
|
||||
except Exception:
|
||||
all_valid = False
|
||||
return all_valid
|
||||
|
||||
@@ -627,7 +627,7 @@ class SecureExecutionManager:
|
||||
f"Completed execution of {operation_name} in {execution_time:.2f}s"
|
||||
)
|
||||
|
||||
async def validate_factory_function(
|
||||
def validate_factory_function(
|
||||
self, factory_function: Callable[[], Any], graph_name: str
|
||||
) -> None:
|
||||
"""Validate that a factory function is safe to execute.
|
||||
|
||||
@@ -2,6 +2,49 @@
|
||||
|
||||
This guide provides best practices for constructing LangGraph workflows in the Business Buddy framework, based on project patterns and modern LangGraph standards. Focus on anti-technical debt patterns, service integration, and maintainable graph architectures.
|
||||
|
||||
## Core Package Support
|
||||
|
||||
The Business Buddy framework provides extensive utilities through the core package (`@src/biz_bud/core/`) to simplify graph construction and reduce code duplication:
|
||||
|
||||
### Configuration Management
|
||||
- **`generate_config_hash`**: Create deterministic cache keys from configuration
|
||||
- **`load_config` / `load_config_async`**: Type-safe configuration loading with precedence
|
||||
- **`AppConfig`**: Strongly-typed configuration schema
|
||||
|
||||
### Async/Sync Utilities
|
||||
- **`create_async_sync_wrapper`**: Create sync/async factory function pairs
|
||||
- **`handle_sync_async_context`**: Smart context detection and execution
|
||||
- **`detect_async_context`**: Reliable async context detection
|
||||
- **`run_in_appropriate_context`**: Execute async code from any context
|
||||
|
||||
### State Management
|
||||
- **`process_state_query`**: Extract queries from various sources
|
||||
- **`format_raw_input`**: Normalize raw input to consistent format
|
||||
- **`extract_state_update_data`**: Extract data from LangGraph state updates
|
||||
- **`create_initial_state_dict`**: Create properly formatted initial states
|
||||
- **`normalize_errors_to_list`**: Ensure errors are always list format
|
||||
|
||||
### Type Compatibility
|
||||
- **`create_type_safe_wrapper`**: Wrap functions for LangGraph type safety
|
||||
- **`wrap_for_langgraph`**: Decorator for type-safe conditional edges
|
||||
|
||||
### Caching Infrastructure
|
||||
- **`LLMCache`**: Specialized cache for LLM responses
|
||||
- **`GraphCache`**: Thread-safe cache for compiled graph instances
|
||||
- **Cache backends**: Memory, Redis, and file-based caching
|
||||
|
||||
### Error Handling & Cleanup
|
||||
- **`CleanupRegistry`**: Centralized resource cleanup management
|
||||
- **`normalize_errors_to_list`**: Consistent error list handling
|
||||
- **Custom exceptions**: Project-specific error types
|
||||
|
||||
### LangGraph Utilities
|
||||
- **`standard_node`**: Decorator for consistent node patterns
|
||||
- **`handle_errors`**: Automatic error handling decorator
|
||||
- **`log_node_execution`**: Node execution logging
|
||||
- **`route_error_severity`**: Error-based routing logic
|
||||
- **`route_llm_output`**: LLM output-based routing
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
@@ -18,6 +61,156 @@ src/biz_bud/graphs/
|
||||
└── error_handling.py # Graph-level error handling and recovery
|
||||
```
|
||||
|
||||
## Using Core Package Utilities
|
||||
|
||||
### Configuration and Caching
|
||||
|
||||
```python
|
||||
from biz_bud.core.config.loader import load_config_async, generate_config_hash
|
||||
from biz_bud.core.caching.cache_manager import GraphCache
|
||||
|
||||
async def create_cached_graph():
|
||||
"""Create graph with configuration-based caching."""
|
||||
# Load configuration with type safety
|
||||
config = await load_config_async(overrides={"temperature": 0.8})
|
||||
|
||||
# Generate cache key from config
|
||||
cache_key = generate_config_hash(config)
|
||||
|
||||
# Use GraphCache for thread-safe caching
|
||||
graph_cache = GraphCache()
|
||||
graph = await graph_cache.get_or_create(
|
||||
cache_key,
|
||||
create_graph_with_config,
|
||||
config
|
||||
)
|
||||
|
||||
return graph
|
||||
```
|
||||
|
||||
### State Creation Utilities
|
||||
|
||||
```python
|
||||
from biz_bud.core.utils.graph_helpers import (
|
||||
process_state_query,
|
||||
format_raw_input,
|
||||
create_initial_state_dict
|
||||
)
|
||||
from biz_bud.core.utils import normalize_errors_to_list
|
||||
|
||||
def create_state_from_input(user_input: str | dict) -> dict:
|
||||
"""Create properly formatted state using core utilities."""
|
||||
# Process query from various sources
|
||||
query = process_state_query(
|
||||
query=None,
|
||||
messages=[{"role": "user", "content": user_input}],
|
||||
state_update=None,
|
||||
default_query="default query"
|
||||
)
|
||||
|
||||
# Format raw input consistently
|
||||
raw_input_str, extracted_query = format_raw_input(user_input, query)
|
||||
|
||||
# Create initial state with all required fields
|
||||
state = create_initial_state_dict(
|
||||
raw_input_str=raw_input_str,
|
||||
user_query=extracted_query,
|
||||
messages=[{"type": "human", "content": extracted_query}],
|
||||
thread_id="thread_123",
|
||||
config_dict={"llm": {"temperature": 0.7}}
|
||||
)
|
||||
|
||||
# Ensure errors are always lists
|
||||
state["errors"] = normalize_errors_to_list(state.get("errors"))
|
||||
|
||||
return state
|
||||
```
|
||||
|
||||
### Type-Safe Wrappers for LangGraph
|
||||
|
||||
```python
|
||||
from biz_bud.core.langgraph import (
|
||||
create_type_safe_wrapper,
|
||||
wrap_for_langgraph,
|
||||
route_error_severity,
|
||||
route_llm_output
|
||||
)
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
# Create type-safe wrappers for routing functions
|
||||
safe_error_router = create_type_safe_wrapper(route_error_severity)
|
||||
safe_llm_router = create_type_safe_wrapper(route_llm_output)
|
||||
|
||||
# Or use decorator pattern
|
||||
@wrap_for_langgraph(dict)
|
||||
def custom_router(state: MyTypedState) -> str:
|
||||
"""Custom routing logic with automatic type casting."""
|
||||
return route_error_severity(state)
|
||||
|
||||
# Use in graph construction
|
||||
builder = StateGraph(MyTypedState)
|
||||
builder.add_conditional_edges(
|
||||
"process_node",
|
||||
safe_error_router, # No type errors!
|
||||
{
|
||||
"retry": "process_node",
|
||||
"error": "error_handler",
|
||||
"continue": "next_node"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Async/Sync Factory Pattern
|
||||
|
||||
```python
|
||||
from biz_bud.core.networking.async_utils import (
|
||||
create_async_sync_wrapper,
|
||||
handle_sync_async_context
|
||||
)
|
||||
|
||||
# Create dual sync/async factories
|
||||
def resolve_config_sync(runnable_config):
|
||||
return load_config()
|
||||
|
||||
async def resolve_config_async(runnable_config):
|
||||
return await load_config_async()
|
||||
|
||||
# Generate both versions automatically
|
||||
sync_factory, async_factory = create_async_sync_wrapper(
|
||||
resolve_config_sync,
|
||||
resolve_config_async
|
||||
)
|
||||
|
||||
# Smart context handling
|
||||
graph = handle_sync_async_context(
|
||||
app_config,
|
||||
service_factory,
|
||||
lambda: create_graph_with_services(app_config, service_factory),
|
||||
lambda: get_graph() # sync fallback
|
||||
)
|
||||
```
|
||||
|
||||
### Resource Cleanup Integration
|
||||
|
||||
```python
|
||||
from biz_bud.core.cleanup_registry import get_cleanup_registry
|
||||
|
||||
async def setup_graph_with_cleanup():
|
||||
"""Set up graph with proper cleanup registration."""
|
||||
registry = get_cleanup_registry()
|
||||
|
||||
# Register graph cleanup
|
||||
registry.register_cleanup("cleanup_graph_cache", cleanup_graph_cache)
|
||||
|
||||
# Create graph
|
||||
graph = await create_production_graph()
|
||||
|
||||
# Cleanup will be called automatically on shutdown
|
||||
# or manually: await registry.cleanup_caches(["graph_cache"])
|
||||
|
||||
return graph
|
||||
```
|
||||
|
||||
## Graph Construction Best Practices
|
||||
|
||||
### 1. Use Factory Functions for Graph Creation
|
||||
@@ -43,7 +236,16 @@ graph = await create_analysis_workflow()
|
||||
|
||||
### 2. Leverage Service Factory Integration
|
||||
|
||||
**✅ Inject Services at Graph Level:**
|
||||
The service factory provides centralized service management with automatic lifecycle handling:
|
||||
|
||||
**Available Services:**
|
||||
- **LLM Clients**: OpenAI, Anthropic, Google, etc.
|
||||
- **Vector Stores**: Qdrant, Redis, PostgreSQL+pgvector
|
||||
- **Cache Backends**: Redis, Memory, File-based
|
||||
- **HTTP Clients**: Configured with retry and timeout
|
||||
- **Specialized Services**: Extraction, validation, search
|
||||
|
||||
**✅ Service Access Patterns:**
|
||||
```python
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
@@ -52,10 +254,43 @@ from biz_bud.services.factory import get_global_factory
|
||||
async def analysis_node(state: dict) -> dict:
|
||||
"""Node with automatic service access."""
|
||||
factory = await get_global_factory()
|
||||
|
||||
# Get typed services with automatic initialization
|
||||
llm_client = await factory.get_llm_client()
|
||||
vector_store = await factory.get_vector_store()
|
||||
cache_backend = await factory.get_cache_backend()
|
||||
|
||||
# Services are automatically cleaned up on exit
|
||||
return await process_with_services(state, llm_client, vector_store)
|
||||
|
||||
# Alternative: Direct service access
|
||||
@standard_node
|
||||
async def search_node(state: dict) -> dict:
|
||||
"""Use specific service methods."""
|
||||
factory = await get_global_factory()
|
||||
|
||||
# Get search service
|
||||
search_service = await factory.get_service("SearchService")
|
||||
results = await search_service.search(state["query"])
|
||||
|
||||
return {"search_results": results}
|
||||
```
|
||||
|
||||
**✅ Service Configuration:**
|
||||
```python
|
||||
from biz_bud.core.config import AppConfig
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
# Create factory with custom config
|
||||
config = AppConfig(
|
||||
llm_config={"model": "gpt-4", "temperature": 0.7},
|
||||
redis_config={"host": "localhost", "port": 6379}
|
||||
)
|
||||
|
||||
async with ServiceFactory(config) as factory:
|
||||
# Services use config automatically
|
||||
llm = await factory.get_llm_client() # Uses gpt-4
|
||||
cache = await factory.get_redis_backend() # Connects to localhost:6379
|
||||
```
|
||||
|
||||
### 3. Use Typed State Objects
|
||||
@@ -336,6 +571,131 @@ builder.add_conditional_edges(
|
||||
|
||||
## Complete Graph Construction Template
|
||||
|
||||
### Using All Core Utilities Together
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
import operator
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.types import Command
|
||||
|
||||
# Core utilities
|
||||
from biz_bud.core.config.loader import load_config_async, generate_config_hash
|
||||
from biz_bud.core.caching.cache_manager import GraphCache
|
||||
from biz_bud.core.utils.graph_helpers import (
|
||||
process_state_query,
|
||||
format_raw_input,
|
||||
create_initial_state_dict
|
||||
)
|
||||
from biz_bud.core.utils import normalize_errors_to_list
|
||||
from biz_bud.core.langgraph import (
|
||||
standard_node,
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
create_type_safe_wrapper,
|
||||
route_error_severity,
|
||||
route_llm_output
|
||||
)
|
||||
from biz_bud.core.networking.async_utils import handle_sync_async_context
|
||||
from biz_bud.core.cleanup_registry import get_cleanup_registry
|
||||
|
||||
# States and services
|
||||
from biz_bud.states import InputState
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
|
||||
class ProductionState(InputState):
|
||||
"""Production workflow state with all fields."""
|
||||
processing_status: str = "pending"
|
||||
results: Annotated[list[dict], operator.add] = []
|
||||
confidence_score: float = 0.0
|
||||
|
||||
# Create type-safe routers
|
||||
safe_error_router = create_type_safe_wrapper(route_error_severity)
|
||||
safe_llm_router = create_type_safe_wrapper(route_llm_output)
|
||||
|
||||
@standard_node
|
||||
@handle_errors
|
||||
@log_node_execution
|
||||
async def intelligent_processing_node(state: ProductionState) -> dict:
|
||||
"""Process with full core utility integration."""
|
||||
factory = await get_global_factory()
|
||||
|
||||
# Normalize any errors
|
||||
state["errors"] = normalize_errors_to_list(state.get("errors"))
|
||||
|
||||
# Get services
|
||||
llm_client = await factory.get_llm_client()
|
||||
cache = await factory.get_cache_backend()
|
||||
|
||||
# Process with caching
|
||||
cache_key = f"process:{state.get('thread_id')}:{state.get('raw_input')}"
|
||||
cached_result = await cache.get(cache_key)
|
||||
|
||||
if cached_result:
|
||||
return {"results": [cached_result], "processing_status": "cached"}
|
||||
|
||||
# Process new
|
||||
result = await llm_client.llm_chat(
|
||||
prompt=state["raw_input"],
|
||||
system_prompt="Process this intelligently."
|
||||
)
|
||||
|
||||
# Cache result
|
||||
await cache.set(cache_key, result, ttl=3600)
|
||||
|
||||
return {
|
||||
"results": [{"output": result, "cached": False}],
|
||||
"processing_status": "completed",
|
||||
"confidence_score": 0.95
|
||||
}
|
||||
|
||||
async def create_production_graph_with_caching():
|
||||
"""Create graph with all core utilities."""
|
||||
# Load config
|
||||
config = await load_config_async()
|
||||
|
||||
# Generate cache key
|
||||
cache_key = generate_config_hash(config)
|
||||
|
||||
# Use graph cache
|
||||
graph_cache = GraphCache()
|
||||
|
||||
async def build_graph():
|
||||
builder = StateGraph(ProductionState)
|
||||
|
||||
# Add nodes
|
||||
builder.add_node("process", intelligent_processing_node)
|
||||
builder.add_node("error_handler", handle_graph_error)
|
||||
|
||||
# Use type-safe routing
|
||||
builder.add_conditional_edges(
|
||||
"process",
|
||||
safe_error_router,
|
||||
{
|
||||
"retry": "process",
|
||||
"error": "error_handler",
|
||||
"continue": END
|
||||
}
|
||||
)
|
||||
|
||||
builder.add_edge(START, "process")
|
||||
|
||||
# Compile with services
|
||||
factory = await get_global_factory()
|
||||
return builder.compile().with_config({
|
||||
"configurable": {"service_factory": factory}
|
||||
})
|
||||
|
||||
# Get or create cached graph
|
||||
graph = await graph_cache.get_or_create(cache_key, build_graph)
|
||||
|
||||
# Register cleanup
|
||||
registry = get_cleanup_registry()
|
||||
registry.register_cleanup("production_graph_cache", graph_cache.clear)
|
||||
|
||||
return graph
|
||||
```
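To show how the template is consumed, here is a minimal usage sketch. It assumes the `create_production_graph_with_caching` coroutine defined above and LangGraph's standard `ainvoke` interface; the input values are placeholders.

```python
# Minimal usage sketch (assumes the template above and LangGraph's ainvoke API)
async def run_production_graph_once() -> dict:
    graph = await create_production_graph_with_caching()
    # Field names mirror the cache key built inside intelligent_processing_node
    return await graph.ainvoke({
        "raw_input": "Summarize the latest supplier report",
        "thread_id": "demo-thread-1",
    })
```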
|
||||
|
||||
### Production-Ready Graph Template
|
||||
|
||||
```python
|
||||
@@ -538,6 +898,12 @@ async def process_large_dataset(items: list[dict]) -> list[dict]:
|
||||
5. **Implement Proper Resource Management**: Use context managers and cleanup registries
|
||||
6. **Centralize Configuration**: Single source of truth through `AppConfig` and `RunnableConfig`
|
||||
7. **Reuse Edge Helpers**: Use the project's routing utilities to prevent duplication
|
||||
8. **Leverage Core Utilities** (a short sketch follows this list):
|
||||
- `generate_config_hash` for cache keys
|
||||
- `create_type_safe_wrapper` for LangGraph compatibility
|
||||
- `normalize_errors_to_list` for consistent error handling
|
||||
- `GraphCache` for thread-safe graph caching
|
||||
- `handle_sync_async_context` for context-aware execution
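
A minimal sketch of two of these helpers, using only calls already shown in this guide; the `merge_errors` helper name and the error payload shape are illustrative assumptions.

```python
from biz_bud.core.langgraph import create_type_safe_wrapper, route_error_severity
from biz_bud.core.utils import normalize_errors_to_list

# Wrap the shared router once so it satisfies LangGraph's expected callable type
safe_error_router = create_type_safe_wrapper(route_error_severity)

def merge_errors(state: dict, new_error: dict) -> list:
    """Illustrative helper: normalize whatever 'errors' holds, then append."""
    errors = normalize_errors_to_list(state.get("errors"))
    return [*errors, new_error]
```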
|
||||
|
||||
### ❌ Avoid This
|
||||
|
||||
@@ -548,6 +914,53 @@ async def process_large_dataset(items: list[dict]) -> list[dict]:
|
||||
5. **Duplicate Routing Logic**: Don't recreate common routing patterns
|
||||
6. **Missing Error Handling**: Don't skip error handling decorators
|
||||
7. **Service Leaks**: Don't forget proper cleanup and context management (see the contrast sketch after this list)
|
||||
8. **Reimplementing Core Utilities**: Don't duplicate functionality that exists in core
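
To make the service-leak point concrete, a small contrast sketch using only the factory calls shown earlier; the function names are illustrative and both are assumed to run inside an event loop.

```python
from biz_bud.services.factory import ServiceFactory

async def leaky_version(config) -> None:
    # Anti-pattern (illustrative): the factory is never cleaned up, so connections can leak
    factory = ServiceFactory(config)
    llm = await factory.get_llm_client()
    ...

async def scoped_version(config) -> None:
    # Preferred: the async context manager guarantees cleanup on exit
    async with ServiceFactory(config) as factory:
        llm = await factory.get_llm_client()
        ...
```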
|
||||
|
||||
### Core Package Quick Reference
|
||||
|
||||
```python
|
||||
# Configuration
|
||||
from biz_bud.core.config.loader import load_config_async, generate_config_hash
|
||||
from biz_bud.core.config.schemas import AppConfig
|
||||
|
||||
# Async/Sync utilities
|
||||
from biz_bud.core.networking.async_utils import (
|
||||
create_async_sync_wrapper,
|
||||
handle_sync_async_context,
|
||||
detect_async_context,
|
||||
run_in_appropriate_context
|
||||
)
|
||||
|
||||
# State management
|
||||
from biz_bud.core.utils.graph_helpers import (
|
||||
process_state_query,
|
||||
format_raw_input,
|
||||
extract_state_update_data,
|
||||
create_initial_state_dict
|
||||
)
|
||||
from biz_bud.core.utils import normalize_errors_to_list
|
||||
|
||||
# Type compatibility
|
||||
from biz_bud.core.langgraph import (
|
||||
create_type_safe_wrapper,
|
||||
wrap_for_langgraph,
|
||||
standard_node,
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
route_error_severity,
|
||||
route_llm_output
|
||||
)
|
||||
|
||||
# Caching
|
||||
from biz_bud.core.caching.cache_manager import LLMCache, GraphCache
|
||||
from biz_bud.core.caching.decorators import redis_cache
|
||||
|
||||
# Cleanup
|
||||
from biz_bud.core.cleanup_registry import get_cleanup_registry
|
||||
|
||||
# Services
|
||||
from biz_bud.services.factory import get_global_factory, ServiceFactory
|
||||
```
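As a quick orientation for the caching and cleanup imports above, a minimal sketch that mirrors the registration pattern from the template earlier in this guide; the registration name is arbitrary.

```python
from biz_bud.core.caching.cache_manager import GraphCache
from biz_bud.core.cleanup_registry import get_cleanup_registry

graph_cache = GraphCache()

# Register the cache's clear method so shutdown hooks can release it
registry = get_cleanup_registry()
registry.register_cleanup("docs_example_graph_cache", graph_cache.clear)
```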
|
||||
|
||||
### Migration Checklist
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core.errors import create_error_info
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
|
||||
from biz_bud.logging import debug_highlight, error_highlight, info_highlight
|
||||
|
||||
# Import from local nodes directory
|
||||
try:
|
||||
from .nodes.analysis import catalog_impact_analysis_node
|
||||
from .nodes.c_intel import (
|
||||
batch_analyze_components_node,
|
||||
find_affected_catalog_items_node,
|
||||
@@ -44,130 +45,12 @@ except ImportError:
|
||||
extract_components_from_sources_node = None
|
||||
research_catalog_item_components_node = None
|
||||
load_catalog_data_node = None
|
||||
|
||||
|
||||
@standard_node(node_name="catalog_impact_analysis", metric_name="impact_analysis")
|
||||
async def catalog_impact_analysis_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Analyze the impact of changes on catalog items.
|
||||
|
||||
This node performs comprehensive impact analysis for catalog changes,
|
||||
including price changes, component substitutions, and availability updates.
|
||||
|
||||
Args:
|
||||
state: Current workflow state
|
||||
config: Optional runtime configuration
|
||||
|
||||
Returns:
|
||||
Updated state with impact analysis results
|
||||
"""
|
||||
info_highlight("Performing catalog impact analysis...", category="CatalogImpact")
|
||||
|
||||
# Get analysis parameters
|
||||
component_focus = state.get("component_focus", {})
|
||||
catalog_data = state.get("catalog_data", {})
|
||||
|
||||
if not component_focus or not catalog_data:
|
||||
warning_highlight("Missing component focus or catalog data", category="CatalogImpact")
|
||||
return {
|
||||
"impact_analysis": {},
|
||||
"analysis_metadata": {"message": "Insufficient data for impact analysis"},
|
||||
}
|
||||
|
||||
try:
|
||||
# Analyze impact across different dimensions
|
||||
impact_results: dict[str, Any] = {
|
||||
"affected_items": [],
|
||||
"cost_impact": {},
|
||||
"availability_impact": {},
|
||||
"quality_impact": {},
|
||||
"recommendations": [],
|
||||
}
|
||||
|
||||
# Extract component details
|
||||
component_name = component_focus.get("name", "")
|
||||
component_type = component_focus.get("type", "ingredient")
|
||||
|
||||
# Analyze affected catalog items
|
||||
affected_count = 0
|
||||
for category, items in catalog_data.items():
|
||||
if isinstance(items, list):
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
# Check if item uses the component
|
||||
components = item.get("components", [])
|
||||
ingredients = item.get("ingredients", [])
|
||||
|
||||
uses_component = False
|
||||
if component_type == "ingredient" and ingredients:
|
||||
uses_component = any(
|
||||
component_name.lower() in str(ing).lower() for ing in ingredients
|
||||
)
|
||||
elif components:
|
||||
uses_component = any(
|
||||
component_name.lower() in str(comp).lower() for comp in components
|
||||
)
|
||||
|
||||
if uses_component:
|
||||
affected_count += 1
|
||||
impact_results["affected_items"].append(
|
||||
{
|
||||
"name": item.get("name", "Unknown"),
|
||||
"category": category,
|
||||
"price": item.get("price", 0),
|
||||
"dependency_level": "high", # Simplified
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate aggregate impacts
|
||||
if affected_count > 0:
|
||||
impact_results["cost_impact"] = {
|
||||
"items_affected": affected_count,
|
||||
"potential_cost_increase": f"{affected_count * 5}%", # Simplified calculation
|
||||
"risk_level": "medium" if affected_count < 10 else "high",
|
||||
}
|
||||
|
||||
impact_results["recommendations"] = [
|
||||
f"Monitor {component_name} availability closely",
|
||||
f"Consider alternative components for {affected_count} affected items",
|
||||
"Update pricing strategy if costs increase significantly",
|
||||
]
|
||||
|
||||
info_highlight(
|
||||
f"Impact analysis completed: {affected_count} items affected", category="CatalogImpact"
|
||||
)
|
||||
|
||||
return {
|
||||
"impact_analysis": impact_results,
|
||||
"analysis_metadata": {
|
||||
"component_analyzed": component_name,
|
||||
"items_affected": affected_count,
|
||||
"analysis_depth": "comprehensive",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Catalog impact analysis failed: {str(e)}"
|
||||
error_highlight(error_msg, category="CatalogImpact")
|
||||
return {
|
||||
"impact_analysis": {},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=error_msg,
|
||||
node="catalog_impact_analysis",
|
||||
severity="error",
|
||||
category="analysis_error",
|
||||
)
|
||||
],
|
||||
}
|
||||
catalog_impact_analysis_node = None
|
||||
|
||||
|
||||
@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization")
|
||||
async def catalog_optimization_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Generate optimization recommendations for the catalog.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core.types import create_error_info
|
||||
from biz_bud.core.utils.regex_security import findall_safe, search_safe
|
||||
from biz_bud.logging import error_highlight, get_logger, info_highlight
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
from biz_bud.states.catalog import CatalogIntelState
|
||||
@@ -39,15 +40,15 @@ def _is_component_match(component: str, item_component: str) -> bool:
|
||||
return True
|
||||
|
||||
# Check if component appears as a whole word in item_component
|
||||
# Use word boundaries to prevent substring matches
|
||||
# Use word boundaries to prevent substring matches with safe regex
|
||||
pattern = r"\b" + re.escape(component_lower) + r"\b"
|
||||
if re.search(pattern, item_component_lower):
|
||||
if search_safe(pattern, item_component_lower):
|
||||
return True
|
||||
|
||||
# Check reverse - if item_component appears as whole word in component
|
||||
# This handles cases like "goat" matching "goat meat"
|
||||
pattern = r"\b" + re.escape(item_component_lower) + r"\b"
|
||||
return bool(re.search(pattern, component_lower))
|
||||
return bool(search_safe(pattern, component_lower))
|
||||
|
||||
|
||||
async def identify_component_focus_node(
|
||||
@@ -143,17 +144,17 @@ async def identify_component_focus_node(
|
||||
found_components: list[str] = []
|
||||
content_lower = content.lower()
|
||||
for component in components:
|
||||
# Use word boundary matching to avoid false positives
|
||||
# Use word boundary matching to avoid false positives with safe regex
|
||||
pattern = r"\b" + re.escape(component.lower()) + r"\b"
|
||||
if re.search(pattern, content_lower):
|
||||
if search_safe(pattern, content_lower):
|
||||
found_components.append(component)
|
||||
_logger.info(f"Found component: {component}")
|
||||
|
||||
# Also look for context clues like "goat meat shortage" -> "goat"
|
||||
|
||||
# First check for specific meat shortages to prioritize them
|
||||
# First check for specific meat shortages to prioritize them using safe regex
|
||||
meat_shortage_pattern = r"(\w+\s+meat)\s+shortage"
|
||||
meat_shortage_matches = re.findall(meat_shortage_pattern, content, re.IGNORECASE)
|
||||
meat_shortage_matches = findall_safe(meat_shortage_pattern, content, flags=re.IGNORECASE)
|
||||
if meat_shortage_matches:
|
||||
# If we find a specific meat shortage, focus on that
|
||||
_logger.info(f"Found specific meat shortage: {meat_shortage_matches[0]}")
|
||||
@@ -175,7 +176,7 @@ async def identify_component_focus_node(
|
||||
]
|
||||
|
||||
for pattern in component_context_patterns:
|
||||
matches = re.findall(pattern, content, re.IGNORECASE)
|
||||
matches = findall_safe(pattern, content, flags=re.IGNORECASE)
|
||||
for match in matches:
|
||||
# Extract the base component (e.g., "goat" from "goat meat")
|
||||
match_clean = match.strip().lower()
|
||||
|
||||
File diff suppressed because it is too large
@@ -13,7 +13,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core.errors import create_error_info
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
|
||||
from biz_bud.logging import debug_highlight, error_highlight, info_highlight
|
||||
|
||||
# Import from local nodes directory
|
||||
try:
|
||||
@@ -22,6 +22,7 @@ try:
|
||||
paperless_orchestrator_node,
|
||||
paperless_search_node,
|
||||
)
|
||||
from .nodes.processing import process_document_node
|
||||
|
||||
# Legacy imports for compatibility
|
||||
paperless_upload_node = paperless_orchestrator_node # Alias
|
||||
@@ -35,94 +36,7 @@ except ImportError:
|
||||
paperless_orchestrator_node = None
|
||||
paperless_document_retrieval_node = None
|
||||
paperless_metadata_management_node = None
|
||||
|
||||
|
||||
@standard_node(node_name="paperless_document_processor", metric_name="document_processing")
|
||||
async def process_document_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Process documents for Paperless-NGX upload.
|
||||
|
||||
This node prepares documents by extracting metadata, applying tags,
|
||||
and formatting them for the Paperless-NGX API.
|
||||
|
||||
Args:
|
||||
state: Current workflow state
|
||||
config: Optional runtime configuration
|
||||
|
||||
Returns:
|
||||
Updated state with processed documents
|
||||
"""
|
||||
info_highlight("Processing documents for Paperless-NGX...", category="PaperlessProcessor")
|
||||
|
||||
documents = state.get("documents", [])
|
||||
if not documents:
|
||||
warning_highlight("No documents to process", category="PaperlessProcessor")
|
||||
return {"processed_documents": []}
|
||||
|
||||
try:
|
||||
processed_docs = []
|
||||
|
||||
for doc in documents:
|
||||
# Extract metadata
|
||||
processed = {
|
||||
"content": doc.get("content", ""),
|
||||
"title": doc.get("title", "Untitled"),
|
||||
"correspondent": doc.get("correspondent"),
|
||||
"document_type": doc.get("document_type", "general"),
|
||||
"tags": doc.get("tags", []),
|
||||
"created_date": doc.get("created_date"),
|
||||
"metadata": {
|
||||
"source": doc.get("source", "upload"),
|
||||
"original_filename": doc.get("filename"),
|
||||
"processing_timestamp": None, # Would add timestamp
|
||||
},
|
||||
}
|
||||
|
||||
# Auto-tag based on content (simplified)
|
||||
content = processed["content"]
|
||||
tags = processed["tags"]
|
||||
if isinstance(content, str) and isinstance(tags, list):
|
||||
content_lower = content.lower()
|
||||
if "invoice" in content_lower:
|
||||
tags.append("invoice")
|
||||
processed["document_type"] = "invoice"
|
||||
elif "receipt" in content_lower:
|
||||
tags.append("receipt")
|
||||
processed["document_type"] = "receipt"
|
||||
elif "contract" in content_lower:
|
||||
tags.append("contract")
|
||||
processed["document_type"] = "contract"
|
||||
|
||||
processed_docs.append(processed)
|
||||
|
||||
info_highlight(
|
||||
f"Processed {len(processed_docs)} documents for Paperless-NGX",
|
||||
category="PaperlessProcessor",
|
||||
)
|
||||
|
||||
return {
|
||||
"processed_documents": processed_docs,
|
||||
"processing_metadata": {
|
||||
"total_processed": len(processed_docs),
|
||||
"auto_tagged": sum(bool(d.get("tags")) for d in processed_docs),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Document processing failed: {str(e)}"
|
||||
error_highlight(error_msg, category="PaperlessProcessor")
|
||||
return {
|
||||
"processed_documents": [],
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=error_msg,
|
||||
node="process_document",
|
||||
severity="error",
|
||||
category="processing_error",
|
||||
)
|
||||
],
|
||||
}
|
||||
process_document_node = None
|
||||
|
||||
|
||||
@standard_node(node_name="paperless_query_builder", metric_name="query_building")
|
||||
|
||||
@@ -538,15 +538,15 @@ async def check_r2r_duplicate_node(
|
||||
url_variations = _get_url_variations(url)
|
||||
if found_parent_url in url_variations:
|
||||
logger.info(
|
||||
" → Found as parent URL (site already scraped)"
|
||||
" -> Found as parent URL (site already scraped)"
|
||||
)
|
||||
elif found_source_url in url_variations:
|
||||
logger.info(" → Found as exact source URL match")
|
||||
logger.info(" -> Found as exact source URL match")
|
||||
elif found_sourceURL in url_variations:
|
||||
logger.info(" → Found as sourceURL match")
|
||||
logger.info(" -> Found as sourceURL match")
|
||||
else:
|
||||
logger.warning(
|
||||
f"⚠️ URL mismatch! Searched for '{url}' (variations: {url_variations}) but got source='{found_source_url}', parent='{found_parent_url}', sourceURL='{found_sourceURL}'"
|
||||
f"WARNING: URL mismatch! Searched for '{url}' (variations: {url_variations}) but got source='{found_source_url}', parent='{found_parent_url}', sourceURL='{found_sourceURL}'"
|
||||
)
|
||||
|
||||
# Cache positive result
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
"""Firecrawl configuration loading utilities."""
|
||||
"""Firecrawl configuration loading utilities for RAG graph.
|
||||
|
||||
import os
|
||||
from typing import Any, NamedTuple
|
||||
This module re-exports the shared Firecrawl configuration utilities
|
||||
with specific defaults for the RAG graph use case.
|
||||
"""
|
||||
|
||||
from biz_bud.core.config.loader import load_config_async
|
||||
from biz_bud.core.errors import ConfigurationError
|
||||
from typing import Any
|
||||
|
||||
|
||||
class FirecrawlSettings(NamedTuple):
|
||||
"""Firecrawl API configuration settings."""
|
||||
|
||||
api_key: str | None
|
||||
base_url: str | None
|
||||
from biz_bud.nodes.integrations.firecrawl.config import FirecrawlSettings
|
||||
from biz_bud.nodes.integrations.firecrawl.config import (
|
||||
load_firecrawl_settings as _load_firecrawl_settings,
|
||||
)
|
||||
|
||||
|
||||
async def load_firecrawl_settings(state: dict[str, Any]) -> FirecrawlSettings:
|
||||
"""Load Firecrawl API settings from configuration and environment.
|
||||
"""Load Firecrawl API settings with RAG-specific defaults.
|
||||
|
||||
Args:
|
||||
state: The current workflow state containing configuration.
|
||||
@@ -24,38 +22,11 @@ async def load_firecrawl_settings(state: dict[str, Any]) -> FirecrawlSettings:
|
||||
FirecrawlSettings with api_key and base_url.
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key is found in environment or configuration.
|
||||
ConfigurationError: If no API key is found (RAG requires API key).
|
||||
"""
|
||||
# First try to get from environment
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY")
|
||||
base_url = os.getenv("FIRECRAWL_BASE_URL")
|
||||
# RAG graph always requires API key
|
||||
return await _load_firecrawl_settings(state, require_api_key=True)
|
||||
|
||||
# If not in environment, try to get from state config
|
||||
if not api_key or not base_url:
|
||||
config_dict = state.get("config", {})
|
||||
|
||||
# Try to get from state's api_config first
|
||||
api_config = config_dict.get("api_config", {})
|
||||
api_key = api_key or api_config.get("firecrawl_api_key")
|
||||
base_url = base_url or api_config.get("firecrawl_base_url")
|
||||
|
||||
# If still not found, load from app config
|
||||
if not api_key or not base_url:
|
||||
app_config = await load_config_async()
|
||||
if hasattr(app_config, "api_config"):
|
||||
if not api_key:
|
||||
api_key = getattr(app_config.api_config, "firecrawl_api_key", None)
|
||||
if not base_url:
|
||||
base_url = getattr(app_config.api_config, "firecrawl_base_url", None)
|
||||
|
||||
# Set default base URL if not specified
|
||||
if not base_url:
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
|
||||
# Raise error if api_key is still missing
|
||||
if not api_key:
|
||||
raise ConfigurationError(
|
||||
"Firecrawl API key is required but was not found in environment or config."
|
||||
)
|
||||
|
||||
return FirecrawlSettings(api_key=api_key, base_url=base_url)
|
||||
# Re-export for backward compatibility
|
||||
__all__ = ["FirecrawlSettings", "load_firecrawl_settings"]
|
||||
|
||||
@@ -104,12 +104,15 @@ async def repomix_process_node(
|
||||
output_path = None
|
||||
tempfile_path = None
|
||||
try:
|
||||
# Create a temporary file for the output
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", suffix=".md", delete=False
|
||||
) as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
tempfile_path = tmp_file.name
|
||||
# Create a temporary file path for the output
|
||||
# We use mkstemp to get a unique filename securely, then reopen it with aiofiles
|
||||
temp_fd, tempfile_path = tempfile.mkstemp(suffix=".md")
|
||||
os.close(temp_fd) # Close the file descriptor immediately
|
||||
output_path = tempfile_path
|
||||
|
||||
# Create an empty file asynchronously
|
||||
async with aiofiles.open(tempfile_path, mode="w") as tmp_file:
|
||||
await tmp_file.write("") # Create empty file
|
||||
|
||||
# Build the repomix command
|
||||
cmd = [
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
# Removed broken core import
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.core.utils.regex_security import search_safe
|
||||
from biz_bud.core.utils.url_analyzer import analyze_url_type
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.nodes import NodeLLMConfigOverride, call_model_node
|
||||
@@ -258,7 +259,6 @@ async def analyze_url_for_params_node(
|
||||
logger.error(f"Error analyzing URL: {e}")
|
||||
|
||||
# Try to extract explicit values from user input as fallback
|
||||
import re
|
||||
|
||||
# Set defaults
|
||||
max_depth = 2
|
||||
@@ -269,7 +269,7 @@ async def analyze_url_for_params_node(
|
||||
user_input_lower = user_input.lower()
|
||||
|
||||
# Look for explicit depth instructions
|
||||
depth_match = re.search(r"(?:max\s+)?depth\s+(?:of\s+)?(\d+)", user_input_lower)
|
||||
depth_match = search_safe(r"(?:max\s+)?depth\s+(?:of\s+)?(\d+)", user_input_lower)
|
||||
has_explicit_depth = False
|
||||
if depth_match:
|
||||
max_depth = min(5, int(depth_match.group(1)))
|
||||
@@ -277,7 +277,7 @@ async def analyze_url_for_params_node(
|
||||
logger.info(f"Extracted explicit max_depth={max_depth} from user input")
|
||||
|
||||
# Look for explicit page count
|
||||
pages_match = re.search(r"(\d+)\s+pages?", user_input_lower)
|
||||
pages_match = search_safe(r"(\d+)\s+pages?", user_input_lower)
|
||||
has_explicit_pages = False
|
||||
if pages_match:
|
||||
max_pages = min(1000, int(pages_match.group(1)))
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core import preserve_url_fields
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.core.utils.regex_security import search_safe, sub_safe
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.tools.clients.r2r_utils import (
|
||||
authenticate_r2r_client,
|
||||
@@ -48,8 +49,6 @@ def _extract_meaningful_name_from_url(url: str) -> str:
|
||||
A clean, meaningful name extracted from the URL
|
||||
|
||||
"""
|
||||
import re
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Check if it's a git repository (GitHub, GitLab, Bitbucket)
|
||||
@@ -60,7 +59,7 @@ def _extract_meaningful_name_from_url(url: str) -> str:
|
||||
]
|
||||
|
||||
for pattern in git_patterns:
|
||||
if match := re.search(pattern, url):
|
||||
if match := search_safe(pattern, url):
|
||||
repo_name = match.group(1)
|
||||
# Remove .git extension if present
|
||||
if repo_name.endswith(".git"):
|
||||
@@ -91,7 +90,7 @@ def _extract_meaningful_name_from_url(url: str) -> str:
|
||||
return parsed.netloc
|
||||
|
||||
# Remove common prefixes
|
||||
domain = re.sub(r"^(www\.|docs\.|api\.|blog\.)", "", domain)
|
||||
domain = sub_safe(r"^(www\.|docs\.|api\.|blog\.)", "", domain)
|
||||
|
||||
# Extract the main part of the domain
|
||||
parts = domain.split(".")
|
||||
@@ -110,7 +109,7 @@ def _extract_meaningful_name_from_url(url: str) -> str:
|
||||
|
||||
# Clean up the name
|
||||
name = name.replace(".", "_").lower()
|
||||
name = re.sub(r"[^a-z0-9\-_]", "", name) or "website"
|
||||
name = sub_safe(r"[^a-z0-9\-_]", "", name) or "website"
|
||||
|
||||
return name
|
||||
|
||||
@@ -152,7 +151,7 @@ async def upload_to_r2r_node(
|
||||
"type": "status",
|
||||
"node": "r2r_upload",
|
||||
"stage": "initialization",
|
||||
"message": "🚀 Starting R2R upload process",
|
||||
"message": "Starting R2R upload process",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
@@ -237,7 +236,7 @@ async def upload_to_r2r_node(
|
||||
"type": "status",
|
||||
"node": "r2r_upload",
|
||||
"stage": "connection",
|
||||
"message": f"📡 Connecting to R2R at {r2r_config['base_url']}",
|
||||
"message": f"Connecting to R2R at {r2r_config['base_url']}",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
@@ -459,10 +458,8 @@ async def upload_to_r2r_node(
|
||||
]:
|
||||
title_slug = title_slug.replace(suffix, "")
|
||||
# Convert to URL-friendly slug
|
||||
import re
|
||||
|
||||
title_slug = re.sub(r"[^a-z0-9\s-]", "", title_slug)
|
||||
title_slug = re.sub(r"\s+", "-", title_slug.strip())
|
||||
title_slug = sub_safe(r"[^a-z0-9\s-]", "", title_slug)
|
||||
title_slug = sub_safe(r"\s+", "-", title_slug.strip())
|
||||
title_slug = title_slug[:40] # Limit length
|
||||
|
||||
# Create unique page URL with page number for guaranteed uniqueness
|
||||
@@ -679,13 +676,11 @@ async def upload_to_r2r_node(
|
||||
break
|
||||
|
||||
# Also check if it's the same content with just a page number difference
|
||||
import re
|
||||
|
||||
# Remove page numbers from both titles for comparison
|
||||
clean_title_no_page = re.sub(
|
||||
clean_title_no_page = sub_safe(
|
||||
r"\s*-?\s*page\s*\d+\s*", "", clean_title
|
||||
)
|
||||
existing_title_no_page = re.sub(
|
||||
existing_title_no_page = sub_safe(
|
||||
r"\s*-?\s*page\s*\d+\s*", "", existing_title
|
||||
)
|
||||
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
"""Research-specific nodes for the research workflow graph.
|
||||
|
||||
This module contains nodes that are specific to the research workflow and
|
||||
wouldn't be reused in other graphs. These nodes handle query derivation,
|
||||
synthesis validation, and other research-specific operations.
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core.errors import create_error_info
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
|
||||
|
||||
# Import from local nodes directory
|
||||
try:
|
||||
_legacy_imports_available = True
|
||||
except ImportError:
|
||||
_legacy_imports_available = False
|
||||
legacy_derive_query_node = None
|
||||
legacy_validate_node = None
|
||||
|
||||
|
||||
@standard_node(node_name="research_query_derivation", metric_name="query_derivation")
|
||||
async def derive_research_query_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Derive focused research queries from user input.
|
||||
|
||||
This node transforms general user queries into specific, actionable
|
||||
research queries optimized for web search and information gathering.
|
||||
|
||||
Args:
|
||||
state: Current workflow state containing user query
|
||||
config: Optional runtime configuration
|
||||
|
||||
Returns:
|
||||
Updated state with derived research queries
|
||||
"""
|
||||
info_highlight("Deriving research queries from user input...", category="QueryDerivation")
|
||||
|
||||
# Get user query
|
||||
user_query = state.get("query", "")
|
||||
if not user_query:
|
||||
warning_highlight("No user query found for derivation", category="QueryDerivation")
|
||||
return {"search_queries": []}
|
||||
|
||||
try:
|
||||
# Get LLM service for query expansion
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
service_factory = ServiceFactory(config=state.get("config", {}))
|
||||
llm_service = await service_factory.get_llm_service()
|
||||
|
||||
# Create prompt for query derivation
|
||||
derivation_prompt = f"""
|
||||
Given the user's research request, derive 2-3 specific search queries that would help gather comprehensive information.
|
||||
|
||||
User request: {user_query}
|
||||
|
||||
Consider:
|
||||
1. Breaking down complex topics into searchable components
|
||||
2. Including relevant keywords and synonyms
|
||||
3. Focusing on authoritative sources
|
||||
4. Including time-relevant terms if needed
|
||||
|
||||
Provide 2-3 focused search queries, one per line.
|
||||
"""
|
||||
|
||||
# Get derived queries
|
||||
from langchain_core.messages import HumanMessage
|
||||
response = await llm_service.call_model_lc([HumanMessage(content=derivation_prompt)])
|
||||
|
||||
# Parse response
|
||||
derived_queries = []
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
# Split by newlines and clean (handle string content)
|
||||
content_str = content if isinstance(content, str) else str(content)
|
||||
for line in content_str.split("\n"):
|
||||
cleaned = line.strip()
|
||||
if cleaned and not cleaned.startswith("#") and len(cleaned) > 10:
|
||||
# Remove numbering if present
|
||||
if cleaned[0].isdigit() and cleaned[1] in ".):":
|
||||
cleaned = cleaned[2:].strip()
|
||||
derived_queries.append(cleaned)
|
||||
|
||||
# Always include original query
|
||||
if user_query not in derived_queries:
|
||||
derived_queries.insert(0, user_query)
|
||||
|
||||
# Limit to 3 queries
|
||||
derived_queries = derived_queries[:3]
|
||||
|
||||
info_highlight(
|
||||
f"Derived {len(derived_queries)} research queries from user input",
|
||||
category="QueryDerivation",
|
||||
)
|
||||
|
||||
return {
|
||||
"search_queries": derived_queries,
|
||||
"query_derivation_metadata": {
|
||||
"original_query": user_query,
|
||||
"derived_count": len(derived_queries),
|
||||
"method": "llm_expansion",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Query derivation failed: {str(e)}"
|
||||
warning_highlight(error_msg, category="QueryDerivation")
|
||||
# Fall back to using original query
|
||||
return {
|
||||
"search_queries": [user_query],
|
||||
"query_derivation_metadata": {
|
||||
"original_query": user_query,
|
||||
"derived_count": 1,
|
||||
"method": "fallback",
|
||||
"error": str(e),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@standard_node(node_name="synthesize_research_results", metric_name="research_synthesis")
|
||||
async def synthesize_research_results_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Synthesize research findings into a coherent response.
|
||||
|
||||
This node takes search results, extracted information, and other
|
||||
research data to create a comprehensive synthesis.
|
||||
|
||||
Args:
|
||||
state: Current workflow state with research data
|
||||
config: Optional runtime configuration
|
||||
|
||||
Returns:
|
||||
Updated state with research synthesis
|
||||
"""
|
||||
info_highlight("Synthesizing research results...", category="ResearchSynthesis")
|
||||
|
||||
# Gather all research data
|
||||
search_results = state.get("search_results", [])
|
||||
extracted_info = state.get("extracted_info", {})
|
||||
query = state.get("query", "")
|
||||
|
||||
if not search_results and not extracted_info:
|
||||
warning_highlight("No research data to synthesize", category="ResearchSynthesis")
|
||||
return {
|
||||
"synthesis": "No research data was found to synthesize.",
|
||||
"synthesis_metadata": {"sources_used": 0},
|
||||
}
|
||||
|
||||
try:
|
||||
# Build context for synthesis
|
||||
context_parts = []
|
||||
|
||||
# Add search results
|
||||
for idx, result in enumerate(search_results[:10]): # Limit to top 10
|
||||
title = result.get("title", "Untitled")
|
||||
content = result.get("content", "")
|
||||
url = result.get("url", "")
|
||||
context_parts.append(f"Source {idx + 1}: {title}\nURL: {url}\n{content[:500]}...\n")
|
||||
|
||||
# Add extracted information
|
||||
if extracted_info:
|
||||
context_parts.append("\nExtracted Key Information:")
|
||||
context_parts.extend(
|
||||
f"- {info['title']}: {info.get('chunks', 0)} chunks extracted"
|
||||
for info in extracted_info.values()
|
||||
if isinstance(info, dict) and info.get("title")
|
||||
)
|
||||
# Get LLM service
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
service_factory = ServiceFactory(config=state.get("config", {}))
|
||||
llm_service = await service_factory.get_llm_service()
|
||||
|
||||
# Create synthesis prompt
|
||||
synthesis_prompt = f"""
|
||||
Based on the following research findings, provide a comprehensive synthesis that answers the user's query.
|
||||
|
||||
User Query: {query}
|
||||
|
||||
Research Findings:
|
||||
{chr(10).join(context_parts)}
|
||||
|
||||
Please provide a well-structured response that:
|
||||
1. Directly addresses the user's query
|
||||
2. Synthesizes information from multiple sources
|
||||
3. Highlights key findings and insights
|
||||
4. Notes any contradictions or gaps in the information
|
||||
5. Provides a clear, actionable summary
|
||||
|
||||
Keep the response focused and informative.
|
||||
"""
|
||||
|
||||
# Generate synthesis
|
||||
from langchain_core.messages import HumanMessage
|
||||
response = await llm_service.call_model_lc([HumanMessage(content=synthesis_prompt)])
|
||||
|
||||
synthesis = response.content if hasattr(response, "content") else ""
|
||||
info_highlight(
|
||||
f"Research synthesis completed using {len(search_results)} sources",
|
||||
category="ResearchSynthesis",
|
||||
)
|
||||
|
||||
return {
|
||||
"synthesis": synthesis,
|
||||
"synthesis_metadata": {
|
||||
"sources_used": len(search_results),
|
||||
"extraction_used": bool(extracted_info),
|
||||
"synthesis_length": len(synthesis),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Research synthesis failed: {str(e)}"
|
||||
error_highlight(error_msg, category="ResearchSynthesis")
|
||||
return {
|
||||
"synthesis": "Failed to synthesize research results due to an error.",
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=error_msg,
|
||||
node="synthesize_research_results",
|
||||
severity="error",
|
||||
category="synthesis_error",
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@standard_node(node_name="validate_research_synthesis", metric_name="synthesis_validation")
|
||||
async def validate_research_synthesis_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Validate the quality and accuracy of research synthesis.
|
||||
|
||||
This node checks if the synthesis adequately addresses the user's query,
|
||||
is well-supported by sources, and meets quality standards.
|
||||
|
||||
Args:
|
||||
state: Current workflow state with synthesis
|
||||
config: Optional runtime configuration
|
||||
|
||||
Returns:
|
||||
Updated state with validation results
|
||||
"""
|
||||
debug_highlight("Validating research synthesis...", category="SynthesisValidation")
|
||||
|
||||
synthesis = state.get("synthesis", "")
|
||||
query = state.get("query", "")
|
||||
sources_count = state.get("synthesis_metadata", {}).get("sources_used", 0)
|
||||
|
||||
# Basic validation checks
|
||||
validation_results: dict[str, Any] = {"is_valid": True, "issues": [], "score": 100}
|
||||
|
||||
# Check synthesis length
|
||||
if len(synthesis) < 100:
|
||||
validation_results["issues"].append("Synthesis is too short")
|
||||
validation_results["score"] -= 20
|
||||
|
||||
# Check if synthesis addresses the query
|
||||
query_terms = query.lower().split()
|
||||
synthesis_lower = synthesis.lower()
|
||||
addressed_terms = sum(term in synthesis_lower for term in query_terms)
|
||||
if addressed_terms < len(query_terms) / 2:
|
||||
validation_results["issues"].append("Synthesis may not fully address the query")
|
||||
validation_results["score"] -= 15
|
||||
|
||||
# Check source usage
|
||||
if sources_count < 2:
|
||||
validation_results["issues"].append("Synthesis based on limited sources")
|
||||
validation_results["score"] -= 10
|
||||
|
||||
# Determine if synthesis is valid
|
||||
validation_results["is_valid"] = validation_results["score"] >= 70
|
||||
|
||||
info_highlight(
|
||||
f"Synthesis validation: score={validation_results['score']}, "
|
||||
f"valid={validation_results['is_valid']}",
|
||||
category="SynthesisValidation",
|
||||
)
|
||||
|
||||
return {
|
||||
"is_valid": validation_results["is_valid"],
|
||||
"validation_results": validation_results,
|
||||
"requires_human_feedback": validation_results["score"] < 80,
|
||||
}
|
||||
|
||||
|
||||
# Export all research-specific nodes
|
||||
__all__ = [
|
||||
"derive_research_query_node",
|
||||
"synthesize_research_results_node",
|
||||
"validate_research_synthesis_node",
|
||||
]
|
||||
|
||||
# Legacy imports are available but not re-exported to avoid F401 errors
|
||||
# They can be imported directly from their respective modules if needed
|
||||
@@ -215,7 +215,7 @@ def info_success(message: str, exc_info: bool | BaseException | None = None) ->
|
||||
message: The message to log.
|
||||
exc_info: Optional exception information to include in the log.
|
||||
"""
|
||||
_root_logger.info(f"✓ {message}", exc_info=exc_info)
|
||||
_root_logger.info(f"[OK] {message}", exc_info=exc_info)
|
||||
|
||||
|
||||
def info_highlight(
|
||||
@@ -236,7 +236,7 @@ def info_highlight(
|
||||
message = f"[{progress}] {message}"
|
||||
if category:
|
||||
message = f"[{category}] {message}"
|
||||
_root_logger.info(f"ℹ {message}", exc_info=exc_info)
|
||||
_root_logger.info(f"INFO: {message}", exc_info=exc_info)
|
||||
|
||||
|
||||
def warning_highlight(
|
||||
|
||||
@@ -43,6 +43,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from biz_bud.core.errors import ValidationError as CoreValidationError
|
||||
from biz_bud.core.errors import create_error_info, handle_exception_group
|
||||
from biz_bud.core.utils import normalize_errors_to_list
|
||||
from biz_bud.logging import debug_highlight, info_highlight, warning_highlight
|
||||
|
||||
# --- Pydantic models for runtime validation only ---
|
||||
@@ -268,9 +269,9 @@ async def parse_and_validate_initial_payload(
|
||||
if "organization" in raw_payload:
|
||||
org_val = raw_payload["organization"]
|
||||
if isinstance(org_val, list):
|
||||
# Filter to only include dict items
|
||||
filtered_org = [item for item in org_val if isinstance(item, dict)]
|
||||
if filtered_org:
|
||||
if filtered_org := [
|
||||
item for item in org_val if isinstance(item, dict)
|
||||
]:
|
||||
filtered_payload["organization"] = filtered_org
|
||||
|
||||
# Handle metadata field
|
||||
@@ -407,11 +408,8 @@ async def parse_and_validate_initial_payload(
|
||||
category="validation",
|
||||
)
|
||||
existing_errors = state.get("errors", [])
|
||||
if isinstance(existing_errors, list):
|
||||
existing_errors = existing_errors.copy()
|
||||
existing_errors.append(error_info)
|
||||
else:
|
||||
existing_errors = [error_info]
|
||||
existing_errors = normalize_errors_to_list(existing_errors).copy()
|
||||
existing_errors.append(error_info)
|
||||
updater = updater.set("errors", existing_errors)
|
||||
|
||||
# Update input_metadata
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from biz_bud.core import ErrorCategory, ErrorInfo
|
||||
from biz_bud.core.langgraph import standard_node
|
||||
from biz_bud.core.utils.regex_security import search_safe
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.services.llm.client import LangchainLLMClient
|
||||
from biz_bud.states.error_handling import ErrorAnalysis, ErrorContext, ErrorHandlingState
|
||||
@@ -243,7 +244,7 @@ def _analyze_validation_error(error_message: str) -> ErrorAnalysis:
|
||||
|
||||
def _analyze_rate_limit_error(error_message: str) -> ErrorAnalysis:
|
||||
"""Analyze rate limit errors."""
|
||||
if wait_match := re.search(r"(\d+)\s*(second|minute|hour)", error_message.lower()):
|
||||
if wait_match := search_safe(r"(\d+)\s*(second|minute|hour)", error_message.lower()):
|
||||
wait_time = f"{wait_match[1]} {wait_match[2]}s"
|
||||
else:
|
||||
wait_time = None
|
||||
@@ -324,20 +325,20 @@ async def _llm_error_analysis(
|
||||
|
||||
# Extract root cause if more detailed
|
||||
if "root cause:" in content:
|
||||
if root_cause_match := re.search(r"root cause:\s*(.+?)(?:\n|$)", content):
|
||||
if root_cause_match := search_safe(r"root cause:\s*(.+?)(?:\n|$)", content):
|
||||
enhanced["root_cause"] = root_cause_match[1].strip()
|
||||
|
||||
# Extract additional suggested actions
|
||||
if "suggested actions:" in content or "recommendations:" in content:
|
||||
if action_section := re.search(
|
||||
if action_section := search_safe(
|
||||
r"(?:suggested actions|recommendations):\s*(.+?)(?:\n\n|$)",
|
||||
content,
|
||||
re.DOTALL,
|
||||
flags=re.DOTALL,
|
||||
):
|
||||
action_lines = action_section[1].strip().split("\n")
|
||||
actions = []
|
||||
for line in action_lines:
|
||||
if line := line.strip("- •*").strip():
|
||||
if line := line.strip("- *").strip():
|
||||
actions.append(line.lower().replace(" ", "_"))
|
||||
if actions:
|
||||
enhanced["suggested_actions"] = (
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from biz_bud.core.config.loader import load_config_async
|
||||
from biz_bud.core.errors import ConfigurationError
|
||||
|
||||
|
||||
class FirecrawlSettings(NamedTuple):
|
||||
@@ -13,17 +14,21 @@ class FirecrawlSettings(NamedTuple):
|
||||
base_url: str | None
|
||||
|
||||
|
||||
async def load_firecrawl_settings(state: dict[str, Any]) -> FirecrawlSettings:
|
||||
async def load_firecrawl_settings(
|
||||
state: dict[str, Any],
|
||||
require_api_key: bool = False
|
||||
) -> FirecrawlSettings:
|
||||
"""Load Firecrawl API settings from configuration and environment.
|
||||
|
||||
Args:
|
||||
state: The current workflow state containing configuration.
|
||||
require_api_key: If True, raise ConfigurationError when API key is not found.
|
||||
|
||||
Returns:
|
||||
FirecrawlSettings with api_key and base_url.
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key is found in environment or configuration.
|
||||
ConfigurationError: If require_api_key is True and no API key is found.
|
||||
"""
|
||||
# First try to get from environment
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY")
|
||||
@@ -51,4 +56,10 @@ async def load_firecrawl_settings(state: dict[str, Any]) -> FirecrawlSettings:
|
||||
if not base_url:
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
|
||||
# Raise error if api_key is still missing and required
|
||||
if not api_key and require_api_key:
|
||||
raise ConfigurationError(
|
||||
"Firecrawl API key is required but was not found in environment or config."
|
||||
)
|
||||
|
||||
return FirecrawlSettings(api_key=api_key, base_url=base_url)
|
||||
|
||||
@@ -133,6 +133,41 @@ class ContentNormalizer:
|
||||
error_code=ErrorNamespace.CFG_DEPENDENCY_MISSING,
|
||||
) from e
|
||||
|
||||
def _process_tokens(self, doc: Doc) -> list[str]:
|
||||
"""Process tokens from a spaCy document.
|
||||
|
||||
Args:
|
||||
doc: Processed spaCy document
|
||||
|
||||
Returns:
|
||||
List of processed tokens
|
||||
"""
|
||||
tokens = []
|
||||
for token in doc:
|
||||
# Skip punctuation and whitespace if configured
|
||||
if self.config.remove_punctuation and token.is_punct:
|
||||
continue
|
||||
if token.is_space:
|
||||
continue
|
||||
|
||||
# Get token text
|
||||
text = token.lemma_ if self.config.lemmatize and token.lemma_ else token.text
|
||||
# Normalize case
|
||||
if self.config.normalize_case:
|
||||
text = text.lower()
|
||||
|
||||
# Skip short tokens
|
||||
if len(text) < self.config.min_token_length:
|
||||
continue
|
||||
|
||||
# Skip stopwords if configured
|
||||
if self.config.remove_stopwords and token.is_stop:
|
||||
continue
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
return tokens
|
||||
|
||||
def normalize_content(self, content: str) -> tuple[str, list[str]]:
|
||||
"""Normalize content for consistent fingerprinting.
|
||||
|
||||
@@ -158,30 +193,8 @@ class ContentNormalizer:
|
||||
# Process with spaCy
|
||||
doc: Doc = _nlp_model(content) # type: ignore[misc]
|
||||
|
||||
# Extract and process tokens
|
||||
tokens = []
|
||||
for token in doc:
|
||||
# Skip punctuation and whitespace if configured
|
||||
if self.config.remove_punctuation and token.is_punct:
|
||||
continue
|
||||
if token.is_space:
|
||||
continue
|
||||
|
||||
# Get token text
|
||||
text = token.lemma_ if self.config.lemmatize and token.lemma_ else token.text
|
||||
# Normalize case
|
||||
if self.config.normalize_case:
|
||||
text = text.lower()
|
||||
|
||||
# Skip short tokens
|
||||
if len(text) < self.config.min_token_length:
|
||||
continue
|
||||
|
||||
# Skip stopwords if configured
|
||||
if self.config.remove_stopwords and token.is_stop:
|
||||
continue
|
||||
|
||||
tokens.append(text)
|
||||
# Extract and process tokens using shared method
|
||||
tokens = self._process_tokens(doc)
|
||||
|
||||
# Create normalized content
|
||||
normalized_content = " ".join(tokens)
|
||||
@@ -228,26 +241,8 @@ class ContentNormalizer:
|
||||
doc = docs[valid_idx]
|
||||
valid_idx += 1
|
||||
|
||||
# Extract and process tokens (same logic as normalize_content)
|
||||
tokens = []
|
||||
for token in doc:
|
||||
if self.config.remove_punctuation and token.is_punct:
|
||||
continue
|
||||
if token.is_space:
|
||||
continue
|
||||
|
||||
text = token.lemma_ if self.config.lemmatize and token.lemma_ else token.text
|
||||
if self.config.normalize_case:
|
||||
text = text.lower()
|
||||
|
||||
if len(text) < self.config.min_token_length:
|
||||
continue
|
||||
|
||||
if self.config.remove_stopwords and token.is_stop:
|
||||
continue
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
# Extract and process tokens using shared method
|
||||
tokens = self._process_tokens(doc)
|
||||
normalized_content = " ".join(tokens)
|
||||
results.append((normalized_content, tokens))
|
||||
|
||||
@@ -422,6 +417,27 @@ class LSHIndex:
|
||||
self.fingerprints[item_id] = fingerprint
|
||||
self.item_count += 1
|
||||
|
||||
def _calculate_bucket_hash(self, fingerprint: int, bucket_idx: int, bits_per_bucket: int) -> int:
|
||||
"""Calculate hash for a specific bucket.
|
||||
|
||||
Args:
|
||||
fingerprint: SimHash fingerprint as integer
|
||||
bucket_idx: Index of the bucket
|
||||
bits_per_bucket: Number of bits per bucket
|
||||
|
||||
Returns:
|
||||
Bucket hash value
|
||||
"""
|
||||
start_bit = bucket_idx * bits_per_bucket
|
||||
end_bit = start_bit + bits_per_bucket
|
||||
|
||||
bucket_hash = 0
|
||||
for bit_pos in range(start_bit, min(end_bit, self.config.simhash_bits)):
|
||||
if fingerprint & (1 << bit_pos):
|
||||
bucket_hash |= (1 << (bit_pos - start_bit))
|
||||
|
||||
return bucket_hash
|
||||
|
||||
def _add_simhash(self, item_id: str, fingerprint: int) -> None:
|
||||
"""Add SimHash fingerprint using bit sampling LSH.
|
||||
|
||||
@@ -429,19 +445,11 @@ class LSHIndex:
|
||||
item_id: Unique identifier for the item
|
||||
fingerprint: SimHash fingerprint as integer
|
||||
"""
|
||||
# Create multiple hash buckets using bit sampling
|
||||
num_buckets = 4 # Number of hash buckets for better recall
|
||||
bits_per_bucket = self.config.simhash_bits // num_buckets
|
||||
# Create multiple hash buckets using common parameters
|
||||
num_buckets, bits_per_bucket = self._get_bucket_params()
|
||||
|
||||
for bucket_idx in range(num_buckets):
|
||||
# Extract bits for this bucket
|
||||
start_bit = bucket_idx * bits_per_bucket
|
||||
end_bit = start_bit + bits_per_bucket
|
||||
|
||||
bucket_hash = 0
|
||||
for bit_pos in range(start_bit, min(end_bit, self.config.simhash_bits)):
|
||||
if fingerprint & (1 << bit_pos):
|
||||
bucket_hash |= (1 << (bit_pos - start_bit))
|
||||
bucket_hash = self._calculate_bucket_hash(fingerprint, bucket_idx, bits_per_bucket)
|
||||
|
||||
# Add to bucket
|
||||
bucket_key = (bucket_idx, bucket_hash)
|
||||
@@ -465,6 +473,16 @@ class LSHIndex:
|
||||
results = self.minhash_lsh.query(fingerprint)
|
||||
return [str(result) for result in results][:max_results]
|
||||
|
||||
def _get_bucket_params(self) -> tuple[int, int]:
|
||||
"""Get common bucket parameters.
|
||||
|
||||
Returns:
|
||||
Tuple of (num_buckets, bits_per_bucket)
|
||||
"""
|
||||
num_buckets = 4 # Number of hash buckets for better recall
|
||||
bits_per_bucket = self.config.simhash_bits // num_buckets
|
||||
return num_buckets, bits_per_bucket
|
||||
|
||||
def _query_simhash(self, fingerprint: int, max_results: int) -> list[str]:
|
||||
"""Query SimHash LSH index.
|
||||
|
||||
@@ -477,18 +495,11 @@ class LSHIndex:
|
||||
"""
|
||||
candidates: set[str] = set()
|
||||
|
||||
# Query all buckets
|
||||
num_buckets = 4
|
||||
bits_per_bucket = self.config.simhash_bits // num_buckets
|
||||
# Query all buckets using common parameters
|
||||
num_buckets, bits_per_bucket = self._get_bucket_params()
|
||||
|
||||
for bucket_idx in range(num_buckets):
|
||||
start_bit = bucket_idx * bits_per_bucket
|
||||
end_bit = start_bit + bits_per_bucket
|
||||
|
||||
bucket_hash = 0
|
||||
for bit_pos in range(start_bit, min(end_bit, self.config.simhash_bits)):
|
||||
if fingerprint & (1 << bit_pos):
|
||||
bucket_hash |= (1 << (bit_pos - start_bit))
|
||||
bucket_hash = self._calculate_bucket_hash(fingerprint, bucket_idx, bits_per_bucket)
|
||||
|
||||
bucket_key = (bucket_idx, bucket_hash)
|
||||
if bucket_key in self.simhash_buckets:
|
||||
|
||||
@@ -123,7 +123,7 @@ async def optimized_search_node(
|
||||
try:
|
||||
# Step 1: Optimize queries
|
||||
logger.info(f"Optimizing {len(queries)} search queries")
|
||||
optimized_queries = await query_optimizer.optimize_batch(
|
||||
optimized_queries = query_optimizer.optimize_batch(
|
||||
queries=queries, context=context
|
||||
)
|
||||
|
||||
|
||||
@@ -52,9 +52,9 @@ class QueryOptimizer:
|
||||
self, raw_queries: list[str], context: str = ""
|
||||
) -> list[OptimizedQuery]:
|
||||
"""Optimize a list of queries for better search results."""
|
||||
return await self.optimize_batch(queries=raw_queries, context=context)
|
||||
return self.optimize_batch(queries=raw_queries, context=context)
|
||||
|
||||
async def optimize_batch(
|
||||
def optimize_batch(
|
||||
self, queries: list[str], context: str = ""
|
||||
) -> list[OptimizedQuery]:
|
||||
"""Convert raw queries into optimized search queries.
|
||||
|
||||
@@ -145,6 +145,9 @@ class ConcurrentSearchOrchestrator:
|
||||
self.provider_failures: dict[str, list[ProviderFailure]] = defaultdict(list)
|
||||
self.provider_circuit_open: dict[str, bool] = defaultdict(bool)
|
||||
|
||||
# Track background tasks to prevent garbage collection
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
async def execute_search_batch(
|
||||
self, batch: SearchBatch, use_cache: bool = True, min_results_per_query: int = 3
|
||||
) -> dict[str, dict[str, list[SearchResult]] | dict[str, dict[str, int | float]]]:
|
||||
@@ -403,8 +406,11 @@ class ConcurrentSearchOrchestrator:
|
||||
# Open circuit if too many recent failures
|
||||
if len(failures) >= 3:
|
||||
self.provider_circuit_open[provider] = True
|
||||
# Schedule circuit reset
|
||||
asyncio.create_task(self._reset_circuit(provider))
|
||||
# Schedule circuit reset and save task to prevent GC
|
||||
task = asyncio.create_task(self._reset_circuit(provider))
|
||||
self._background_tasks.add(task)
|
||||
# Remove task from set when complete
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _reset_circuit(self, provider: str, delay: int = 60) -> None:
|
||||
"""Reset circuit breaker after delay."""
|
||||
|
||||
@@ -778,13 +778,13 @@ examples:
|
||||
task: "should I invest in apple stocks?"
|
||||
response:
|
||||
{
|
||||
"server": "💰 Finance Agent",
|
||||
"server": "Finance Agent",
|
||||
"agent_role_prompt: "You are a seasoned finance analyst AI assistant. Your primary goal is to compose comprehensive, astute, impartial, and methodically arranged financial reports based on provided data and trends."
|
||||
}
|
||||
task: "could reselling sneakers become profitable?"
|
||||
response:
|
||||
{
|
||||
"server": "📈 Business Analyst Agent",
|
||||
"server": "Business Analyst Agent",
|
||||
"agent_role_prompt": "You are an experienced AI business analyst assistant. Your main objective is to produce comprehensive, insightful, impartial, and systematically structured business reports based on provided business data, market trends, and strategic analysis."
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -380,13 +380,8 @@ class ServiceContainer:
|
||||
f"Resolved dependencies for {service_type.__name__}: {list(resolved_deps.keys())}"
|
||||
)
|
||||
|
||||
# Create service instance
|
||||
if inspect.isclass(factory):
|
||||
# Constructor injection
|
||||
service = factory(self.config, **resolved_deps)
|
||||
else:
|
||||
# Factory function
|
||||
service = factory(self.config, **resolved_deps)
|
||||
# Create service instance (works for both classes and factory functions)
|
||||
service = factory(self.config, **resolved_deps)
|
||||
|
||||
# Initialize if it's a BaseService
|
||||
if hasattr(service, "initialize"):
|
||||
|
||||
@@ -21,6 +21,7 @@ from .service_factory import (
|
||||
cleanup_global_factory,
|
||||
ensure_healthy_global_factory,
|
||||
force_cleanup_global_factory,
|
||||
get_cached_factory_for_config,
|
||||
get_global_factory,
|
||||
get_global_factory_manager,
|
||||
is_global_factory_initialized,
|
||||
@@ -35,6 +36,7 @@ __all__ = [
|
||||
"LLMClientWrapper",
|
||||
# Global factory functions
|
||||
"get_global_factory",
|
||||
"get_cached_factory_for_config",
|
||||
"set_global_factory",
|
||||
"cleanup_global_factory",
|
||||
"is_global_factory_initialized",
|
||||
|
||||
@@ -292,6 +292,65 @@ class ServiceFactory:
|
||||
|
||||
return await self.get_service(TavilyClient)
|
||||
|
||||
async def _get_service_with_dependencies(
|
||||
self,
|
||||
service_class: type[T],
|
||||
dependencies: dict[str, BaseService[Any]]
|
||||
) -> T: # pyrefly: ignore
|
||||
"""Get service with dependency injection (shared helper method).
|
||||
|
||||
Args:
|
||||
service_class: The service class to instantiate
|
||||
dependencies: Dictionary of dependencies to inject
|
||||
|
||||
Returns:
|
||||
An initialized service instance
|
||||
"""
|
||||
# Check if already exists
|
||||
if service_class in self._services:
|
||||
return cast("T", self._services[service_class])
|
||||
|
||||
task_to_await: asyncio.Task[BaseService[Any]] | None = None
|
||||
|
||||
async with self._creation_lock:
|
||||
# Double-check after lock
|
||||
if service_class in self._services:
|
||||
return cast("T", self._services[service_class])
|
||||
|
||||
# Check if initialization is already in progress
|
||||
if service_class in self._initializing:
|
||||
task_to_await = self._initializing[service_class]
|
||||
else:
|
||||
# Create initialization task using cleanup registry
|
||||
task_to_await = asyncio.create_task(
|
||||
self._cleanup_registry.create_service_with_dependencies(
|
||||
service_class,
|
||||
dependencies
|
||||
)
|
||||
)
|
||||
self._initializing[service_class] = task_to_await
|
||||
|
||||
# task_to_await is guaranteed to be set by the logic above
|
||||
assert task_to_await is not None
|
||||
|
||||
try:
|
||||
service = await task_to_await
|
||||
self._services[service_class] = service
|
||||
return cast("T", service)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"{service_class.__name__} initialization cancelled")
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{service_class.__name__} initialization timed out")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize {service_class.__name__}: {e}")
|
||||
logger.exception(f"{service_class.__name__} initialization exception details")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup is handled by cleanup registry
|
||||
self._initializing.pop(service_class, None)
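This helper hinges on a double-checked lock plus a shared asyncio.Task, so concurrent callers await one initialization instead of creating duplicates. A minimal standalone sketch of that pattern (the `create` callable stands in for the cleanup registry; the names are illustrative, not the project's API):

import asyncio
from collections.abc import Callable, Coroutine
from typing import Any

class LazyAsyncRegistry:
    """Create each service at most once, even when callers race."""

    def __init__(self) -> None:
        self._services: dict[type, Any] = {}
        self._initializing: dict[type, asyncio.Task[Any]] = {}
        self._lock = asyncio.Lock()

    async def get(
        self,
        service_class: type,
        create: Callable[[type], Coroutine[Any, Any, Any]],
    ) -> Any:
        # Fast path: already built.
        if service_class in self._services:
            return self._services[service_class]

        async with self._lock:
            # Double-check after acquiring the lock.
            if service_class in self._services:
                return self._services[service_class]
            # Reuse an in-flight task so concurrent callers share one initialization.
            task = self._initializing.get(service_class)
            if task is None:
                task = asyncio.create_task(create(service_class))
                self._initializing[service_class] = task

        try:
            service = await task  # awaited outside the lock so other lookups are not blocked
            self._services[service_class] = service
            return service
        finally:
            self._initializing.pop(service_class, None)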
|
||||
|
||||
async def get_semantic_extraction(self) -> "SemanticExtractionService": # pyrefly: ignore
|
||||
"""Get the semantic extraction service with dependency injection.
|
||||
|
||||
@@ -301,62 +360,18 @@ class ServiceFactory:
|
||||
"""
|
||||
from biz_bud.services.semantic_extraction import SemanticExtractionService
|
||||
|
||||
# Check if already exists
|
||||
if SemanticExtractionService in self._services:
|
||||
return cast(
|
||||
"SemanticExtractionService", self._services[SemanticExtractionService]
|
||||
)
|
||||
|
||||
# Get dependencies first
|
||||
llm_client = await self.get_llm_client()
|
||||
vector_store = await self.get_vector_store()
|
||||
|
||||
task_to_await: asyncio.Task[BaseService[Any]] | None = None
|
||||
|
||||
async with self._creation_lock:
|
||||
# Double-check after lock
|
||||
if SemanticExtractionService in self._services:
|
||||
return cast(
|
||||
"SemanticExtractionService",
|
||||
self._services[SemanticExtractionService],
|
||||
)
|
||||
|
||||
# Check if initialization is already in progress
|
||||
if SemanticExtractionService in self._initializing:
|
||||
task_to_await = self._initializing[SemanticExtractionService]
|
||||
else:
|
||||
# Create initialization task using cleanup registry
|
||||
task_to_await = asyncio.create_task(
|
||||
self._cleanup_registry.create_service_with_dependencies(
|
||||
SemanticExtractionService,
|
||||
{
|
||||
"llm_client": llm_client,
|
||||
"vector_store": vector_store,
|
||||
}
|
||||
)
|
||||
)
|
||||
self._initializing[SemanticExtractionService] = task_to_await
|
||||
|
||||
# task_to_await is guaranteed to be set by the logic above
|
||||
assert task_to_await is not None
|
||||
|
||||
try:
|
||||
service = await task_to_await
|
||||
self._services[SemanticExtractionService] = service
|
||||
return cast("SemanticExtractionService", service)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("SemanticExtractionService initialization cancelled")
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("SemanticExtractionService initialization timed out")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize SemanticExtractionService: {e}")
|
||||
logger.exception("SemanticExtractionService initialization exception details")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup is handled by cleanup registry
|
||||
self._initializing.pop(SemanticExtractionService, None)
|
||||
# Use shared helper method
|
||||
return await self._get_service_with_dependencies(
|
||||
SemanticExtractionService,
|
||||
{
|
||||
"llm_client": llm_client,
|
||||
"vector_store": vector_store,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def get_llm_for_node(
|
||||
@@ -404,7 +419,6 @@ class ServiceFactory:
resolved_profile = config.get("model", llm_profile_override or "large")
# Remove model from kwargs to avoid duplication
llm_kwargs = {k: v for k, v in config.items() if k != "model"}
llm_kwargs = config

# Add any remaining kwargs that weren't passed to build_llm_config
for key, value in kwargs.items():
|
||||
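Keeping `model` out of llm_kwargs matters because the resolved model is passed separately; forwarding the raw config would supply it twice. A tiny illustration of the failure mode (the function and values are made up):

def build_llm(model: str, **llm_kwargs):
    return {"model": model, **llm_kwargs}

config = {"model": "gpt-large", "temperature": 0.2}

# Passing the full config alongside an explicit model duplicates the keyword:
# build_llm(config["model"], **config)  -> TypeError: got multiple values for argument 'model'

llm_kwargs = {k: v for k, v in config.items() if k != "model"}
print(build_llm(config["model"], **llm_kwargs))  # {'model': 'gpt-large', 'temperature': 0.2}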
@@ -602,6 +616,9 @@ class _GlobalFactoryManager(AsyncFactoryManager["ServiceFactory"]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Cache for configuration-specific factories
|
||||
self._config_cache: dict[str, ServiceFactory] = {}
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _create_service_factory(config: AppConfig) -> ServiceFactory:
|
||||
@@ -616,6 +633,46 @@ class _GlobalFactoryManager(AsyncFactoryManager["ServiceFactory"]):
|
||||
"""Ensure we have a healthy factory, recreating if necessary."""
|
||||
return await super().ensure_healthy_factory(self._create_service_factory, config)
|
||||
|
||||
async def get_factory_for_config(self, config_hash: str, config: AppConfig) -> ServiceFactory:
|
||||
"""Get or create a factory for a specific configuration hash.
|
||||
|
||||
This method provides thread-safe caching of factories by configuration,
|
||||
allowing multiple configurations to coexist without conflicts.
|
||||
|
||||
Args:
|
||||
config_hash: Hash of the configuration for caching
|
||||
config: Application configuration
|
||||
|
||||
Returns:
|
||||
ServiceFactory instance for the given configuration
|
||||
"""
|
||||
# Fast path: check cache without lock
|
||||
if config_hash in self._config_cache:
|
||||
return self._config_cache[config_hash]
|
||||
|
||||
# Slow path: acquire lock and create if needed
|
||||
async with self._cache_lock:
|
||||
# Double-check pattern
|
||||
if config_hash in self._config_cache:
|
||||
return self._config_cache[config_hash]
|
||||
|
||||
# Create new factory
|
||||
logger.info(f"Creating new service factory for config hash: {config_hash}")
|
||||
factory = self._create_service_factory(config)
|
||||
self._config_cache[config_hash] = factory
|
||||
return factory
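The cache is keyed by a caller-supplied config_hash; the diff does not show how that hash is derived. One plausible scheme, offered only as an assumption, hashes a canonical JSON dump of the configuration:

import hashlib
import json
from typing import Any

def config_hash_for(config_dict: dict[str, Any]) -> str:
    """Stable hash over a canonical JSON rendering of the configuration (illustrative only)."""
    canonical = json.dumps(config_dict, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

# Hypothetical usage, assuming AppConfig exposes a dict form such as model_dump():
# factory = await get_cached_factory_for_config(config_hash_for(config.model_dump()), config)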
|
||||
|
||||
async def cleanup_config_cache(self) -> None:
|
||||
"""Clean up all cached configuration-specific factories."""
|
||||
async with self._cache_lock:
|
||||
for config_hash, factory in self._config_cache.items():
|
||||
try:
|
||||
await factory.cleanup()
|
||||
logger.info(f"Cleaned up factory for config hash: {config_hash}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up factory {config_hash}: {e}")
|
||||
self._config_cache.clear()
|
||||
|
||||
|
||||
# Global factory manager instance
|
||||
_global_factory_manager = _GlobalFactoryManager()
|
||||
@@ -632,6 +689,22 @@ async def get_global_factory(config: AppConfig | None = None) -> ServiceFactory:
|
||||
return await _global_factory_manager.get_factory(config)
|
||||
|
||||
|
||||
async def get_cached_factory_for_config(config_hash: str, config: AppConfig) -> ServiceFactory:
|
||||
"""Get or create a cached factory for a specific configuration.
|
||||
|
||||
This function provides thread-safe caching of factories by configuration hash,
|
||||
avoiding the deadlock issues that can occur with nested lock acquisition.
|
||||
|
||||
Args:
|
||||
config_hash: Hash of the configuration for caching
|
||||
config: Application configuration
|
||||
|
||||
Returns:
|
||||
ServiceFactory instance for the given configuration
|
||||
"""
|
||||
return await _global_factory_manager.get_factory_for_config(config_hash, config)
|
||||
|
||||
|
||||
def set_global_factory(factory: ServiceFactory) -> None:
|
||||
"""Set the global factory instance."""
|
||||
_global_factory_manager.set_factory(factory)
|
||||
@@ -670,7 +743,7 @@ def reset_global_factory_state() -> None:
|
||||
|
||||
async def check_global_factory_health() -> bool:
|
||||
"""Check if the global factory is healthy and functional."""
|
||||
return await _global_factory_manager.check_factory_health()
|
||||
return _global_factory_manager.check_factory_health()
|
||||
|
||||
|
||||
async def ensure_healthy_global_factory(config: AppConfig | None = None) -> ServiceFactory:
|
||||
|
||||
@@ -882,7 +882,7 @@ class LangchainLLMClient(BaseService[LLMServiceConfig]):
|
||||
(
|
||||
should_retry,
|
||||
delay,
|
||||
) = await self._exception_handler.handle_llm_exception(
|
||||
) = self._exception_handler.handle_llm_exception(
|
||||
last_exception, attempt, max_retries, base_delay, max_delay
|
||||
)
|
||||
if should_retry and delay is not None:
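handle_llm_exception is now a plain synchronous call returning a (should_retry, delay) decision. The surrounding retry loop is not fully shown; a generic sketch of how such a loop usually consumes that decision (simple capped exponential backoff with jitter, not the project's exact policy):

import asyncio
import random
from collections.abc import Awaitable, Callable
from typing import Any

async def call_with_retries(
    call: Callable[[], Awaitable[Any]],
    max_retries: int = 3,
    base_delay: float = 0.5,
    max_delay: float = 8.0,
) -> Any:
    """Retry an async call with capped exponential backoff and jitter."""
    last_exc: Exception | None = None
    for attempt in range(max_retries + 1):
        try:
            return await call()
        except Exception as exc:
            last_exc = exc
            if attempt >= max_retries:
                break  # out of attempts; re-raise below
            delay = min(max_delay, base_delay * (2 ** attempt))
            delay *= 0.5 + random.random() / 2  # jitter so retries do not synchronize
            await asyncio.sleep(delay)
    assert last_exc is not None
    raise last_exc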
|
||||
|
||||
@@ -829,7 +829,7 @@ class VectorStore(BaseService[VectorStoreConfig]):
formatted_results = []
for result in results:
score = result.get("score", 0.0)
if score == 0.0 and "score" not in result:
if abs(score) < 1e-9 and "score" not in result:
logger.warning(
"Vector search result missing score - defaulting to 0.0: %s",
result.get("id", "unknown"),
|
||||
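The strict `== 0.0` check was replaced with a small absolute tolerance. A quick illustration of the difference (values made up):

import math

score = 0.1 + 0.2 - 0.3        # 5.551115123125783e-17, not exactly 0.0
print(score == 0.0)            # False: exact comparison misses near-zero values
print(abs(score) < 1e-9)       # True: tolerance-based check used in the diff
print(math.isclose(score, 0.0, abs_tol=1e-9))  # equivalent stdlib form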
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Browser automation tool for scraping web pages using Selenium."""
|
||||
|
||||
import asyncio
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, Self
|
||||
@@ -70,16 +71,13 @@ class Browser:
|
||||
async def open(self, url: str, wait_time: float = 0) -> None:
|
||||
"""Open a URL.
|
||||
|
||||
Note: This method contains blocking operations (driver.get and time.sleep).
|
||||
Consider using asyncio.to_thread() if running in an async context to avoid
|
||||
blocking the event loop.
|
||||
Note: driver.get is a blocking operation. Consider using asyncio.to_thread()
|
||||
for better async performance with many concurrent requests.
|
||||
"""
|
||||
# TODO: Consider wrapping blocking operations with asyncio.to_thread()
|
||||
# TODO: Consider wrapping driver.get with asyncio.to_thread() for better concurrency
|
||||
self.driver.get(url)
|
||||
if wait_time > 0:
|
||||
import time
|
||||
|
||||
time.sleep(wait_time)
|
||||
await asyncio.sleep(wait_time)
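driver.get remains a blocking call inside an async method, which the TODO acknowledges. A hedged sketch of the asyncio.to_thread wrapping it suggests (not part of this change):

import asyncio

async def open_non_blocking(driver, url: str, wait_time: float = 0) -> None:
    """Run the blocking Selenium navigation in a worker thread."""
    await asyncio.to_thread(driver.get, url)  # driver.get blocks until the page loads
    if wait_time > 0:
        await asyncio.sleep(wait_time)  # non-blocking pause for dynamic content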
|
||||
|
||||
def get_page_content(self) -> str:
|
||||
"""Get page content."""
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Document processing tools for markdown, text, and various file formats."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from biz_bud.core.utils.regex_security import findall_safe, search_safe, sub_safe
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -63,52 +64,69 @@ def extract_markdown_metadata(content: str) -> dict[str, Any]:
|
||||
Dictionary with extracted metadata and document structure
|
||||
"""
|
||||
try:
|
||||
# Extract headers
|
||||
# Extract headers using safe regex
|
||||
headers = []
|
||||
header_pattern = r"^(#{1,6})\s+(.*?)$"
|
||||
for match in re.finditer(header_pattern, content, re.MULTILINE):
|
||||
level = len(match.group(1))
|
||||
text = match.group(2).strip()
|
||||
headers.append(
|
||||
{
|
||||
"level": level,
|
||||
"text": text,
|
||||
"id": _slugify(text),
|
||||
}
|
||||
)
|
||||
matches = cast(list[tuple[str, str]], findall_safe(header_pattern, content, flags=re.MULTILINE))
|
||||
for match in matches:
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
level_markers_raw, text_raw = cast(tuple[str, str], match)
|
||||
level_markers = str(level_markers_raw)
|
||||
text = str(text_raw).strip()
|
||||
level = len(level_markers)
|
||||
headers.append(
|
||||
{
|
||||
"level": level,
|
||||
"text": text,
|
||||
"id": _slugify(text),
|
||||
}
|
||||
)
|
||||
|
||||
# Extract links (exclude images which start with !)
|
||||
# Extract links (exclude images which start with !) using safe regex
|
||||
links = []
|
||||
link_pattern = r"(?<!\!)\[([^\]]+)\]\(([^)]+)\)"
|
||||
for match in re.finditer(link_pattern, content):
|
||||
links.append(
|
||||
{
|
||||
"text": match.group(1),
|
||||
"url": match.group(2),
|
||||
"is_internal": not match.group(2).startswith(
|
||||
("http://", "https://")
|
||||
),
|
||||
}
|
||||
)
|
||||
link_matches = cast(list[tuple[str, str]], findall_safe(link_pattern, content))
|
||||
for match in link_matches:
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
text_raw, url_raw = cast(tuple[str, str], match)
|
||||
text_str = str(text_raw)
|
||||
url_str = str(url_raw)
|
||||
links.append(
|
||||
{
|
||||
"text": text_str,
|
||||
"url": url_str,
|
||||
"is_internal": not url_str.startswith(
|
||||
("http://", "https://")
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Extract images
|
||||
# Extract images using safe regex
|
||||
images = []
|
||||
image_pattern = r"!\[([^\]]*)\]\(([^)]+)\)"
|
||||
for match in re.finditer(image_pattern, content):
|
||||
images.append(
|
||||
{
|
||||
"alt_text": match.group(1),
|
||||
"url": match.group(2),
|
||||
"is_internal": not match.group(2).startswith(
|
||||
("http://", "https://")
|
||||
),
|
||||
}
|
||||
)
|
||||
image_matches = cast(list[tuple[str, str]], findall_safe(image_pattern, content))
|
||||
for match in image_matches:
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
alt_text_raw, url_raw = cast(tuple[str, str], match)
|
||||
alt_text_str = str(alt_text_raw)
|
||||
url_str = str(url_raw)
|
||||
images.append(
|
||||
{
|
||||
"alt_text": alt_text_str,
|
||||
"url": url_str,
|
||||
"is_internal": not url_str.startswith(
|
||||
("http://", "https://")
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Extract frontmatter if present
|
||||
frontmatter = {}
|
||||
if content.startswith("---"):
|
||||
if end_match := re.search(r"\n---\n", content):
|
||||
if end_match := search_safe(r"\n---\n", content):
|
||||
frontmatter_text = content[3 : end_match.start()]
|
||||
# Simple YAML-like parsing for basic frontmatter
|
||||
for line in frontmatter_text.split("\n"):
|
||||
@@ -217,40 +235,41 @@ def extract_code_blocks_from_markdown(
|
||||
try:
|
||||
code_blocks = []
|
||||
|
||||
# Pattern for fenced code blocks
|
||||
# Pattern for fenced code blocks using safe regex
|
||||
if language:
|
||||
pattern = rf"```{re.escape(language)}\n(.*?)```"
|
||||
else:
|
||||
pattern = r"```(\w*)\n(.*?)```"
|
||||
|
||||
for match in re.finditer(pattern, content, re.DOTALL):
|
||||
if language:
|
||||
matches = findall_safe(pattern, content, flags=re.DOTALL)
|
||||
for code_content in matches:
|
||||
code_blocks.append(
|
||||
{
|
||||
"language": language,
|
||||
"code": match.group(1).strip(),
|
||||
"line_count": len(match.group(1).strip().split("\n")),
|
||||
}
|
||||
)
|
||||
else:
|
||||
lang = match.group(1) or "text"
|
||||
code_blocks.append(
|
||||
{
|
||||
"language": lang,
|
||||
"code": match.group(2).strip(),
|
||||
"line_count": len(match.group(2).strip().split("\n")),
|
||||
"code": code_content.strip(),
|
||||
"line_count": len(code_content.strip().split("\n")),
|
||||
}
|
||||
)
|
||||
else:
|
||||
pattern = r"```(\w*)\n(.*?)```"
|
||||
matches = cast(list[tuple[str, str]], findall_safe(pattern, content, flags=re.DOTALL))
|
||||
for match in matches:
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
lang_raw, code_content_raw = cast(tuple[str, str], match)
|
||||
lang_str = str(lang_raw) or "text"
|
||||
code_content_str = str(code_content_raw).strip()
|
||||
code_blocks.append(
|
||||
{
|
||||
"language": lang_str,
|
||||
"code": code_content_str,
|
||||
"line_count": len(code_content_str.split("\n")),
|
||||
}
|
||||
)
|
||||
|
||||
# Also extract inline code (but not from within code blocks)
|
||||
inline_code = []
|
||||
# First remove all code blocks from consideration
|
||||
content_without_blocks = re.sub(r"```.*?```", "", content, flags=re.DOTALL)
|
||||
content_without_blocks = sub_safe(r"```.*?```", "", content, flags=re.DOTALL)
|
||||
# Then find inline code in the remaining content
|
||||
inline_pattern = r"`([^`\n]+)`"
|
||||
for match in re.finditer(inline_pattern, content_without_blocks):
|
||||
inline_code.append(match.group(1))
|
||||
|
||||
inline_matches = findall_safe(inline_pattern, content_without_blocks)
|
||||
inline_code = list(inline_matches)
|
||||
return {
|
||||
"code_blocks": code_blocks,
|
||||
"inline_code": inline_code,
|
||||
@@ -288,18 +307,23 @@ def generate_table_of_contents(content: str, max_level: int = 6) -> dict[str, An
|
||||
headers = []
|
||||
header_pattern = r"^(#{1,6})\s+(.*?)$"
|
||||
|
||||
for match in re.finditer(header_pattern, content, re.MULTILINE):
|
||||
level = len(match.group(1))
|
||||
if level <= max_level:
|
||||
text = match.group(2).strip()
|
||||
slug = _slugify(text)
|
||||
headers.append(
|
||||
{
|
||||
"level": level,
|
||||
"text": text,
|
||||
"slug": slug,
|
||||
}
|
||||
)
|
||||
header_matches = cast(list[tuple[str, str]], findall_safe(header_pattern, content, flags=re.MULTILINE))
|
||||
for match in header_matches:
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
level_markers_raw, text_raw = cast(tuple[str, str], match)
|
||||
level_markers = str(level_markers_raw)
|
||||
level = len(level_markers)
|
||||
if level <= max_level:
|
||||
text_str = str(text_raw).strip()
|
||||
slug = _slugify(text_str)
|
||||
headers.append(
|
||||
{
|
||||
"level": level,
|
||||
"text": text_str,
|
||||
"slug": slug,
|
||||
}
|
||||
)
|
||||
|
||||
# Generate TOC markdown
|
||||
toc_lines = []
|
||||
@@ -420,25 +444,25 @@ def _markdown_to_html(content: str) -> str:
|
||||
html = content
|
||||
|
||||
# Headers
|
||||
html = re.sub(r"^######\s+(.*?)$", r"<h6>\1</h6>", html, flags=re.MULTILINE)
|
||||
html = re.sub(r"^#####\s+(.*?)$", r"<h5>\1</h5>", html, flags=re.MULTILINE)
|
||||
html = re.sub(r"^####\s+(.*?)$", r"<h4>\1</h4>", html, flags=re.MULTILINE)
|
||||
html = re.sub(r"^###\s+(.*?)$", r"<h3>\1</h3>", html, flags=re.MULTILINE)
|
||||
html = re.sub(r"^##\s+(.*?)$", r"<h2>\1</h2>", html, flags=re.MULTILINE)
|
||||
html = re.sub(r"^#\s+(.*?)$", r"<h1>\1</h1>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^######\s+(.*?)$", r"<h6>\1</h6>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^#####\s+(.*?)$", r"<h5>\1</h5>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^####\s+(.*?)$", r"<h4>\1</h4>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^###\s+(.*?)$", r"<h3>\1</h3>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^##\s+(.*?)$", r"<h2>\1</h2>", html, flags=re.MULTILINE)
|
||||
html = sub_safe(r"^#\s+(.*?)$", r"<h1>\1</h1>", html, flags=re.MULTILINE)
|
||||
|
||||
# Bold and italic
|
||||
html = re.sub(r"\*\*(.*?)\*\*", r"<strong>\1</strong>", html)
|
||||
html = re.sub(r"\*(.*?)\*", r"<em>\1</em>", html)
|
||||
html = sub_safe(r"\*\*(.*?)\*\*", r"<strong>\1</strong>", html)
|
||||
html = sub_safe(r"\*(.*?)\*", r"<em>\1</em>", html)
|
||||
|
||||
# Links
|
||||
html = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r'<a href="\2">\1</a>', html)
|
||||
html = sub_safe(r"\[([^\]]+)\]\(([^)]+)\)", r'<a href="\2">\1</a>', html)
|
||||
|
||||
# Images
|
||||
html = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r'<img src="\2" alt="\1" />', html)
|
||||
html = sub_safe(r"!\[([^\]]*)\]\(([^)]+)\)", r'<img src="\2" alt="\1" />', html)
|
||||
|
||||
# Code blocks
|
||||
html = re.sub(
|
||||
html = sub_safe(
|
||||
r"```(\w*)\n(.*?)```",
|
||||
r'<pre><code class="language-\1">\2</code></pre>',
|
||||
html,
|
||||
@@ -446,7 +470,7 @@ def _markdown_to_html(content: str) -> str:
|
||||
)
|
||||
|
||||
# Inline code
|
||||
html = re.sub(r"`([^`]+)`", r"<code>\1</code>", html)
|
||||
html = sub_safe(r"`([^`]+)`", r"<code>\1</code>", html)
|
||||
|
||||
# Paragraphs
|
||||
paragraphs = html.split("\n\n")
|
||||
@@ -463,8 +487,8 @@ def _markdown_to_html(content: str) -> str:
|
||||
def _slugify(text: str) -> str:
|
||||
"""Convert text to URL-friendly slug."""
|
||||
slug = text.lower()
|
||||
slug = re.sub(r"[^\w\s-]", "", slug)
|
||||
slug = re.sub(r"[-\s]+", "-", slug)
|
||||
slug = sub_safe(r"[^\w\s-]", "", slug)
|
||||
slug = sub_safe(r"[-\s]+", "-", slug)
|
||||
return slug.strip("-")
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,8 @@ supporting integration with agent frameworks and LangGraph state management.
|
||||
import json
|
||||
from typing import Annotated, Any, ClassVar, cast
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Removed broken core import
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.core.utils.json_extractor import extract_json_robust, extract_json_simple
|
||||
from biz_bud.core.utils.regex_security import findall_safe, search_safe, sub_safe
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -73,9 +74,10 @@ def extract_json_from_text(text: str, use_robust_extraction: bool = True) -> Jso
|
||||
|
||||
def extract_python_code(text: str) -> str | None:
|
||||
"""Extract Python code from markdown code blocks."""
|
||||
# Look specifically for python code blocks
|
||||
python_pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
|
||||
return match[1].strip() if (match := python_pattern.search(text)) else None
|
||||
# Look specifically for python code blocks using safe regex
|
||||
if match := search_safe(r"```python\s*\n(.*?)```", text, flags=re.DOTALL):
|
||||
return match.group(1).strip()
|
||||
return None
|
||||
|
||||
|
||||
def safe_eval_python(
|
||||
@@ -115,13 +117,13 @@ def extract_list_from_text(text: str) -> list[str]:
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# Numbered list (1. item, 2. item, etc.)
|
||||
if re.match(r"^\d+[\.\)]\s+", line):
|
||||
item = re.sub(r"^\d+[\.\)]\s+", "", line)
|
||||
# Numbered list (1. item, 2. item, etc.) using safe regex
|
||||
if search_safe(r"^\d+[\.\)]\s+", line):
|
||||
item = sub_safe(r"^\d+[\.\)]\s+", "", line)
|
||||
items.append(item.strip())
|
||||
in_list = True
|
||||
elif re.match(r"^[-*•]\s+", line):
|
||||
item = re.sub(r"^[-*•]\s+", "", line)
|
||||
elif search_safe(r"^[-*]\s+", line):
|
||||
item = sub_safe(r"^[-*]\s+", "", line)
|
||||
items.append(item.strip())
|
||||
in_list = True
|
||||
elif not line and in_list:
|
||||
@@ -151,7 +153,7 @@ def extract_key_value_pairs(text: str) -> dict[str, str]:
|
||||
lines = text.strip().split("\n")
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if (match := re.match(pattern1, line)) or (match := re.match(pattern2, line)):
|
||||
if (match := search_safe(pattern1, line)) or (match := search_safe(pattern2, line)):
|
||||
key, value = match.groups()
|
||||
pairs[key.strip()] = value.strip()
|
||||
|
||||
@@ -194,7 +196,7 @@ def extract_code_blocks(text: str, language: str = "") -> list[str]:
|
||||
if not language:
|
||||
pattern = r"```[^\n]*\n(.*?)```"
|
||||
|
||||
return re.findall(pattern, text, re.DOTALL)
|
||||
return findall_safe(pattern, text, flags=re.DOTALL)
|
||||
|
||||
|
||||
def parse_action_args(text: str) -> ActionArgsDict:
|
||||
@@ -264,13 +266,18 @@ def extract_thought_action_pairs(text: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
pairs = []
|
||||
|
||||
# Pattern for thought-action pairs
|
||||
# Pattern for thought-action pairs using safe regex
|
||||
pattern = r"Thought:\s*(.+?)(?:\n|$).*?Action:\s*(.+?)(?:\n|$)"
|
||||
|
||||
for match in re.finditer(pattern, text, re.MULTILINE | re.DOTALL):
|
||||
thought = match.group(1).strip()
|
||||
action = match.group(2).strip()
|
||||
pairs.append((thought, action))
|
||||
matches = cast(list[tuple[str, str]], findall_safe(pattern, text, flags=re.MULTILINE | re.DOTALL))
|
||||
for match in matches:
|
||||
# Note: findall_safe returns list of strings for single group or tuples for multiple groups
|
||||
if isinstance(match, tuple) and len(match) == 2:
|
||||
# Type narrowing for pyrefly - we've verified it's a 2-tuple
|
||||
thought_raw, action_raw = cast(tuple[str, str], match)
|
||||
thought = str(thought_raw).strip()
|
||||
action = str(action_raw).strip()
|
||||
pairs.append((thought, action))
|
||||
|
||||
return pairs
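When a pattern has more than one capture group, findall returns tuples rather than match objects, which is why the loop unpacks 2-tuples. A quick demonstration with plain `re` (the `_safe` wrapper is assumed to behave the same way):

import re

text = "Thought: check prices\nAction: search_web\n"
pattern = r"Thought:\s*(.+?)(?:\n|$).*?Action:\s*(.+?)(?:\n|$)"

matches = re.findall(pattern, text, re.MULTILINE | re.DOTALL)
print(matches)  # [('check prices', 'search_web')] -- one tuple per match, one item per group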
|
||||
|
||||
@@ -330,11 +337,11 @@ def clean_extracted_text(text: str) -> str:
|
||||
"""Clean extracted text by removing extra whitespace and normalizing quotes.
|
||||
|
||||
Normalizes all quote variants to standard straight quotes:
|
||||
- Curly double quotes (", ") → straight double quotes (")
|
||||
- Curly single quotes (', ') → straight single quotes (')
|
||||
- Curly double quotes (\u201c, \u201d) -> straight double quotes (")
|
||||
- Curly single quotes (\u2018, \u2019) -> straight single quotes (')
|
||||
"""
|
||||
return (
|
||||
re.sub(r"\s+", " ", text)
|
||||
sub_safe(r"\s+", " ", text)
|
||||
.strip()
|
||||
.replace("\u201c", '"') # Left curly double quote → straight double quote
|
||||
.replace("\u201d", '"') # Right curly double quote → straight double quote
|
||||
@@ -350,13 +357,13 @@ def clean_text(text: str) -> str:
|
||||
|
||||
def normalize_whitespace(text: str) -> str:
|
||||
"""Normalize whitespace in text."""
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
return sub_safe(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
def remove_html_tags(text: str) -> str:
|
||||
"""Remove HTML tags from text."""
|
||||
# Simple HTML tag removal
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = sub_safe(r"<[^>]+>", "", text)
|
||||
# Decode common HTML entities
|
||||
text = text.replace("&", "&")
|
||||
text = text.replace("<", "<")
|
||||
@@ -374,8 +381,10 @@ def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
|
||||
|
||||
def extract_sentences(text: str) -> list[str]:
|
||||
"""Extract sentences from text."""
|
||||
# Simple sentence extraction
|
||||
sentences = re.split(r"[.!?]+", text)
|
||||
# Simple sentence extraction using safe approach
|
||||
# Replace sentence endings with a delimiter, then split
|
||||
temp_text = sub_safe(r"[.!?]+", "||SENTENCE_BREAK||", text)
|
||||
sentences = temp_text.split("||SENTENCE_BREAK||")
|
||||
return [s.strip() for s in sentences if s.strip()]
|
||||
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class DefaultIntrospectionProvider(IntrospectionProvider):
|
||||
base_confidence = 0.5
|
||||
|
||||
# Boost confidence based on clear indicators
|
||||
if len(capabilities) > 0:
|
||||
if capabilities:
|
||||
base_confidence += 0.2
|
||||
|
||||
if len(keywords) > 3:
|
||||
@@ -285,7 +285,7 @@ class DefaultIntrospectionProvider(IntrospectionProvider):
|
||||
'api', 'data', 'extract', 'scrape', 'search', 'analyze', 'process',
|
||||
'document', 'url', 'website', 'database', 'research'
|
||||
]
|
||||
technical_matches = sum(1 for term in technical_terms if term in query.lower())
|
||||
technical_matches = sum(term in query.lower() for term in technical_terms)
|
||||
base_confidence += min(technical_matches * 0.05, 0.2)
|
||||
|
||||
return min(base_confidence, 1.0)
|
||||
@@ -328,8 +328,9 @@ class DefaultIntrospectionProvider(IntrospectionProvider):
|
||||
if tool == workflow:
|
||||
covered_capabilities.add(capability)
|
||||
|
||||
coverage = len(covered_capabilities.intersection(set(capabilities))) / len(capabilities)
|
||||
return coverage
|
||||
return len(covered_capabilities.intersection(set(capabilities))) / len(
|
||||
capabilities
|
||||
)
|
||||
|
||||
def _generate_selection_reasoning(
|
||||
self,
|
||||
|
||||
@@ -258,7 +258,7 @@ async def get_capability_analysis(
|
||||
# Create comprehensive result
|
||||
result = IntrospectionResult(
|
||||
analysis=analysis,
|
||||
selection=selection if selection else await introspection_provider.select_tools([], include_workflows=include_workflows),
|
||||
selection=selection or await introspection_provider.select_tools([], include_workflows=include_workflows),
|
||||
provider=introspection_provider.provider_name,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
"""BeautifulSoup scraping provider implementation."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from biz_bud.core.errors.tool_exceptions import ScraperError
|
||||
from biz_bud.core.networking.async_utils import gather_with_concurrency
|
||||
from biz_bud.core.networking.http_client import HTTPClient
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.tools.models import ContentType, PageMetadata, ScrapedContent
|
||||
|
||||
from ..interface import ScrapeProvider
|
||||
from .jina import _create_error_result, _scrape_batch_impl
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -85,14 +84,7 @@ class BeautifulSoupScrapeProvider(ScrapeProvider):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"BeautifulSoup scraping failed for '{url}': {e}")
|
||||
return ScrapedContent(
|
||||
url=url,
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=url),
|
||||
error=str(e),
|
||||
)
|
||||
return _create_error_result(url, str(e))
|
||||
|
||||
async def scrape_batch(
|
||||
self,
|
||||
@@ -110,42 +102,7 @@ class BeautifulSoupScrapeProvider(ScrapeProvider):
|
||||
Returns:
|
||||
List of ScrapedContent objects
|
||||
"""
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_urls = list(dict.fromkeys(urls))
|
||||
|
||||
# Create semaphore for concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def scrape_with_semaphore(url: str) -> ScrapedContent:
|
||||
"""Scrape a URL with semaphore control."""
|
||||
async with semaphore:
|
||||
return await self.scrape(url, timeout)
|
||||
|
||||
# Execute scraping concurrently
|
||||
tasks = [scrape_with_semaphore(url) for url in unique_urls]
|
||||
results = await gather_with_concurrency(max_concurrent, *tasks, return_exceptions=True)
|
||||
|
||||
# Process results and handle exceptions
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
processed_results.append(
|
||||
ScrapedContent(
|
||||
url=unique_urls[i],
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=unique_urls[i]),
|
||||
error=str(result),
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
return await _scrape_batch_impl(self, urls, max_concurrent, timeout)
|
||||
|
||||
def _extract_main_content(self, soup: BeautifulSoup) -> str:
|
||||
"""Extract main content from HTML soup.
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"""Firecrawl scraping provider implementation."""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.core.networking.async_utils import gather_with_concurrency
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.tools.models import ContentType, PageMetadata, ScrapedContent
|
||||
|
||||
from ..interface import ScrapeProvider
|
||||
from .jina import _create_error_result, _scrape_batch_impl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from biz_bud.services.factory.service_factory import ServiceFactory
|
||||
@@ -96,14 +95,7 @@ class FirecrawlScrapeProvider(ScrapeProvider):
|
||||
|
||||
# Otherwise, it was a real scraping error
|
||||
logger.error(f"Firecrawl scraping failed for '{url}': {e}")
|
||||
return ScrapedContent(
|
||||
url=url,
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=url),
|
||||
error=str(e),
|
||||
)
|
||||
return _create_error_result(url, str(e))
|
||||
|
||||
async def scrape_batch(
|
||||
self,
|
||||
@@ -121,39 +113,4 @@ class FirecrawlScrapeProvider(ScrapeProvider):
|
||||
Returns:
|
||||
List of ScrapedContent objects
|
||||
"""
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_urls = list(dict.fromkeys(urls))
|
||||
|
||||
# Create semaphore for concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def scrape_with_semaphore(url: str) -> ScrapedContent:
|
||||
"""Scrape a URL with semaphore control."""
|
||||
async with semaphore:
|
||||
return await self.scrape(url, timeout)
|
||||
|
||||
# Execute scraping concurrently
|
||||
tasks = [scrape_with_semaphore(url) for url in unique_urls]
|
||||
results = await gather_with_concurrency(max_concurrent, *tasks, return_exceptions=True)
|
||||
|
||||
# Process results and handle exceptions
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
processed_results.append(
|
||||
ScrapedContent(
|
||||
url=unique_urls[i],
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=unique_urls[i]),
|
||||
error=str(result),
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
return await _scrape_batch_impl(self, urls, max_concurrent, timeout)
|
||||
|
||||
@@ -16,6 +16,74 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _create_error_result(url: str, error: str) -> ScrapedContent:
|
||||
"""Create a ScrapedContent object for an error.
|
||||
|
||||
Args:
|
||||
url: The URL that failed
|
||||
error: Error message
|
||||
|
||||
Returns:
|
||||
ScrapedContent with error information
|
||||
"""
|
||||
return ScrapedContent(
|
||||
url=url,
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=url),
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
async def _scrape_batch_impl(
|
||||
provider: ScrapeProvider,
|
||||
urls: list[str],
|
||||
max_concurrent: int = 5,
|
||||
timeout: int = 30,
|
||||
) -> list[ScrapedContent]:
|
||||
"""Common implementation for batch scraping.
|
||||
|
||||
Args:
|
||||
provider: The scrape provider instance
|
||||
urls: List of URLs to scrape
|
||||
max_concurrent: Maximum concurrent operations
|
||||
timeout: Request timeout per URL
|
||||
|
||||
Returns:
|
||||
List of ScrapedContent objects
|
||||
"""
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_urls = list(dict.fromkeys(urls))
|
||||
|
||||
# Create semaphore for concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def scrape_with_semaphore(url: str) -> ScrapedContent:
|
||||
"""Scrape a URL with semaphore control."""
|
||||
async with semaphore:
|
||||
return await provider.scrape(url, timeout)
|
||||
|
||||
# Execute scraping concurrently
|
||||
tasks = [scrape_with_semaphore(url) for url in unique_urls]
|
||||
results = await gather_with_concurrency(max_concurrent, *tasks, return_exceptions=True)
|
||||
|
||||
# Process results and handle exceptions
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
processed_results.append(
|
||||
_create_error_result(unique_urls[i], str(result))
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
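gather_with_concurrency is imported from biz_bud.core.networking.async_utils; its body is not part of this diff. A common shape for such a helper is a semaphore-bounded asyncio.gather, shown here only as an assumption about its behaviour:

import asyncio
from collections.abc import Awaitable
from typing import Any

async def gather_with_concurrency(limit: int, *aws: Awaitable[Any],
                                  return_exceptions: bool = False) -> list[Any]:
    """Run awaitables with at most `limit` of them in flight at once."""
    semaphore = asyncio.Semaphore(limit)

    async def bounded(aw: Awaitable[Any]) -> Any:
        async with semaphore:
            return await aw

    return await asyncio.gather(*(bounded(aw) for aw in aws),
                                return_exceptions=return_exceptions)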
|
||||
|
||||
|
||||
class JinaScrapeProvider(ScrapeProvider):
|
||||
"""Scraping provider using Jina Reader API through ServiceFactory."""
|
||||
|
||||
@@ -84,14 +152,7 @@ class JinaScrapeProvider(ScrapeProvider):
|
||||
|
||||
# Otherwise, it was a real scraping error
|
||||
logger.error(f"Jina scraping failed for '{url}': {e}")
|
||||
return ScrapedContent(
|
||||
url=url,
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=url),
|
||||
error=str(e),
|
||||
)
|
||||
return _create_error_result(url, str(e))
|
||||
|
||||
async def scrape_batch(
|
||||
self,
|
||||
@@ -109,39 +170,4 @@ class JinaScrapeProvider(ScrapeProvider):
|
||||
Returns:
|
||||
List of ScrapedContent objects
|
||||
"""
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_urls = list(dict.fromkeys(urls))
|
||||
|
||||
# Create semaphore for concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def scrape_with_semaphore(url: str) -> ScrapedContent:
|
||||
"""Scrape a URL with semaphore control."""
|
||||
async with semaphore:
|
||||
return await self.scrape(url, timeout)
|
||||
|
||||
# Execute scraping concurrently
|
||||
tasks = [scrape_with_semaphore(url) for url in unique_urls]
|
||||
results = await gather_with_concurrency(max_concurrent, *tasks, return_exceptions=True)
|
||||
|
||||
# Process results and handle exceptions
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
processed_results.append(
|
||||
ScrapedContent(
|
||||
url=unique_urls[i],
|
||||
content="",
|
||||
title=None,
|
||||
content_type=ContentType.HTML,
|
||||
metadata=PageMetadata(source_url=unique_urls[i]),
|
||||
error=str(result),
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
return await _scrape_batch_impl(self, urls, max_concurrent, timeout)
|
||||
|
||||
@@ -61,7 +61,7 @@ class ComprehensiveDiscoveryProvider(URLDiscoveryProvider):
|
||||
start_time = time.time()
|
||||
|
||||
# Use existing URLDiscoverer
|
||||
discovery_result = await self.discoverer.discover_urls(base_url)
|
||||
discovery_result = self.discoverer.discover_urls(base_url)
|
||||
|
||||
discovery_time = time.time() - start_time
|
||||
|
||||
@@ -152,7 +152,7 @@ class SitemapOnlyDiscoveryProvider(URLDiscoveryProvider):
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
discovery_result = await self.discoverer.discover_urls(base_url)
|
||||
discovery_result = self.discoverer.discover_urls(base_url)
|
||||
discovery_time = time.time() - start_time
|
||||
|
||||
result = DiscoveryResult(
|
||||
@@ -238,7 +238,7 @@ class HTMLParsingDiscoveryProvider(URLDiscoveryProvider):
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
discovery_result = await self.discoverer.discover_urls(base_url)
|
||||
discovery_result = self.discoverer.discover_urls(base_url)
|
||||
discovery_time = time.time() - start_time
|
||||
|
||||
result = DiscoveryResult(
|
||||
|
||||
@@ -17,7 +17,92 @@ from ..interface import URLNormalizationProvider
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StandardNormalizationProvider(URLNormalizationProvider):
|
||||
class BaseNormalizationProvider(URLNormalizationProvider):
|
||||
"""Base class for URL normalization providers."""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
"""Initialize base normalization provider.
|
||||
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.normalizer = self._create_normalizer()
|
||||
self._error_message = "Failed to normalize URL"
|
||||
|
||||
def _create_normalizer(self) -> URLNormalizer:
|
||||
"""Create URLNormalizer instance with provider-specific settings.
|
||||
|
||||
Subclasses should override this method to provide custom settings.
|
||||
|
||||
Returns:
|
||||
Configured URLNormalizer instance
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _create_normalizer")
|
||||
|
||||
def normalize_url(self, url: str) -> str:
|
||||
"""Normalize URL using provider rules.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Raises:
|
||||
URLNormalizationError: If normalization fails
|
||||
"""
|
||||
try:
|
||||
return self._perform_normalization(url)
|
||||
except Exception as e:
|
||||
raise URLNormalizationError(
|
||||
message=self._error_message,
|
||||
url=url,
|
||||
normalization_rules=list(self.get_normalization_config().keys()),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
def _perform_normalization(self, url: str) -> str:
|
||||
"""Perform actual normalization. Override for custom logic.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL
|
||||
"""
|
||||
return self.normalizer.normalize(url)
|
||||
|
||||
def get_normalization_config(self) -> dict[str, Any]:
|
||||
"""Get normalization configuration details.
|
||||
|
||||
Returns:
|
||||
Dictionary with normalization settings
|
||||
"""
|
||||
base_config = {
|
||||
"default_protocol": self.normalizer.default_protocol,
|
||||
"normalize_protocol": self.normalizer.normalize_protocol,
|
||||
"remove_fragments": self.normalizer.remove_fragments,
|
||||
"remove_www": self.normalizer.remove_www,
|
||||
"lowercase_domain": self.normalizer.lowercase_domain,
|
||||
"sort_query_params": self.normalizer.sort_query_params,
|
||||
"remove_trailing_slash": self.normalizer.remove_trailing_slash,
|
||||
}
|
||||
return self._extend_config(base_config)
|
||||
|
||||
def _extend_config(self, base_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extend base configuration with provider-specific settings.
|
||||
|
||||
Args:
|
||||
base_config: Base configuration dictionary
|
||||
|
||||
Returns:
|
||||
Extended configuration
|
||||
"""
|
||||
return base_config
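With this template-method base class, a new provider only has to supply its normalizer settings. A hypothetical subclass (not part of the diff, and assuming the module's existing Any and URLNormalizer imports) to show the extension point:

class QueryStrippingNormalizationProvider(BaseNormalizationProvider):
    """Example provider that maximally canonicalizes query handling (illustrative only)."""

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        super().__init__(config)
        self._error_message = "Failed to normalize URL with query-stripping rules"

    def _create_normalizer(self) -> URLNormalizer:
        # Assumes URLNormalizer accepts the same keyword arguments used elsewhere in this module.
        return URLNormalizer(
            default_protocol=self.config.get("default_protocol", "https"),
            normalize_protocol=True,
            remove_fragments=True,
            remove_www=True,
            lowercase_domain=True,
            sort_query_params=True,
            remove_trailing_slash=True,
        )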
|
||||
|
||||
|
||||
class StandardNormalizationProvider(BaseNormalizationProvider):
|
||||
"""Standard URL normalization using core URLNormalizer."""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
@@ -26,10 +111,12 @@ class StandardNormalizationProvider(URLNormalizationProvider):
|
||||
Args:
|
||||
config: Optional configuration dictionary for normalization rules
|
||||
"""
|
||||
self.config = config or {}
|
||||
super().__init__(config)
|
||||
self._error_message = "Failed to normalize URL with standard rules"
|
||||
|
||||
# Create URLNormalizer with standard configuration
|
||||
self.normalizer = URLNormalizer(
|
||||
def _create_normalizer(self) -> URLNormalizer:
|
||||
"""Create URLNormalizer with standard configuration."""
|
||||
return URLNormalizer(
|
||||
default_protocol=self.config.get("default_protocol", "https"),
|
||||
normalize_protocol=self.config.get("normalize_protocol", True),
|
||||
remove_fragments=self.config.get("remove_fragments", True),
|
||||
@@ -39,46 +126,8 @@ class StandardNormalizationProvider(URLNormalizationProvider):
|
||||
remove_trailing_slash=self.config.get("remove_trailing_slash", True),
|
||||
)
|
||||
|
||||
def normalize_url(self, url: str) -> str:
|
||||
"""Normalize URL using standard rules.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Raises:
|
||||
URLNormalizationError: If normalization fails
|
||||
"""
|
||||
try:
|
||||
return self.normalizer.normalize(url)
|
||||
except Exception as e:
|
||||
raise URLNormalizationError(
|
||||
message="Failed to normalize URL with standard rules",
|
||||
url=url,
|
||||
normalization_rules=list(self.get_normalization_config().keys()),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
def get_normalization_config(self) -> dict[str, Any]:
|
||||
"""Get normalization configuration details.
|
||||
|
||||
Returns:
|
||||
Dictionary with normalization settings
|
||||
"""
|
||||
return {
|
||||
"default_protocol": self.normalizer.default_protocol,
|
||||
"normalize_protocol": self.normalizer.normalize_protocol,
|
||||
"remove_fragments": self.normalizer.remove_fragments,
|
||||
"remove_www": self.normalizer.remove_www,
|
||||
"lowercase_domain": self.normalizer.lowercase_domain,
|
||||
"sort_query_params": self.normalizer.sort_query_params,
|
||||
"remove_trailing_slash": self.normalizer.remove_trailing_slash,
|
||||
}
|
||||
|
||||
|
||||
class ConservativeNormalizationProvider(URLNormalizationProvider):
|
||||
class ConservativeNormalizationProvider(BaseNormalizationProvider):
|
||||
"""Conservative URL normalization with minimal changes."""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
@@ -87,10 +136,12 @@ class ConservativeNormalizationProvider(URLNormalizationProvider):
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self.config = config or {}
|
||||
super().__init__(config)
|
||||
self._error_message = "Failed to normalize URL with conservative rules"
|
||||
|
||||
# Create URLNormalizer with conservative settings
|
||||
self.normalizer = URLNormalizer(
|
||||
def _create_normalizer(self) -> URLNormalizer:
|
||||
"""Create URLNormalizer with conservative settings."""
|
||||
return URLNormalizer(
|
||||
default_protocol=self.config.get("default_protocol", "https"),
|
||||
normalize_protocol=False, # Don't change protocol
|
||||
remove_fragments=True, # Safe to remove fragments
|
||||
@@ -100,46 +151,8 @@ class ConservativeNormalizationProvider(URLNormalizationProvider):
|
||||
remove_trailing_slash=False, # Keep trailing slashes
|
||||
)
|
||||
|
||||
def normalize_url(self, url: str) -> str:
|
||||
"""Normalize URL using conservative rules.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Raises:
|
||||
URLNormalizationError: If normalization fails
|
||||
"""
|
||||
try:
|
||||
return self.normalizer.normalize(url)
|
||||
except Exception as e:
|
||||
raise URLNormalizationError(
|
||||
message="Failed to normalize URL with conservative rules",
|
||||
url=url,
|
||||
normalization_rules=list(self.get_normalization_config().keys()),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
def get_normalization_config(self) -> dict[str, Any]:
|
||||
"""Get normalization configuration details.
|
||||
|
||||
Returns:
|
||||
Dictionary with normalization settings
|
||||
"""
|
||||
return {
|
||||
"default_protocol": self.normalizer.default_protocol,
|
||||
"normalize_protocol": self.normalizer.normalize_protocol,
|
||||
"remove_fragments": self.normalizer.remove_fragments,
|
||||
"remove_www": self.normalizer.remove_www,
|
||||
"lowercase_domain": self.normalizer.lowercase_domain,
|
||||
"sort_query_params": self.normalizer.sort_query_params,
|
||||
"remove_trailing_slash": self.normalizer.remove_trailing_slash,
|
||||
}
|
||||
|
||||
|
||||
class AggressiveNormalizationProvider(URLNormalizationProvider):
|
||||
class AggressiveNormalizationProvider(BaseNormalizationProvider):
|
||||
"""Aggressive URL normalization with maximum canonicalization."""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
@@ -148,10 +161,12 @@ class AggressiveNormalizationProvider(URLNormalizationProvider):
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self.config = config or {}
|
||||
super().__init__(config)
|
||||
self._error_message = "Failed to normalize URL with aggressive rules"
|
||||
|
||||
# Create URLNormalizer with aggressive settings
|
||||
self.normalizer = URLNormalizer(
|
||||
def _create_normalizer(self) -> URLNormalizer:
|
||||
"""Create URLNormalizer with aggressive settings."""
|
||||
return URLNormalizer(
|
||||
default_protocol=self.config.get("default_protocol", "https"),
|
||||
normalize_protocol=True, # Always normalize to HTTPS
|
||||
remove_fragments=True, # Remove all fragments
|
||||
@@ -161,32 +176,11 @@ class AggressiveNormalizationProvider(URLNormalizationProvider):
|
||||
remove_trailing_slash=True, # Remove trailing slashes
|
||||
)
|
||||
|
||||
def normalize_url(self, url: str) -> str:
|
||||
"""Normalize URL using aggressive rules.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Raises:
|
||||
URLNormalizationError: If normalization fails
|
||||
"""
|
||||
try:
|
||||
normalized = self.normalizer.normalize(url)
|
||||
|
||||
# Additional aggressive normalization
|
||||
normalized = self._apply_aggressive_rules(normalized)
|
||||
|
||||
return normalized
|
||||
except Exception as e:
|
||||
raise URLNormalizationError(
|
||||
message="Failed to normalize URL with aggressive rules",
|
||||
url=url,
|
||||
normalization_rules=list(self.get_normalization_config().keys()),
|
||||
original_error=e,
|
||||
) from e
|
||||
def _perform_normalization(self, url: str) -> str:
|
||||
"""Perform normalization with additional aggressive rules."""
|
||||
normalized = self.normalizer.normalize(url)
|
||||
# Additional aggressive normalization
|
||||
return self._apply_aggressive_rules(normalized)
|
||||
|
||||
def _apply_aggressive_rules(self, url: str) -> str:
|
||||
"""Apply additional aggressive normalization rules.
|
||||
@@ -237,21 +231,9 @@ class AggressiveNormalizationProvider(URLNormalizationProvider):
|
||||
# Return original URL if aggressive rules fail
|
||||
return url
|
||||
|
||||
def get_normalization_config(self) -> dict[str, Any]:
|
||||
"""Get normalization configuration details.
|
||||
|
||||
Returns:
|
||||
Dictionary with normalization settings
|
||||
"""
|
||||
return {
|
||||
"default_protocol": self.normalizer.default_protocol,
|
||||
"normalize_protocol": self.normalizer.normalize_protocol,
|
||||
"remove_fragments": self.normalizer.remove_fragments,
|
||||
"remove_www": self.normalizer.remove_www,
|
||||
"lowercase_domain": self.normalizer.lowercase_domain,
|
||||
"sort_query_params": self.normalizer.sort_query_params,
|
||||
"remove_trailing_slash": self.normalizer.remove_trailing_slash,
|
||||
} | {
|
||||
def _extend_config(self, base_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extend base configuration with aggressive-specific settings."""
|
||||
return base_config | {
|
||||
"remove_tracking_params": True,
|
||||
"aggressive_query_cleaning": True,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,11 @@ from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.states.buddy import ExecutionRecord
|
||||
from biz_bud.states.planner import QueryStep
|
||||
from biz_bud.tools.capabilities.workflow.validation_helpers import (
|
||||
create_key_points,
|
||||
create_summary,
|
||||
extract_content_from_result,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -261,10 +266,8 @@ class IntermediateResultsConverter:
|
||||
# Extract key information from result string
|
||||
extracted_info[step_id] = {
|
||||
"content": result,
|
||||
"summary": f"{result[:300]}..." if len(result) > 300 else result,
|
||||
"key_points": [f"{result[:200]}..."]
|
||||
if len(result) > 200
|
||||
else [result],
|
||||
"summary": create_summary(result),
|
||||
"key_points": create_key_points(result),
|
||||
"facts": [],
|
||||
}
|
||||
sources.append(
|
||||
@@ -279,45 +282,17 @@ class IntermediateResultsConverter:
|
||||
f"Dict result for step {step_id}, keys: {list(result.keys())}"
|
||||
)
|
||||
# Handle dictionary results - extract actual content
|
||||
content = None
|
||||
|
||||
# Try to extract meaningful content from various possible keys
|
||||
for content_key in [
|
||||
"synthesis",
|
||||
"final_response",
|
||||
"content",
|
||||
"response",
|
||||
"result",
|
||||
"output",
|
||||
]:
|
||||
if content_key in result and result[content_key]:
|
||||
content = str(result[content_key])
|
||||
logger.debug(
|
||||
f"Found content in key '{content_key}' for step {step_id}"
|
||||
)
|
||||
break
|
||||
|
||||
# If no content found, stringify the whole result
|
||||
if not content:
|
||||
content = str(result)
|
||||
logger.debug(
|
||||
f"No specific content key found, using stringified result for step {step_id}"
|
||||
)
|
||||
content = extract_content_from_result(result, step_id)
|
||||
|
||||
# Extract key points if available
|
||||
key_points = result.get("key_points", [])
|
||||
if not key_points and content:
|
||||
# Create key points from content
|
||||
key_points = (
|
||||
[f"{content[:200]}..."] if len(content) > 200 else [content]
|
||||
)
|
||||
key_points = create_key_points(
|
||||
content,
|
||||
result.get("key_points", []) or None
|
||||
)
|
||||
|
||||
extracted_info[step_id] = {
|
||||
"content": content,
|
||||
"summary": result.get(
|
||||
"summary",
|
||||
f"{content[:300]}..." if len(content) > 300 else content,
|
||||
),
|
||||
"summary": result.get("summary") or create_summary(content),
|
||||
"key_points": key_points,
|
||||
"facts": result.get("facts", []),
|
||||
}
|
||||
@@ -333,14 +308,11 @@ class IntermediateResultsConverter:
|
||||
f"Unexpected result type for step {step_id}: {type(result).__name__}"
|
||||
)
|
||||
# Handle other types by converting to string
|
||||
content_str: str = str(result)
|
||||
summary = (
|
||||
f"{content_str[:300]}..." if len(content_str) > 300 else content_str
|
||||
)
|
||||
content_str = str(result)
|
||||
extracted_info[step_id] = {
|
||||
"content": content_str,
|
||||
"summary": summary,
|
||||
"key_points": [content_str],
|
||||
"summary": create_summary(content_str),
|
||||
"key_points": create_key_points(content_str),
|
||||
"facts": [],
|
||||
}
|
||||
sources.append(
|
||||
@@ -359,6 +331,7 @@ class IntermediateResultsConverter:
|
||||
|
||||
# Tool functions for workflow execution utilities
|
||||
|
||||
|
||||
@tool
|
||||
def create_success_execution_record(
|
||||
step_id: str,
|
||||
|
||||
@@ -7,6 +7,14 @@ from langchain_core.tools import tool
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.states.planner import ExecutionPlan, QueryStep
|
||||
from biz_bud.tools.capabilities.workflow.validation_helpers import (
|
||||
process_dependencies_field,
|
||||
validate_bool_field,
|
||||
validate_list_field,
|
||||
validate_literal_field,
|
||||
validate_optional_string_field,
|
||||
validate_string_field,
|
||||
)
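The validate_* helpers are imported from validation_helpers, but their bodies are not part of this diff. Based on how they are called below, two of them presumably look something like this (an assumption, not the project's code):

from typing import Any

from biz_bud.logging import get_logger

logger = get_logger(__name__)

def validate_string_field(data: dict[str, Any], key: str, default: str) -> str:
    """Return data[key] as a string, falling back to a default on bad types."""
    raw = data.get(key, default)
    if not isinstance(raw, (str, int)):
        logger.warning(f"Invalid {key} type: {type(raw)}, using default")
        return default
    return str(raw)

def validate_literal_field(
    data: dict[str, Any],
    key: str,
    allowed: list[str],
    default: str,
    field_name: str,
) -> str:
    """Return data[key] if it is one of the allowed literals, else the default."""
    raw = data.get(key, default)
    if not isinstance(raw, str) or raw not in allowed:
        logger.warning(f"Invalid {field_name} value: {raw!r}, using default '{default}'")
        return default
    return raw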
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -64,63 +72,38 @@ class PlanParser:
|
||||
steps = []
|
||||
for step_data in execution_plan_data.get("steps", []):
|
||||
if isinstance(step_data, dict):
|
||||
# Ensure dependencies is a list of strings
|
||||
dependencies_raw = step_data.get("dependencies", [])
|
||||
if isinstance(dependencies_raw, str):
|
||||
dependencies = [dependencies_raw]
|
||||
elif isinstance(dependencies_raw, list):
|
||||
dependencies = [str(dep) for dep in dependencies_raw]
|
||||
else:
|
||||
dependencies = []
|
||||
# Validate dependencies
|
||||
dependencies = process_dependencies_field(
|
||||
step_data.get("dependencies", [])
|
||||
)
|
||||
|
||||
# Validate priority with proper type checking
|
||||
priority_raw = step_data.get("priority", "medium")
|
||||
if not isinstance(priority_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid priority type: {type(priority_raw)}, using default 'medium'"
|
||||
)
|
||||
priority: Literal["high", "medium", "low"] = "medium"
|
||||
elif priority_raw not in ["high", "medium", "low"]:
|
||||
logger.warning(
|
||||
f"Invalid priority value: {priority_raw}, using default 'medium'"
|
||||
)
|
||||
priority = "medium"
|
||||
else:
|
||||
priority = priority_raw # type: ignore[assignment] # Validated above
|
||||
# Validate priority
|
||||
priority: Literal["high", "medium", "low"] = validate_literal_field( # type: ignore[assignment]
|
||||
step_data, "priority", ["high", "medium", "low"], "medium", "priority"
|
||||
)
|
||||
|
||||
# Validate status with proper type checking
|
||||
status_raw = step_data.get("status", "pending")
|
||||
valid_statuses: list[str] = [
|
||||
# Validate status
|
||||
valid_statuses = [
|
||||
"pending",
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"blocked",
|
||||
]
|
||||
if not isinstance(status_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid status type: {type(status_raw)}, using default 'pending'"
|
||||
)
|
||||
status: Literal[
|
||||
"pending",
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"blocked",
|
||||
] = "pending"
|
||||
elif status_raw not in valid_statuses:
|
||||
logger.warning(
|
||||
f"Invalid status value: {status_raw}, using default 'pending'"
|
||||
)
|
||||
status = "pending"
|
||||
else:
|
||||
status = status_raw # type: ignore[assignment] # Validated above
|
||||
status_str = validate_literal_field(
|
||||
step_data, "status", valid_statuses, "pending", "status"
|
||||
)
|
||||
status: Literal[
|
||||
"pending",
|
||||
"in_progress",
|
||||
"completed",
|
||||
"failed",
|
||||
"blocked",
|
||||
] = status_str # type: ignore[assignment]
|
||||
|
||||
# Validate results field
|
||||
results_raw = step_data.get("results")
|
||||
if results_raw is not None and not isinstance(
|
||||
results_raw, dict
|
||||
):
|
||||
if results_raw is not None and not isinstance(results_raw, dict):
|
||||
logger.warning(
|
||||
f"Invalid results type: {type(results_raw)}, setting to None"
|
||||
)
|
||||
@@ -128,58 +111,21 @@ class PlanParser:
|
||||
else:
|
||||
results = results_raw
|
||||
|
||||
# Ensure all string fields are properly converted with validation
|
||||
step_id_raw = step_data.get("id", f"step_{len(steps) + 1}")
|
||||
if not isinstance(step_id_raw, (str, int)):
|
||||
logger.warning(
|
||||
f"Invalid step_id type: {type(step_id_raw)}, using default"
|
||||
)
|
||||
step_id = f"step_{len(steps) + 1}"
|
||||
else:
|
||||
step_id = str(step_id_raw)
|
||||
|
||||
description_raw = step_data.get("description", "")
|
||||
if not isinstance(description_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid description type: {type(description_raw)}, converting to string"
|
||||
)
|
||||
description = str(description_raw)
|
||||
|
||||
agent_name_raw = step_data.get("agent_name", "main")
|
||||
if agent_name_raw is not None and not isinstance(
|
||||
agent_name_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid agent_name type: {type(agent_name_raw)}, converting to string"
|
||||
)
|
||||
agent_name = (
|
||||
str(agent_name_raw) if agent_name_raw is not None else None
|
||||
# Validate string fields
|
||||
step_id = validate_string_field(
|
||||
step_data, "id", f"step_{len(steps) + 1}"
|
||||
)
|
||||
|
||||
agent_role_prompt_raw = step_data.get("agent_role_prompt")
|
||||
if agent_role_prompt_raw is not None and not isinstance(
|
||||
agent_role_prompt_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid agent_role_prompt type: {type(agent_role_prompt_raw)}, converting to string"
|
||||
)
|
||||
agent_role_prompt = (
|
||||
str(agent_role_prompt_raw)
|
||||
if agent_role_prompt_raw is not None
|
||||
else None
|
||||
description = validate_string_field(
|
||||
step_data, "description", ""
|
||||
)
|
||||
|
||||
error_message_raw = step_data.get("error_message")
|
||||
if error_message_raw is not None and not isinstance(
|
||||
error_message_raw, str
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid error_message type: {type(error_message_raw)}, converting to string"
|
||||
)
|
||||
error_message = (
|
||||
str(error_message_raw)
|
||||
if error_message_raw is not None
|
||||
else None
|
||||
agent_name = validate_optional_string_field(
|
||||
step_data, "agent_name"
|
||||
) or "main"
|
||||
agent_role_prompt = validate_optional_string_field(
|
||||
step_data, "agent_role_prompt"
|
||||
)
|
||||
error_message = validate_optional_string_field(
|
||||
step_data, "error_message"
|
||||
)
|
||||
|
||||
step = QueryStep(
|
||||
@@ -199,88 +145,34 @@ class PlanParser:
|
||||
steps.append(step)
|
||||
|
||||
if steps:
|
||||
# Validate current_step_id with proper type checking
|
||||
current_step_id_raw = execution_plan_data.get("current_step_id")
|
||||
if current_step_id_raw is not None:
|
||||
if not isinstance(current_step_id_raw, (str, int)):
|
||||
logger.warning(
|
||||
f"Invalid current_step_id type: {type(current_step_id_raw)}, setting to None"
|
||||
)
|
||||
current_step_id = None
|
||||
else:
|
||||
current_step_id = str(current_step_id_raw)
|
||||
else:
|
||||
current_step_id = None
|
||||
|
||||
# Validate completed_steps with proper error handling
|
||||
completed_steps_raw = execution_plan_data.get("completed_steps", [])
|
||||
completed_steps = []
|
||||
if not isinstance(completed_steps_raw, list):
|
||||
logger.warning(
|
||||
f"Invalid completed_steps type: {type(completed_steps_raw)}, using empty list"
|
||||
)
|
||||
else:
|
||||
for step in completed_steps_raw:
|
||||
if isinstance(step, (str, int)):
|
||||
completed_steps.append(str(step))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid completed step type: {type(step)}, skipping"
|
||||
)
|
||||
|
||||
# Validate failed_steps with proper error handling
|
||||
failed_steps_raw = execution_plan_data.get("failed_steps", [])
|
||||
failed_steps = []
|
||||
if not isinstance(failed_steps_raw, list):
|
||||
logger.warning(
|
||||
f"Invalid failed_steps type: {type(failed_steps_raw)}, using empty list"
|
||||
)
|
||||
else:
|
||||
for step in failed_steps_raw:
|
||||
if isinstance(step, (str, int)):
|
||||
failed_steps.append(str(step))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid failed step type: {type(step)}, skipping"
|
||||
)
|
||||
|
||||
# Validate can_execute_parallel with proper type checking
|
||||
can_execute_parallel_raw = execution_plan_data.get(
|
||||
"can_execute_parallel", False
|
||||
# Validate current_step_id
|
||||
current_step_id = validate_optional_string_field(
|
||||
execution_plan_data, "current_step_id"
|
||||
)
|
||||
if not isinstance(can_execute_parallel_raw, (bool, int, str)):
|
||||
logger.warning(
|
||||
f"Invalid can_execute_parallel type: {type(can_execute_parallel_raw)}, using False"
|
||||
)
|
||||
can_execute_parallel = False
|
||||
else:
|
||||
try:
|
||||
can_execute_parallel = bool(can_execute_parallel_raw)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Cannot convert can_execute_parallel to bool: {can_execute_parallel_raw}, using False"
|
||||
)
|
||||
can_execute_parallel = False
|
||||
|
||||
# Validate execution_mode with proper type checking
|
||||
execution_mode_raw = execution_plan_data.get(
|
||||
"execution_mode", "sequential"
|
||||
# Validate completed_steps
|
||||
completed_steps = validate_list_field(
|
||||
execution_plan_data, "completed_steps", str
|
||||
)
|
||||
|
||||
# Validate failed_steps
|
||||
failed_steps = validate_list_field(
|
||||
execution_plan_data, "failed_steps", str
|
||||
)
|
||||
|
||||
# Validate can_execute_parallel
|
||||
can_execute_parallel = validate_bool_field(
|
||||
execution_plan_data, "can_execute_parallel", False
|
||||
)
|
||||
|
||||
# Validate execution_mode
|
||||
execution_mode: Literal["sequential", "parallel", "hybrid"] = validate_literal_field( # type: ignore[assignment]
|
||||
execution_plan_data,
|
||||
"execution_mode",
|
||||
["sequential", "parallel", "hybrid"],
|
||||
"sequential",
|
||||
"execution_mode"
|
||||
)
|
||||
valid_modes = ["sequential", "parallel", "hybrid"]
|
||||
if not isinstance(execution_mode_raw, str):
|
||||
logger.warning(
|
||||
f"Invalid execution_mode type: {type(execution_mode_raw)}, using 'sequential'"
|
||||
)
|
||||
execution_mode: Literal["sequential", "parallel", "hybrid"] = (
|
||||
"sequential"
|
||||
)
|
||||
elif execution_mode_raw not in valid_modes:
|
||||
logger.warning(
|
||||
f"Invalid execution_mode value: {execution_mode_raw}, using 'sequential'"
|
||||
)
|
||||
execution_mode = "sequential"
|
||||
else:
|
||||
execution_mode = execution_mode_raw # type: ignore[assignment] # Validated above
|
||||
|
||||
# Add fallback step if no valid steps found in structured plan
|
||||
if not steps:
|
||||
@@ -554,10 +446,11 @@ def validate_execution_plan(
|
||||
|
||||
# Check step required fields
|
||||
required_step_fields = ["id", "description"]
|
||||
for field in required_step_fields:
|
||||
if field not in step:
|
||||
errors.append(f"Step {i} missing required field '{field}'")
|
||||
|
||||
errors.extend(
|
||||
f"Step {i} missing required field '{field}'"
|
||||
for field in required_step_fields
|
||||
if field not in step
|
||||
)
|
||||
# Check step field types
|
||||
if "dependencies" in step and not isinstance(step["dependencies"], list):
|
||||
errors.append(f"Step {i} dependencies must be a list")
|
||||
@@ -568,11 +461,10 @@ def validate_execution_plan(
|
||||
# Check optional fields
|
||||
optional_fields = ["current_step_id", "completed_steps", "failed_steps"]
|
||||
for field in optional_fields:
|
||||
if field in plan_data:
|
||||
if field in ["completed_steps", "failed_steps"] and not isinstance(plan_data[field], list):
|
||||
errors.append(f"Field '{field}' must be a list")
|
||||
if field in plan_data and (field in ["completed_steps", "failed_steps"] and not isinstance(plan_data[field], list)):
|
||||
errors.append(f"Field '{field}' must be a list")
|
||||
|
||||
is_valid = len(errors) == 0
|
||||
is_valid = not errors
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
||||
src/biz_bud/tools/capabilities/workflow/validation_helpers.py (new file, +339 lines)
@@ -0,0 +1,339 @@
|
||||
"""Validation helper functions for workflow utilities.
|
||||
|
||||
This module provides reusable validation functions to reduce code duplication
|
||||
in workflow planning and execution modules.
|
||||
"""
|
||||
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def validate_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
expected_type: type[T],
|
||||
default_value: T,
|
||||
field_display_name: str | None = None,
|
||||
) -> T:
|
||||
"""Validate a field in a dictionary and return the value or default.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
expected_type: Expected type of the field
|
||||
default_value: Default value if field is missing or invalid
|
||||
field_display_name: Display name for logging (defaults to field_name)
|
||||
|
||||
Returns:
|
||||
The validated value or default
|
||||
"""
|
||||
display_name = field_display_name or field_name
|
||||
value = data.get(field_name, default_value)
|
||||
|
||||
if value is None:
|
||||
return default_value
|
||||
|
||||
if not isinstance(value, expected_type):
|
||||
logger.warning(
|
||||
f"Invalid {display_name} type: {type(value)}, using default {repr(default_value)}"
|
||||
)
|
||||
return default_value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_string_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
default_value: str = "",
|
||||
convert_to_string: bool = True,
|
||||
) -> str:
|
||||
"""Validate a string field with optional conversion.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
default_value: Default value if field is missing or invalid
|
||||
convert_to_string: Whether to convert non-string values to string
|
||||
|
||||
Returns:
|
||||
The validated string value
|
||||
"""
|
||||
value = data.get(field_name, default_value)
|
||||
|
||||
if value is None:
|
||||
return default_value
|
||||
|
||||
if not isinstance(value, str):
|
||||
if convert_to_string:
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, converting to string"
|
||||
)
|
||||
return str(value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, using default"
|
||||
)
|
||||
return default_value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_literal_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
valid_values: list[str],
|
||||
default_value: str,
|
||||
type_name: str | None = None,
|
||||
) -> str:
|
||||
"""Validate a field that must be one of a set of literal values.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
valid_values: List of valid string values
|
||||
default_value: Default value (must be in valid_values)
|
||||
type_name: Type name for logging (e.g., "priority", "status")
|
||||
|
||||
Returns:
|
||||
The validated literal value
|
||||
"""
|
||||
display_name = type_name or field_name
|
||||
value = data.get(field_name, default_value)
|
||||
|
||||
if not isinstance(value, str):
|
||||
logger.warning(
|
||||
f"Invalid {display_name} type: {type(value)}, using default '{default_value}'"
|
||||
)
|
||||
return default_value
|
||||
|
||||
if value not in valid_values:
|
||||
logger.warning(
|
||||
f"Invalid {display_name} value: {value}, using default '{default_value}'"
|
||||
)
|
||||
return default_value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_list_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
item_type: type[T] | None = None,
|
||||
default_value: list[T] | None = None,
|
||||
) -> list[T]:
|
||||
"""Validate a list field with optional item type checking.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
item_type: Expected type of list items (if None, accepts any)
|
||||
default_value: Default list value
|
||||
|
||||
Returns:
|
||||
The validated list
|
||||
"""
|
||||
if default_value is None:
|
||||
default_value = []
|
||||
|
||||
value = data.get(field_name, default_value)
|
||||
|
||||
if not isinstance(value, list):
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, using empty list"
|
||||
)
|
||||
return list(default_value)
|
||||
|
||||
if item_type is None:
|
||||
return value
|
||||
|
||||
# Validate and filter items
|
||||
validated_items = []
|
||||
for item in value:
|
||||
if isinstance(item, item_type):
|
||||
validated_items.append(item)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid {field_name} item type: {type(item)}, skipping"
|
||||
)
|
||||
|
||||
return validated_items
|
||||
|
||||
|
||||
def validate_optional_string_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
convert_to_string: bool = True,
|
||||
) -> str | None:
|
||||
"""Validate an optional string field.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
convert_to_string: Whether to convert non-string values to string
|
||||
|
||||
Returns:
|
||||
The validated string value or None
|
||||
"""
|
||||
value = data.get(field_name)
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if not isinstance(value, str):
|
||||
if convert_to_string:
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, converting to string"
|
||||
)
|
||||
return str(value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, setting to None"
|
||||
)
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_bool_field(
|
||||
data: dict[str, Any],
|
||||
field_name: str,
|
||||
default_value: bool = False,
|
||||
) -> bool:
|
||||
"""Validate a boolean field with type conversion.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing the field
|
||||
field_name: Name of the field to validate
|
||||
default_value: Default boolean value
|
||||
|
||||
Returns:
|
||||
The validated boolean value
|
||||
"""
|
||||
value = data.get(field_name, default_value)
|
||||
|
||||
if not isinstance(value, (bool, int, str)):
|
||||
logger.warning(
|
||||
f"Invalid {field_name} type: {type(value)}, using {default_value}"
|
||||
)
|
||||
return default_value
|
||||
|
||||
try:
|
||||
# Handle string boolean values explicitly
|
||||
if isinstance(value, str):
|
||||
value_lower = value.lower().strip()
|
||||
if value_lower in ('false', '0', 'no', 'off', 'disabled'):
|
||||
return False
|
||||
elif value_lower in ('true', '1', 'yes', 'on', 'enabled'):
|
||||
return True
|
||||
# For other non-empty strings, default to True but warn
|
||||
elif value_lower:
|
||||
logger.warning(
|
||||
f"Ambiguous string value for {field_name}: '{value}', treating as True"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
# Empty string
|
||||
return False
|
||||
|
||||
return bool(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Cannot convert {field_name} to bool: {value}, using {default_value}"
|
||||
)
|
||||
return default_value
|
||||
|
||||
|
||||
def process_dependencies_field(
|
||||
dependencies_raw: Any,
|
||||
) -> list[str]:
|
||||
"""Process and validate a dependencies field.
|
||||
|
||||
Args:
|
||||
dependencies_raw: Raw dependencies value (string, list, or other)
|
||||
|
||||
Returns:
|
||||
List of string dependencies
|
||||
"""
|
||||
if isinstance(dependencies_raw, str):
|
||||
return [dependencies_raw]
|
||||
elif isinstance(dependencies_raw, list):
|
||||
return [str(dep) for dep in dependencies_raw]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def extract_content_from_result(
|
||||
result: dict[str, Any],
|
||||
step_id: str,
|
||||
content_keys: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Extract meaningful content from a result dictionary.
|
||||
|
||||
Args:
|
||||
result: Result dictionary to extract content from
|
||||
step_id: Step ID for logging
|
||||
content_keys: List of keys to try for content extraction
|
||||
|
||||
Returns:
|
||||
Extracted content as string
|
||||
"""
|
||||
if content_keys is None:
|
||||
content_keys = [
|
||||
"synthesis",
|
||||
"final_response",
|
||||
"content",
|
||||
"response",
|
||||
"result",
|
||||
"output",
|
||||
]
|
||||
|
||||
# Try to extract meaningful content from various possible keys
|
||||
for content_key in content_keys:
|
||||
if content_key in result and result[content_key]:
|
||||
content = str(result[content_key])
|
||||
logger.debug(
|
||||
f"Found content in key '{content_key}' for step {step_id}"
|
||||
)
|
||||
return content
|
||||
|
||||
# If no content found, stringify the whole result
|
||||
content = str(result)
|
||||
logger.debug(
|
||||
f"No specific content key found, using stringified result for step {step_id}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def create_summary(content: str, max_length: int = 300) -> str:
|
||||
"""Create a summary from content.
|
||||
|
||||
Args:
|
||||
content: Content to summarize
|
||||
max_length: Maximum length for summary
|
||||
|
||||
Returns:
|
||||
Summary string
|
||||
"""
|
||||
return f"{content[:max_length]}..." if len(content) > max_length else content
|
||||
|
||||
|
||||
def create_key_points(content: str, existing_points: list[str] | None = None) -> list[str]:
|
||||
"""Create key points from content.
|
||||
|
||||
Args:
|
||||
content: Content to extract key points from
|
||||
existing_points: Existing key points to use if available
|
||||
|
||||
Returns:
|
||||
List of key points
|
||||
"""
|
||||
if existing_points:
|
||||
return existing_points
|
||||
|
||||
return [f"{content[:200]}..."] if len(content) > 200 else [content]
|
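For reference, a minimal sketch of how these helpers can be exercised in isolation, mirroring the PlanParser call sites earlier in this diff. The step dict below is illustrative only, and the import path assumes the src/ layout maps onto the biz_bud package the same way as the other imports in this diff.

import asyncio  # not required by the helpers; shown only to match the project's async test style

from biz_bud.tools.capabilities.workflow.validation_helpers import (
    process_dependencies_field,
    validate_bool_field,
    validate_literal_field,
    validate_string_field,
)

# Hypothetical raw step data with a few malformed fields.
step_data = {
    "id": 7,                       # int instead of str -> converted with a warning
    "description": None,           # None -> falls back to the default ""
    "priority": "urgent",          # not in the allowed set -> falls back to "medium"
    "dependencies": "step_1",      # bare string -> wrapped into ["step_1"]
    "can_execute_parallel": "no",  # string boolean -> False
}

step_id = validate_string_field(step_data, "id", "step_fallback")
description = validate_string_field(step_data, "description", "")
priority = validate_literal_field(step_data, "priority", ["high", "medium", "low"], "medium", "priority")
dependencies = process_dependencies_field(step_data.get("dependencies", []))
parallel = validate_bool_field(step_data, "can_execute_parallel", False)

assert step_id == "7"
assert description == ""
assert priority == "medium"
assert dependencies == ["step_1"]
assert parallel is False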
||||
@@ -757,16 +757,17 @@ class TestConcurrencyRaces:
|
||||
async def mixed_operation_task(task_id):
|
||||
"""Mixed factory operations."""
|
||||
try:
|
||||
operations = {
|
||||
0: lambda: get_global_factory(config),
|
||||
1: lambda: check_global_factory_health(),
|
||||
2: lambda: ensure_healthy_global_factory(config)
|
||||
}
|
||||
operation = operations[task_id % 3]
|
||||
result = await operation()
|
||||
# Direct async calls instead of lambda coroutines to avoid union type issues
|
||||
op_type = task_id % 3
|
||||
if op_type == 0:
|
||||
result = await get_global_factory(config)
|
||||
elif op_type == 1:
|
||||
result = await check_global_factory_health()
|
||||
else:
|
||||
result = await ensure_healthy_global_factory(config)
|
||||
|
||||
operation_names = {0: "get", 1: "health", 2: "ensure"}
|
||||
name = operation_names[task_id % 3]
|
||||
name = operation_names[op_type]
|
||||
|
||||
return f"{name}_{id(result) if hasattr(result, '__hash__') else result}"
|
||||
except Exception as e:
|
||||
|
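A standalone sketch of the dispatch change above: awaiting the result of a dict of lambdas gives the type checker a union of coroutine types, while branching directly keeps each awaited call concretely typed. The coroutines here are stand-ins for the real factory helpers, not the project's functions.

import asyncio

# Stand-ins for get_global_factory / check_global_factory_health / ensure_healthy_global_factory.
async def op_get() -> dict:
    return {"factory": True}

async def op_health() -> bool:
    return True

async def op_ensure() -> dict:
    return {"factory": True, "healthy": True}

async def mixed_operation(task_id: int) -> str:
    op_type = task_id % 3
    # Direct branches: each await has a single, concrete return type.
    if op_type == 0:
        result = await op_get()
    elif op_type == 1:
        result = await op_health()
    else:
        result = await op_ensure()
    name = {0: "get", 1: "health", 2: "ensure"}[op_type]
    return f"{name}_{result}"

print(asyncio.run(mixed_operation(4)))  # -> "health_True"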
||||
@@ -103,13 +103,35 @@ class WebExtractorNode:
|
||||
"""Mock web extractor node for testing HTML processing."""
|
||||
|
||||
def extract_text_from_html(self, html_content):
|
||||
"""Mock HTML text extractor."""
|
||||
"""Mock HTML text extractor with ReDoS protection."""
|
||||
try:
|
||||
import re
|
||||
# SECURITY FIX: Input validation to prevent ReDoS attacks
|
||||
if len(html_content) > 100000: # 100KB limit
|
||||
raise ValueError("HTML input too long")
|
||||
|
||||
# Simple HTML tag removal
|
||||
text = re.sub(r"<[^>]+>", "", html_content)
|
||||
return text.strip()
|
||||
# Check for potentially dangerous patterns
|
||||
if html_content.count('<') > 1000:
|
||||
raise ValueError("Too many HTML tags")
|
||||
|
||||
# Use safer HTML processing instead of vulnerable regex
|
||||
import html
|
||||
try:
|
||||
# Simple character-based tag removal (safer than regex)
|
||||
result = []
|
||||
in_tag = False
|
||||
for char in html_content:
|
||||
if char == '<':
|
||||
in_tag = True
|
||||
elif char == '>':
|
||||
in_tag = False
|
||||
elif not in_tag:
|
||||
result.append(char)
|
||||
|
||||
text = ''.join(result)
|
||||
return html.unescape(text).strip()
|
||||
except Exception:
|
||||
# Fallback: just remove common problematic characters
|
||||
return ''.join(c for c in html_content if c not in '<>').strip()
|
||||
except Exception:
|
||||
return "Extraction failed"
|
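A quick standalone check of the character-scan approach used above (the same logic, lifted out of the mock class): a single linear pass over the input strips tags without any regex, so there is no backtracking to exploit.

import html

def strip_tags(html_content: str) -> str:
    # Linear single-pass scan: no regex, so no catastrophic backtracking.
    out, in_tag = [], False
    for ch in html_content:
        if ch == "<":
            in_tag = True
        elif ch == ">":
            in_tag = False
        elif not in_tag:
            out.append(ch)
    return html.unescape("".join(out)).strip()

print(strip_tags("<p>Hello &amp; <b>welcome</b></p>"))  # -> "Hello & welcome"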
||||
|
||||
@@ -805,12 +827,34 @@ class TestMalformedInput:
|
||||
return sanitized
|
||||
|
||||
def extract_text_from_html_manually(self, html_content):
|
||||
"""Manual HTML text extraction for testing."""
|
||||
"""Manual HTML text extraction for testing with ReDoS protection."""
|
||||
try:
|
||||
import re
|
||||
# SECURITY FIX: Input validation to prevent ReDoS attacks
|
||||
if len(html_content) > 100000: # 100KB limit
|
||||
raise ValueError("HTML input too long")
|
||||
|
||||
# Simple HTML tag removal
|
||||
text = re.sub(r"<[^>]+>", "", html_content)
|
||||
return text.strip()
|
||||
# Check for potentially dangerous patterns
|
||||
if html_content.count('<') > 1000:
|
||||
raise ValueError("Too many HTML tags")
|
||||
|
||||
# Use safer HTML processing instead of vulnerable regex
|
||||
import html
|
||||
try:
|
||||
# Simple character-based tag removal (safer than regex)
|
||||
result = []
|
||||
in_tag = False
|
||||
for char in html_content:
|
||||
if char == '<':
|
||||
in_tag = True
|
||||
elif char == '>':
|
||||
in_tag = False
|
||||
elif not in_tag:
|
||||
result.append(char)
|
||||
|
||||
text = ''.join(result)
|
||||
return html.unescape(text).strip()
|
||||
except Exception:
|
||||
# Fallback: just remove common problematic characters
|
||||
return ''.join(c for c in html_content if c not in '<>').strip()
|
||||
except Exception:
|
||||
return "Extraction failed"
|
||||
|
||||
@@ -326,8 +326,8 @@ async def test_optimized_state_creation() -> None:
|
||||
|
||||
# Verify state structure
|
||||
assert state["raw_input"] == f'{{"query": "{query}"}}'
|
||||
assert state["parsed_input"]["user_query"] == query
|
||||
assert state["messages"][0]["content"] == query
|
||||
assert state["parsed_input"].get("user_query") == query
|
||||
assert state["messages"][0].content == query
|
||||
assert "config" in state
|
||||
assert "context" in state
|
||||
assert "errors" in state
|
||||
@@ -348,7 +348,7 @@ async def test_backward_compatibility() -> None:
|
||||
assert graph is not None
|
||||
|
||||
state_sync = create_initial_state_sync(query="sync test")
|
||||
assert state_sync["parsed_input"]["user_query"] == "sync test"
|
||||
assert state_sync["parsed_input"].get("user_query") == "sync test"
|
||||
|
||||
state_legacy = get_initial_state()
|
||||
assert "messages" in state_legacy
|
||||
|
||||
@@ -298,8 +298,8 @@ class TestGraphPerformance:
|
||||
|
||||
def test_configuration_hashing_performance(self) -> None:
|
||||
"""Test that configuration hashing is efficient."""
|
||||
from biz_bud.core.config.loader import generate_config_hash
|
||||
from biz_bud.core.config.schemas import AppConfig
|
||||
from biz_bud.graphs.graph import _generate_config_hash
|
||||
|
||||
# Create a complex configuration using AppConfig with defaults
|
||||
complex_config = AppConfig(
|
||||
@@ -318,7 +318,7 @@ class TestGraphPerformance:
|
||||
|
||||
hashes = []
|
||||
for _ in range(batch_size):
|
||||
hash_value = _generate_config_hash(complex_config)
|
||||
hash_value = generate_config_hash(complex_config)
|
||||
hashes.append(hash_value)
|
||||
|
||||
total_time = time.perf_counter() - start_time
|
||||
|
||||
@@ -1078,31 +1078,30 @@ class TestRefactoringRegressionTests:
|
||||
"""Test the new force_refresh parameter works in cache_async decorator."""
|
||||
from biz_bud.core.caching.decorators import cache_async
|
||||
|
||||
call_count = 0
|
||||
call_count = [0] # Use list to make it mutable
|
||||
|
||||
@cache_async(backend=mock_cache_backend, ttl=300)
|
||||
async def cached_func(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result_{x}_{call_count}"
|
||||
call_count[0] += 1
|
||||
return f"result_{x}_{call_count[0]}"
|
||||
|
||||
# First call - cache miss
|
||||
mock_cache_backend.get.return_value = None
|
||||
result1 = await cached_func(42)
|
||||
assert result1 == "result_42_1"
|
||||
assert call_count == 1
|
||||
assert call_count[0] == 1
|
||||
|
||||
# Second call - should use cache
|
||||
mock_cache_backend.get.return_value = pickle.dumps("cached_result_42")
|
||||
result2 = await cached_func(42)
|
||||
assert result2 == "cached_result_42"
|
||||
assert call_count == 1 # Function not called again
|
||||
assert call_count[0] == 1 # Function not called again
|
||||
|
||||
# Third call with force_refresh - should bypass cache
|
||||
# Note: force_refresh is handled by the decorator, not the function signature
|
||||
result3 = await cached_func(42) # Remove force_refresh parameter
|
||||
assert result3 == "cached_result_42" # Should return cached result
|
||||
assert call_count == 1 # Function not called again
|
||||
assert call_count[0] == 1 # Function not called again
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_decorator_with_helper_functions_end_to_end(self, backend_requiring_ainit):
|
||||
@@ -1159,12 +1158,11 @@ class TestRefactoringRegressionTests:
|
||||
"""Test that serialization failures don't break caching functionality."""
|
||||
from biz_bud.core.caching.decorators import cache_async
|
||||
|
||||
call_count = 0
|
||||
call_count = [0] # Use list to make it mutable
|
||||
|
||||
@cache_async(backend=mock_cache_backend, ttl=300)
|
||||
async def func_returning_non_serializable() -> Any:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
call_count[0] += 1
|
||||
return non_serializable_object
|
||||
|
||||
mock_cache_backend.get.return_value = None
|
||||
@@ -1172,7 +1170,7 @@ class TestRefactoringRegressionTests:
|
||||
# Should succeed despite serialization failure
|
||||
result = await func_returning_non_serializable()
|
||||
assert result is non_serializable_object
|
||||
assert call_count == 1
|
||||
assert call_count[0] == 1
|
||||
|
||||
# Cache get should have been called (double-check pattern calls get twice for cache miss)
|
||||
assert mock_cache_backend.get.call_count == 2
|
||||
@@ -1181,20 +1179,19 @@ class TestRefactoringRegressionTests:
|
||||
# Second call should work normally (no caching due to serialization failure)
|
||||
result2 = await func_returning_non_serializable()
|
||||
assert result2 is non_serializable_object
|
||||
assert call_count == 2
|
||||
assert call_count[0] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deserialization_error_graceful_handling(self, mock_cache_backend, corrupted_pickle_data):
|
||||
"""Test that deserialization failures don't break caching functionality."""
|
||||
from biz_bud.core.caching.decorators import cache_async
|
||||
|
||||
call_count = 0
|
||||
call_count = [0] # Use list to make it mutable
|
||||
|
||||
@cache_async(backend=mock_cache_backend, ttl=300)
|
||||
async def normal_func() -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result_{call_count}"
|
||||
call_count[0] += 1
|
||||
return f"result_{call_count[0]}"
|
||||
|
||||
# Return corrupted data from cache
|
||||
mock_cache_backend.get.return_value = corrupted_pickle_data
|
||||
@@ -1202,7 +1199,7 @@ class TestRefactoringRegressionTests:
|
||||
# Should fall back to computing result
|
||||
result = await normal_func()
|
||||
assert result == "result_1"
|
||||
assert call_count == 1
|
||||
assert call_count[0] == 1
|
||||
|
||||
# Cache operations should have been called (double-check pattern calls get twice for cache miss)
|
||||
assert mock_cache_backend.get.call_count == 2
|
||||
@@ -1260,16 +1257,15 @@ class TestRefactoringRegressionTests:
|
||||
"""Test that thread safety wasn't broken by refactoring."""
|
||||
from biz_bud.core.caching.decorators import cache_async
|
||||
|
||||
call_count = 0
|
||||
call_count = [0] # Use list to make it mutable
|
||||
call_lock = asyncio.Lock()
|
||||
|
||||
@cache_async(backend=mock_cache_backend, ttl=300)
|
||||
async def concurrent_func(x: int) -> str:
|
||||
nonlocal call_count
|
||||
async with call_lock:
|
||||
call_count += 1
|
||||
call_count[0] += 1
|
||||
await asyncio.sleep(0.01) # Simulate work
|
||||
return f"result_{x}_{call_count}"
|
||||
return f"result_{x}_{call_count[0]}"
|
||||
|
||||
# Mock cache miss for concurrent calls
|
||||
mock_cache_backend.get.return_value = None
|
||||
@@ -1287,7 +1283,7 @@ class TestRefactoringRegressionTests:
|
||||
assert all(result.startswith("result_") for result in results)
|
||||
|
||||
# Function should have been called for each unique argument
|
||||
assert call_count == 5
|
||||
assert call_count[0] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_key_function_still_works(self, mock_cache_backend):
|
||||
|
||||
@@ -16,17 +16,17 @@ class TestInMemoryCacheLRU:
|
||||
"""Test LRU (Least Recently Used) behavior of InMemoryCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def lru_cache(self) -> InMemoryCache:
|
||||
def lru_cache(self) -> InMemoryCache[bytes]:
|
||||
"""Provide LRU cache with small max size for testing."""
|
||||
return InMemoryCache(max_size=3)
|
||||
return InMemoryCache[bytes](max_size=3)
|
||||
|
||||
@pytest.fixture
|
||||
def unlimited_cache(self) -> InMemoryCache:
|
||||
def unlimited_cache(self) -> InMemoryCache[bytes]:
|
||||
"""Provide cache with no size limit."""
|
||||
return InMemoryCache(max_size=None)
|
||||
return InMemoryCache[bytes](max_size=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_basic_eviction(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_basic_eviction(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test basic LRU eviction when cache exceeds max size."""
|
||||
# Fill cache to capacity
|
||||
await lru_cache.set("key1", b"value1")
|
||||
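The fixture changes above parameterize the cache as a generic. A toy stand-in (not the project's InMemoryCache, whose import path is not shown in this diff) illustrates what the subscripted annotation buys at the call site: the checker can infer that get() returns the value type or None.

import asyncio
from typing import Generic, TypeVar

V = TypeVar("V")

class ToyCache(Generic[V]):
    """Toy stand-in for a typed async in-memory cache."""

    def __init__(self) -> None:
        self._data: dict[str, V] = {}

    async def set(self, key: str, value: V) -> None:
        self._data[key] = value

    async def get(self, key: str) -> V | None:
        return self._data.get(key)

async def main() -> None:
    cache: ToyCache[bytes] = ToyCache()
    await cache.set("key1", b"value1")
    value = await cache.get("key1")  # inferred as bytes | None
    assert value == b"value1"

asyncio.run(main())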
@@ -48,7 +48,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await lru_cache.get("key4") == b"value4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_access_updates_order(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_access_updates_order(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that accessing a key updates its position in LRU order."""
|
||||
# Fill cache to capacity
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -68,7 +68,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await lru_cache.get("key4") == b"value4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_multiple_accesses(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_multiple_accesses(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test LRU behavior with multiple access patterns."""
|
||||
# Fill cache
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -90,7 +90,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await lru_cache.get("key4") == b"value4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_update_existing_key_preserves_order(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_update_existing_key_preserves_order(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that updating existing key doesn't trigger eviction."""
|
||||
# Fill cache to capacity
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -114,7 +114,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await lru_cache.get("key4") == b"value4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_no_eviction_when_under_limit(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_no_eviction_when_under_limit(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that no eviction occurs when under max size."""
|
||||
# Add keys under the limit
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -128,7 +128,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await lru_cache.size() == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_unlimited_cache_no_eviction(self, unlimited_cache: InMemoryCache):
|
||||
async def test_lru_unlimited_cache_no_eviction(self, unlimited_cache: InMemoryCache[bytes]):
|
||||
"""Test that unlimited cache never evicts based on size."""
|
||||
# Add many keys
|
||||
for i in range(100):
|
||||
@@ -145,7 +145,7 @@ class TestInMemoryCacheLRU:
|
||||
@pytest.mark.parametrize("max_size", [1, 2, 5, 10])
|
||||
async def test_lru_various_cache_sizes(self, max_size: int):
|
||||
"""Test LRU behavior with various cache sizes."""
|
||||
cache = InMemoryCache(max_size=max_size)
|
||||
cache = InMemoryCache[bytes](max_size=max_size)
|
||||
|
||||
# Fill cache beyond capacity
|
||||
for i in range(max_size + 5):
|
||||
@@ -164,7 +164,7 @@ class TestInMemoryCacheLRU:
|
||||
assert await cache.get(f"key{i}") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_ordereddict_behavior(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_ordereddict_behavior(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that OrderedDict maintains proper LRU order."""
|
||||
# Fill cache
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -182,7 +182,7 @@ class TestInMemoryCacheLRU:
|
||||
assert cache_keys[-1] == "key1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_exists_updates_order(self, lru_cache: InMemoryCache):
|
||||
async def test_lru_exists_updates_order(self, lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that exists() check also updates LRU order."""
|
||||
# Fill cache
|
||||
await lru_cache.set("key1", b"value1")
|
||||
@@ -211,12 +211,12 @@ class TestInMemoryCacheTTLLRU:
|
||||
"""Test interaction between TTL expiration and LRU eviction."""
|
||||
|
||||
@pytest.fixture
|
||||
def ttl_lru_cache(self) -> InMemoryCache:
|
||||
def ttl_lru_cache(self) -> InMemoryCache[bytes]:
|
||||
"""Provide cache with small max size for TTL + LRU testing."""
|
||||
return InMemoryCache(max_size=3)
|
||||
return InMemoryCache[bytes](max_size=3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_entries_dont_affect_lru_count(self, ttl_lru_cache: InMemoryCache):
|
||||
async def test_expired_entries_dont_affect_lru_count(self, ttl_lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that expired entries don't count toward max size."""
|
||||
# Add entries with short TTL
|
||||
await ttl_lru_cache.set("key1", b"value1", ttl=1)
|
||||
@@ -243,7 +243,7 @@ class TestInMemoryCacheTTLLRU:
|
||||
assert await ttl_lru_cache.get("key6") == b"value6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_removes_from_lru_order(self, ttl_lru_cache: InMemoryCache):
|
||||
async def test_cleanup_expired_removes_from_lru_order(self, ttl_lru_cache: InMemoryCache[bytes]):
|
||||
"""Test that cleanup_expired properly removes entries from LRU tracking."""
|
||||
# Add mix of expiring and non-expiring entries
|
||||
await ttl_lru_cache.set("expire1", b"value1", ttl=1)
|
||||
@@ -274,7 +274,7 @@ class TestInMemoryCacheTTLLRU:
|
||||
assert await ttl_lru_cache.get("new2") == b"newvalue2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_eviction_with_mixed_ttl(self, ttl_lru_cache: InMemoryCache):
|
||||
async def test_lru_eviction_with_mixed_ttl(self, ttl_lru_cache: InMemoryCache[bytes]):
|
||||
"""Test LRU eviction when entries have different TTL values."""
|
||||
# Add entries with different TTL
|
||||
await ttl_lru_cache.set("short_ttl", b"value1", ttl=1)
|
||||
@@ -300,12 +300,12 @@ class TestInMemoryCacheConcurrentLRU:
|
||||
"""Test LRU behavior under concurrent access."""
|
||||
|
||||
@pytest.fixture
|
||||
def concurrent_cache(self) -> InMemoryCache:
|
||||
def concurrent_cache(self) -> InMemoryCache[bytes]:
|
||||
"""Provide cache for concurrency testing."""
|
||||
return InMemoryCache(max_size=5)
|
||||
return InMemoryCache[bytes](max_size=5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_lru_operations(self, concurrent_cache: InMemoryCache):
|
||||
async def test_concurrent_lru_operations(self, concurrent_cache: InMemoryCache[bytes]):
|
||||
"""Test LRU behavior with concurrent get/set operations."""
|
||||
# Pre-populate cache
|
||||
for i in range(5):
|
||||
@@ -349,7 +349,7 @@ class TestInMemoryCacheConcurrentLRU:
|
||||
assert accessed_keys_present > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_thread_safety(self, concurrent_cache: InMemoryCache):
|
||||
async def test_lru_thread_safety(self, concurrent_cache: InMemoryCache[bytes]):
|
||||
"""Test that LRU operations are thread-safe."""
|
||||
# Fill cache
|
||||
for i in range(5):
|
||||
@@ -389,7 +389,7 @@ class TestInMemoryCacheLRUEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_cache_size_one(self):
|
||||
"""Test LRU behavior with cache size of 1."""
|
||||
cache = InMemoryCache(max_size=1)
|
||||
cache = InMemoryCache[bytes](max_size=1)
|
||||
|
||||
# Add first key
|
||||
await cache.set("key1", b"value1")
|
||||
@@ -408,7 +408,7 @@ class TestInMemoryCacheLRUEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_cache_size_zero(self):
|
||||
"""Test behavior with cache size of 0."""
|
||||
cache = InMemoryCache(max_size=0)
|
||||
cache = InMemoryCache[bytes](max_size=0)
|
||||
|
||||
# Should not store anything
|
||||
await cache.set("key1", b"value1")
|
||||
@@ -418,7 +418,7 @@ class TestInMemoryCacheLRUEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_large_cache_performance(self):
|
||||
"""Test LRU performance with larger cache size."""
|
||||
cache = InMemoryCache(max_size=1000)
|
||||
cache = InMemoryCache[bytes](max_size=1000)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -453,7 +453,7 @@ class TestInMemoryCacheLRUEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_behavior_after_clear(self):
|
||||
"""Test that LRU behavior works correctly after cache clear."""
|
||||
cache = InMemoryCache(max_size=3)
|
||||
cache = InMemoryCache[bytes](max_size=3)
|
||||
|
||||
# Fill and clear cache
|
||||
for i in range(3):
|
||||
@@ -479,7 +479,7 @@ class TestInMemoryCacheLRUEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_delete_behavior(self):
|
||||
"""Test LRU behavior when keys are deleted."""
|
||||
cache = InMemoryCache(max_size=3)
|
||||
cache = InMemoryCache[bytes](max_size=3)
|
||||
|
||||
# Fill cache
|
||||
await cache.set("key1", b"value1")
|
||||
|
||||
@@ -414,7 +414,7 @@ class TestContextProfileMap:
|
||||
present_contexts = [context for context in expected_contexts if context in profile_map]
|
||||
|
||||
# Validate all present contexts exist (type is guaranteed to be string)
|
||||
assert len(present_contexts) > 0, "At least one expected context should be present"
|
||||
assert present_contexts, "At least one expected context should be present"
|
||||
|
||||
def test_get_context_mappings(self):
|
||||
"""Test get_context_mappings function."""
|
||||
|
||||
@@ -1332,14 +1332,13 @@ class TestUserInteractionEdgeCases:
|
||||
|
||||
def test_memory_efficiency_with_large_states(self):
|
||||
"""Test user interaction functions with large state objects."""
|
||||
large_state = {f"field_{i}": f"value_{i}" for i in range(10000)}
|
||||
large_state.update({
|
||||
"human_interrupt": "stop",
|
||||
"status": "error",
|
||||
"user": {"permissions": ["basic_access"]},
|
||||
"pending_user_input": True,
|
||||
"input_type": "text",
|
||||
})
|
||||
from typing import Any
|
||||
large_state: dict[str, Any] = {f"field_{i}": f"value_{i}" for i in range(10000)}
|
||||
large_state["human_interrupt"] = "stop"
|
||||
large_state["status"] = "error"
|
||||
large_state["user"] = {"permissions": ["basic_access"]}
|
||||
large_state["pending_user_input"] = True
|
||||
large_state["input_type"] = "text"
|
||||
# Functions should efficiently access only needed fields
|
||||
interrupt_router = human_interrupt()
|
||||
result = interrupt_router(large_state)
|
||||
|
||||
@@ -770,7 +770,8 @@ class TestErrorAggregatorEdgeCases:
|
||||
for i in range(expected_limit + 2) # Try more than the limit
|
||||
]
|
||||
should_report_results = [aggregator.should_report_error(error_info) for error_info in error_infos]
|
||||
allowed_count = sum(1 for should_report, _ in should_report_results if should_report)
|
||||
allowed_count = sum(bool(should_report)
|
||||
for should_report, _ in should_report_results)
|
||||
return (severity, allowed_count, expected_limit)
|
||||
|
||||
# Test all severity types using comprehension
|
||||
|
||||
@@ -533,7 +533,7 @@ class TestCreateAndAddError:
|
||||
mock_add_error.return_value = {**basic_state, "errors": [created_error]}
|
||||
|
||||
# Merge contexts instead of unpacking extra_context as kwargs
|
||||
merged_context = {**context, **extra_context}
|
||||
merged_context = context | extra_context
|
||||
await create_and_add_error(
|
||||
basic_state,
|
||||
"Test message",
|
||||
|
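The change above swaps dict unpacking for the PEP 584 union operator; both build a new dict, with right-hand keys winning on conflict (Python 3.9+). A minimal check:

context = {"user": "alice", "retries": 1}
extra_context = {"retries": 3, "source": "test"}

merged_unpack = {**context, **extra_context}
merged_union = context | extra_context  # right operand wins on duplicate keys

assert merged_unpack == merged_union == {"user": "alice", "retries": 3, "source": "test"}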
||||
@@ -662,7 +662,7 @@ class TestErrorRouter:
|
||||
router = ErrorRouter()
|
||||
|
||||
with patch.object(router, '_log_error') as mock_log:
|
||||
action, error = await router._apply_action(RouteAction.LOG, basic_error_info, {})
|
||||
action, error = router._apply_action(RouteAction.LOG, basic_error_info, {})
|
||||
|
||||
assert action == RouteAction.LOG
|
||||
assert error == basic_error_info
|
||||
@@ -674,7 +674,7 @@ class TestErrorRouter:
|
||||
router = ErrorRouter()
|
||||
|
||||
with patch('biz_bud.core.errors.router.info_highlight') as mock_info:
|
||||
action, error = await router._apply_action(RouteAction.SUPPRESS, basic_error_info, {})
|
||||
action, error = router._apply_action(RouteAction.SUPPRESS, basic_error_info, {})
|
||||
|
||||
assert action == RouteAction.SUPPRESS
|
||||
assert error is None
|
||||
@@ -690,7 +690,7 @@ class TestErrorRouter:
|
||||
escalated_error["severity"] = "critical"
|
||||
mock_escalate.return_value = escalated_error
|
||||
|
||||
action, error = await router._apply_action(RouteAction.ESCALATE, basic_error_info, {})
|
||||
action, error = router._apply_action(RouteAction.ESCALATE, basic_error_info, {})
|
||||
|
||||
assert action == RouteAction.ESCALATE
|
||||
assert error == escalated_error
|
||||
@@ -701,7 +701,7 @@ class TestErrorRouter:
|
||||
"""Test applying AGGREGATE action."""
|
||||
router = ErrorRouter(aggregator=mock_aggregator)
|
||||
|
||||
action, error = await router._apply_action(RouteAction.AGGREGATE, basic_error_info, {})
|
||||
action, error = router._apply_action(RouteAction.AGGREGATE, basic_error_info, {})
|
||||
|
||||
assert action == RouteAction.AGGREGATE
|
||||
assert error == basic_error_info
|
||||
@@ -712,7 +712,7 @@ class TestErrorRouter:
|
||||
"""Test applying AGGREGATE action without aggregator."""
|
||||
router = ErrorRouter(aggregator=None)
|
||||
|
||||
action, error = await router._apply_action(RouteAction.AGGREGATE, basic_error_info, {})
|
||||
action, error = router._apply_action(RouteAction.AGGREGATE, basic_error_info, {})
|
||||
|
||||
assert action == RouteAction.AGGREGATE
|
||||
assert error == basic_error_info
|
||||
@@ -723,7 +723,7 @@ class TestErrorRouter:
|
||||
"""Test applying action with None error."""
|
||||
router = ErrorRouter()
|
||||
|
||||
action, error = await router._apply_action(RouteAction.LOG, None, {})
|
||||
action, error = router._apply_action(RouteAction.LOG, None, {})
|
||||
|
||||
assert action == RouteAction.LOG
|
||||
assert error is None
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -484,7 +485,7 @@ class TestCustomHandlers:
|
||||
context = {"retry_count": 1, "max_retries": 3}
|
||||
|
||||
with patch("biz_bud.core.errors.router_config.logger") as mock_logger:
|
||||
result = await retry_handler(sample_error_info, context)
|
||||
result = retry_handler(sample_error_info, context)
|
||||
|
||||
# Should return None to signal retry
|
||||
assert result is None
|
||||
@@ -498,7 +499,7 @@ class TestCustomHandlers:
|
||||
"""Test retry handler when at retry limit."""
|
||||
context = {"retry_count": 3, "max_retries": 3}
|
||||
|
||||
result = await retry_handler(sample_error_info, context)
|
||||
result = retry_handler(sample_error_info, context)
|
||||
|
||||
# Should return modified error
|
||||
assert result is not None
|
||||
@@ -511,7 +512,7 @@ class TestCustomHandlers:
|
||||
"""Test retry handler with default context values."""
|
||||
context = {}
|
||||
|
||||
result = await retry_handler(sample_error_info, context)
|
||||
result = retry_handler(sample_error_info, context)
|
||||
|
||||
# Should use defaults (0 retries, max 3)
|
||||
assert result is None # Under limit
|
||||
@@ -522,7 +523,7 @@ class TestCustomHandlers:
|
||||
context = {"test": "value"}
|
||||
|
||||
with patch("biz_bud.core.errors.router_config.logger") as mock_logger:
|
||||
result = await notification_handler(sample_error_info, context)
|
||||
result = notification_handler(sample_error_info, context)
|
||||
|
||||
# Should log critical message
|
||||
mock_logger.critical.assert_called_once()
|
||||
@@ -674,7 +675,7 @@ class TestRouterConfigIntegration:
|
||||
error_type=critical_error["error_type"],
|
||||
severity=critical_error["severity"],
|
||||
category=critical_error["category"],
|
||||
context=critical_error["context"],
|
||||
context=cast(dict[str, Any] | None, critical_error["context"]),
|
||||
traceback_str=critical_error["traceback"],
|
||||
)
|
||||
action, _ = await router.route_error(critical_error_info)
|
||||
|
||||
@@ -755,6 +755,7 @@ class TestConcurrencyAndThreadSafety:
|
||||
async def slow_extract():
|
||||
await asyncio.sleep(0.01) # Small delay
|
||||
await original_extract()
|
||||
|
||||
manager._extract_service_configs = slow_extract
|
||||
|
||||
# Simulate concurrent loading attempts
|
||||
@@ -772,10 +773,8 @@ class TestConcurrencyAndThreadSafety:
|
||||
results = await asyncio.gather(task1, task2, return_exceptions=True)
|
||||
|
||||
# One should succeed, one should fail with loading error
|
||||
success_count = sum(bool(not isinstance(r, Exception))
|
||||
for r in results)
|
||||
error_count = sum(bool(isinstance(r, ConfigurationLoadError))
|
||||
for r in results)
|
||||
success_count = sum(not isinstance(r, Exception) for r in results)
|
||||
error_count = sum(isinstance(r, ConfigurationLoadError) for r in results)
|
||||
|
||||
assert success_count >= 1 # At least one should succeed
|
||||
assert error_count >= 0 # May or may not have concurrent access error
|
||||
|
||||
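The counting simplifications in these hunks rely on bool being an int subclass: summing a generator of booleans counts the True values directly, so the explicit bool(...) wrapper adds nothing. A minimal check with illustrative values:

results = ["ok", ValueError("boom"), "ok", RuntimeError("bad")]

# Each isinstance(...) is True (1) or False (0), so sum() counts matches.
error_count = sum(isinstance(r, Exception) for r in results)
success_count = sum(not isinstance(r, Exception) for r in results)

assert (error_count, success_count) == (2, 2)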
@@ -544,8 +544,7 @@ class TestConcurrentOperations:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Should only succeed once, others should be no-ops or return early
|
||||
success_count = sum(bool(not isinstance(r, Exception))
|
||||
for r in results)
|
||||
success_count = sum(not isinstance(r, Exception) for r in results)
|
||||
assert success_count >= 1
|
||||
assert lifecycle_manager._started is True
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ class TestRestartServiceBasic:
|
||||
@asynccontextmanager
|
||||
async def failing_get_service(service_type):
|
||||
raise RuntimeError("Service initialization failed")
|
||||
yield # This line will never be reached but is needed for the generator
|
||||
|
||||
lifecycle_manager.registry.get_service = Mock(side_effect=failing_get_service)
|
||||
|
||||
@@ -298,10 +297,7 @@ class TestRestartServicePerformance:
|
||||
services = [MockTestService(f"service_{i}") for i in range(3)]
|
||||
|
||||
# Set up all services in registry using dictionary comprehension
|
||||
lifecycle_manager.registry._services.update({
|
||||
service_type: service
|
||||
for service_type, service in zip(service_types, services)
|
||||
})
|
||||
lifecycle_manager.registry._services.update(dict(zip(service_types, services)))
|
||||
|
||||
# Test restarting each service type sequentially
|
||||
async def restart_service_with_new_mock(service_type, service_id):
|
||||
|
||||
@@ -905,6 +905,5 @@ class TestConcurrentOperations:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Should handle concurrent operations gracefully
|
||||
exception_count = sum(bool(isinstance(r, Exception))
|
||||
for r in results)
|
||||
exception_count = sum(isinstance(r, Exception) for r in results)
|
||||
assert exception_count <= len(results) # Some may fail due to concurrency, but not all
|
||||
|
||||
@@ -1004,7 +1004,6 @@ class TestErrorHandling:
|
||||
@asynccontextmanager
|
||||
async def failing_factory():
|
||||
raise RuntimeError("Factory failed")
|
||||
yield # This line will never be reached but is needed for the generator
|
||||
|
||||
service_registry.register_factory(service_type, failing_factory)
|
||||
|
||||
@@ -1037,7 +1036,6 @@ class TestErrorHandling:
|
||||
@asynccontextmanager
|
||||
async def bad_factory():
|
||||
raise RuntimeError("Init failed")
|
||||
yield # This line will never be reached but is needed for the generator
|
||||
|
||||
service_registry.register_factory(service_a, good_factory)
|
||||
service_registry.register_factory(service_b, bad_factory)
|
||||
@@ -1093,8 +1091,7 @@ class TestConcurrentOperations:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Should handle concurrent cleanups gracefully
|
||||
exception_count = sum(bool(isinstance(r, Exception))
|
||||
for r in results)
|
||||
exception_count = sum(isinstance(r, Exception) for r in results)
|
||||
assert exception_count <= len(results)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1114,8 +1111,7 @@ class TestConcurrentOperations:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Should handle race conditions gracefully
|
||||
success_count = sum(bool(not isinstance(r, Exception))
|
||||
for r in results)
|
||||
success_count = sum(not isinstance(r, Exception) for r in results)
|
||||
assert success_count >= 1 # At least some should succeed
|
||||
|
||||
|
||||
|
||||
@@ -315,8 +315,8 @@ class TestServiceLifecycle:
|
||||
|
||||
results = {
|
||||
Mock: mock_service1,
|
||||
type("Service1", (), {}): mock_service2,
|
||||
type("Service2", (), {}): error,
|
||||
type("Service1", (), {}): mock_service2, # type: ignore[misc]
|
||||
type("Service2", (), {}): error, # type: ignore[misc]
|
||||
}
|
||||
|
||||
succeeded, failed = cleanup_registry.partition_results(results)
|
||||
@@ -407,8 +407,8 @@ class TestServiceCleanup:
|
||||
service2 = Mock()
|
||||
service2.cleanup = AsyncMock()
|
||||
|
||||
service_class1 = type("Service1", (), {})
|
||||
service_class2 = type("Service2", (), {})
|
||||
service_class1 = type("Service1", (), {}) # type: ignore[misc]
|
||||
service_class2 = type("Service2", (), {}) # type: ignore[misc]
|
||||
|
||||
services = {
|
||||
service_class1: service1,
|
||||
@@ -435,7 +435,7 @@ class TestServiceCleanup:
|
||||
await asyncio.sleep(15.0) # Longer than timeout
|
||||
|
||||
service.cleanup = slow_cleanup
|
||||
service_class = type("SlowService", (), {})
|
||||
service_class = type("SlowService", (), {}) # type: ignore[misc]
|
||||
|
||||
services = {service_class: service}
|
||||
|
||||
@@ -447,7 +447,7 @@ class TestServiceCleanup:
|
||||
"""Test cleanup with service errors."""
|
||||
service = Mock()
|
||||
service.cleanup = AsyncMock(side_effect=RuntimeError("Cleanup failed"))
|
||||
service_class = type("FailingService", (), {})
|
||||
service_class = type("FailingService", (), {}) # type: ignore[misc]
|
||||
|
||||
services = {service_class: service}
|
||||
|
||||
@@ -460,7 +460,7 @@ class TestServiceCleanup:
|
||||
# Setup initialized services
|
||||
service = Mock()
|
||||
service.cleanup = AsyncMock()
|
||||
service_class = type("TestService", (), {})
|
||||
service_class = type("TestService", (), {}) # type: ignore[misc]
|
||||
services = {service_class: service}
|
||||
|
||||
# Setup initializing tasks
|
||||
|
||||
@@ -501,7 +501,7 @@ class TestDeprecationWarnings:
|
||||
# Filter warnings for deprecation warnings using comprehension
|
||||
deprecation_warnings = [
|
||||
warning
|
||||
for warning in w
|
||||
for warning in (w or [])
|
||||
if issubclass(warning.category, DeprecationWarning)
|
||||
]
Some files were not shown because too many files have changed in this diff.