Modernize LangGraph routing helpers for mapping-based state (#61)

* Modernize LangGraph routing helpers for mapping-based state

* Refine LangGraph v1 reducers and search node handling

* Harden search workflow state handling

* Refine LangGraph routing and search handling

* Harden LangGraph search routing and reducer safety

* Improve LangGraph search hygiene and scraping validation

* Refine LangGraph integration and search sanitization

* Refine state access protocol for routing helpers

* Refine edge helper state access and numeric coercion

* Harden search sanitization and routing utilities

* refactor: remove url processing compatibility layer

* Harden langgraph reducers and routing helpers

* Break circular import in legacy URL processing shim

* Fix pytest async hook interaction

* Modernize LangGraph helpers and sanitizers

* Address follow-up LangGraph modernization issues

* Harden LangGraph routing and URL helpers

* Refine RAG integration nodes for typed state
This commit is contained in:
2025-09-19 19:31:47 -04:00
committed by GitHub
parent 8ad47a7640
commit 34154c5dd6
77 changed files with 4489 additions and 4985 deletions

View File

@@ -43,6 +43,8 @@ from biz_bud.services.factory import ServiceFactory
from biz_bud.states.buddy import BuddyState
from biz_bud.tools.capabilities.workflow.execution import ResponseFormatter
CompiledGraph = CompiledStateGraph[BuddyState]
logger = get_logger(__name__)
@@ -57,7 +59,7 @@ __all__ = [
def create_buddy_orchestrator_graph(
config: AppConfig | None = None,
) -> "CompiledStateGraph":
) -> CompiledGraph:
"""Create the Buddy orchestrator graph with all components.
Args:
@@ -124,7 +126,7 @@ def create_buddy_orchestrator_graph(
def create_buddy_orchestrator_agent(
config: AppConfig | None = None,
service_factory: ServiceFactory | None = None,
) -> "CompiledStateGraph":
) -> CompiledGraph:
"""Create the Buddy orchestrator agent.
Args:
@@ -149,13 +151,13 @@ def create_buddy_orchestrator_agent(
# Direct singleton instance
_buddy_agent_instance: "CompiledStateGraph | None" = None
_buddy_agent_instance: CompiledGraph | None = None
def get_buddy_agent(
config: AppConfig | None = None,
service_factory: ServiceFactory | None = None,
) -> "CompiledStateGraph":
) -> CompiledGraph:
"""Get or create the Buddy agent instance.
Uses singleton pattern for default instance.
@@ -326,12 +328,12 @@ async def stream_buddy_agent(
# Export for LangGraph API
def buddy_agent_factory(config: RunnableConfig) -> "CompiledStateGraph":
def buddy_agent_factory(config: RunnableConfig) -> "CompiledGraph":
"""Create factory function for LangGraph API."""
return get_buddy_agent()
async def buddy_agent_factory_async(config: RunnableConfig) -> "CompiledStateGraph":
async def buddy_agent_factory_async(config: RunnableConfig) -> "CompiledGraph":
"""Async factory function for LangGraph API to avoid blocking calls."""
# Use asyncio.to_thread to run the synchronous initialization in a thread
# This prevents blocking the event loop

View File

@@ -8,8 +8,13 @@ mappings and condition evaluation.
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from collections.abc import Callable, Hashable, Mapping
from typing import TypeVar, cast
from .core import StateLike, get_state_value
KeyT = TypeVar("KeyT", bound=Hashable)
class BasicRouters:
@@ -18,9 +23,9 @@ class BasicRouters:
@staticmethod
def route_on_key(
state_key: str,
mapping: dict[Any, str],
mapping: Mapping[KeyT, str],
default: str = "end",
) -> Callable[[dict[str, Any]], str]:
) -> Callable[[StateLike], str]:
"""Route based on a state key value.
Simple routing pattern that looks up a value in state and maps it
@@ -45,18 +50,23 @@ class BasicRouters:
```
"""
def router(state: dict[str, Any]) -> str:
value = state.get(state_key)
return mapping.get(value, default)
def router(state: StateLike) -> str:
value = get_state_value(state, state_key)
if isinstance(value, Hashable):
try:
return mapping.get(cast(KeyT, value), default)
except TypeError:
return default
return default
return router
@staticmethod
def route_on_condition(
condition: Callable[[dict[str, Any]], bool],
condition: Callable[[StateLike], bool],
true_target: str,
false_target: str,
) -> Callable[[dict[str, Any]], str]:
) -> Callable[[StateLike], str]:
"""Route based on a boolean condition.
Simple binary routing based on evaluating a condition function.
@@ -81,7 +91,7 @@ class BasicRouters:
```
"""
def router(state: dict[str, Any]) -> str:
def router(state: StateLike) -> str:
return true_target if condition(state) else false_target
return router
@@ -92,7 +102,7 @@ class BasicRouters:
threshold: float,
above_target: str,
below_target: str,
) -> Callable[[dict[str, Any]], str]:
) -> Callable[[StateLike], str]:
"""Route based on numeric threshold comparison.
Compares a numeric value in state against a threshold and routes
@@ -119,8 +129,8 @@ class BasicRouters:
```
"""
def router(state: dict[str, Any]) -> str:
value = state.get(state_key, 0)
def router(state: StateLike) -> str:
value = get_state_value(state, state_key, 0)
try:
numeric_value = float(value)
return above_target if numeric_value >= threshold else below_target

View File

@@ -4,17 +4,17 @@ This module provides the Buddy-specific routing logic that handles
orchestration phases and execution flow.
"""
import time
from collections.abc import Callable
from typing import Any
from langgraph.types import Command
from .router_factories import create_command_router
from .routing_rules import CommandRoutingRule
def create_buddy_command_router() -> Callable[[Any], Command[str]]:
import time
from collections.abc import Callable
from langgraph.types import Command
from .core import StateLike
from .router_factories import create_command_router
from .routing_rules import CommandRoutingRule
def create_buddy_command_router() -> Callable[[StateLike], Command[str]]:
"""Create a specialized Command router for Buddy orchestration patterns.
This provides the standard Buddy routing logic using Commands for enhanced

View File

@@ -1,14 +1,9 @@
"""Command and Send pattern implementations for LangGraph edge routing.
This module provides advanced routing patterns using LangGraph's Command
and Send objects for dynamic control flow and map-reduce patterns.
"""
"""Command and Send pattern implementations for LangGraph edge routing."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from langgraph.types import Command, Send
@@ -16,29 +11,140 @@ from biz_bud.logging import debug_highlight, get_logger
logger = get_logger(__name__)
StateMapping = Mapping[str, object]
def _is_json_compatible(value: object) -> bool:
if value is None or isinstance(value, (bool, int, float, str)):
return True
if isinstance(value, (list, tuple)):
return all(_is_json_compatible(item) for item in value)
if isinstance(value, Mapping):
return all(isinstance(key, str) and _is_json_compatible(val) for key, val in value.items())
return False
def _sanitize_value(value: object) -> object:
if value is None or isinstance(value, (bool, int, float, str)):
return value
if isinstance(value, (list, tuple)):
return [_sanitize_value(item) for item in value]
if isinstance(value, Mapping):
sanitized: dict[str, object] = {}
for key, val in value.items():
sanitized[str(key)] = _sanitize_value(val)
return sanitized
return repr(value)
def _sanitize_mapping(mapping: Mapping[str, object]) -> dict[str, object]:
    """Sanitize every entry of *mapping* into a plain JSON-safe dict.

    Keys that are not already ``str`` are stringified; values are passed
    through ``_sanitize_value``.
    """
    return {
        (key if isinstance(key, str) else str(key)): _sanitize_value(value)
        for key, value in mapping.items()
    }
def _estimate_size(obj: object, limit: int = 64_000) -> int:
try:
import json
serialized = json.dumps(obj, default=lambda _: "<non-serializable>")
return len(serialized.encode("utf-8"))
except Exception:
return limit + 1
# Private helper functions for DRY code
def _log_command(
target: str, updates: dict[str, Any] | None, category: str
target: str, updates: Mapping[str, object] | None, category: str
) -> Command[str]:
"""Create a Command object with consistent logging."""
msg = f"{category}: routing to '{target}'"
if updates:
msg += f" with updates: {list(updates.keys())}"
debug_highlight(msg, category=category)
return Command(goto=target, update=updates or {})
if updates is not None:
safe_updates = _sanitize_mapping(updates)
else:
safe_updates = {}
return Command(goto=target, update=safe_updates)
def _log_sends(targets: list[tuple[str, dict[str, Any]]], category: str) -> list[Send]:
def _log_sends(
targets: Sequence[tuple[str, Mapping[str, object]]] | object,
category: str,
) -> list[Send]:
"""Create Send objects with consistent logging."""
sends = [Send(t, state) for t, state in targets]
debug_highlight(f"{category}: dispatching {len(sends)} branches", category=category)
return sends
if not isinstance(targets, Sequence) or isinstance(targets, (str, bytes, bytearray)):
debug_highlight(
f"{category}: invalid targets container: {type(targets)!r}",
category=category,
)
raise TypeError(f"{category}: 'targets' must be a sequence of (str, Mapping) tuples")
MAX_BRANCHES = 1000
if len(targets) > MAX_BRANCHES:
debug_highlight(
f"{category}: too many target branches: {len(targets)} (limit {MAX_BRANCHES})",
category=category,
)
raise ValueError(f"{category}: too many branches ({len(targets)})")
MAX_PAYLOAD_BYTES = 64_000
safe_sends: list[Send] = []
invalid_count = 0
for idx, item in enumerate(targets):
if not isinstance(item, tuple) or len(item) != 2:
debug_highlight(
f"{category}: invalid target entry at index {idx}: {item!r}",
category=category,
)
invalid_count += 1
continue
target, state = item
if not isinstance(target, str) or not target:
debug_highlight(
f"{category}: invalid target at index {idx}: {target!r}",
category=category,
)
invalid_count += 1
continue
if not isinstance(state, Mapping):
debug_highlight(
f"{category}: invalid state mapping at index {idx}: {type(state)!r}",
category=category,
)
invalid_count += 1
continue
payload = _sanitize_mapping(state)
if _estimate_size(payload) > MAX_PAYLOAD_BYTES:
debug_highlight(
f"{category}: payload for '{target}' exceeds {MAX_PAYLOAD_BYTES} bytes; dropping entry",
category=category,
)
invalid_count += 1
continue
safe_sends.append(Send(target, payload))
if len(targets) > 0 and not safe_sends:
debug_highlight(
f"{category}: all {len(targets)} target entries were invalid; aborting dispatch",
category=category,
)
raise ValueError(f"{category}: no valid target entries to dispatch")
debug_highlight(
f"{category}: dispatching {len(safe_sends)} branches (invalid: {invalid_count})",
category=category,
)
return safe_sends
def create_command_router(
routing_logic: Callable[[dict[str, Any]], tuple[str, dict[str, Any] | None]],
) -> Callable[[dict[str, Any]], Command[str]]:
routing_logic: Callable[
[StateMapping], tuple[str, Mapping[str, object] | None]
],
) -> Callable[[StateMapping], Command[str]]:
"""Create a router that returns Command objects for combined state update and routing.
This factory creates routers that can both update state and control flow
@@ -63,7 +169,7 @@ def create_command_router(
```
"""
def router(state: dict[str, Any]) -> Command[str]:
def router(state: StateMapping) -> Command[str]:
target, updates = routing_logic(state)
return _log_command(target, updates, category="CommandRouter")
@@ -71,8 +177,10 @@ def create_command_router(
def create_dynamic_send_router(
target_generator: Callable[[dict[str, Any]], list[tuple[str, dict[str, Any]]]],
) -> Callable[[dict[str, Any]], list[Send]]:
target_generator: Callable[
[StateMapping], Sequence[tuple[str, Mapping[str, object]]]
],
) -> Callable[[StateMapping], list[Send]]:
"""Create a router that generates Send objects for dynamic fan-out patterns.
This factory creates routers for map-reduce patterns where you need to
@@ -97,7 +205,7 @@ def create_dynamic_send_router(
```
"""
def router(state: dict[str, Any]) -> list[Send]:
def router(state: StateMapping) -> list[Send]:
targets = target_generator(state)
return _log_sends(targets, category="SendRouter")
@@ -105,12 +213,12 @@ def create_dynamic_send_router(
def create_conditional_command_router(
conditions: list[
tuple[Callable[[dict[str, Any]], bool], str, dict[str, Any] | None]
conditions: Sequence[
tuple[Callable[[StateMapping], bool], str, Mapping[str, object] | None]
],
default_target: str = "end",
default_updates: dict[str, Any] | None = None,
) -> Callable[[dict[str, Any]], Command[str]]:
default_updates: Mapping[str, object] | None = None,
) -> Callable[[StateMapping], Command[str]]:
"""Create a router with multiple conditions that returns Command objects.
Args:
@@ -131,20 +239,22 @@ def create_conditional_command_router(
```
"""
def router(state: dict[str, Any]) -> Command[str]:
def router(state: StateMapping) -> Command[str]:
for condition_fn, target, updates in conditions:
if condition_fn(state):
debug_highlight(
f"Conditional command router: condition met, routing to '{target}'",
category="ConditionalCommand",
)
return Command(goto=target, update=updates or {})
update_payload = dict(updates) if updates is not None else {}
return Command(goto=target, update=update_payload)
debug_highlight(
f"Conditional command router: no conditions met, using default '{default_target}'",
category="ConditionalCommand",
)
return Command(goto=default_target, update=default_updates or {})
default_payload = dict(default_updates) if default_updates is not None else {}
return Command(goto=default_target, update=default_payload)
return router
@@ -154,7 +264,7 @@ def create_map_reduce_router(
processor_node: str = "process_item",
reducer_node: str = "reduce_results",
item_state_key: str = "current_item",
) -> Callable[[dict[str, Any]], list[Send] | Command[str]]:
) -> Callable[[StateMapping], list[Send] | Command[str]]:
"""Create a router for map-reduce patterns using Send.
This router dispatches items to parallel processors and then
@@ -180,9 +290,20 @@ def create_map_reduce_router(
```
"""
# Create reusable send router for item processing
def _gen(state):
items = state.get(items_key, [])
start = state.get("processed_count", 0)
def _gen(state: StateMapping) -> list[tuple[str, Mapping[str, object]]]:
raw_items = state.get(items_key, [])
if not isinstance(raw_items, Sequence) or isinstance(
raw_items, (str, bytes, bytearray)
):
return []
items = list(raw_items)
processed_raw = state.get("processed_count", 0)
try:
start = max(int(processed_raw), 0)
except (TypeError, ValueError):
start = 0
return [
(
processor_node,
@@ -197,14 +318,24 @@ def create_map_reduce_router(
send_router = create_dynamic_send_router(_gen)
def router(state: dict[str, Any]) -> list[Send] | Command[str]:
items = state.get(items_key, [])
processed_count = state.get("processed_count", 0)
def router(state: StateMapping) -> list[Send] | Command[str]:
raw_items = state.get(items_key, [])
if (
not isinstance(raw_items, Sequence)
or isinstance(raw_items, (str, bytes, bytearray))
):
raw_items = ()
if processed_count < len(items):
processed_raw = state.get("processed_count", 0)
try:
processed_count = max(int(processed_raw), 0)
except (TypeError, ValueError):
processed_count = 0
if processed_count < len(raw_items):
return send_router(state)
else:
return _log_command(reducer_node, None, category="MapReduce")
return _log_command(reducer_node, None, category="MapReduce")
return router
@@ -216,7 +347,7 @@ def create_retry_command_router(
failure_node: str = "failure",
attempt_key: str = "retry_attempts",
success_key: str = "is_successful",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateMapping], Command[str]]:
"""Create a router that handles retry logic with Command pattern.
Args:
@@ -241,9 +372,14 @@ def create_retry_command_router(
```
"""
def router(state: dict[str, Any]) -> Command[str]:
attempts = state.get(attempt_key, 0)
is_successful = state.get(success_key, False)
def router(state: StateMapping) -> Command[str]:
attempts_raw = state.get(attempt_key, 0)
try:
attempts = max(int(attempts_raw), 0)
except (TypeError, ValueError):
attempts = 0
is_successful = bool(state.get(success_key, False))
if is_successful:
debug_highlight(
@@ -260,7 +396,7 @@ def create_retry_command_router(
goto=retry_node,
update={
attempt_key: attempts + 1,
"last_retry_timestamp": None, # Would be actual timestamp
"last_retry_timestamp": None,
},
)
else:
@@ -280,10 +416,10 @@ def create_retry_command_router(
def create_subgraph_command_router(
subgraph_mapping: dict[str, tuple[str, dict[str, Any] | None]],
subgraph_mapping: Mapping[str, tuple[str, Mapping[str, object] | None]],
state_key: str = "task_type",
parent_return_node: str = "consolidate_results",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateMapping], Command[str]]:
"""Create a router for delegating to subgraphs with Command.PARENT support.
Args:
@@ -304,15 +440,15 @@ def create_subgraph_command_router(
```
"""
def router(state: dict[str, Any]) -> Command[str]:
def router(state: StateMapping) -> Command[str]:
task_type = state.get(state_key)
if task_type and isinstance(task_type, str) and task_type in subgraph_mapping:
subgraph_node, initial_state = subgraph_mapping[task_type]
updates = {"delegated_to": subgraph_node}
updates: dict[str, object] = {"delegated_to": subgraph_node}
if initial_state:
updates.update(initial_state)
updates.update(dict(initial_state))
debug_highlight(
f"Subgraph router: delegating '{task_type}' to '{subgraph_node}'",
@@ -343,10 +479,10 @@ def route_on_success(
success_node: str = "continue",
failure_node: str = "error_handler",
success_key: str = "success",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateMapping], Command[str]]:
"""Create simple success/failure router with Command pattern."""
return create_conditional_command_router(
[(lambda s: s.get(success_key, False), success_node, None)],
[(lambda s: bool(s.get(success_key, False)), success_node, None)],
default_target=failure_node,
default_updates={"status": "failed"},
)
@@ -355,14 +491,18 @@ def route_on_success(
def fan_out_tasks(
tasks_key: str = "tasks",
processor_node: str = "process_task",
) -> Callable[[dict[str, Any]], list[Send]]:
) -> Callable[[StateMapping], list[Send]]:
"""Create simple fan-out router for parallel task processing."""
def generator(state: dict[str, Any]) -> list[tuple[str, dict[str, Any]]]:
tasks = state.get(tasks_key, [])
def generator(state: StateMapping) -> list[tuple[str, Mapping[str, object]]]:
tasks_value = state.get(tasks_key, [])
if not isinstance(tasks_value, Sequence) or isinstance(
tasks_value, (str, bytes, bytearray)
):
return []
return [
(processor_node, {"task": task, "task_id": i})
for i, task in enumerate(tasks)
for i, task in enumerate(tasks_value)
]
return create_dynamic_send_router(generator)
@@ -373,9 +513,9 @@ class CommandRouters:
@staticmethod
def command_route_with_update(
routing_fn: Callable[[dict[str, Any]], str],
update_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> Callable[[dict[str, Any]], Command[str]]:
routing_fn: Callable[[StateMapping], str],
update_fn: Callable[[StateMapping], Mapping[str, object]],
) -> Callable[[StateMapping], Command[str]]:
"""Create Command router that updates state while routing.
Combines routing decision with state updates in a single operation.
@@ -401,7 +541,7 @@ class CommandRouters:
```
"""
def router(state: dict[str, Any]) -> Command[str]:
def router(state: StateMapping) -> Command[str]:
target = routing_fn(state)
updates = update_fn(state)
return _log_command(target, updates, category="CommandUpdate")
@@ -410,13 +550,13 @@ class CommandRouters:
@staticmethod
def command_route_with_retry(
success_check: Callable[[dict[str, Any]], bool],
success_check: Callable[[StateMapping], bool],
success_target: str,
retry_target: str,
failure_target: str,
max_attempts: int = 3,
attempt_key: str = "attempts",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateMapping], Command[str]]:
"""Create Command router with retry logic.
Implements retry pattern with attempt counting and eventual failure routing.
@@ -446,8 +586,12 @@ class CommandRouters:
```
"""
def router(state: dict[str, Any]) -> Command[str]:
attempts = state.get(attempt_key, 0)
def router(state: StateMapping) -> Command[str]:
attempts_raw = state.get(attempt_key, 0)
try:
attempts = max(int(attempts_raw), 0)
except (TypeError, ValueError):
attempts = 0
if success_check(state):
return _log_command(
@@ -475,7 +619,7 @@ class CommandRouters:
processor_node: str,
item_key: str = "item",
include_index: bool = True,
) -> Callable[[dict[str, Any]], list[Send]]:
) -> Callable[[StateMapping], list[Send]]:
"""Create Send router for parallel processing of items.
Distributes items from state to parallel processor nodes using Send objects.
@@ -502,14 +646,21 @@ class CommandRouters:
```
"""
def router(state: dict[str, Any]) -> list[Send]:
items = state.get(items_key, [])
targets = []
def router(state: StateMapping) -> list[Send]:
items_value = state.get(items_key, [])
if not isinstance(items_value, Sequence) or isinstance(
items_value, (str, bytes, bytearray)
):
return []
for i, item in enumerate(items):
processor_state = {item_key: item}
targets: list[tuple[str, Mapping[str, object]]] = []
for i, item in enumerate(items_value):
processor_state: dict[str, object] = {item_key: item}
if include_index:
processor_state.update({"item_index": i, "total_items": len(items)})
processor_state.update(
{"item_index": i, "total_items": len(items_value)}
)
targets.append((processor_node, processor_state))
return _log_sends(targets, category="ParallelProcessor")
@@ -518,12 +669,12 @@ class CommandRouters:
@staticmethod
def send_conditional(
condition_fn: Callable[[dict[str, Any], Any], bool],
condition_fn: Callable[[StateMapping, object], bool],
items_key: str,
true_node: str,
false_node: str,
item_key: str = "item",
) -> Callable[[dict[str, Any]], list[Send]]:
) -> Callable[[StateMapping], list[Send]]:
"""Create conditional Send router for item filtering.
Evaluates each item against a condition and sends to different nodes
@@ -555,11 +706,16 @@ class CommandRouters:
```
"""
def router(state: dict[str, Any]) -> list[Send]:
items = state.get(items_key, [])
targets = []
def router(state: StateMapping) -> list[Send]:
items_value = state.get(items_key, [])
if not isinstance(items_value, Sequence) or isinstance(
items_value, (str, bytes, bytearray)
):
return []
for item in items:
targets: list[tuple[str, Mapping[str, object]]] = []
for item in items_value:
target_node = true_node if condition_fn(state, item) else false_node
targets.append((target_node, {item_key: item}))

View File

@@ -11,14 +11,14 @@ maintainability and organization.
from __future__ import annotations
import warnings
from collections.abc import Callable
from typing import Any
from collections.abc import Callable, Mapping
from langgraph.types import Command
from .basic_routing import BasicRouters
from .command_patterns import CommandRouters
from .workflow_routing import WorkflowRouters
from .core import StateLike, get_state_value
# Issue deprecation warning when this module is imported
warnings.warn(
@@ -38,7 +38,7 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
error_target: str = "error_handler",
success_target: str = "continue",
threshold: int = 1,
) -> Callable[[dict[str, Any]], str]:
) -> Callable[[StateLike], str]:
"""Route based on error presence - DEPRECATED.
Use error handling modules for new code.
@@ -49,15 +49,18 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
stacklevel=2,
)
def router(state: dict[str, Any]) -> str:
errors = state.get(error_key, [])
def router(state: StateLike) -> str:
raw_errors = get_state_value(state, error_key, [])
# Normalize errors inline to avoid circular import
if not errors:
errors = []
elif not isinstance(errors, list):
errors = [errors]
error_count = len(errors)
if not raw_errors:
normalized: list[object] = []
elif isinstance(raw_errors, list):
normalized = raw_errors
else:
normalized = [raw_errors]
error_count = len(normalized)
return error_target if error_count >= threshold else success_target
@@ -66,10 +69,10 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
@staticmethod
def command_route_on_error_with_recovery(
error_key: str = "errors",
recovery_strategies: dict[str, str] | None = None,
recovery_strategies: Mapping[str, str] | None = None,
max_recovery_attempts: int = 2,
final_failure_target: str = "human_intervention",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateLike], Command[str]]:
"""Create error-aware Command router with recovery - DEPRECATED.
Use error handling modules for new code.
@@ -87,22 +90,32 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
"parsing": "alternative_parser",
}
def router(state: dict[str, Any]) -> Command[str]:
errors = state.get(error_key, [])
recovery_attempts = state.get("recovery_attempts", 0)
def router(state: StateLike) -> Command[str]:
raw_errors = get_state_value(state, error_key, [])
recovery_attempts = get_state_value(state, "recovery_attempts", 0)
if not errors:
try:
attempts = int(recovery_attempts)
except (TypeError, ValueError):
attempts = 0
if not raw_errors:
return Command(goto="continue", update={"status": "no_errors"})
if recovery_attempts >= max_recovery_attempts:
normalized_errors = raw_errors if isinstance(raw_errors, list) else [raw_errors]
if attempts >= max_recovery_attempts:
return Command(
goto=final_failure_target,
update={"status": "recovery_exhausted", "final_errors": errors},
update={
"status": "recovery_exhausted",
"final_errors": normalized_errors,
},
)
error_type = "unknown"
if isinstance(errors, list) and errors:
first_error = errors[0]
if normalized_errors:
first_error = normalized_errors[0]
if isinstance(first_error, dict):
error_type = first_error.get("type", "unknown")
@@ -111,9 +124,9 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
return Command(
goto=recovery_target,
update={
"recovery_attempts": recovery_attempts + 1,
"recovery_attempts": attempts + 1,
"recovery_strategy": error_type,
"original_errors": errors,
"original_errors": normalized_errors,
},
)

View File

@@ -1,29 +1,57 @@
"""Core routing factory functions for edge helpers.
This module provides factory functions for creating flexible routing functions
that can be used as conditional edges in LangGraph workflows.
"""
from collections.abc import Callable
from typing import Any, Literal, Protocol, TypeVar
# Generic state type for edge functions
StateT = TypeVar("StateT")
"""Core routing factory functions for edge helpers.
This module provides factory functions for creating flexible routing functions
that can be used as conditional edges in LangGraph workflows.
"""
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from typing import Protocol, runtime_checkable
@runtime_checkable
class StateProtocol(Protocol):
"""Protocol for state objects that provide ``get`` access."""
def get(self, key: str, default: object | None = None) -> object | None:
"""Return the stored value for ``key`` when available."""
...
@runtime_checkable
class AttributeStateProtocol(Protocol):
"""Protocol for state objects that expose attributes for lookup."""
def __getattr__(self, key: str) -> object:
"""Return the value for ``key`` when exposed as an attribute."""
...
StateLike = Mapping[str, object] | StateProtocol | AttributeStateProtocol
def get_state_value(
state: StateLike,
key: str,
default: object | None = None,
) -> object | None:
"""Safely extract a value from mapping or mapping-like state objects."""
if isinstance(state, Mapping):
return state.get(key, default)
if isinstance(state, StateProtocol):
return state.get(key, default)
if isinstance(state, AttributeStateProtocol):
return getattr(state, key, default)
return getattr(state, key, default)
class StateProtocol(Protocol):
"""Protocol for state objects that can be used with edge helpers."""
def get(self, key: str, default: Any = None) -> Any:
"""Get a value from the state."""
...
def create_enum_router(
enum_to_target: dict[str, str],
state_key: str = "routing_decision",
default_target: str = "end",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_enum_router(
enum_to_target: Mapping[str, str],
state_key: str = "routing_decision",
default_target: str = "end",
) -> Callable[[StateLike], str]:
"""Create a router that maps enum values to target nodes.
Args:
@@ -43,22 +71,18 @@ def create_enum_router(
graph.add_conditional_edges("source_node", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
enum_value = state.get(state_key)
else:
enum_value = getattr(state, state_key, None)
return enum_to_target.get(str(enum_value), default_target)
def router(state: StateLike) -> str:
enum_value = get_state_value(state, state_key)
return enum_to_target.get(str(enum_value), default_target)
return router
def create_bool_router(
true_target: str,
false_target: str,
state_key: str = "condition",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_bool_router(
true_target: str,
false_target: str,
state_key: str = "condition",
) -> Callable[[StateLike], str]:
"""Create a router based on a boolean condition in state.
Args:
@@ -74,22 +98,18 @@ def create_bool_router(
graph.add_conditional_edges("validator", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
condition = state.get(state_key, False)
else:
condition = getattr(state, state_key, False)
return true_target if condition else false_target
def router(state: StateLike) -> str:
condition = get_state_value(state, state_key, False)
return true_target if bool(condition) else false_target
return router
def create_status_router(
status_mapping: dict[str, str],
state_key: str = "status",
default_target: str = "end",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_status_router(
status_mapping: Mapping[str, str],
state_key: str = "status",
default_target: str = "end",
) -> Callable[[StateLike], str]:
"""Create a router based on operation status.
Args:
@@ -112,13 +132,13 @@ def create_status_router(
return create_enum_router(status_mapping, state_key, default_target)
def create_threshold_router(
threshold: float,
above_target: str,
below_target: str,
state_key: str = "score",
equal_target: str | None = None,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_threshold_router(
threshold: float,
above_target: str,
below_target: str,
state_key: str = "score",
equal_target: str | None = None,
) -> Callable[[StateLike], str]:
"""Create a router based on numeric threshold comparison.
Args:
@@ -139,15 +159,11 @@ def create_threshold_router(
graph.add_conditional_edges("scorer", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
value = state.get(state_key, 0.0)
else:
value = getattr(state, state_key, 0.0)
try:
numeric_value = float(value)
if numeric_value > threshold:
def router(state: StateLike) -> str:
value = get_state_value(state, state_key, 0.0)
try:
numeric_value = float(value)
if numeric_value > threshold:
return above_target
elif numeric_value < threshold:
return below_target
@@ -159,11 +175,11 @@ def create_threshold_router(
return router
def create_field_presence_router(
required_fields: list[str],
complete_target: str,
incomplete_target: str,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_field_presence_router(
required_fields: Sequence[str],
complete_target: str,
incomplete_target: str,
) -> Callable[[StateLike], str]:
"""Create a router based on presence of required fields in state.
Args:
@@ -183,27 +199,24 @@ def create_field_presence_router(
graph.add_conditional_edges("data_checker", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
for field in required_fields:
if hasattr(state, "get") or isinstance(state, dict):
value = state.get(field)
else:
value = getattr(state, field, None)
if value is None or value == "":
return incomplete_target
return complete_target
def router(state: StateLike) -> str:
for field in required_fields:
value = get_state_value(state, field)
if value is None or value == "":
return incomplete_target
return complete_target
return router
def create_list_length_router(
min_length: int,
sufficient_target: str,
insufficient_target: str,
state_key: str = "items",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
def create_list_length_router(
min_length: int,
sufficient_target: str,
insufficient_target: str,
state_key: str = "items",
) -> Callable[[StateLike], str]:
"""Create a router based on list length in state.
Args:
@@ -222,27 +235,18 @@ def create_list_length_router(
graph.add_conditional_edges("result_checker", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
items = state.get(state_key, [])
else:
items = getattr(state, state_key, [])
# Only process actual lists/sequences, not strings or other types
if not isinstance(items, (list, tuple)):
return insufficient_target
try:
return (
sufficient_target if len(items) >= min_length else insufficient_target
)
except TypeError:
return insufficient_target
return router
# Type aliases for common router patterns
BoolRouter = Callable[[StateT], Literal["true", "false"]]
StatusRouter = Callable[[StateT], Literal["pending", "running", "completed", "failed"]]
ContinueRouter = Callable[[StateT], Literal["continue", "end"]]
def router(state: StateLike) -> str:
items = get_state_value(state, state_key, [])
# Only process actual sequences (excluding strings/bytes)
if not isinstance(items, Sequence) or isinstance(items, (str, bytes)):
return insufficient_target
try:
return (
sufficient_target if len(items) >= min_length else insufficient_target
)
except TypeError:
return insufficient_target
return router

View File

@@ -1,395 +1,269 @@
"""Error handling edge helpers for robust workflow management.
This module provides edge helpers for detecting errors, implementing retry logic,
and routing to appropriate error handling or recovery nodes.
"""
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from .core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def detect_errors_list(
error_target: str = "error_handler",
success_target: str = "continue",
errors_key: str = "errors",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that detects errors in a list format.
This is specifically designed for states that use an 'errors' list
to accumulate error information.
Args:
error_target: Target node when errors are detected
success_target: Target node when no errors present
errors_key: Key in state containing list of errors
Returns:
Router function that routes based on error presence
Example:
error_detector = detect_errors_list(
error_target="error_handler",
success_target="next_step"
)
graph.add_conditional_edges("some_node", error_detector)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if isinstance(state, dict):
errors = state.get(errors_key, [])
else:
errors = getattr(state, errors_key, [])
# Check if we have any errors
return error_target if len(errors) > 0 else success_target
return router
def handle_error(
error_types: dict[str, str] | None = None,
error_key: str = "error",
default_target: str = "generic_error_handler",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that handles different types of errors.
Args:
error_types: Mapping of error type/class names to handler nodes
error_key: Key in state containing error information
default_target: Default target if error type not in mapping
Returns:
Router function that routes based on error type
Example:
error_router = handle_error({
"ValidationError": "validation_recovery",
"NetworkError": "network_retry",
"AuthenticationError": "auth_failure_handler"
})
graph.add_conditional_edges("error_detector", error_router)
"""
if error_types is None:
error_types = {
"ValidationError": "validation_error_handler",
"NetworkError": "network_error_handler",
"TimeoutError": "timeout_error_handler",
"AuthenticationError": "auth_error_handler",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
else:
error = getattr(state, error_key, None)
if error is None:
return "no_error"
# Handle different error formats
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type:
return error_types.get(error_type, default_target)
return default_target
return router
def retry_on_failure(
max_retries: int = 3,
retry_count_key: str = "retry_count",
error_key: str = "error",
) -> Callable[
[dict[str, Any] | StateProtocol],
Literal["retry", "max_retries_exceeded", "success"],
]:
"""Create a router that implements retry logic for failures.
Args:
max_retries: Maximum number of retry attempts
retry_count_key: Key in state tracking retry attempts
error_key: Key in state containing error information
Returns:
Router function that returns "retry", "max_retries_exceeded", or "success"
Example:
retry_router = retry_on_failure(max_retries=5)
graph.add_conditional_edges("failure_handler", retry_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["retry", "max_retries_exceeded", "success"]:
# Check if there's an error
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
retry_count = state.get(retry_count_key, 0)
else:
error = getattr(state, error_key, None)
retry_count = getattr(state, retry_count_key, 0)
# No error means success
if error is None:
return "success"
try:
current_retries = int(retry_count)
return "max_retries_exceeded" if current_retries >= max_retries else "retry"
except (ValueError, TypeError):
return "retry"
return router
def fallback_to_default(
fallback_conditions: list[str] | None = None,
output_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["use_output", "fallback"]]:
"""Create a router that falls back to default when output is inadequate.
Args:
fallback_conditions: List of conditions that trigger fallback
output_key: Key in state containing output to check
Returns:
Router function that returns "use_output" or "fallback"
Example:
fallback_router = fallback_to_default([
"empty_output", "low_confidence", "invalid_format"
])
graph.add_conditional_edges("output_validator", fallback_router)
"""
if fallback_conditions is None:
fallback_conditions = ["empty_output", "low_confidence", "error"]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["use_output", "fallback"]:
if hasattr(state, "get") or isinstance(state, dict):
output = state.get(output_key)
else:
output = getattr(state, output_key, None)
# Check for empty or null output
if output is None or output == "":
return "fallback"
# Check for other fallback conditions in state
for condition in fallback_conditions:
if hasattr(state, "get") or isinstance(state, dict):
condition_value = state.get(condition, False)
else:
condition_value = getattr(state, condition, False)
if condition_value:
return "fallback"
return "use_output"
return router
def check_critical_error(
critical_error_types: list[str] | None = None,
error_key: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["critical", "non_critical"]]:
"""Create a router that identifies critical errors requiring immediate attention.
Args:
critical_error_types: List of error types considered critical
error_key: Key in state containing error information
Returns:
Router function that returns "critical" or "non_critical"
Example:
critical_router = check_critical_error([
"SecurityError", "DataCorruptionError", "SystemFailure"
])
graph.add_conditional_edges("error_classifier", critical_router)
"""
if critical_error_types is None:
critical_error_types = [
"SecurityError",
"AuthenticationError",
"AuthorizationError",
"DataCorruptionError",
"SystemFailure",
"OutOfMemoryError",
]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["critical", "non_critical"]:
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
else:
error = getattr(state, error_key, None)
if error is None:
return "non_critical"
# Determine error type
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type and error_type in critical_error_types:
return "critical"
return "non_critical"
return router
def escalation_policy(
escalation_threshold: int = 3,
failure_count_key: str = "consecutive_failures",
escalation_types: dict[str, str] | None = None,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that implements error escalation policies.
Args:
escalation_threshold: Number of failures before escalation
failure_count_key: Key in state tracking consecutive failures
escalation_types: Mapping of failure types to escalation targets
Returns:
Router function that handles escalation logic
Example:
escalation_router = escalation_policy(
escalation_threshold=3,
escalation_types={
"timeout": "timeout_escalation",
"validation": "validation_escalation"
}
)
graph.add_conditional_edges("failure_monitor", escalation_router)
"""
if escalation_types is None:
escalation_types = {
"timeout": "timeout_escalation",
"network": "network_escalation",
"validation": "validation_escalation",
"auth": "security_escalation",
"security": "security_escalation",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
failure_count = state.get(failure_count_key, 0)
error = state.get("error")
else:
failure_count = getattr(state, failure_count_key, 0)
error = getattr(state, "error", None)
try:
failures = int(failure_count)
if failures < escalation_threshold:
return "continue_monitoring"
# Determine escalation type based on error
if error:
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error.lower()
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type:
for escalation_key, target in escalation_types.items():
if escalation_key.lower() in error_type.lower():
return target
return "generic_escalation"
except (ValueError, TypeError):
return "continue_monitoring"
return router
def circuit_breaker(
failure_threshold: int = 5,
recovery_timeout: int = 60,
state_key: str = "circuit_breaker_state",
failure_count_key: str = "failure_count",
last_failure_key: str = "last_failure_time",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["allow", "reject", "probe"]]:
"""Create a router that implements circuit breaker pattern.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before trying to close circuit
state_key: Key storing circuit state (closed/open/half_open)
failure_count_key: Key tracking failure count
last_failure_key: Key storing last failure timestamp
Returns:
Router function implementing circuit breaker logic
Example:
breaker_router = circuit_breaker(failure_threshold=3, recovery_timeout=30)
graph.add_conditional_edges("service_call", breaker_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["allow", "reject", "probe"]:
import time
if hasattr(state, "get") or isinstance(state, dict):
circuit_state = state.get(state_key, "closed")
failure_count = state.get(failure_count_key, 0)
last_failure = state.get(last_failure_key, 0)
else:
circuit_state = getattr(state, state_key, "closed")
failure_count = getattr(state, failure_count_key, 0)
last_failure = getattr(state, last_failure_key, 0)
current_time = time.time()
# Convert numeric values with error handling
try:
failure_count = int(failure_count)
except (ValueError, TypeError):
failure_count = 0
try:
last_failure = float(last_failure)
except (ValueError, TypeError):
last_failure = 0.0
if circuit_state == "open":
# Check if recovery timeout has passed
return (
"probe" if current_time - last_failure >= recovery_timeout else "reject"
)
elif circuit_state == "half_open":
return "probe" # Allow one request to test
else: # closed state
return "reject" if failure_count > failure_threshold else "allow"
return router
"""Error handling routing helpers for LangGraph workflows."""
from __future__ import annotations
import time
from collections.abc import Callable, Mapping, Sequence
from numbers import Real
from typing import Literal
from .core import StateLike, get_state_value
def _extract_error_type(error: object | None) -> str | None:
"""Return a normalized error type string when possible."""
if error is None:
return None
if isinstance(error, Mapping):
raw_type = error.get("type") or error.get("error_type")
return str(raw_type) if raw_type is not None else None
if isinstance(error, str):
return error
return error.__class__.__name__ if hasattr(error, "__class__") else None
def _to_int(value: object | None) -> int | None:
if value is None:
return None
if isinstance(value, bool): # bool is an int subclass; preserve explicit intent
return int(value)
if isinstance(value, Real):
return int(value)
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return None
try:
return int(stripped)
except ValueError:
return None
return None
def _to_float(value: object | None) -> float | None:
if value is None:
return None
if isinstance(value, bool):
return float(value)
if isinstance(value, Real):
return float(value)
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return None
try:
return float(stripped)
except ValueError:
return None
return None
def _has_content(value: object | None) -> bool:
if value is None:
return False
if isinstance(value, str):
return bool(value)
if isinstance(value, bool):
return True
if isinstance(value, Real):
return True
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return len(value) > 0
return bool(value)
def detect_errors_list(
    error_target: str = "error_handler",
    success_target: str = "continue",
    errors_key: str = "errors",
) -> Callable[[StateLike], str]:
    """Create a router that detects errors stored in a list.

    The router routes to ``error_target`` only when the state value under
    ``errors_key`` is a non-empty, non-string sequence.
    """

    def router(state: StateLike) -> str:
        collected = get_state_value(state, errors_key, [])
        is_error_list = isinstance(collected, Sequence) and not isinstance(
            collected, (str, bytes, bytearray)
        )
        if is_error_list and len(collected) > 0:
            return error_target
        return success_target

    return router
def handle_error(
    error_types: Mapping[str, str] | None = None,
    error_key: str = "error",
    default_target: str = "generic_error_handler",
) -> Callable[[StateLike], str]:
    """Create a router that dispatches on the type of a recorded error.

    When ``error_types`` is omitted, a default mapping covering validation,
    network, timeout, and authentication errors is used. A missing error
    routes to ``"no_error"``; an unrecognized type routes to ``default_target``.
    """
    default_handlers = {
        "ValidationError": "validation_error_handler",
        "NetworkError": "network_error_handler",
        "TimeoutError": "timeout_error_handler",
        "AuthenticationError": "auth_error_handler",
    }
    handlers = default_handlers if error_types is None else dict(error_types)

    def router(state: StateLike) -> str:
        error = get_state_value(state, error_key)
        if error is None:
            return "no_error"
        kind = _extract_error_type(error)
        return handlers.get(kind, default_target) if kind else default_target

    return router
def retry_on_failure(
    max_retries: int = 3,
    retry_count_key: str = "retry_count",
    error_key: str = "error",
) -> Callable[[StateLike], Literal["retry", "max_retries_exceeded", "success"]]:
    """Create a router implementing bounded retry logic for failures.

    No recorded error means success; an unreadable retry counter is treated
    as "retry" so transient state problems do not abort the workflow.
    """

    def router(state: StateLike) -> Literal["retry", "max_retries_exceeded", "success"]:
        if get_state_value(state, error_key) is None:
            return "success"
        attempts = _to_int(get_state_value(state, retry_count_key, 0))
        if attempts is None or attempts < max_retries:
            return "retry"
        return "max_retries_exceeded"

    return router
def fallback_to_default(
    fallback_conditions: Sequence[str] | None = None,
    output_key: str = "output",
) -> Callable[[StateLike], Literal["use_output", "fallback"]]:
    """Create a router that falls back to a default when output is inadequate.

    Falls back when the output has no meaningful content, or when any of the
    watched condition flags in state is truthy.
    """
    if fallback_conditions is None:
        watched = ["empty_output", "low_confidence", "error"]
    else:
        watched = list(fallback_conditions)

    def router(state: StateLike) -> Literal["use_output", "fallback"]:
        if not _has_content(get_state_value(state, output_key)):
            return "fallback"
        if any(bool(get_state_value(state, flag, False)) for flag in watched):
            return "fallback"
        return "use_output"

    return router
def check_critical_error(
    critical_error_types: Sequence[str] | None = None,
    error_key: str = "error",
) -> Callable[[StateLike], Literal["critical", "non_critical"]]:
    """Create a router that flags errors whose type is considered critical.

    When ``critical_error_types`` is empty or omitted, a built-in list of
    security/system failure types is used.
    """
    default_critical = (
        "SecurityError",
        "AuthenticationError",
        "AuthorizationError",
        "DataCorruptionError",
        "SystemFailure",
        "OutOfMemoryError",
    )
    critical = set(critical_error_types) if critical_error_types else set(default_critical)

    def router(state: StateLike) -> Literal["critical", "non_critical"]:
        error = get_state_value(state, error_key)
        if error is None:
            return "non_critical"
        kind = _extract_error_type(error)
        return "critical" if kind and kind in critical else "non_critical"

    return router
def escalation_policy(
    escalation_threshold: int = 3,
    failure_count_key: str = "consecutive_failures",
    escalation_types: Mapping[str, str] | None = None,
) -> Callable[[StateLike], str]:
    """Create a router that escalates after repeated failures.

    Below the threshold (or when the counter is unreadable) the router keeps
    monitoring. At or above it, the recorded error type is matched against
    the escalation markers by case-insensitive substring; the first match
    wins, otherwise a generic escalation is chosen.
    """
    if escalation_types is None:
        escalations: dict[str, str] = {
            "timeout": "timeout_escalation",
            "network": "network_escalation",
            "validation": "validation_escalation",
            "auth": "security_escalation",
            "security": "security_escalation",
        }
    else:
        escalations = dict(escalation_types)

    def router(state: StateLike) -> str:
        failures = _to_int(get_state_value(state, failure_count_key, 0))
        if failures is None or failures < escalation_threshold:
            return "continue_monitoring"
        kind = _extract_error_type(get_state_value(state, "error"))
        if kind:
            normalized = kind.lower()
            # Preserve mapping order: the first matching marker wins.
            for marker, target in escalations.items():
                if marker.lower() in normalized:
                    return target
        return "generic_escalation"

    return router
def circuit_breaker(
    failure_threshold: int = 5,
    recovery_timeout: int = 60,
    state_key: str = "circuit_breaker_state",
    failure_count_key: str = "failure_count",
    last_failure_key: str = "last_failure_time",
) -> Callable[[StateLike], Literal["allow", "reject", "probe"]]:
    """Create a router that implements the circuit breaker pattern.

    Open circuits reject traffic until ``recovery_timeout`` seconds have
    elapsed since the last failure, then allow a single probe. Half-open
    circuits always probe. Closed circuits allow traffic until the failure
    count exceeds ``failure_threshold``.
    """

    def router(state: StateLike) -> Literal["allow", "reject", "probe"]:
        breaker_state = str(get_state_value(state, state_key, "closed"))
        if breaker_state == "half_open":
            return "probe"  # let one request test the downstream service
        if breaker_state == "open":
            last_failure = _to_float(get_state_value(state, last_failure_key, 0.0)) or 0.0
            waited_long_enough = time.time() - last_failure >= recovery_timeout
            return "probe" if waited_long_enough else "reject"
        # Closed (or unrecognized) state: allow until failures exceed the threshold.
        failures = _to_int(get_state_value(state, failure_count_key, 0)) or 0
        return "reject" if failures > failure_threshold else "allow"

    return router
# Explicit public API of this module (error-handling routing helpers);
# the private converters (_to_int, _to_float, ...) are intentionally excluded.
__all__ = [
    "detect_errors_list",
    "handle_error",
    "retry_on_failure",
    "fallback_to_default",
    "check_critical_error",
    "escalation_policy",
    "circuit_breaker",
]

View File

@@ -1,262 +1,183 @@
"""Flow control edge helpers for managing workflow progression.
This module provides edge helpers for controlling the flow of execution
in LangGraph workflows, including continuation logic, timeout checks,
and multi-step progress tracking.
"""
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from .core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def should_continue(
state: dict[str, Any] | StateProtocol,
) -> Literal["continue", "end"]:
"""Decide whether to continue or end based on tool calls in the last AI message.
Checks for the presence of tool calls in the last message to determine
if the workflow should continue processing or end.
Args:
state: State containing messages with potential tool calls
Returns:
"continue" if tool calls are present, "end" otherwise
Example:
graph.add_conditional_edges("agent", should_continue)
"""
# Check for messages in state
if hasattr(state, "get") or isinstance(state, dict):
messages = state.get("messages", [])
else:
messages = getattr(state, "messages", [])
if not messages:
return "end"
last_message = messages[-1]
# Check various formats for tool calls
if isinstance(last_message, dict):
# Check for tool_calls in additional_kwargs (must be a list)
additional_kwargs = last_message.get("additional_kwargs", {})
tool_calls = additional_kwargs.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check for tool_calls directly (must be a list)
tool_calls = last_message.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check for function_call
if last_message.get("function_call"):
return "continue"
# Check if message object has tool_calls attribute (must be a list)
if hasattr(last_message, "tool_calls"):
tool_calls = getattr(last_message, "tool_calls", None)
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check if message object has additional_kwargs with tool_calls (must be a list)
if hasattr(last_message, "additional_kwargs"):
additional_kwargs = getattr(last_message, "additional_kwargs", None) or {}
if isinstance(additional_kwargs, dict):
tool_calls = additional_kwargs.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
return "end"
def timeout_check(
timeout_seconds: float = 300.0,
start_time_key: str = "start_time",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks for timeout conditions.
Args:
timeout_seconds: Maximum allowed execution time in seconds
start_time_key: Key in state containing start timestamp
Returns:
Router function that returns "timeout" or "continue"
Example:
timeout_router = timeout_check(timeout_seconds=60.0)
graph.add_conditional_edges("long_task", timeout_router)
"""
def router(state: dict[str, Any] | StateProtocol) -> Literal["timeout", "continue"]:
if hasattr(state, "get") or isinstance(state, dict):
start_time = state.get(start_time_key)
else:
start_time = getattr(state, start_time_key, None)
if start_time is None:
# No start time recorded, assume we should continue
return "continue"
current_time = time.time()
elapsed = current_time - start_time
return "timeout" if elapsed > timeout_seconds else "continue"
return router
def multi_step_progress(
total_steps: int,
step_key: str = "current_step",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router for multi-step workflow progress tracking.
Args:
total_steps: Total number of steps in the workflow
step_key: Key in state containing current step number
Returns:
Router function that returns "next_step", "complete", or "error"
Example:
progress_router = multi_step_progress(total_steps=5)
graph.add_conditional_edges("step_processor", progress_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["next_step", "complete", "error"]:
if hasattr(state, "get") or isinstance(state, dict):
current_step = state.get(step_key, 0)
else:
current_step = getattr(state, step_key, 0)
try:
step_num = int(current_step)
if step_num < 0:
return "error"
elif step_num >= total_steps:
return "complete"
else:
return "next_step"
except (ValueError, TypeError):
return "error"
return router
def check_iteration_limit(
max_iterations: int = 10,
iteration_key: str = "iteration_count",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that enforces iteration limits to prevent infinite loops.
Args:
max_iterations: Maximum allowed iterations
iteration_key: Key in state containing iteration counter
Returns:
Router function that returns "continue" or "limit_reached"
Example:
limit_router = check_iteration_limit(max_iterations=5)
graph.add_conditional_edges("retry_node", limit_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["continue", "limit_reached"]:
if hasattr(state, "get") or isinstance(state, dict):
iterations = state.get(iteration_key, 0)
else:
iterations = getattr(state, iteration_key, 0)
try:
iter_count = int(iterations)
return "limit_reached" if iter_count >= max_iterations else "continue"
except (ValueError, TypeError):
return "continue"
return router
def check_completion_criteria(
required_conditions: list[str],
condition_prefix: str = "completed_",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks if all completion criteria are met.
Args:
required_conditions: List of condition names that must be true
condition_prefix: Prefix for condition keys in state
Returns:
Router function that returns "complete" or "continue"
Example:
completion_router = check_completion_criteria([
"data_validated", "report_generated", "notifications_sent"
])
graph.add_conditional_edges("final_check", completion_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["complete", "continue"]:
for condition in required_conditions:
condition_key = f"{condition_prefix}{condition}"
if hasattr(state, "get") or isinstance(state, dict):
is_complete = state.get(condition_key, False)
else:
is_complete = getattr(state, condition_key, False)
if not is_complete:
return "continue"
return "complete"
return router
def check_workflow_state(
state_transitions: dict[str, str],
state_key: str = "workflow_state",
default_target: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on workflow state transitions.
Args:
state_transitions: Mapping of current states to next targets
state_key: Key in state containing current workflow state
default_target: Default target if state not found in transitions
Returns:
Router function that returns target based on workflow state
Example:
workflow_router = check_workflow_state({
"initialized": "processing",
"processing": "validation",
"validation": "completion",
"completion": "end"
})
graph.add_conditional_edges("state_manager", workflow_router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
current_state = state.get(state_key)
else:
current_state = getattr(state, state_key, None)
return state_transitions.get(str(current_state), default_target)
return router
"""Flow control routing helpers for LangGraph workflows."""
from __future__ import annotations
import time
from collections.abc import Callable, Mapping, Sequence
from typing import Literal
from .core import StateLike, get_state_value
def _to_float(value: object | None) -> float | None:
"""Convert ``value`` to ``float`` when possible."""
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _to_int(value: object | None) -> int | None:
"""Convert ``value`` to ``int`` when possible."""
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def should_continue(state: StateLike) -> Literal["continue", "end"]:
    """Decide whether to continue based on tool calls in the last assistant message.

    Returns ``"continue"`` when the most recent assistant/AI message carries
    tool calls — directly, inside ``additional_kwargs``, or as a legacy
    ``function_call`` — and ``"end"`` otherwise. Both mapping-style and
    attribute-style message objects are supported.
    """
    messages_value = get_state_value(state, "messages", [])
    # Nothing to inspect: missing, non-sequence, or empty message history.
    if not isinstance(messages_value, Sequence) or not messages_value:
        return "end"
    # Walk backwards to find the most recent assistant-authored message;
    # role may live under "role" or "type" depending on the message format.
    last_assistant: object | None = None
    for message in reversed(messages_value):
        role: object | None
        if isinstance(message, Mapping):
            role = message.get("role") or message.get("type")
        else:
            role = getattr(message, "role", None) or getattr(message, "type", None)
        if role in ("assistant", "ai", "ai_message", "assistant_message"):
            last_assistant = message
            break
    if last_assistant is None:
        return "end"

    def _has_tool_calls(msg: object) -> bool:
        # Only non-empty, non-string sequences count as genuine tool-call lists.
        if isinstance(msg, Mapping):
            additional_kwargs = msg.get("additional_kwargs")
            if isinstance(additional_kwargs, Mapping):
                tool_calls = additional_kwargs.get("tool_calls")
                if isinstance(tool_calls, Sequence) and not isinstance(
                    tool_calls, (str, bytes, bytearray)
                ) and tool_calls:
                    return True
            direct_tool_calls = msg.get("tool_calls")
            if isinstance(direct_tool_calls, Sequence) and not isinstance(
                direct_tool_calls, (str, bytes, bytearray)
            ) and direct_tool_calls:
                return True
            # Legacy OpenAI-style single function call.
            if msg.get("function_call"):
                return True
        else:
            # Attribute-style message objects (e.g. LangChain messages —
            # presumably AIMessage-like; confirm against callers).
            tool_calls_attr = getattr(msg, "tool_calls", None)
            if isinstance(tool_calls_attr, Sequence) and not isinstance(
                tool_calls_attr, (str, bytes, bytearray)
            ) and tool_calls_attr:
                return True
            additional_kwargs_attr = getattr(msg, "additional_kwargs", None)
            if isinstance(additional_kwargs_attr, Mapping):
                attr_tool_calls = additional_kwargs_attr.get("tool_calls")
                if isinstance(attr_tool_calls, Sequence) and not isinstance(
                    attr_tool_calls, (str, bytes, bytearray)
                ) and attr_tool_calls:
                    return True
        return False

    return "continue" if _has_tool_calls(last_assistant) else "end"
def timeout_check(
    timeout_seconds: float = 300.0,
    start_time_key: str = "start_time",
) -> Callable[[StateLike], Literal["timeout", "continue"]]:
    """Create a router that reports whether a workflow exceeded its time budget.

    A missing or unparseable start timestamp is treated as "still in budget".
    """

    def router(state: StateLike) -> Literal["timeout", "continue"]:
        started_at = _to_float(get_state_value(state, start_time_key))
        if started_at is None:
            return "continue"  # no start time recorded; nothing to compare against
        return "timeout" if time.time() - started_at > timeout_seconds else "continue"

    return router
def multi_step_progress(
    total_steps: int,
    step_key: str = "current_step",
) -> Callable[[StateLike], Literal["next_step", "complete", "error"]]:
    """Create a router that tracks progress through a fixed number of steps.

    Unparseable or negative step counters route to ``"error"``; a counter at
    or past ``total_steps`` routes to ``"complete"``.
    """

    def router(state: StateLike) -> Literal["next_step", "complete", "error"]:
        step = _to_int(get_state_value(state, step_key, 0))
        if step is None or step < 0:
            return "error"
        return "complete" if step >= total_steps else "next_step"

    return router
def check_iteration_limit(
    max_iterations: int = 10,
    iteration_key: str = "iteration_count",
) -> Callable[[StateLike], Literal["continue", "limit_reached"]]:
    """Create a router that stops runaway loops once an iteration limit is hit.

    An unreadable counter is treated as "continue" rather than aborting.
    """

    def router(state: StateLike) -> Literal["continue", "limit_reached"]:
        count = _to_int(get_state_value(state, iteration_key, 0))
        if count is not None and count >= max_iterations:
            return "limit_reached"
        return "continue"

    return router
def check_completion_criteria(
    required_conditions: Sequence[str],
    condition_prefix: str = "completed_",
) -> Callable[[StateLike], Literal["complete", "continue"]]:
    """Create a router that reports completion once every criterion is truthy.

    Each condition name is looked up under ``condition_prefix + name`` in state.
    """

    def router(state: StateLike) -> Literal["complete", "continue"]:
        all_met = all(
            bool(get_state_value(state, f"{condition_prefix}{name}", False))
            for name in required_conditions
        )
        return "complete" if all_met else "continue"

    return router
def check_workflow_state(
    state_transitions: Mapping[str, str],
    state_key: str = "workflow_state",
    default_target: str = "error",
) -> Callable[[StateLike], str]:
    """Create a router that maps the current workflow state to its next target.

    The transition table is copied up front so later caller mutations of the
    original mapping cannot affect routing.
    """
    transition_table = dict(state_transitions)

    def router(state: StateLike) -> str:
        current = str(get_state_value(state, state_key))
        return transition_table.get(current, default_target)

    return router
# Explicit public API of this module (flow-control routing helpers);
# the private converters (_to_int, _to_float) are intentionally excluded.
__all__ = [
    "should_continue",
    "timeout_check",
    "multi_step_progress",
    "check_iteration_limit",
    "check_completion_criteria",
    "check_workflow_state",
]

View File

@@ -1,400 +1,300 @@
"""Monitoring and operational edge helpers for system management.
This module provides edge helpers for monitoring system health, checking
resource availability, triggering notifications, and managing operational
concerns like load balancing and rate limiting.
"""
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from .core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def log_and_monitor(
log_levels: dict[str, str] | None = None,
log_level_key: str = "log_level",
should_log_key: str = "should_log",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that determines logging and monitoring actions.
Args:
log_levels: Mapping of log levels to monitoring actions
log_level_key: Key in state containing log level
should_log_key: Key in state indicating if logging is enabled
Returns:
Router function that returns monitoring action or "no_logging"
Example:
monitor_router = log_and_monitor({
"debug": "debug_monitor",
"info": "info_monitor",
"warning": "alert_monitor",
"error": "urgent_monitor"
})
graph.add_conditional_edges("logger", monitor_router)
"""
if log_levels is None:
log_levels = {
"debug": "debug_log",
"info": "info_log",
"warning": "warning_alert",
"error": "error_alert",
"critical": "critical_alert",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
should_log = state.get(should_log_key, True)
log_level = state.get(log_level_key, "info")
else:
should_log = getattr(state, should_log_key, True)
log_level = getattr(state, log_level_key, "info")
if not should_log:
return "no_logging"
level_str = str(log_level).lower()
return log_levels.get(level_str, "info_log")
return router
def check_resource_availability(
resource_thresholds: dict[str, float] | None = None,
resources_key: str = "resource_usage",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks system resource availability.
Args:
resource_thresholds: Mapping of resource names to threshold percentages
resources_key: Key in state containing resource usage data
Returns:
Router function that returns "resources_available" or resource constraint
Example:
resource_router = check_resource_availability({
"cpu": 0.8, # 80% threshold
"memory": 0.9, # 90% threshold
"disk": 0.95 # 95% threshold
})
graph.add_conditional_edges("resource_check", resource_router)
"""
if resource_thresholds is None:
resource_thresholds = {
"cpu": 0.8,
"memory": 0.9,
"disk": 0.95,
"network": 0.8,
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
resources = state.get(resources_key)
else:
resources = getattr(state, resources_key, None)
if resources is None or not isinstance(resources, dict):
return "resources_unknown"
# Check each resource against its threshold
for resource_name, threshold in resource_thresholds.items():
usage = resources.get(resource_name, 0.0)
try:
usage_percent = float(usage)
if usage_percent >= threshold:
return f"{resource_name}_constrained"
except (ValueError, TypeError):
continue
return "resources_available"
return router
def trigger_notifications(
notification_rules: dict[str, str] | None = None,
alert_level_key: str = "alert_level",
notification_enabled_key: str = "notifications_enabled",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that triggers appropriate notifications.
Args:
notification_rules: Mapping of alert levels to notification types
alert_level_key: Key in state containing alert level
notification_enabled_key: Key in state for notification toggle
Returns:
Router function that returns notification type or "no_notification"
Example:
notify_router = trigger_notifications({
"low": "email_notification",
"medium": "slack_notification",
"high": "sms_notification",
"critical": "phone_notification"
})
graph.add_conditional_edges("alert_handler", notify_router)
"""
if notification_rules is None:
notification_rules = {
"low": "email_notification",
"medium": "slack_notification",
"high": "sms_notification",
"critical": "phone_notification",
"emergency": "all_channels",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
notifications_enabled = state.get(notification_enabled_key, True)
alert_level = state.get(alert_level_key)
else:
notifications_enabled = getattr(state, notification_enabled_key, True)
alert_level = getattr(state, alert_level_key, None)
if not notifications_enabled or alert_level is None:
return "no_notification"
level_str = str(alert_level).lower()
return notification_rules.get(level_str, "no_notification")
return router
def check_rate_limiting(
rate_limits: dict[str, dict[str, float]] | None = None,
request_counts_key: str = "request_counts",
time_window_key: str = "time_window",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that enforces rate limiting policies.
Args:
rate_limits: Mapping of endpoints/operations to rate limit configs
request_counts_key: Key in state containing request count data
time_window_key: Key in state containing time window info
Returns:
Router function that returns "allowed" or "rate_limited"
Example:
rate_router = check_rate_limiting({
"api_calls": {"max_requests": 100, "window_seconds": 60},
"file_uploads": {"max_requests": 10, "window_seconds": 60}
})
graph.add_conditional_edges("rate_limiter", rate_router)
"""
if rate_limits is None:
rate_limits = {
"default": {"max_requests": 100, "window_seconds": 60},
"api_calls": {"max_requests": 100, "window_seconds": 60},
"heavy_operations": {"max_requests": 10, "window_seconds": 300},
}
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["allowed", "rate_limited"]:
if hasattr(state, "get") or isinstance(state, dict):
request_counts = state.get(request_counts_key, {})
time_window_data = state.get(time_window_key, {})
else:
request_counts = getattr(state, request_counts_key, {})
time_window_data = getattr(state, time_window_key, {})
# Use time from time_window_data if available, otherwise use system time
current_time = time_window_data.get("current_time", time.time())
# Check rate limits for each configured endpoint
for endpoint, limits in rate_limits.items():
if endpoint not in request_counts:
continue
max_requests = limits.get("max_requests", 100)
window_seconds = limits.get("window_seconds", 60)
# Get request history for this endpoint
endpoint_data = request_counts.get(endpoint, {})
request_count = endpoint_data.get("count", 0)
window_start = endpoint_data.get("window_start", current_time)
# Check if we're still in the same time window
if (
current_time - window_start < window_seconds
and request_count > max_requests
):
return "rate_limited"
return "allowed"
return router
def load_balance(
load_balancing_strategy: str = "round_robin",
available_nodes: list[str] | None = None,
node_status_key: str = "node_status",
current_node_key: str = "current_node_index",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that implements load balancing across nodes.
Args:
load_balancing_strategy: Strategy to use ("round_robin", "least_loaded")
available_nodes: List of available node names
node_status_key: Key in state containing node health status
current_node_key: Key in state tracking current node index
Returns:
Router function that returns selected node name
Example:
balance_router = load_balance(
strategy="round_robin",
available_nodes=["node1", "node2", "node3"]
)
graph.add_conditional_edges("load_balancer", balance_router)
"""
if available_nodes is None:
available_nodes = ["primary_node", "secondary_node"]
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
node_status = state.get(node_status_key, {}) or {}
current_index = state.get(current_node_key, 0)
else:
node_status = getattr(state, node_status_key, {}) or {}
current_index = getattr(state, current_node_key, 0)
# Filter healthy nodes
healthy_nodes = []
for node in available_nodes:
node_health = node_status.get(node, {})
if node_health.get("healthy", True): # Default to healthy
healthy_nodes.append(node)
if not healthy_nodes:
return "no_available_nodes"
if load_balancing_strategy == "round_robin":
try:
index = int(current_index) % len(healthy_nodes)
return healthy_nodes[index]
except (ValueError, TypeError):
return healthy_nodes[0]
elif load_balancing_strategy == "least_loaded":
# Find node with lowest load
min_load = float("inf")
selected_node = healthy_nodes[0]
for node in healthy_nodes:
node_data = node_status.get(node, {})
load = node_data.get("load", 0.0)
try:
if float(load) < min_load:
min_load = float(load)
selected_node = node
except (ValueError, TypeError):
continue
return selected_node
else: # Default to first healthy node
return healthy_nodes[0]
return router
def health_check(
health_criteria: dict[str, Any] | None = None,
health_status_key: str = "health_status",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that performs system health checks.
Args:
health_criteria: Criteria for determining system health
health_status_key: Key in state containing health check results
Returns:
Router function that returns "healthy", "degraded", or "unhealthy"
Example:
health_router = health_check({
"response_time_ms": 1000,
"error_rate": 0.05,
"uptime_percent": 0.99
})
graph.add_conditional_edges("health_monitor", health_router)
"""
if health_criteria is None:
health_criteria = {
"response_time_ms": 1000,
"error_rate": 0.05,
"cpu_usage": 0.8,
"memory_usage": 0.9,
}
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["healthy", "degraded", "unhealthy"]:
if hasattr(state, "get") or isinstance(state, dict):
health_status = state.get(health_status_key, {})
else:
health_status = getattr(state, health_status_key, {})
if not isinstance(health_status, dict):
return "unhealthy"
# If health_status is empty, we have no health data to assess
if not health_status:
return "unhealthy"
unhealthy_count = 0
degraded_count = 0
for metric, threshold in health_criteria.items():
current_value = health_status.get(metric)
if current_value is None:
continue
try:
value = float(current_value)
threshold_val = float(threshold)
# Different metrics have different "good" directions
if metric in [
"uptime_percent",
"availability",
"success_rate",
]:
# Higher is better (like uptime_percent)
if value < threshold_val * 0.8: # 80% of threshold
unhealthy_count += 1
elif value < threshold_val:
degraded_count += 1
else:
# Lower is better (default for most metrics including unknown ones)
if abs(threshold_val) < 1e-9: # Use tolerance for float comparison
# Special case for zero threshold - any positive value is degraded
if value > 1e-9:
degraded_count += 1
elif value > threshold_val * 1.5: # 150% of threshold
unhealthy_count += 1
elif value > threshold_val:
degraded_count += 1
except (ValueError, TypeError):
continue
if unhealthy_count > 0:
return "unhealthy"
elif degraded_count > 0:
return "degraded"
else:
return "healthy"
return router
"""Monitoring and observability routing helpers for LangGraph graphs."""
from __future__ import annotations
import time
from collections.abc import Callable, Mapping
from typing import Literal
from .core import StateLike, get_state_value
def _as_mapping(value: object) -> Mapping[str, object]:
"""Return ``value`` when it behaves like a mapping, otherwise an empty dict."""
if isinstance(value, Mapping):
return value
return {}
def log_and_monitor(
    log_levels: Mapping[str, str] | None = None,
    log_level_key: str = "log_level",
    should_log_key: str = "should_log",
) -> Callable[[StateLike], str]:
    """Build a router that maps the state's log level onto a monitoring action."""
    default_actions = {
        "debug": "debug_log",
        "info": "info_log",
        "warning": "warning_alert",
        "error": "error_alert",
        "critical": "critical_alert",
    }
    actions = default_actions if log_levels is None else dict(log_levels)

    def router(state: StateLike) -> str:
        # Logging defaults to enabled; any falsy flag suppresses monitoring.
        if not get_state_value(state, should_log_key, True):
            return "no_logging"
        level_name = str(get_state_value(state, log_level_key, "info")).lower()
        return actions.get(level_name, "info_log")

    return router
def check_resource_availability(
    resource_thresholds: Mapping[str, float] | None = None,
    resources_key: str = "resource_usage",
) -> Callable[[StateLike], str]:
    """Build a router that flags the first resource exceeding its threshold."""
    default_limits = {
        "cpu": 0.8,
        "memory": 0.9,
        "disk": 0.95,
        "network": 0.8,
    }
    limits = default_limits if resource_thresholds is None else dict(resource_thresholds)

    def router(state: StateLike) -> str:
        usage_map = get_state_value(state, resources_key)
        if not isinstance(usage_map, Mapping):
            return "resources_unknown"
        for name, limit in limits.items():
            # Readings that cannot be coerced to float are simply skipped.
            try:
                measured = float(usage_map.get(name, 0.0))
            except (TypeError, ValueError):
                continue
            if measured >= limit:
                return f"{name}_constrained"
        return "resources_available"

    return router
def trigger_notifications(
    notification_rules: Mapping[str, str] | None = None,
    alert_level_key: str = "alert_level",
    notification_enabled_key: str = "notifications_enabled",
) -> Callable[[StateLike], str]:
    """Build a router that picks the notification channel for the current alert."""
    default_channels = {
        "low": "email_notification",
        "medium": "slack_notification",
        "high": "sms_notification",
        "critical": "phone_notification",
        "emergency": "all_channels",
    }
    channel_map = default_channels if notification_rules is None else dict(notification_rules)

    def router(state: StateLike) -> str:
        # Either a disabled toggle or a missing alert level suppresses delivery.
        if not get_state_value(state, notification_enabled_key, True):
            return "no_notification"
        level = get_state_value(state, alert_level_key)
        if level is None:
            return "no_notification"
        return channel_map.get(str(level).lower(), "no_notification")

    return router
def check_rate_limiting(
    rate_limits: Mapping[str, Mapping[str, float]] | None = None,
    request_counts_key: str = "request_counts",
    time_window_key: str = "time_window",
) -> Callable[[StateLike], Literal["allowed", "rate_limited"]]:
    """Create a router that enforces rate limiting policies.

    Args:
        rate_limits: Mapping of endpoint name to a config mapping with
            ``max_requests`` and ``window_seconds`` entries. Defaults cover
            ``default``, ``api_calls``, and ``heavy_operations``.
        request_counts_key: State key holding per-endpoint request data
            (each entry a mapping with ``count`` and ``window_start``).
        time_window_key: State key holding timing info; its ``current_time``
            entry, when present and numeric, overrides ``time.time()``.

    Returns:
        Router returning ``"rate_limited"`` when any configured endpoint's
        count exceeds its limit inside the active window, else ``"allowed"``.
    """
    # Deep-copy caller-supplied configs so later mutation cannot change routing.
    limits = (
        {endpoint: dict(config) for endpoint, config in rate_limits.items()}
        if rate_limits is not None
        else {
            "default": {"max_requests": 100, "window_seconds": 60},
            "api_calls": {"max_requests": 100, "window_seconds": 60},
            "heavy_operations": {"max_requests": 10, "window_seconds": 300},
        }
    )
    def router(state: StateLike) -> Literal["allowed", "rate_limited"]:
        request_counts_value = get_state_value(state, request_counts_key, {})
        request_counts = _as_mapping(request_counts_value)
        time_window_value = get_state_value(state, time_window_key, {})
        time_window_data = _as_mapping(time_window_value)
        # Prefer the state-supplied clock (testability); fall back to wall time
        # both when absent and when the supplied value is not numeric.
        current_time = time_window_data.get("current_time", time.time())
        try:
            current_time = float(current_time)
        except (TypeError, ValueError):
            current_time = time.time()
        for endpoint, config in limits.items():
            # Endpoints with no recorded traffic cannot be over their limit.
            if endpoint not in request_counts:
                continue
            # Malformed config values degrade to the documented defaults.
            try:
                max_requests = int(config.get("max_requests", 100))  # type: ignore[arg-type]
            except (TypeError, ValueError):
                max_requests = 100
            try:
                window_seconds = float(config.get("window_seconds", 60))  # type: ignore[arg-type]
            except (TypeError, ValueError):
                window_seconds = 60.0
            endpoint_value = request_counts.get(endpoint, {})
            endpoint_data = _as_mapping(endpoint_value)
            request_count = endpoint_data.get("count", 0)
            window_start = endpoint_data.get("window_start", current_time)
            # A non-numeric count makes this endpoint unjudgeable: skip it.
            try:
                request_count = int(request_count)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                continue
            # A non-numeric window start is treated as "window opened just now".
            try:
                window_start = float(window_start)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                window_start = current_time
            # NOTE: limit check is strictly greater-than, so exactly
            # max_requests requests within the window are still allowed.
            if (
                current_time - window_start < window_seconds
                and request_count > max_requests
            ):
                return "rate_limited"
        return "allowed"
    return router
def load_balance(
    load_balancing_strategy: str = "round_robin",
    available_nodes: list[str] | None = None,
    node_status_key: str = "node_status",
    current_node_key: str = "current_node_index",
) -> Callable[[StateLike], str]:
    """Build a router that picks a healthy node via the chosen balancing strategy."""
    candidate_nodes = (
        ["primary_node", "secondary_node"]
        if available_nodes is None
        else list(available_nodes)
    )

    def router(state: StateLike) -> str:
        status_map = _as_mapping(get_state_value(state, node_status_key, {}))
        cursor = get_state_value(state, current_node_key, 0)
        # Nodes default to healthy when no status entry says otherwise.
        healthy = [
            node
            for node in candidate_nodes
            if _as_mapping(status_map.get(node, {})).get("healthy", True)
        ]
        if not healthy:
            return "no_available_nodes"
        if load_balancing_strategy == "round_robin":
            try:
                position = int(cursor) % len(healthy)
            except (TypeError, ValueError):
                position = 0
            return healthy[position]
        if load_balancing_strategy == "least_loaded":
            best_load = float("inf")
            chosen = healthy[0]
            for node in healthy:
                info = _as_mapping(status_map.get(node, {}))
                # Non-numeric load readings exclude a node from comparison.
                try:
                    candidate_load = float(info.get("load", 0.0))
                except (TypeError, ValueError):
                    continue
                if candidate_load < best_load:
                    best_load = candidate_load
                    chosen = node
            return chosen
        # Unknown strategies fall back to the first healthy node.
        return healthy[0]

    return router
def health_check(
    health_criteria: Mapping[str, float] | None = None,
    health_status_key: str = "health_status",
) -> Callable[[StateLike], Literal["healthy", "degraded", "unhealthy"]]:
    """Create a router that performs system health checks.

    Args:
        health_criteria: Mapping of metric name to threshold. For metrics in
            {uptime_percent, availability, success_rate} higher is better;
            for every other metric lower is better.
        health_status_key: State key holding the measured metric values.

    Returns:
        Router returning ``"unhealthy"`` if any metric is far past its
        threshold (or no health data exists), ``"degraded"`` if any metric is
        mildly past it, else ``"healthy"``.
    """
    criteria = dict(health_criteria) if health_criteria is not None else {
        "response_time_ms": 1000,
        "error_rate": 0.05,
        "cpu_usage": 0.8,
        "memory_usage": 0.9,
    }
    def router(state: StateLike) -> Literal["healthy", "degraded", "unhealthy"]:
        health_status_value = get_state_value(state, health_status_key, {})
        health_status = _as_mapping(health_status_value)
        # No data at all is treated as the worst case, not as "fine".
        if not health_status:
            return "unhealthy"
        unhealthy_count = 0
        degraded_count = 0
        for metric, threshold in criteria.items():
            current_value = health_status.get(metric)
            # Missing or non-numeric readings are skipped, not penalized.
            if current_value is None:
                continue
            try:
                value = float(current_value)
                threshold_val = float(threshold)
            except (TypeError, ValueError):
                continue
            if metric in {"uptime_percent", "availability", "success_rate"}:
                # Higher-is-better metrics: below 80% of threshold is
                # unhealthy, below the threshold itself is degraded.
                if value < threshold_val * 0.8:
                    unhealthy_count += 1
                elif value < threshold_val:
                    degraded_count += 1
            else:
                # Lower-is-better metrics (the default for unknown names).
                if abs(threshold_val) < 1e-9:
                    # Zero threshold: any measurable positive value counts
                    # as degraded (float tolerance avoids noise at 0.0).
                    if value > 1e-9:
                        degraded_count += 1
                elif value > threshold_val * 1.5:
                    # More than 150% of the threshold is unhealthy.
                    unhealthy_count += 1
                elif value > threshold_val:
                    degraded_count += 1
        # Any unhealthy metric dominates; degraded dominates healthy.
        if unhealthy_count > 0:
            return "unhealthy"
        if degraded_count > 0:
            return "degraded"
        return "healthy"
    return router
__all__ = [
"log_and_monitor",
"check_resource_availability",
"trigger_notifications",
"check_rate_limiting",
"load_balance",
"health_check",
]

View File

@@ -1,27 +1,25 @@
"""Factory functions for creating different types of command routers.
This module provides factory functions for creating various types of
command-based routers that follow LangGraph best practices.
"""
from collections.abc import Callable
from typing import Any
from langgraph.types import Command
from biz_bud.logging import get_logger
from ..validation.condition_security import ConditionSecurityError, ConditionValidator
from .routing_rules import CommandRoutingRule
logger = get_logger(__name__)
"""Factory functions for creating different types of command routers."""
from __future__ import annotations
from collections.abc import Callable, Mapping
from langgraph.types import Command
from biz_bud.logging import get_logger
from ..validation.condition_security import ConditionSecurityError, ConditionValidator
from .core import StateLike
from .routing_rules import CommandRoutingRule
logger = get_logger(__name__)
def create_command_router(
rules: list[CommandRoutingRule],
default_target: str = "__end__",
allow_state_updates: bool = True,
) -> Callable[[Any], Command[str]]:
def create_command_router(
rules: list[CommandRoutingRule],
default_target: str = "__end__",
allow_state_updates: bool = True,
) -> Callable[[StateLike], Command[str]]:
"""Create a Command-based router from a list of routing rules.
Args:
@@ -45,7 +43,7 @@ def create_command_router(
# Sort rules by priority (highest first)
sorted_rules = sorted(rules, key=lambda r: r.priority, reverse=True)
def router_func(state: Any) -> Command[str]:
def router_func(state: StateLike) -> Command[str]:
"""Router function that evaluates rules and returns Command."""
# Evaluate rules in priority order
for rule in sorted_rules:
@@ -107,11 +105,11 @@ def _safe_construct_condition(field_name: str, operator: str, value: str) -> str
return condition
def create_status_command_router(
status_mappings: dict[str, str],
status_field: str = "status",
default_target: str = "__end__",
) -> Callable[[Any], Command[str]]:
def create_status_command_router(
status_mappings: Mapping[str, str],
status_field: str = "status",
default_target: str = "__end__",
) -> Callable[[StateLike], Command[str]]:
"""Create a Command router based on status field values.
Args:
@@ -153,10 +151,10 @@ def create_status_command_router(
return create_command_router(rules, default_target)
def create_conditional_command_router(
condition_mappings: dict[str, str],
default_target: str = "__end__",
) -> Callable[[Any], Command[str]]:
def create_conditional_command_router(
condition_mappings: Mapping[str, str],
default_target: str = "__end__",
) -> Callable[[StateLike], Command[str]]:
"""Create a Command router based on arbitrary conditions.
Args:

View File

@@ -1,17 +1,18 @@
"""Core routing rule classes for command-based routing.
This module provides the base classes and protocols for building
command routing rules with security validation.
"""
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol
from langgraph.types import Command
from biz_bud.core.errors import ValidationError
from biz_bud.logging import get_logger
"""Core routing rule classes for command-based routing."""
from __future__ import annotations
import ast
import contextlib
import time
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from langgraph.types import Command
from biz_bud.core.errors import ValidationError
from biz_bud.logging import get_logger
from .core import StateLike, StateProtocol
from ..validation.condition_security import (
ConditionSecurityError,
@@ -19,15 +20,11 @@ from ..validation.condition_security import (
validate_condition_for_security,
)
logger = get_logger(__name__)
class StateProtocol(Protocol):
"""Protocol for state objects that can be used with Command routers."""
def get(self, key: str, default: Any = None) -> Any:
"""Get a value from the state."""
...
logger = get_logger(__name__)
SupportsGet = StateProtocol
ConditionCallable = Callable[[StateLike], bool]
@dataclass
@@ -47,19 +44,19 @@ class CommandRoutingRule:
description: Human-readable description of the rule
"""
condition: str | Callable[[Any], bool]
target: str
state_updates: dict[str, Any] = field(default_factory=dict)
condition: str | ConditionCallable
target: str
state_updates: dict[str, object] = field(default_factory=dict)
priority: int = 0
description: str = ""
def evaluate(self, state: Any) -> bool:
def evaluate(self, state: StateLike) -> bool:
"""Evaluate if this rule applies to the given state."""
if callable(self.condition):
return self.condition(state)
return self._evaluate_string_condition(self.condition, state)
def _evaluate_string_condition(self, condition: str, state: Any) -> bool:
def _evaluate_string_condition(self, condition: str, state: StateLike) -> bool:
"""Evaluate a string-based condition with comprehensive security validation.
This method provides robust security against injection attacks and ensures
@@ -104,14 +101,14 @@ class CommandRoutingRule:
expected_value = validated_condition[operator_position + len(operator):].strip()
# Define supported operators with their functions
operators = {
">=": lambda a, b: a >= b,
"<=": lambda a, b: a <= b,
"==": lambda a, b: a == b,
"!=": lambda a, b: a != b,
">": lambda a, b: a > b,
"<": lambda a, b: a < b,
}
operators: dict[str, Callable[[object, object], bool]] = {
">=": lambda a, b: bool(a >= b),
"<=": lambda a, b: bool(a <= b),
"==": lambda a, b: bool(a == b),
"!=": lambda a, b: bool(a != b),
">": lambda a, b: bool(a > b),
"<": lambda a, b: bool(a < b),
}
if operator not in operators:
raise ValidationError(f"Unsupported operator: {operator}")
@@ -147,7 +144,7 @@ class CommandRoutingRule:
)
raise ValidationError(f"Condition evaluation failed: {type(e).__name__}") from e
def _parse_condition_value(self, value_str: str) -> Any:
def _parse_condition_value(self, value_str: str) -> object:
"""Parse a condition value string into appropriate Python type.
Args:
@@ -156,33 +153,31 @@ class CommandRoutingRule:
Returns:
Parsed value with appropriate type
"""
value_str = value_str.strip()
value_str = value_str.strip()
if value_str.lower() in ("true", "false"):
return value_str.lower() == "true"
if value_str.lower() in {"null", "none"}:
return None
if (value_str.startswith("'") and value_str.endswith("'")) or (
value_str.startswith('"') and value_str.endswith('"')
):
try:
return ast.literal_eval(value_str)
except (ValueError, SyntaxError):
return value_str[1:-1]
with contextlib.suppress(ValueError):
return int(value_str)
with contextlib.suppress(ValueError):
return float(value_str)
return value_str
# Handle quoted strings (remove quotes but preserve content)
if (value_str.startswith('"') and value_str.endswith('"')) or (
value_str.startswith("'") and value_str.endswith("'")
):
return value_str[1:-1] # Return string without quotes
# Handle boolean values
if value_str.lower() == "true":
return True
elif value_str.lower() == "false":
return False
elif value_str.lower() in {"null", "none"}:
return None
# Handle numeric values
try:
# Try integer first (handles both positive and negative integers)
return (
int(value_str) if value_str.lstrip("-").isdigit() else float(value_str)
)
except ValueError:
# If all else fails, treat as string
return value_str
def _safe_get_state_value(self, state: Any, field_name: str) -> Any:
def _safe_get_state_value(
self, state: StateLike, field_name: str
) -> object | None:
"""Safely get value from state object.
Args:
@@ -193,24 +188,27 @@ class CommandRoutingRule:
Field value or None if not found
"""
# For simple field names (no dots), handle directly
if "." not in field_name:
if hasattr(state, "get") or isinstance(state, dict):
return state.get(field_name)
else:
return getattr(state, field_name, None)
if "." not in field_name:
if isinstance(state, Mapping):
return state.get(field_name)
if isinstance(state, StateProtocol):
return state.get(field_name)
return getattr(state, field_name, None)
# For nested access (with dots), traverse safely
current = state
for part in field_name.split("."):
if current is None:
return None
if hasattr(current, "get") or isinstance(current, dict):
current = current.get(part)
else:
current = getattr(current, part, None)
return current
def create_command(self, state: Any) -> Command[str]:
if isinstance(current, Mapping):
current = current.get(part)
elif isinstance(current, StateProtocol):
current = current.get(part)
else:
current = getattr(current, part, None)
return current
def create_command(self, state: StateLike) -> Command[str]:
"""Create a Command object for this rule.
Args:
@@ -231,4 +229,9 @@ class CommandRoutingRule:
return Command(goto=self.target, update=updates)
__all__ = ["StateProtocol", "CommandRoutingRule"]
__all__ = [
"CommandRoutingRule",
"SupportsGet",
"StateLike",
"StateProtocol",
]

View File

@@ -7,7 +7,8 @@ resource monitoring, and safe execution contexts for LangGraph workflows.
from __future__ import annotations
import uuid
from typing import Any, Literal
from collections.abc import Callable, Mapping, MutableMapping
from typing import Literal, cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import Command
@@ -47,12 +48,12 @@ class SecureGraphRouter:
async def secure_graph_execution(
self,
graph_name: str,
graph_info: dict[str, Any],
execution_state: dict[str, Any],
graph_info: Mapping[str, object],
execution_state: Mapping[str, object],
config: RunnableConfig | None = None,
step_id: str | None = None,
# noqa: ARG001
) -> dict[str, Any]:
) -> dict[str, object]:
"""Execute a graph securely with comprehensive validation and monitoring.
Args:
@@ -83,16 +84,19 @@ class SecureGraphRouter:
self.validator.check_concurrent_limit()
# SECURITY: Validate state data
validated_state = self.validator.validate_state_data(execution_state.copy())
validated_state = self.validator.validate_state_data(
dict(execution_state)
)
# Get and validate factory function
factory_function = graph_info.get("factory_function")
if not factory_function:
factory_obj = graph_info.get("factory_function")
if not callable(factory_obj):
raise SecurityValidationError(
f"No factory function for graph: {validated_graph_name}",
validated_graph_name,
"missing_factory",
)
factory_function = cast(Callable[[], object], factory_obj)
# SECURITY: Validate factory function
self.execution_manager.validate_factory_function(
@@ -133,7 +137,7 @@ class SecureGraphRouter:
def create_security_failure_command(
self,
error: SecurityValidationError | ResourceLimitExceededError,
execution_plan: dict[str, Any],
execution_plan: MutableMapping[str, object],
step_id: str | None = None,
) -> Command[Literal["__end__"]]:
"""Create a command for handling security failures.
@@ -169,7 +173,7 @@ class SecureGraphRouter:
},
)
def get_execution_statistics(self) -> dict[str, Any]:
def get_execution_statistics(self) -> dict[str, object]:
"""Get current execution statistics from the security manager.
Returns:
@@ -196,12 +200,12 @@ def get_secure_router() -> SecureGraphRouter:
async def execute_graph_securely(
graph_name: str,
graph_info: dict[str, Any],
execution_state: dict[str, Any],
graph_info: Mapping[str, object],
execution_state: Mapping[str, object],
config: RunnableConfig | None = None,
step_id: str | None = None,
# noqa: ARG001
) -> dict[str, Any]:
) -> dict[str, object]:
"""Execute graph securely with validation and error handling.
Args:

View File

@@ -1,326 +1,211 @@
"""User interaction edge helpers for human-in-the-loop workflows.
This module provides edge helpers for managing user interactions, interrupts,
feedback loops, and escalation to human operators.
"""
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from .core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def human_interrupt(
interrupt_signals: list[str] | None = None,
interrupt_key: str = "human_interrupt",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that detects human interruption requests.
Args:
interrupt_signals: List of signals that indicate interrupt request
interrupt_key: Key in state containing interrupt signal
Returns:
Router function that returns "interrupt" or "continue"
Example:
interrupt_router = human_interrupt([
"stop", "pause", "cancel", "abort"
])
graph.add_conditional_edges("user_input_check", interrupt_router)
"""
if interrupt_signals is None:
interrupt_signals = ["stop", "pause", "cancel", "abort", "interrupt", "halt"]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["interrupt", "continue"]:
if hasattr(state, "get") or isinstance(state, dict):
signal = state.get(interrupt_key)
else:
signal = getattr(state, interrupt_key, None)
if signal is None:
return "continue"
signal_str = str(signal).lower().strip()
normalized_interrupt_signals = {s.lower() for s in interrupt_signals}
return "interrupt" if signal_str in normalized_interrupt_signals else "continue"
return router
def pass_status_to_user(
    status_levels: dict[str, str] | None = None,
    status_key: str = "status",
    notify_key: str = "notify_user",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that maps workflow status to a notification priority.

    Args:
        status_levels: Mapping of status values to notification urgency.
        status_key: State key holding the current status.
        notify_key: State key flagging whether the user wants notifications.

    Returns:
        Router returning a priority label or "no_notification".
    """
    levels = (
        status_levels
        if status_levels is not None
        else {
            "error": "urgent",
            "failed": "urgent",
            "warning": "medium",
            "completed": "low",
            "success": "low",
            "in_progress": "info",
        }
    )

    def router(state: dict[str, Any] | StateProtocol) -> str:
        supports_get = isinstance(state, dict) or hasattr(state, "get")
        if supports_get:
            status = state.get(status_key)
            should_notify = state.get(notify_key, True)
        else:
            status = getattr(state, status_key, None)
            should_notify = getattr(state, notify_key, True)
        if not should_notify or status is None:
            return "no_notification"
        return levels.get(str(status).lower(), "no_notification")

    return router
def user_feedback_loop(
    feedback_required_conditions: list[str] | None = None,
    feedback_key: str = "requires_feedback",
    feedback_type_key: str = "feedback_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that decides whether and what user feedback to request.

    Args:
        feedback_required_conditions: State flags that trigger a feedback ask.
        feedback_key: State key explicitly requesting feedback.
        feedback_type_key: State key naming the kind of feedback wanted.

    Returns:
        Router returning "feedback_<type>", "feedback_<condition>", or
        "no_feedback".
    """
    conditions = (
        feedback_required_conditions
        if feedback_required_conditions is not None
        else ["low_confidence", "ambiguous_input", "multiple_options", "validation_failed"]
    )

    def _read(state: dict[str, Any] | StateProtocol, key: str, default: Any = None) -> Any:
        # Mapping-style states expose .get; fall back to attribute access.
        if hasattr(state, "get") or isinstance(state, dict):
            return state.get(key, default)
        return getattr(state, key, default)

    def router(state: dict[str, Any] | StateProtocol) -> str:
        if _read(state, feedback_key, False):
            return f"feedback_{_read(state, feedback_type_key, 'general')}"
        # No explicit request: look for any configured trigger condition.
        for condition in conditions:
            if _read(state, condition, False):
                return f"feedback_{condition}"
        return "no_feedback"

    return router
def escalate_to_human(
    escalation_triggers: dict[str, str] | None = None,
    auto_escalate_key: str = "auto_escalate",
    escalation_reason_key: str = "escalation_reason",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that escalates work to human operators when warranted.

    Args:
        escalation_triggers: Mapping of trigger substrings/flags to escalation
            urgency labels.
        auto_escalate_key: State key flagging automatic escalation.
        escalation_reason_key: State key holding the escalation reason text.

    Returns:
        Router returning "escalate_<type>" or "continue_automated".
    """
    triggers = (
        escalation_triggers
        if escalation_triggers is not None
        else {
            "critical_error": "immediate",
            "security_issue": "immediate",
            "multiple_failures": "urgent",
            "user_request": "standard",
            "manual_review": "standard",
            "complex_case": "expert",
        }
    )

    def _read(state: dict[str, Any] | StateProtocol, key: str, default: Any = None) -> Any:
        if hasattr(state, "get") or isinstance(state, dict):
            return state.get(key, default)
        return getattr(state, key, default)

    def router(state: dict[str, Any] | StateProtocol) -> str:
        auto_escalate = _read(state, auto_escalate_key, False)
        reason = _read(state, escalation_reason_key)
        if auto_escalate and reason:
            # Match trigger names as substrings of the stated reason.
            reason_text = str(reason).lower()
            for trigger, level in triggers.items():
                if trigger in reason_text:
                    return f"escalate_{level}"
            return "escalate_standard"
        # Otherwise look for trigger flags set directly in state.
        for trigger, level in triggers.items():
            if _read(state, trigger, False):
                return f"escalate_{level}"
        return "continue_automated"

    return router
def check_user_authorization(
    required_permissions: list[str] | None = None,
    user_key: str = "user",
    permissions_key: str = "permissions",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that checks user authorization levels.

    Args:
        required_permissions: Permissions that must all be held by the user.
        user_key: State key containing user information.
        permissions_key: Key/attribute on the user holding its permissions.

    Returns:
        Router returning "authorized" or "unauthorized".
    """
    required = required_permissions if required_permissions is not None else ["basic_access"]

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["authorized", "unauthorized"]:
        if hasattr(state, "get") or isinstance(state, dict):
            user = state.get(user_key)
        else:
            user = getattr(state, user_key, None)
        if user is None:
            return "unauthorized"
        # Accept mapping-style or attribute-style user objects.
        perms: Any = []
        if isinstance(user, dict):
            perms = user.get(permissions_key, [])
        elif hasattr(user, permissions_key):
            perms = getattr(user, permissions_key, [])
        if not isinstance(perms, list):
            return "unauthorized"
        missing = [permission for permission in required if permission not in perms]
        return "unauthorized" if missing else "authorized"

    return router
def collect_user_input(
    input_types: dict[str, str] | None = None,
    pending_input_key: str = "pending_user_input",
    input_type_key: str = "input_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that selects how pending user input should be gathered.

    Args:
        input_types: Mapping of input type names to collection node labels.
        pending_input_key: State key flagging that input is awaited.
        input_type_key: State key naming the kind of input needed.

    Returns:
        Router returning a collection method label or "no_input_needed".
    """
    mapping = (
        input_types
        if input_types is not None
        else {
            "text": "text_input",
            "choice": "choice_input",
            "confirmation": "confirm_input",
            "file": "file_input",
            "numeric": "number_input",
        }
    )

    def router(state: dict[str, Any] | StateProtocol) -> str:
        if hasattr(state, "get") or isinstance(state, dict):
            pending = state.get(pending_input_key, False)
            requested = state.get(input_type_key, "text")
        else:
            pending = getattr(state, pending_input_key, False)
            requested = getattr(state, input_type_key, "text")
        if not pending:
            return "no_input_needed"
        # Unknown types fall back to the generic text collector.
        return mapping.get(str(requested).lower(), "text_input")

    return router
"""User interaction routing helpers for LangGraph workflows."""
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from typing import Literal
from .core import StateLike, get_state_value
def _normalize_strings(values: Sequence[str]) -> set[str]:
"""Return normalized lowercase strings for comparison."""
return {value.lower().strip() for value in values}
def human_interrupt(
    interrupt_signals: Sequence[str] | None = None,
    interrupt_key: str = "human_interrupt",
) -> Callable[[StateLike], Literal["interrupt", "continue"]]:
    """Create a router that detects human interruption requests.

    Args:
        interrupt_signals: Signal words treated as interrupt requests.
            ``None`` selects the built-in defaults; an explicit empty
            sequence disables interrupt detection entirely.
        interrupt_key: State key holding the raw interrupt signal.

    Returns:
        Router returning ``"interrupt"`` or ``"continue"``.
    """
    # Distinguish "not provided" (None -> defaults) from an explicitly empty
    # sequence; the previous truthiness fallback (`interrupt_signals or ...`)
    # silently re-enabled the defaults when a caller passed [] to opt out.
    if interrupt_signals is None:
        interrupt_signals = ("stop", "pause", "cancel", "abort", "interrupt", "halt")
    signals = _normalize_strings(interrupt_signals)

    def router(state: StateLike) -> Literal["interrupt", "continue"]:
        signal = get_state_value(state, interrupt_key)
        if signal is None:
            return "continue"
        signal_str = str(signal).lower().strip()
        return "interrupt" if signal_str in signals else "continue"

    return router
def pass_status_to_user(
    status_levels: Mapping[str, str] | None = None,
    status_key: str = "status",
    notify_key: str = "notify_user",
) -> Callable[[StateLike], str]:
    """Create a router mapping workflow status to a notification priority.

    Custom mappings are normalized to lowercase keys; statuses without a
    mapping produce "no_notification".
    """
    if status_levels is None:
        levels = {
            "error": "urgent",
            "failed": "urgent",
            "warning": "medium",
            "completed": "low",
            "success": "low",
            "in_progress": "info",
        }
    else:
        levels = {key.lower(): value for key, value in status_levels.items()}

    def router(state: StateLike) -> str:
        if not bool(get_state_value(state, notify_key, True)):
            return "no_notification"
        status = get_state_value(state, status_key)
        if status is None:
            return "no_notification"
        return levels.get(str(status).lower(), "no_notification")

    return router
def user_feedback_loop(
    feedback_required_conditions: Sequence[str] | None = None,
    feedback_key: str = "requires_feedback",
    feedback_type_key: str = "feedback_type",
) -> Callable[[StateLike], str]:
    """Create a router that decides whether and what user feedback to request."""
    if feedback_required_conditions is None:
        conditions = [
            "low_confidence",
            "ambiguous_input",
            "multiple_options",
            "validation_failed",
        ]
    else:
        conditions = list(feedback_required_conditions)

    def router(state: StateLike) -> str:
        # Explicit request wins; the type key refines the routing label.
        if bool(get_state_value(state, feedback_key, False)):
            kind = get_state_value(state, feedback_type_key, "general")
            return f"feedback_{kind}"
        # Otherwise probe each configured trigger flag in order.
        for condition in conditions:
            if bool(get_state_value(state, condition, False)):
                return f"feedback_{condition}"
        return "no_feedback"

    return router
def escalate_to_human(
    escalation_triggers: Mapping[str, str] | None = None,
    auto_escalate_key: str = "auto_escalate",
    escalation_reason_key: str = "escalation_reason",
) -> Callable[[StateLike], str]:
    """Create a router that escalates work to human operators when warranted."""
    default_triggers = {
        "critical_error": "immediate",
        "security_issue": "immediate",
        "multiple_failures": "urgent",
        "user_request": "standard",
        "manual_review": "standard",
        "complex_case": "expert",
    }
    triggers = default_triggers if escalation_triggers is None else dict(escalation_triggers)

    def router(state: StateLike) -> str:
        reason = get_state_value(state, escalation_reason_key)
        if bool(get_state_value(state, auto_escalate_key, False)) and reason:
            # Match trigger names as substrings of the stated reason.
            reason_text = str(reason).lower()
            for trigger, level in triggers.items():
                if trigger in reason_text:
                    return f"escalate_{level}"
            return "escalate_standard"
        # No auto flag: look for trigger flags set directly in state.
        for trigger, level in triggers.items():
            if bool(get_state_value(state, trigger, False)):
                return f"escalate_{level}"
        return "continue_automated"

    return router
def check_user_authorization(
    required_permissions: Sequence[str] | None = None,
    user_key: str = "user",
    permissions_key: str = "permissions",
) -> Callable[[StateLike], Literal["authorized", "unauthorized"]]:
    """Create a router that checks user authorization levels.

    Args:
        required_permissions: Permissions that must all be present. ``None``
            selects the default ``{"basic_access"}``; an explicit empty
            sequence means no permissions are required.
        user_key: State key containing the user object or mapping.
        permissions_key: Key/attribute on the user holding its permissions.

    Returns:
        Router returning ``"authorized"`` or ``"unauthorized"``.
    """
    # "None" means defaults; an explicit empty sequence must stay empty
    # instead of silently re-adding "basic_access" via a truthiness fallback
    # (`required_permissions or [...]`), which broke "no permissions needed".
    required = (
        set(required_permissions) if required_permissions is not None else {"basic_access"}
    )

    def router(state: StateLike) -> Literal["authorized", "unauthorized"]:
        user_value = get_state_value(state, user_key)
        if user_value is None:
            return "unauthorized"
        if isinstance(user_value, dict):
            permissions_value = user_value.get(permissions_key)
        else:
            permissions_value = getattr(user_value, permissions_key, None)
        # Strings are sequences too; reject them so "admin" is not treated
        # as the character collection {'a', 'd', 'm', 'i', 'n'}.
        if not isinstance(permissions_value, Sequence) or isinstance(
            permissions_value, (str, bytes, bytearray)
        ):
            return "unauthorized"
        permissions = {str(permission) for permission in permissions_value}
        return "authorized" if required.issubset(permissions) else "unauthorized"

    return router
def collect_user_input(
    input_types: Mapping[str, str] | None = None,
    pending_input_key: str = "pending_user_input",
    input_type_key: str = "input_type",
) -> Callable[[StateLike], str]:
    """Create a router that selects how pending user input should be gathered."""
    default_inputs = {
        "text": "text_input",
        "choice": "choice_input",
        "confirmation": "confirm_input",
        "file": "file_input",
        "numeric": "number_input",
    }
    mapping = default_inputs if input_types is None else dict(input_types)

    def router(state: StateLike) -> str:
        if not bool(get_state_value(state, pending_input_key, False)):
            return "no_input_needed"
        requested = str(get_state_value(state, input_type_key, "text")).lower()
        # Unknown types fall back to the generic text collector.
        return mapping.get(requested, "text_input")

    return router
__all__ = [
"human_interrupt",
"pass_status_to_user",
"user_feedback_loop",
"escalate_to_human",
"check_user_authorization",
"collect_user_input",
]

View File

@@ -1,406 +1,270 @@
"""Validation edge helpers for quality control and data integrity.
This module provides edge helpers for validating outputs, checking accuracy,
confidence levels, format compliance, and data privacy requirements.
"""
import json
import re
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from .core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def check_accuracy(
    threshold: float = 0.8,
    accuracy_key: str = "accuracy_score",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that checks whether accuracy meets ``threshold``.

    Args:
        threshold: Minimum accuracy (0.0 to 1.0) for the high branch.
        accuracy_key: State key holding the accuracy score.

    Returns:
        Router returning "high_accuracy" or "low_accuracy".
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["high_accuracy", "low_accuracy"]:
        raw = (
            state.get(accuracy_key, 0.0)
            if isinstance(state, dict)
            else getattr(state, accuracy_key, 0.0)
        )
        try:
            meets = float(raw) >= threshold
        except (ValueError, TypeError):
            # Unparseable scores are treated as failing the check.
            return "low_accuracy"
        return "high_accuracy" if meets else "low_accuracy"

    return router
def check_confidence_level(
    threshold: float = 0.7,
    confidence_key: str = "confidence",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router keyed on a confidence score threshold.

    Args:
        threshold: Minimum confidence (0.0 to 1.0) for the high branch.
        confidence_key: State key holding the confidence score.

    Returns:
        Router returning "high_confidence" or "low_confidence".
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["high_confidence", "low_confidence"]:
        raw = (
            state.get(confidence_key, 0.0)
            if isinstance(state, dict)
            else getattr(state, confidence_key, 0.0)
        )
        try:
            meets = float(raw) >= threshold
        except (ValueError, TypeError):
            # Unparseable scores are treated as low confidence.
            return "low_confidence"
        return "high_confidence" if meets else "low_confidence"

    return router
def validate_output_format(
    expected_format: str = "json",
    output_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that validates the output's serialization format.

    Args:
        expected_format: Expected format ("json", "xml", "csv", "text").
        output_key: State key holding the output to validate.

    Returns:
        Router returning "valid_format" or "invalid_format".
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid_format", "invalid_format"]:
        output = (
            state.get(output_key, "")
            if isinstance(state, dict)
            else getattr(state, output_key, "")
        )
        if not output:
            return "invalid_format"
        text = str(output)
        fmt = expected_format.lower()
        try:
            if fmt == "json":
                json.loads(text)
                return "valid_format"
            if fmt == "xml":
                # Cheap structural check: angle-bracket delimited and either
                # self-closing or containing a closing tag.
                stripped = text.strip()
                well_formed = (
                    stripped.startswith("<")
                    and stripped.endswith(">")
                    and (
                        stripped.endswith("/>")
                        or ("</" in stripped and stripped.count("<") >= 2)
                    )
                )
                return "valid_format" if well_formed else "invalid_format"
            if fmt == "csv":
                # Minimal check: the header row contains a comma.
                rows = text.strip().split("\n")
                return "valid_format" if rows and "," in rows[0] else "invalid_format"
            # Any other format: treat non-blank text as valid.
            return "valid_format" if text.strip() else "invalid_format"
        except Exception:
            return "invalid_format"

    return router
def check_privacy_compliance(
    sensitive_patterns: list[str] | None = None,
    content_key: str = "content",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    r"""Build a router that screens content for sensitive-data patterns.

    Args:
        sensitive_patterns: Regex patterns for sensitive data; defaults cover
            SSNs, credit cards, emails, and separator-delimited phone numbers.
        content_key: State key holding the content to scan.

    Returns:
        Router returning "compliant" or "privacy_violation".
    """
    patterns = (
        sensitive_patterns
        if sensitive_patterns is not None
        else [
            r"\b\d{3}-\d{2}-\d{4}\b",  # SSN pattern
            r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",  # Credit card pattern
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",  # Email
            r"\b\d{3}[- ]\d{3}[- ]\d{4}\b",  # Phone number (require separators)
        ]
    )

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["compliant", "privacy_violation"]:
        content = (
            state.get(content_key, "")
            if isinstance(state, dict)
            else getattr(state, content_key, "")
        )
        text = str(content)
        if any(re.search(pattern, text, re.IGNORECASE) for pattern in patterns):
            return "privacy_violation"
        return "compliant"

    return router
def check_data_freshness(
    max_age_seconds: int = 3600,  # 1 hour default
    timestamp_key: str = "timestamp",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that reports whether stored data is recent enough.

    Args:
        max_age_seconds: Maximum age in seconds before data counts as stale.
        timestamp_key: State key holding the data timestamp (epoch number,
            numeric string, or ISO-8601 string).

    Returns:
        Router returning "fresh" or "stale".
    """

    def router(state: dict[str, Any] | StateProtocol) -> Literal["fresh", "stale"]:
        raw = (
            state.get(timestamp_key)
            if isinstance(state, dict)
            else getattr(state, timestamp_key, None)
        )
        if raw is None:
            return "stale"
        try:
            if isinstance(raw, str):
                # Numeric epoch strings first, then ISO-8601 ("Z" tolerated).
                try:
                    epoch = float(raw)
                except (ValueError, TypeError):
                    from datetime import datetime

                    epoch = datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
            else:
                epoch = float(raw)
        except (ValueError, TypeError):
            return "stale"
        return "fresh" if (time.time() - epoch) <= max_age_seconds else "stale"

    return router
def validate_required_fields(
    required_fields: list[str],
    strict_mode: bool = True,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that validates the presence of required state fields.

    Args:
        required_fields: Field names that must be present in state.
        strict_mode: When True, fields must also be non-empty ("" and []
            count as missing); when False, mere presence suffices.

    Returns:
        Router returning "valid" or "missing_fields".
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid", "missing_fields"]:
        for name in required_fields:
            value = (
                state.get(name) if isinstance(state, dict) else getattr(state, name, None)
            )
            if value is None:
                return "missing_fields"
            considered_empty = value == "" or (isinstance(value, list) and not value)
            if strict_mode and considered_empty:
                return "missing_fields"
        return "valid"

    return router
def check_output_length(
    min_length: int = 1,
    max_length: int | None = None,
    content_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that validates output length constraints.

    Args:
        min_length: Minimum required length of the stringified content.
        max_length: Maximum allowed length, or None for no upper bound.
        content_key: State key holding the content to measure.

    Returns:
        Router returning "valid_length", "too_short", or "too_long".
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid_length", "too_short", "too_long"]:
        content = (
            state.get(content_key, "")
            if isinstance(state, dict)
            else getattr(state, content_key, "")
        )
        size = len(str(content))
        if size < min_length:
            return "too_short"
        if max_length is not None and size > max_length:
            return "too_long"
        return "valid_length"

    return router
def create_content_availability_router(
    content_keys: list[str] | None = None,
    success_target: str = "analyze_content",
    failure_target: str = "status_summary",
    error_key: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Build a router that checks content availability and success conditions.

    Intended for workflows that must confirm processing succeeded and left
    usable content behind before moving on.

    Args:
        content_keys: State keys to probe for available content.
        success_target: Node to route to when content exists and no error.
        failure_target: Node to route to when content is missing or an
            error is recorded.
        error_key: State key holding error information.

    Returns:
        Router returning ``success_target`` or ``failure_target``.
    """
    if content_keys is None:
        content_keys = ["scraped_content", "repomix_output"]

    def _read(state: dict[str, Any] | StateProtocol, key: str) -> Any:
        if isinstance(state, dict):
            return state.get(key)
        return getattr(state, key, None)

    def _usable(value: Any) -> bool:
        # Lists need items, strings need non-whitespace text; anything else
        # just needs to be truthy.
        if not value:
            return False
        if isinstance(value, list):
            return len(value) > 0
        if isinstance(value, str):
            return len(value.strip()) > 0
        return bool(value)

    def router(state: dict[str, Any] | StateProtocol) -> str:
        has_error = bool(_read(state, error_key))
        has_content = any(_usable(_read(state, key)) for key in content_keys)
        return success_target if has_content and not has_error else failure_target

    return router
"""Validation and compliance routing helpers for LangGraph graphs."""
from __future__ import annotations
import json
import re
import time
from collections.abc import Callable, Sequence
from numbers import Real
from datetime import datetime, timezone
from typing import Literal
from .core import StateLike, get_state_value
def _to_float(value: object, default: float = 0.0) -> float:
"""Convert ``value`` to ``float`` when possible, otherwise ``default``."""
if value is None:
return default
if isinstance(value, bool):
return float(value)
if isinstance(value, Real):
return float(value)
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return default
try:
return float(stripped)
except ValueError:
return default
return default
def _to_str(value: object, default: str = "") -> str:
"""Return a string representation of ``value``."""
if value is None:
return default
return str(value)
def _has_content(value: object) -> bool:
"""Determine whether ``value`` carries usable content."""
if value is None:
return False
if isinstance(value, str):
return bool(value)
if isinstance(value, bool):
return True
if isinstance(value, Real):
return True
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return len(value) > 0
return bool(value)
def _parse_timestamp(value: object) -> float | None:
"""Parse various timestamp representations into a Unix epoch (UTC)."""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
stripped = value.strip()
try:
return float(stripped)
except (TypeError, ValueError):
pass
try:
dt = datetime.fromisoformat(stripped.replace("Z", "+00:00"))
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.timestamp()
except Exception:
return None
return None
def check_accuracy(
    threshold: float = 0.8,
    accuracy_key: str = "accuracy_score",
) -> Callable[[StateLike], Literal["high_accuracy", "low_accuracy"]]:
    """Create a router comparing the stored accuracy score to ``threshold``."""

    def router(state: StateLike) -> Literal["high_accuracy", "low_accuracy"]:
        raw_score = get_state_value(state, accuracy_key, 0.0)
        if _to_float(raw_score) >= threshold:
            return "high_accuracy"
        return "low_accuracy"

    return router
def check_confidence_level(
    threshold: float = 0.7,
    confidence_key: str = "confidence",
) -> Callable[[StateLike], Literal["high_confidence", "low_confidence"]]:
    """Create a router comparing the stored confidence to ``threshold``."""

    def router(state: StateLike) -> Literal["high_confidence", "low_confidence"]:
        raw_score = get_state_value(state, confidence_key, 0.0)
        if _to_float(raw_score) >= threshold:
            return "high_confidence"
        return "low_confidence"

    return router
def validate_output_format(
    expected_format: str = "json",
    output_key: str = "output",
) -> Callable[[StateLike], Literal["valid_format", "invalid_format"]]:
    """Create a router validating the serialized format of ``output_key``."""
    expected = expected_format.lower()

    def _json_ok(text: str) -> bool:
        json.loads(text)  # raises on malformed JSON
        return True

    def _xml_ok(text: str) -> bool:
        # Cheap structural check: bracket-delimited and either self-closing
        # or containing a closing tag.
        stripped = text.strip()
        if not (stripped.startswith("<") and stripped.endswith(">")):
            return False
        return stripped.endswith("/>") or ("</" in stripped and stripped.count("<") >= 2)

    def _csv_ok(text: str) -> bool:
        rows = text.strip().split("\n")
        return bool(rows) and "," in rows[0]

    def router(state: StateLike) -> Literal["valid_format", "invalid_format"]:
        text = _to_str(get_state_value(state, output_key, ""))
        if not text:
            return "invalid_format"
        try:
            if expected == "json":
                ok = _json_ok(text)
            elif expected == "xml":
                ok = _xml_ok(text)
            elif expected == "csv":
                ok = _csv_ok(text)
            else:
                # Unknown formats: any non-blank text passes.
                ok = bool(text.strip())
        except Exception:  # noqa: BLE001
            ok = False
        return "valid_format" if ok else "invalid_format"

    return router
def check_privacy_compliance(
    sensitive_patterns: Sequence[str] | None = None,
    content_key: str = "content",
) -> Callable[[StateLike], Literal["compliant", "privacy_violation"]]:
    r"""Create a router that flags content matching sensitive-data patterns."""
    if sensitive_patterns is None:
        patterns = [
            r"\b\d{3}-\d{2}-\d{4}\b",  # SSN
            r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",  # Credit card
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",  # Email
            r"\b\d{3}[- ]\d{3}[- ]\d{4}\b",  # Phone number with separators
        ]
    else:
        patterns = list(sensitive_patterns)

    def router(state: StateLike) -> Literal["compliant", "privacy_violation"]:
        text = _to_str(get_state_value(state, content_key, ""))
        matched = any(re.search(pattern, text, re.IGNORECASE) for pattern in patterns)
        return "privacy_violation" if matched else "compliant"

    return router
def check_data_freshness(
    max_age_seconds: int = 3600,
    timestamp_key: str = "timestamp",
) -> Callable[[StateLike], Literal["fresh", "stale"]]:
    """Create a router that reports whether stored data is still fresh."""

    def router(state: StateLike) -> Literal["fresh", "stale"]:
        epoch = _parse_timestamp(get_state_value(state, timestamp_key))
        if epoch is None:
            # Missing or unparseable timestamps are treated as stale.
            return "stale"
        is_fresh = (time.time() - epoch) <= max_age_seconds
        return "fresh" if is_fresh else "stale"

    return router
def validate_required_fields(
    required_fields: Sequence[str],
    strict_mode: bool = True,
) -> Callable[[StateLike], Literal["valid", "missing_fields"]]:
    """Create a router that verifies required state fields are populated."""

    def router(state: StateLike) -> Literal["valid", "missing_fields"]:
        for name in required_fields:
            value = get_state_value(state, name)
            if value is None:
                return "missing_fields"
            # In strict mode the field must also carry non-empty content.
            if strict_mode and not _has_content(value):
                return "missing_fields"
        return "valid"

    return router
def check_output_length(
    min_length: int = 1,
    max_length: int | None = None,
    content_key: str = "output",
) -> Callable[[StateLike], Literal["valid_length", "too_short", "too_long"]]:
    """Create a router that enforces minimum/maximum content length."""

    def router(state: StateLike) -> Literal["valid_length", "too_short", "too_long"]:
        length = len(_to_str(get_state_value(state, content_key, "")))
        if length < min_length:
            return "too_short"
        if max_length is not None and length > max_length:
            return "too_long"
        return "valid_length"

    return router
def create_content_availability_router(
    content_keys: Sequence[str] | None = None,
    success_target: str = "analyze_content",
    failure_target: str = "status_summary",
    error_key: str = "error",
) -> Callable[[StateLike], str]:
    """Create a router that proceeds only when content exists and no error is set."""
    if content_keys is None:
        keys = ["scraped_content", "repomix_output"]
    else:
        keys = list(content_keys)

    def _usable(value: object | None) -> bool:
        # Strings need non-whitespace text; non-string sequences need items;
        # everything else just needs to be truthy.
        if value is None:
            return False
        if isinstance(value, str):
            return bool(value.strip())
        if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
            return len(value) > 0
        return bool(value)

    def router(state: StateLike) -> str:
        if get_state_value(state, error_key):
            return failure_target
        for key in keys:
            if _usable(get_state_value(state, key)):
                return success_target
        return failure_target

    return router
__all__ = [
"check_accuracy",
"check_confidence_level",
"validate_output_format",
"check_privacy_compliance",
"check_data_freshness",
"validate_required_fields",
"check_output_length",
"create_content_availability_router",
]

View File

@@ -8,12 +8,13 @@ routing and provide utilities for building composite routing logic.
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from .core import StateLike, get_state_value
from langgraph.types import Command
from biz_bud.logging import debug_highlight
from biz_bud.logging import debug_highlight, error_highlight
class WorkflowRouters:
@@ -24,7 +25,7 @@ class WorkflowRouters:
stage_key: str = "workflow_stage",
stage_mapping: dict[str, str] | None = None,
default_stage: str = "end",
) -> Callable[[dict[str, Any]], str]:
) -> Callable[[StateLike], str]:
"""Route based on workflow stage progression.
Routes based on current workflow stage, useful for sequential
@@ -55,9 +56,11 @@ class WorkflowRouters:
if stage_mapping is None:
stage_mapping = {}
def router(state: dict[str, Any]) -> str:
current_stage = state.get(stage_key, "unknown")
return stage_mapping.get(current_stage, default_stage)
def router(state: StateLike) -> str:
current_stage = get_state_value(state, stage_key, "unknown")
if isinstance(current_stage, str):
return stage_mapping.get(current_stage, default_stage)
return default_stage
return router
@@ -67,7 +70,7 @@ class WorkflowRouters:
subgraph_mapping: dict[str, str],
default_subgraph: str = "main",
parent_return: str = "consolidate",
) -> Callable[[dict[str, Any]], Command[str]]:
) -> Callable[[StateLike], Command[str]]:
"""Route to subgraphs using Command pattern.
Routes to different subgraphs based on state values, useful for
@@ -97,8 +100,9 @@ class WorkflowRouters:
```
"""
def router(state: dict[str, Any]) -> Command[str]:
task_type = state.get(subgraph_key, "default")
def router(state: StateLike) -> Command[str]:
raw_task_type = get_state_value(state, subgraph_key, "default")
task_type = raw_task_type if isinstance(raw_task_type, str) else str(raw_task_type)
target_subgraph = subgraph_mapping.get(task_type, default_subgraph)
debug_highlight(
@@ -119,9 +123,9 @@ class WorkflowRouters:
@staticmethod
def combine_routers(
routers: Sequence[tuple[Callable[[dict[str, Any]], Any], str]],
routers: Sequence[tuple[Callable[[StateLike], object], str]],
combination_logic: str = "first_match",
) -> Callable[[dict[str, Any]], Any]:
) -> Callable[[StateLike], object]:
"""Combine multiple routers with specified logic.
Allows composition of multiple routing functions with different
@@ -145,8 +149,8 @@ class WorkflowRouters:
```
"""
def combined_router(state: dict[str, Any]) -> Any:
results = []
def combined_router(state: StateLike) -> object:
results: list[tuple[str, object]] = []
for router_func, router_name in routers:
try:
@@ -174,41 +178,37 @@ class WorkflowRouters:
@staticmethod
def create_debug_router(
inner_router: Callable[[dict[str, Any]], Any],
inner_router: Callable[[StateLike], object],
name: str = "unnamed_router",
) -> Callable[[dict[str, Any]], Any]:
"""Wrap a router with debug logging.
fallback: object | None = "end",
rethrow: bool = False,
) -> Callable[[StateLike], object]:
"""Wrap a router with debug logging."""
Adds comprehensive debug logging around router execution,
useful for troubleshooting routing decisions in complex workflows.
Args:
inner_router: The router function to wrap
name: Name for logging identification
Returns:
Debug-wrapped router function
Example:
```python
basic_router = BasicRouters.route_on_key("status", {"ok": "continue"})
debug_router = WorkflowRouters.create_debug_router(
basic_router,
"status_router"
def debug_router(state: StateLike) -> object:
state_keys: list[str] = (
list(state.keys()) if isinstance(state, Mapping) else []
)
graph.add_conditional_edges("processor", debug_router)
```
"""
def debug_router(state: dict[str, Any]) -> Any:
debug_highlight(
f"Router '{name}' evaluating state keys: {list(state.keys())}",
f"Router '{name}' evaluating state keys: {state_keys}",
category="EdgeRouter",
)
result = inner_router(state)
debug_highlight(f"Router '{name}' result: {result}", category="EdgeRouter")
return result
try:
result = inner_router(state)
debug_highlight(
f"Router '{name}' produced route: {result!r}",
category="EdgeRouter",
)
return result
except Exception as exc: # pragma: no cover - defensive logging path
error_highlight(
f"Router '{name}' failed with error: {exc}",
category="EdgeRouter",
)
if rethrow:
raise
return fallback
return debug_router

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
from collections.abc import Callable, Mapping
from functools import wraps
from typing import Any, Callable, TypeVar, cast
from typing import TypeVar, cast
from .cross_cutting import (
handle_errors,
@@ -38,28 +39,30 @@ from .state_immutability import (
update_state_immutably,
)
StateT = TypeVar("StateT")
StateT = TypeVar("StateT", bound=Mapping[str, object])
ReturnT = TypeVar("ReturnT")
def create_type_safe_wrapper(
func: Callable[[StateT], ReturnT]
) -> Callable[[dict[str, Any]], ReturnT]:
) -> Callable[[dict[str, object]], ReturnT]:
"""Wrap a router or helper to satisfy LangGraph's ``dict``-based typing.
LangGraph nodes and routers receive ``dict[str, Any]`` state objects at
LangGraph nodes and routers receive ``dict[str, object]`` state objects at
runtime, but most of our helpers are annotated with TypedDict subclasses or
custom mapping types. This utility provides a lightweight adapter that
custom mapping types. This utility provides a lightweight adapter that
casts the dynamic runtime state into the helper's expected type while
preserving the original return value and function metadata.
The wrapper intentionally performs a shallow cast rather than copying the
state to avoid the overhead of deep cloning large graphs. Callers that need
state to avoid the overhead of deep cloning large graphs. Callers that need
defensive copies should do so within their helper implementation.
"""
@wraps(func)
def wrapper(state: dict[str, Any], *args: Any, **kwargs: Any) -> ReturnT:
def wrapper(
state: dict[str, object], *args: object, **kwargs: object
) -> ReturnT:
return func(cast(StateT, state), *args, **kwargs)
return wrapper

View File

@@ -8,15 +8,19 @@ nodes and tools in the Business Buddy framework.
import asyncio
import functools
import time
from collections.abc import Callable
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from datetime import UTC, datetime
from typing import Any, TypedDict, cast
from typing import ParamSpec, TypeVar, TypedDict, cast
from biz_bud.logging import get_logger
logger = get_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class NodeMetric(TypedDict):
"""Type definition for node metrics."""
@@ -29,7 +33,7 @@ class NodeMetric(TypedDict):
last_error: str | None
def _log_execution_start(node_name: str, context: dict[str, Any]) -> float:
def _log_execution_start(node_name: str, context: Mapping[str, object]) -> float:
"""Log the start of node execution and return start time.
Args:
@@ -51,7 +55,9 @@ def _log_execution_start(node_name: str, context: dict[str, Any]) -> float:
return start_time
def _log_execution_success(node_name: str, start_time: float, context: dict[str, Any]) -> None:
def _log_execution_success(
node_name: str, start_time: float, context: Mapping[str, object]
) -> None:
"""Log successful node execution.
Args:
@@ -71,7 +77,10 @@ def _log_execution_success(node_name: str, start_time: float, context: dict[str,
def _log_execution_error(
node_name: str, start_time: float, error: Exception, context: dict[str, Any]
node_name: str,
start_time: float,
error: Exception,
context: Mapping[str, object],
) -> None:
"""Log node execution error.
@@ -95,7 +104,7 @@ def _log_execution_error(
def log_node_execution(
node_name: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
"""Log node execution with timing and context.
This decorator automatically logs entry, exit, and timing information
@@ -108,16 +117,18 @@ def log_node_execution(
Decorated function with logging
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
actual_node_name = node_name or func.__name__
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
context = _extract_context_from_args(args, kwargs)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
context = _extract_context_from_args(
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
)
start_time = _log_execution_start(actual_node_name, context)
try:
result = await func(*args, **kwargs)
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
_log_execution_success(actual_node_name, start_time, context)
return result
except Exception as e:
@@ -125,12 +136,14 @@ def log_node_execution(
raise
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
context = _extract_context_from_args(args, kwargs)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
context = _extract_context_from_args(
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
)
start_time = _log_execution_start(actual_node_name, context)
try:
result = func(*args, **kwargs)
result = cast(Callable[P, R], func)(*args, **kwargs)
_log_execution_success(actual_node_name, start_time, context)
return result
except Exception as e:
@@ -138,12 +151,16 @@ def log_node_execution(
raise
# Return appropriate wrapper based on function type
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return decorator
def _initialize_metric(state: dict[str, Any] | None, metric_name: str) -> NodeMetric | None:
def _initialize_metric(
state: MutableMapping[str, object] | None, metric_name: str
) -> NodeMetric | None:
"""Initialize or retrieve a metric from state.
Args:
@@ -156,24 +173,51 @@ def _initialize_metric(state: dict[str, Any] | None, metric_name: str) -> NodeMe
if state is None:
return None
if "metrics" not in state:
state["metrics"] = {}
metrics_value = state.get("metrics")
if isinstance(metrics_value, MutableMapping):
metrics: MutableMapping[str, object] = metrics_value
elif isinstance(metrics_value, Mapping):
metrics = dict(metrics_value)
state["metrics"] = metrics
else:
metrics = {}
state["metrics"] = metrics
metrics = state["metrics"]
raw_entry = metrics.get(metric_name)
if isinstance(raw_entry, Mapping):
entry_dict: dict[str, object] = dict(raw_entry)
else:
try:
entry_dict = dict(raw_entry) if raw_entry is not None else {}
except Exception:
entry_dict = {}
if metric_name not in metrics:
metrics[metric_name] = NodeMetric(
count=0,
success_count=0,
failure_count=0,
total_duration_ms=0.0,
avg_duration_ms=0.0,
last_execution=None,
last_error=None,
)
def _as_int(value: object) -> int:
if isinstance(value, bool):
return int(value)
if isinstance(value, (int, float)):
return int(value)
return 0
metric = cast("NodeMetric", metrics[metric_name])
metric["count"] = (metric["count"] or 0) + 1
def _as_float(value: object) -> float:
if isinstance(value, (int, float)):
return float(value)
return 0.0
metric: NodeMetric = NodeMetric(
count=_as_int(entry_dict.get("count")),
success_count=_as_int(entry_dict.get("success_count")),
failure_count=_as_int(entry_dict.get("failure_count")),
total_duration_ms=_as_float(entry_dict.get("total_duration_ms")),
avg_duration_ms=0.0,
last_execution=str(entry_dict.get("last_execution"))
if isinstance(entry_dict.get("last_execution"), str)
else None,
last_error=str(entry_dict.get("last_error"))
if entry_dict.get("last_error") is not None
else None,
)
metrics[metric_name] = metric
return metric
@@ -187,11 +231,13 @@ def _update_metric_success(metric: NodeMetric | None, elapsed_ms: float) -> None
if metric is None:
return
metric["count"] = (metric["count"] or 0) + 1
metric["success_count"] = (metric["success_count"] or 0) + 1
metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms
count = metric["count"] or 1
metric["avg_duration_ms"] = (metric["total_duration_ms"] or 0.0) / count
metric["last_execution"] = datetime.now(UTC).isoformat()
metric["last_error"] = None
def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error: Exception) -> None:
@@ -205,6 +251,7 @@ def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error:
if metric is None:
return
metric["count"] = (metric["count"] or 0) + 1
metric["failure_count"] = (metric["failure_count"] or 0) + 1
metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms
count = metric["count"] or 1
@@ -215,7 +262,7 @@ def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error:
def track_metrics(
metric_name: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
"""Track metrics for node execution.
This decorator updates state with performance metrics including
@@ -228,15 +275,19 @@ def track_metrics(
Decorated function with metric tracking
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
state = args[0] if args and isinstance(args[0], dict) else None
state = (
cast(MutableMapping[str, object], args[0])
if args and isinstance(args[0], MutableMapping)
else None
)
metric = _initialize_metric(state, metric_name)
try:
result = await func(*args, **kwargs)
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
_update_metric_success(metric, elapsed_ms)
return result
@@ -246,13 +297,17 @@ def track_metrics(
raise
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
state = args[0] if args and isinstance(args[0], dict) else None
state = (
cast(MutableMapping[str, object], args[0])
if args and isinstance(args[0], MutableMapping)
else None
)
metric = _initialize_metric(state, metric_name)
try:
result = func(*args, **kwargs)
result = cast(Callable[P, R], func)(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
_update_metric_success(metric, elapsed_ms)
return result
@@ -261,7 +316,9 @@ def track_metrics(
_update_metric_failure(metric, elapsed_ms, e)
raise
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return decorator
@@ -269,10 +326,10 @@ def track_metrics(
def _handle_error(
error: Exception,
func_name: str,
args: tuple[Any, ...],
error_handler: Callable[[Exception], Any] | None,
fallback_value: Any
) -> Any:
args: tuple[object, ...],
error_handler: Callable[[Exception], None] | None,
fallback_value: object | None,
) -> object:
"""Common error handling logic for both sync and async functions.
Args:
@@ -297,12 +354,20 @@ def _handle_error(
error_handler(error)
# Update state with error if available
state = args[0] if args and isinstance(args[0], dict) else None
if state and "errors" in state:
if not isinstance(state["errors"], list):
state["errors"] = []
state = (
cast(MutableMapping[str, object], args[0])
if args and isinstance(args[0], MutableMapping)
else None
)
if state is not None:
errors_value = state.get("errors")
if isinstance(errors_value, list):
errors_list = errors_value
else:
errors_list = []
state["errors"] = errors_list
state["errors"].append(
errors_list.append(
{
"node": func_name,
"error": str(error),
@@ -314,13 +379,14 @@ def _handle_error(
# Return fallback value or re-raise
if fallback_value is not None:
return fallback_value
else:
raise
raise
def handle_errors(
error_handler: Callable[[Exception], Any] | None = None, fallback_value: Any = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
error_handler: Callable[[Exception], None] | None = None,
fallback_value: R | None = None,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
"""Handle errors with standardized error handling in nodes.
This decorator provides consistent error handling with optional
@@ -334,22 +400,38 @@ def handle_errors(
Decorated function with error handling
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return await func(*args, **kwargs)
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
except Exception as e:
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
result = _handle_error(
e,
func.__name__,
cast(tuple[object, ...], args),
error_handler,
fallback_value,
)
return cast(R, result)
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return func(*args, **kwargs)
return cast(Callable[P, R], func)(*args, **kwargs)
except Exception as e:
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
result = _handle_error(
e,
func.__name__,
cast(tuple[object, ...], args),
error_handler,
fallback_value,
)
return cast(R, result)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return decorator
@@ -358,7 +440,7 @@ def retry_on_failure(
max_attempts: int = 3,
backoff_seconds: float = 1.0,
exponential_backoff: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
"""Retry node execution on failure.
Args:
@@ -370,14 +452,14 @@ def retry_on_failure(
Decorated function with retry logic
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
last_exception = None
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception: Exception | None = None
for attempt in range(max_attempts):
try:
return await func(*args, **kwargs)
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
except Exception as e:
last_exception = e
@@ -406,12 +488,12 @@ def retry_on_failure(
)
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
last_exception = None
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception: Exception | None = None
for attempt in range(max_attempts):
try:
return func(*args, **kwargs)
return cast(Callable[P, R], func)(*args, **kwargs)
except Exception as e:
last_exception = e
@@ -432,21 +514,22 @@ def retry_on_failure(
f"{func.__name__}: {str(e)}"
)
if last_exception:
if last_exception is not None:
raise last_exception
else:
raise RuntimeError(
f"Unexpected error in retry logic for {func.__name__}"
)
raise RuntimeError(
f"Unexpected error in retry logic for {func.__name__}"
)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return decorator
def _extract_context_from_args(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> dict[str, Any]:
args: tuple[object, ...], kwargs: Mapping[str, object]
) -> dict[str, object]:
"""Extract context information from function arguments.
Looks for RunnableConfig in args/kwargs and extracts relevant context.
@@ -458,44 +541,35 @@ def _extract_context_from_args(
Returns:
Dictionary with extracted context (run_id, user_id, etc.)
"""
context = {}
context: dict[str, object] = {}
# Check if we can extract RunnableConfig-like data
try:
# Check kwargs for config
if "config" in kwargs:
config = kwargs["config"]
# Check if config looks like a RunnableConfig (duck typing for TypedDict)
if (
isinstance(config, dict)
and any(
key in config
for key in [
"tags",
"metadata",
"callbacks",
"run_name",
"configurable",
]
)
and "metadata" in config
and isinstance(config["metadata"], dict)
):
context |= config["metadata"]
def _sanitize_mapping(meta: Mapping[str, object]) -> dict[str, object]:
cleaned: dict[str, object] = {}
for key, value in meta.items():
if isinstance(value, (str, int, float, bool)) or value is None:
cleaned[key] = value
else:
cleaned[key] = str(value)
return cleaned
# Check args for RunnableConfig
for arg in args:
# Check if arg looks like a RunnableConfig (duck typing for TypedDict)
if isinstance(arg, dict) and any(
key in arg
for key in ["tags", "metadata", "callbacks", "run_name", "configurable"]
):
if "metadata" in arg and isinstance(arg["metadata"], dict):
context |= arg["metadata"]
break
except ImportError:
# LangChain not available, skip RunnableConfig extraction
pass
# Prefer explicit config kwarg when available
if "config" in kwargs:
config = kwargs["config"]
if isinstance(config, Mapping):
metadata = config.get("metadata")
if isinstance(metadata, Mapping):
context.update(_sanitize_mapping(metadata))
# Fallback: inspect args for RunnableConfig-like payloads
for arg in args:
if not isinstance(arg, Mapping):
continue
keys = set(arg.keys())
if {"metadata"}.issubset(keys) and keys & {"tags", "callbacks", "run_name", "configurable"}:
metadata = arg.get("metadata")
if isinstance(metadata, Mapping):
context.update(_sanitize_mapping(metadata))
break
return context
@@ -505,7 +579,7 @@ def standard_node(
node_name: str | None = None,
metric_name: str | None = None,
retry_attempts: int = 0,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
"""Composite decorator applying standard cross-cutting concerns.
This decorator combines logging, metrics, error handling, and retries
@@ -520,9 +594,9 @@ def standard_node(
Decorated function with all standard concerns applied
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
# Apply decorators in order (innermost to outermost)
decorated = func
decorated: Callable[P, Awaitable[R] | R] = func
# Add retry if requested
if retry_attempts > 0:
@@ -545,7 +619,7 @@ def standard_node(
return decorator
def route_error_severity(state: dict[str, Any]) -> str:
def route_error_severity(state: Mapping[str, object]) -> str:
"""Route based on error severity in state.
This function examines the state for error information and routes
@@ -560,7 +634,7 @@ def route_error_severity(state: dict[str, Any]) -> str:
# Check for errors in state
if "errors" in state and state["errors"]:
errors = state["errors"]
if isinstance(errors, list) and errors:
if isinstance(errors, Sequence) and errors:
# Count critical errors
critical_count = 0
for error in errors:
@@ -606,7 +680,7 @@ def route_error_severity(state: dict[str, Any]) -> str:
return "END"
def route_llm_output(state: dict[str, Any]) -> str:
def route_llm_output(state: Mapping[str, object]) -> str:
"""Route based on LLM output analysis.
This function examines LLM output in the state and determines

View File

@@ -1,30 +1,29 @@
"""Centralized graph builder to reduce duplication in graph creation."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Generic, Sequence, TypeVar, Union, cast
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
from langgraph.cache.base import BaseCache
from langgraph.store.base import BaseStore
from langgraph.types import All
from biz_bud.logging import get_logger
logger = get_logger(__name__)
# Type variable for state types - using Any for maximum compatibility with LangGraph
StateT = TypeVar("StateT", bound=Any)
# Type aliases for clarity - LangGraph nodes can be sync or async with various signatures
NodeFunction = Any # Maximum flexibility for LangGraph node functions
RouterFunction = Callable[[Any], str] # Routers should return strings for LangGraph compatibility
ConditionalMapping = dict[str, str]
"""Centralized graph builder to reduce duplication in graph creation."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Generic, TypeVar, cast
from collections.abc import Awaitable, Callable, Mapping, Sequence
from langgraph.cache.base import BaseCache
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CachePolicy, CompiledStateGraph as CompiledGraph, RetryPolicy
from langgraph.store.base import BaseStore
from langgraph.types import All
from biz_bud.logging import get_logger
logger = get_logger(__name__)
# Type aliases for clarity - LangGraph nodes can be sync or async with various signatures
RuntimeState = dict[str, object]
StateT = TypeVar("StateT", bound=RuntimeState)
NodeFunction = Callable[..., object | Awaitable[object]]
RouterFunction = Callable[[Mapping[str, object]], str]
ConditionalMapping = dict[str, str]
@dataclass
@@ -33,10 +32,10 @@ class NodeConfig:
func: NodeFunction
defer: bool = False
metadata: dict[str, Any] | None = None
input_schema: type[Any] | None = None
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None
cache_policy: CachePolicy | None = None
metadata: dict[str, object] | None = None
input_schema: type[object] | None = None
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None
cache_policy: CachePolicy | None = None
@dataclass
@@ -55,29 +54,29 @@ class ConditionalEdgeConfig:
@dataclass
class GraphBuilderConfig(Generic[StateT]):
class GraphBuilderConfig(Generic[StateT]):
"""Configuration for building a StateGraph."""
state_class: type[StateT]
context_class: type[Any] | None = None
input_schema: type[Any] | None = None
output_schema: type[Any] | None = None
nodes: dict[str, NodeFunction | NodeConfig] = field(default_factory=dict)
edges: list[EdgeConfig] = field(default_factory=list)
conditional_edges: list[ConditionalEdgeConfig] = field(default_factory=list)
checkpointer: BaseCheckpointSaver[Any] | None = None
cache: BaseCache[Any] | None = None
store: BaseStore[Any] | None = None
interrupt_before: All | list[str] | None = None
interrupt_after: All | list[str] | None = None
debug: bool = False
name: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
context_class: type[object] | None = None
input_schema: type[object] | None = None
output_schema: type[object] | None = None
nodes: dict[str, NodeFunction | NodeConfig] = field(default_factory=dict)
edges: list[EdgeConfig] = field(default_factory=list)
conditional_edges: list[ConditionalEdgeConfig] = field(default_factory=list)
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None
cache: BaseCache[object] | None = None
store: BaseStore[object] | None = None
interrupt_before: All | list[str] | None = None
interrupt_after: All | list[str] | None = None
debug: bool = False
name: str | None = None
metadata: dict[str, object] = field(default_factory=dict)
def build_graph_from_config(
config: GraphBuilderConfig[StateT],
) -> CompiledStateGraph[StateT]:
def build_graph_from_config(
config: GraphBuilderConfig[StateT],
) -> CompiledGraph[StateT]:
"""Build a StateGraph from configuration.
This function reduces duplication by centralizing the graph building logic.
@@ -154,7 +153,7 @@ def build_graph_from_config(
logger.debug(f"Added conditional edge from: {cond_edge.source}")
# Compile with optional runtime configuration
compiled = builder.compile(
compiled = builder.compile(
checkpointer=config.checkpointer,
cache=config.cache,
store=config.store,
@@ -169,38 +168,45 @@ def build_graph_from_config(
setattr(compiled.builder, "config", config)
logger.info("Graph compilation completed successfully")
return compiled
return compiled
class GraphBuilder(Generic[StateT]):
"""Fluent API for building graphs."""
def __init__(
self,
state_class: type[StateT],
*,
context_class: type[Any] | None = None,
input_schema: type[Any] | None = None,
output_schema: type[Any] | None = None,
):
"""Initialize the builder with a state class."""
self.config = GraphBuilderConfig(
state_class=state_class,
context_class=context_class,
input_schema=input_schema,
output_schema=output_schema,
)
def __init__(
self,
state_class: type[StateT],
*,
context_class: type[object] | None = None,
input_schema: type[object] | None = None,
output_schema: type[object] | None = None,
):
"""Initialize the builder with a state class."""
if not isinstance(state_class, type): # pragma: no cover - defensive guard
raise TypeError("GraphBuilder requires a state class type.")
if not issubclass(state_class, Mapping) and not hasattr(state_class, "get"):
raise TypeError(
"GraphBuilder requires a mapping-compatible state class that exposes a 'get' method."
)
self.config = GraphBuilderConfig(
state_class=state_class,
context_class=context_class,
input_schema=input_schema,
output_schema=output_schema,
)
def add_node(
self,
name: str,
func: NodeFunction,
*,
defer: bool = False,
metadata: dict[str, Any] | None = None,
input_schema: type[Any] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
name: str,
func: NodeFunction,
*,
defer: bool = False,
metadata: dict[str, object] | None = None,
input_schema: type[object] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
) -> "GraphBuilder[StateT]":
"""Add a node to the graph with modern LangGraph options."""
@@ -231,40 +237,40 @@ class GraphBuilder(Generic[StateT]):
)
return self
def with_checkpointer(
self, checkpointer: BaseCheckpointSaver[Any]
) -> "GraphBuilder[StateT]":
def with_checkpointer(
self, checkpointer: BaseCheckpointSaver[RuntimeState]
) -> "GraphBuilder[StateT]":
"""Set the checkpointer for the graph."""
self.config.checkpointer = checkpointer
return self
def with_context(
self, context_class: type[Any]
) -> "GraphBuilder[StateT]":
def with_context(
self, context_class: type[object]
) -> "GraphBuilder[StateT]":
"""Define the context schema for runtime access."""
self.config.context_class = context_class
return self
def with_input_schema(
self, input_schema: type[Any]
) -> "GraphBuilder[StateT]":
def with_input_schema(
self, input_schema: type[object]
) -> "GraphBuilder[StateT]":
"""Define the input schema for the graph."""
self.config.input_schema = input_schema
return self
def with_output_schema(
self, output_schema: type[Any]
) -> "GraphBuilder[StateT]":
def with_output_schema(
self, output_schema: type[object]
) -> "GraphBuilder[StateT]":
"""Define the output schema for the graph."""
self.config.output_schema = output_schema
return self
def with_cache(self, cache: BaseCache[Any]) -> "GraphBuilder[StateT]":
def with_cache(self, cache: BaseCache[object]) -> "GraphBuilder[StateT]":
"""Attach a LangGraph cache implementation."""
self.config.cache = cache
return self
def with_store(self, store: BaseStore[Any]) -> "GraphBuilder[StateT]":
def with_store(self, store: BaseStore[object]) -> "GraphBuilder[StateT]":
"""Attach a LangGraph store implementation."""
self.config.store = store
return self
@@ -290,31 +296,31 @@ class GraphBuilder(Generic[StateT]):
self.config.debug = enabled
return self
def with_metadata(self, **kwargs: Any) -> "GraphBuilder[StateT]":
"""Add metadata to the graph."""
self.config.metadata.update(kwargs)
return self
def with_metadata(self, **kwargs: object) -> "GraphBuilder[StateT]":
"""Add metadata to the graph."""
self.config.metadata.update(kwargs)
return self
def build(self) -> CompiledStateGraph[StateT]:
def build(self) -> CompiledGraph[StateT]:
"""Build and compile the graph."""
return build_graph_from_config(self.config)
def create_simple_linear_graph(
state_class: type[StateT],
nodes: list[tuple[str, NodeFunction] | tuple[str, NodeFunction, dict[str, Any]] | tuple[str, NodeFunction, NodeConfig]],
checkpointer: BaseCheckpointSaver[Any] | None = None,
*,
context_class: type[Any] | None = None,
input_schema: type[Any] | None = None,
output_schema: type[Any] | None = None,
cache: BaseCache[Any] | None = None,
store: BaseStore[Any] | None = None,
def create_simple_linear_graph(
state_class: type[StateT],
nodes: list[tuple[str, NodeFunction] | tuple[str, NodeFunction, dict[str, object]] | tuple[str, NodeFunction, NodeConfig]],
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None,
*,
context_class: type[object] | None = None,
input_schema: type[object] | None = None,
output_schema: type[object] | None = None,
cache: BaseCache[object] | None = None,
store: BaseStore[object] | None = None,
name: str | None = None,
interrupt_before: All | list[str] | None = None,
interrupt_after: All | list[str] | None = None,
debug: bool = False,
) -> CompiledStateGraph[StateT]:
) -> CompiledGraph[StateT]:
"""Create a simple linear graph where nodes execute in sequence.
Args:
@@ -348,21 +354,21 @@ def create_simple_linear_graph(
# Add all nodes
for entry in nodes:
if len(entry) == 2:
name, func = entry
node_kwargs: dict[str, Any] = {}
else:
name, func, extras = entry
if isinstance(extras, NodeConfig):
node_kwargs = {
"defer": extras.defer,
if len(entry) == 2:
name, func = entry
node_kwargs: dict[str, object] = {}
else:
name, func, extras = entry
if isinstance(extras, NodeConfig):
node_kwargs = {
"defer": extras.defer,
"metadata": extras.metadata,
"input_schema": extras.input_schema,
"retry_policy": extras.retry_policy,
"cache_policy": extras.cache_policy,
}
else:
node_kwargs = cast(dict[str, Any], extras)
node_kwargs = cast(dict[str, object], extras)
builder.add_node(name, func, **node_kwargs)
# Connect nodes linearly
@@ -388,30 +394,30 @@ def create_simple_linear_graph(
return builder.build()
def create_branching_graph(
state_class: type[StateT],
initial_node: tuple[str, NodeFunction],
router: RouterFunction,
branches: dict[
str,
list[
tuple[str, NodeFunction]
| tuple[str, NodeFunction, dict[str, Any]]
| tuple[str, NodeFunction, NodeConfig]
],
],
checkpointer: BaseCheckpointSaver[Any] | None = None,
*,
context_class: type[Any] | None = None,
input_schema: type[Any] | None = None,
output_schema: type[Any] | None = None,
cache: BaseCache[Any] | None = None,
store: BaseStore[Any] | None = None,
def create_branching_graph(
state_class: type[StateT],
initial_node: tuple[str, NodeFunction],
router: RouterFunction,
branches: dict[
str,
list[
tuple[str, NodeFunction]
| tuple[str, NodeFunction, dict[str, object]]
| tuple[str, NodeFunction, NodeConfig]
],
],
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None,
*,
context_class: type[object] | None = None,
input_schema: type[object] | None = None,
output_schema: type[object] | None = None,
cache: BaseCache[object] | None = None,
store: BaseStore[object] | None = None,
name: str | None = None,
interrupt_before: All | list[str] | None = None,
interrupt_after: All | list[str] | None = None,
debug: bool = False,
) -> CompiledStateGraph[StateT]:
) -> CompiledGraph[StateT]:
"""Create a graph with branching logic.
Args:
@@ -445,21 +451,21 @@ def create_branching_graph(
# Add nodes in this branch
for entry in branch_nodes:
if len(entry) == 2:
name, func = entry
node_kwargs: dict[str, Any] = {}
else:
name, func, extras = entry
if isinstance(extras, NodeConfig):
node_kwargs = {
"defer": extras.defer,
if len(entry) == 2:
name, func = entry
node_kwargs: dict[str, object] = {}
else:
name, func, extras = entry
if isinstance(extras, NodeConfig):
node_kwargs = {
"defer": extras.defer,
"metadata": extras.metadata,
"input_schema": extras.input_schema,
"retry_policy": extras.retry_policy,
"cache_policy": extras.cache_policy,
}
else:
node_kwargs = cast(dict[str, Any], extras)
node_kwargs = cast(dict[str, object], extras)
builder.add_node(name, func, **node_kwargs)
# Connect nodes within branch

View File

@@ -1,27 +1,33 @@
"""Graph configuration utilities for LangGraph integration.
This module provides utilities for configuring graphs with RunnableConfig,
enabling consistent configuration injection across all nodes and tools.
This module centralises helpers that wrap LangGraph nodes with
``RunnableConfig`` metadata so graph execution receives application
configuration, service factories, and runtime overrides consistently.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any, cast
from collections.abc import Awaitable, Callable, Mapping
from copy import copy
from typing import TypeVar, cast
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langgraph.graph import StateGraph
from .runnable_config import ConfigurationProvider, create_runnable_config
StateT = TypeVar("StateT", bound=dict[str, object])
NodeCallable = Callable[..., object | Awaitable[object]]
def configure_graph_with_injection(
graph_builder: StateGraph[Any],
app_config: Any,
service_factory: Any | None = None,
**config_overrides: Any,
) -> StateGraph[Any]:
graph_builder: StateGraph[StateT],
app_config: object,
service_factory: object | None = None,
**config_overrides: object,
) -> StateGraph[StateT]:
"""Configure a graph builder with dependency injection.
This function wraps all nodes in the graph to automatically inject
@@ -44,25 +50,51 @@ def configure_graph_with_injection(
# Get all nodes that have been added
nodes = graph_builder.nodes
# Wrap each node to inject configuration
for node_name, node_func in nodes.items():
# Skip if already configured
if hasattr(node_func, "_runnable_config_injected"):
# Plan node replacements without mutating while iterating
replacements: list[tuple[str, object]] = []
for node_name, node_entry in list(nodes.items()):
target = _resolve_node_callable(node_entry)
if target is None or hasattr(target, "_runnable_config_injected"):
continue
# Create wrapper that injects config
if callable(node_func):
callable_node = cast(Callable[..., Any] | Callable[..., Awaitable[Any]], node_func)
wrapped_node = create_config_injected_node(callable_node, base_config)
# Replace the node
graph_builder.nodes[node_name] = wrapped_node
wrapped_node = create_config_injected_node(target, base_config)
if hasattr(node_entry, "func"):
new_entry = copy(node_entry)
setattr(new_entry, "func", wrapped_node)
replacements.append((node_name, new_entry))
else:
replacements.append((node_name, wrapped_node))
for node_name, new_entry in replacements:
try:
graph_builder.add_node(node_name, new_entry) # type: ignore[arg-type]
except Exception as exc:
raise RuntimeError(
f"Failed to inject configuration for node '{node_name}'. "
"Ensure injection occurs before graph compilation."
) from exc
return graph_builder
def _resolve_node_callable(node_entry: object) -> NodeCallable | None:
"""Extract the callable associated with a node entry."""
if hasattr(node_entry, "func"):
func = getattr(node_entry, "func")
if callable(func):
return cast(NodeCallable, func)
if callable(node_entry):
return cast(NodeCallable, node_entry)
return None
def create_config_injected_node(
node_func: Callable[..., Any] | Callable[..., Awaitable[Any]], base_config: RunnableConfig
) -> Any:
node_func: NodeCallable, base_config: RunnableConfig
) -> object:
"""Create a node wrapper that injects RunnableConfig.
Args:
@@ -76,42 +108,64 @@ def create_config_injected_node(
import inspect
from functools import wraps
from langchain_core.runnables import RunnableLambda
sig = inspect.signature(node_func)
expects_config = "config" in sig.parameters
accepts_config = "config" in sig.parameters
is_coroutine = asyncio.iscoroutinefunction(node_func)
def _merge_config(runtime: RunnableConfig | Mapping[str, object] | None) -> RunnableConfig:
provider = ConfigurationProvider(base_config)
if runtime is None:
return provider.to_runnable_config()
if isinstance(runtime, RunnableConfig):
return provider.merge_with(runtime).to_runnable_config()
if isinstance(runtime, Mapping):
runtime_config = RunnableConfig()
for key, value in runtime.items():
runtime_config[key] = value
return provider.merge_with(runtime_config).to_runnable_config()
return provider.to_runnable_config()
def _inject_kwargs(
args: tuple[object, ...],
kwargs: dict[str, object | None],
cfg: RunnableConfig,
) -> tuple[tuple[object, ...], dict[str, object]]:
if accepts_config and "config" not in kwargs:
updated_kwargs: dict[str, object] = dict(kwargs)
updated_kwargs["config"] = cfg
return args, updated_kwargs
return args, cast(dict[str, object], kwargs)
if is_coroutine:
if expects_config:
# Node already expects config, wrap to provide it
@wraps(node_func)
async def config_aware_wrapper(
state: dict[str, Any], config: RunnableConfig | None = None
# noqa: ARG001
) -> Any:
# Merge base config with runtime config
if config:
provider = ConfigurationProvider(base_config)
merged_provider = provider.merge_with(config)
merged_config = merged_provider.to_runnable_config()
else:
merged_config = base_config
async def wrapper(*args: object, **kwargs: object) -> object:
cfg_arg = kwargs.get("config")
merged = _merge_config(
cast(RunnableConfig | Mapping[str, object] | None, cfg_arg)
)
args2, kwargs2 = _inject_kwargs(args, kwargs, merged)
return await cast(Awaitable[object], node_func(*args2, **kwargs2))
# Call original node with merged config
if asyncio.iscoroutinefunction(node_func):
coro_result = node_func(state, config=merged_config)
return await cast(Awaitable[Any], coro_result)
else:
return node_func(state, config=merged_config)
wrapped: Any = RunnableLambda(config_aware_wrapper).with_config(base_config)
else:
# Node doesn't expect config, just wrap with config
wrapped = RunnableLambda(node_func).with_config(base_config)
# Mark as injected to avoid double wrapping
wrapped._runnable_config_injected = True
@wraps(node_func)
def wrapper(*args: object, **kwargs: object) -> object:
cfg_arg = kwargs.get("config")
merged = _merge_config(
cast(RunnableConfig | Mapping[str, object] | None, cfg_arg)
)
args2, kwargs2 = _inject_kwargs(args, kwargs, merged)
return node_func(*args2, **kwargs2)
return wrapped
try:
wrapper.__dict__.update(copy(getattr(node_func, "__dict__", {}))) # type: ignore[attr-defined]
except Exception:
pass
setattr(wrapper, "_runnable_config_injected", True)
runnable = RunnableLambda(wrapper).with_config(base_config)
setattr(runnable, "_runnable_config_injected", True)
return runnable
def update_node_to_use_config(
@@ -137,42 +191,25 @@ def update_node_to_use_config(
sig = inspect.signature(node_func)
# Check if already accepts config
# If the node already consumes a config argument we return it untouched so
# callers keep the original implementation.
if "config" in sig.parameters:
return node_func
@wraps(node_func)
async def wrapper(
state: dict[str, object], config: RunnableConfig | None = None
# noqa: ARG001
) -> object:
# Call original without config (for backward compatibility)
async def wrapper(state: dict[str, object], config: RunnableConfig) -> object:
if asyncio.iscoroutinefunction(node_func):
coro_result = node_func(state)
return await cast(Awaitable[Any], coro_result)
else:
return node_func(state)
return await cast(Awaitable[object], coro_result)
return node_func(state)
return wrapper
def extract_config_from_state(state: dict[str, object]) -> object | None:
"""Extract AppConfig from state for backward compatibility.
"""Return the runnable configuration stored on the state mapping."""
Args:
state: The graph state dictionary.
Returns:
The AppConfig if found in state, None otherwise.
"""
# Check common locations
if "config" in state and isinstance(state["config"], dict):
# Try to reconstruct AppConfig from dict
try:
# Import would need to be dynamic to avoid circular dependencies
# This is a placeholder implementation
return state["config"]
except Exception:
pass
return state.get("app_config")
cfg = state.get("config")
if isinstance(cfg, Mapping):
return dict(cfg)
return cfg

View File

@@ -1,10 +1,9 @@
"""RunnableConfig utilities for LangGraph integration.
"""Utilities for working with LangChain's :class:`RunnableConfig`."""
This module provides utilities for working with LangChain's RunnableConfig
to enable configuration injection and dependency management in LangGraph workflows.
"""
from __future__ import annotations
from typing import Any, TypeVar, cast
from collections.abc import Mapping
from typing import TypeVar, cast
from langchain_core.runnables import RunnableConfig
@@ -26,28 +25,35 @@ class ConfigurationProvider:
"""
self._config = config or RunnableConfig()
def _get_configurable(self) -> dict[str, Any]:
def _get_configurable(self) -> Mapping[str, object] | object:
"""Get the configurable section from config.
Handles both dict-like and attribute-like access to RunnableConfig.
Returns:
The configurable dictionary.
The configurable mapping or the original value when not mapping.
"""
return self._config.get("configurable", {})
configurable = self._config.get("configurable", {})
if isinstance(configurable, Mapping):
return dict(configurable)
return configurable
def get_metadata(self, key: str, default: Any = None) -> Any:
"""Get metadata value from the configuration.
def _get_mapping(
self, key: str, config: RunnableConfig | None = None
) -> dict[str, object]:
"""Safely extract a mapping from the runnable configuration."""
Args:
key: The metadata key to retrieve.
default: Default value if key not found.
source = (config or self._config).get(key, {})
if isinstance(source, Mapping):
return dict(source)
return {}
Returns:
The metadata value or default.
"""
def get_metadata(self, key: str, default: object | None = None) -> object | None:
"""Get metadata value from the configuration."""
metadata = self._config.get("metadata", {})
return metadata.get(key, default)
if isinstance(metadata, Mapping):
return dict(metadata).get(key, default)
return default
def get_run_id(self) -> str | None:
"""Get the current run ID from configuration.
@@ -76,23 +82,27 @@ class ConfigurationProvider:
session_id = self.get_metadata("session_id")
return session_id if isinstance(session_id, str) else None
def get_app_config(self) -> Any | None:
def get_app_config(self) -> object | None:
"""Get the application configuration object.
Returns:
The app configuration if available, None otherwise.
"""
configurable = self._get_configurable()
return configurable.get("app_config")
if isinstance(configurable, Mapping):
return configurable.get("app_config")
return None
def get_service_factory(self) -> Any | None:
def get_service_factory(self) -> object | None:
"""Get the service factory instance.
Returns:
The service factory if available, None otherwise.
"""
configurable = self._get_configurable()
return configurable.get("service_factory")
if isinstance(configurable, Mapping):
return configurable.get("service_factory")
return None
def get_llm_profile(self) -> str:
"""Get the LLM profile to use.
@@ -101,7 +111,10 @@ class ConfigurationProvider:
The LLM profile name, defaults to "large".
"""
configurable = self._get_configurable()
profile = configurable.get("llm_profile_override", "large")
if isinstance(configurable, Mapping):
profile = configurable.get("llm_profile_override", "large")
else:
profile = "large"
return profile if isinstance(profile, str) else "large"
def get_temperature_override(self) -> float | None:
@@ -111,8 +124,10 @@ class ConfigurationProvider:
The temperature override if set, None otherwise.
"""
configurable = self._get_configurable()
value = configurable.get("temperature_override")
return float(value) if isinstance(value, int | float) else None
if isinstance(configurable, Mapping):
value = configurable.get("temperature_override")
return float(value) if isinstance(value, int | float) else None
return None
def get_max_tokens_override(self) -> int | None:
"""Get max tokens override for LLM calls.
@@ -121,6 +136,8 @@ class ConfigurationProvider:
The max tokens override if set, None otherwise.
"""
configurable = self._get_configurable()
if not isinstance(configurable, Mapping):
return None
value = configurable.get("max_tokens_override")
try:
return int(value) if isinstance(value, int | float) else None
@@ -134,7 +151,10 @@ class ConfigurationProvider:
True if streaming is enabled, False otherwise.
"""
configurable = self._get_configurable()
value = configurable.get("streaming_enabled", False)
if isinstance(configurable, Mapping):
value = configurable.get("streaming_enabled", False)
else:
value = False
return bool(value)
def is_metrics_enabled(self) -> bool:
@@ -144,7 +164,10 @@ class ConfigurationProvider:
True if metrics are enabled, False otherwise.
"""
configurable = self._get_configurable()
value = configurable.get("metrics_enabled", True)
if isinstance(configurable, Mapping):
value = configurable.get("metrics_enabled", True)
else:
value = True
return bool(value)
def get_custom_value(self, key: str, default: T | None = None) -> T | None:
@@ -158,8 +181,10 @@ class ConfigurationProvider:
The configuration value or default.
"""
configurable = self._get_configurable()
result = configurable.get(key, default)
return cast("T", result) if result is not None else default
if isinstance(configurable, Mapping):
result = configurable.get(key, default)
return cast(T, result) if result is not None else default
return default
def merge_with(self, other: RunnableConfig) -> "ConfigurationProvider":
"""Create a new provider by merging with another config.
@@ -174,13 +199,13 @@ class ConfigurationProvider:
merged = RunnableConfig()
# Copy metadata
self_metadata = self._config.get("metadata", {})
other_metadata = other.get("metadata", {})
self_metadata = self._get_mapping("metadata")
other_metadata = self._get_mapping("metadata", other)
merged["metadata"] = {**self_metadata, **other_metadata}
# Copy configurable
self_configurable = self._config.get("configurable", {})
other_configurable = other.get("configurable", {})
self_configurable = self._get_mapping("configurable")
other_configurable = self._get_mapping("configurable", other)
merged["configurable"] = {**self_configurable, **other_configurable}
# Copy other attributes
@@ -202,7 +227,10 @@ class ConfigurationProvider:
@classmethod
def from_app_config(
cls, app_config: Any, service_factory: Any | None = None, **metadata: Any
cls,
app_config: object,
service_factory: object | None = None,
**metadata: object,
) -> "ConfigurationProvider":
"""Create a provider from app configuration.
@@ -217,21 +245,21 @@ class ConfigurationProvider:
config = RunnableConfig()
# Set configurable values
configurable_dict = {
configurable_dict: dict[str, object | None] = {
"app_config": app_config,
"service_factory": service_factory,
}
config["configurable"] = configurable_dict
# Set metadata
config["metadata"] = metadata
config["metadata"] = dict(metadata)
return cls(config)
def create_runnable_config(
app_config: Any | None = None,
service_factory: Any | None = None,
app_config: object | None = None,
service_factory: object | None = None,
llm_profile_override: str | None = None,
temperature_override: float | None = None,
max_tokens_override: int | None = None,
@@ -240,7 +268,7 @@ def create_runnable_config(
run_id: str | None = None,
user_id: str | None = None,
session_id: str | None = None,
**custom_values: Any,
**custom_values: object,
) -> RunnableConfig:
"""Create a RunnableConfig with common settings.
@@ -263,7 +291,7 @@ def create_runnable_config(
config = RunnableConfig()
# Set configurable values
configurable = {
configurable: dict[str, object] = {
"metrics_enabled": metrics_enabled,
"streaming_enabled": streaming_enabled,
**custom_values,

View File

@@ -15,23 +15,27 @@ used during development and testing. For production, consider:
from __future__ import annotations
import copy
from collections.abc import Callable
from typing import Any, TypeVar, cast
from collections.abc import (
Callable,
Iterator,
ItemsView,
KeysView,
Mapping,
MutableMapping,
Sequence,
ValuesView,
)
from types import MappingProxyType
from typing import cast
try: # pragma: no cover - pandas is optional in lightweight test environments
import pandas as pd # type: ignore
except ModuleNotFoundError: # pragma: no cover - executed when pandas isn't installed
pd = None # type: ignore[assignment]
from typing_extensions import ParamSpec
from biz_bud.core.errors import ImmutableStateError, StateValidationError
P = ParamSpec("P")
T = TypeVar("T")
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def _states_equal(state1: Any, state2: Any) -> bool:
def _states_equal(state1: object, state2: object) -> bool:
"""Compare two states safely, handling DataFrames and other complex objects.
Args:
@@ -44,7 +48,7 @@ def _states_equal(state1: Any, state2: Any) -> bool:
if type(state1) is not type(state2):
return False
if isinstance(state1, dict) and isinstance(state2, dict):
if isinstance(state1, Mapping) and isinstance(state2, Mapping):
if set(state1.keys()) != set(state2.keys()):
return False
@@ -53,7 +57,9 @@ def _states_equal(state1: Any, state2: Any) -> bool:
return False
return True
elif isinstance(state1, (list, tuple)):
elif isinstance(state1, Sequence) and isinstance(state2, Sequence) and not isinstance(
state1, (str, bytes, bytearray)
):
if len(state1) != len(state2):
return False
return all(_states_equal(a, b) for a, b in zip(state1, state2))
@@ -90,221 +96,92 @@ def _states_equal(state1: Any, state2: Any) -> bool:
return False
class ImmutableDict:
"""An immutable dictionary that prevents modifications.
class ImmutableDict(Mapping[str, object]):
"""An immutable mapping that prevents modifications."""
This class wraps a regular dict but prevents any modifications,
raising ImmutableStateError on mutation attempts.
"""
def __init__(
self, initial: Mapping[str, object] | None = None, **kwargs: object
) -> None:
data: dict[str, object] = dict(initial or {})
if kwargs:
data.update(kwargs)
self._data = copy.deepcopy(data)
self._proxy = MappingProxyType(self._data)
_data: dict[str, Any]
_frozen: bool
def _raise_mutation_error(self) -> None:
raise ImmutableStateError(
"Cannot modify immutable state. Create a new state object instead."
)
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize ImmutableDict.
def __getitem__(self, key: str) -> object:
return self._proxy[key]
Args:
*args: Positional arguments passed to dict constructor.
**kwargs: Keyword arguments passed to dict constructor.
"""
self._data = dict(*args, **kwargs)
self._frozen = True
def _check_frozen(self) -> None:
if self._frozen:
raise ImmutableStateError(
"Cannot modify immutable state. Create a new state object instead."
)
# Read-only dict interface
def __getitem__(self, key: str) -> Any:
"""Get item by key.
Args:
key: The key to retrieve.
Returns:
The value associated with the key.
"""
return self._data[key]
def __contains__(self, key: object) -> bool:
"""Check if key is in the dictionary.
Args:
key: The key to check for.
Returns:
True if key exists, False otherwise.
"""
return key in self._data
def __iter__(self) -> Any:
"""Return an iterator over the dictionary keys.
Returns:
An iterator over the keys.
"""
return iter(self._data)
def __iter__(self) -> Iterator[str]:
return iter(self._proxy)
def __len__(self) -> int:
"""Return the number of items in the dictionary.
return len(self._proxy)
Returns:
The number of key-value pairs.
"""
return len(self._data)
def __contains__(self, key: object) -> bool:
return key in self._proxy
def __repr__(self) -> str:
"""Return a string representation of the dictionary.
return f"ImmutableDict({dict(self._proxy)!r})"
Returns:
A string representation of the ImmutableDict.
"""
return f"ImmutableDict({self._data!r})"
def get(self, key: str, default: object | None = None) -> object | None:
return self._proxy.get(key, default)
def get(self, key: str, default: Any = None) -> Any:
"""Get value by key with optional default.
def keys(self) -> KeysView[str]:
return self._proxy.keys()
Args:
key: The key to retrieve.
default: Default value if key is not found.
def values(self) -> ValuesView[object]:
return self._proxy.values()
Returns:
The value associated with the key or default.
"""
return self._data.get(key, default)
def items(self) -> ItemsView[str, object]:
return self._proxy.items()
def keys(self) -> Any:
"""Return a view of the dictionary keys.
def copy(self) -> Mapping[str, object]:
return MappingProxyType(copy.deepcopy(dict(self._proxy)))
Returns:
A view object of the dictionary keys.
"""
return self._data.keys()
# Mutation methods intentionally raise errors
def __setitem__(self, key: str, value: object) -> None: # pragma: no cover
self._raise_mutation_error()
def values(self) -> Any:
"""Return a view of the dictionary values.
def __delitem__(self, key: str) -> None: # pragma: no cover
self._raise_mutation_error()
Returns:
A view object of the dictionary values.
"""
return self._data.values()
def pop(self, key: str, default: object | None = None) -> object: # pragma: no cover
self._raise_mutation_error()
raise AssertionError("pop should never return")
def items(self) -> Any:
"""Return a view of the dictionary key-value pairs.
def popitem(self) -> tuple[str, object]: # pragma: no cover
self._raise_mutation_error()
raise AssertionError("popitem should never return")
Returns:
A view object of the dictionary items.
"""
return self._data.items()
def clear(self) -> None: # pragma: no cover
self._raise_mutation_error()
def copy(self) -> dict[str, Any]:
"""Create a shallow copy of the dictionary.
def update(self, *args: object, **kwargs: object) -> None: # pragma: no cover
self._raise_mutation_error()
Returns:
A new dictionary with the same key-value pairs.
"""
return self._data.copy()
# Mutation methods that raise errors
def __setitem__(self, key: str, value: Any) -> None:
"""Prevent item assignment (raises ImmutableStateError).
Args:
key: The key to set (will raise error).
value: The value to set (will raise error).
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
def __delitem__(self, key: str) -> None:
"""Prevent item deletion (raises ImmutableStateError).
Args:
key: The key to delete (will raise error).
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
def pop(self, key: str, default: Any = None) -> Any:
"""Prevent popping items (raises ImmutableStateError).
Args:
key: The key to pop (will raise error).
default: Default value (will raise error before using).
Returns:
Never returns as it always raises an error.
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
return None # Never reached
def popitem(self) -> tuple[str, Any]:
"""Prevent popping items (raises ImmutableStateError).
Returns:
Never returns as it always raises an error.
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
return ("", None) # Never reached
def clear(self) -> None:
"""Prevent clearing the dictionary (raises ImmutableStateError).
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
def update(self, *args: Any, **kwargs: Any) -> None:
"""Prevent updating the dictionary (raises ImmutableStateError).
Args:
*args: Arguments to update with (will raise error).
**kwargs: Keyword arguments to update with (will raise error).
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
def setdefault(self, key: str, default: Any = None) -> Any:
"""Prevent setting default values (raises ImmutableStateError).
Args:
key: The key to set default for (will raise error).
default: The default value (will raise error before using).
Returns:
Never returns as it always raises an error.
Raises:
ImmutableStateError: Always, as state is immutable.
"""
self._check_frozen()
return None # Never reached
def setdefault(self, key: str, default: object | None = None) -> object: # pragma: no cover
self._raise_mutation_error()
raise AssertionError("setdefault should never return")
def __deepcopy__(self, memo: dict[int, object]) -> ImmutableDict:
"""Support deepcopy by creating a new immutable dict with copied data."""
# Deep copy the internal data
new_data = {key: copy.deepcopy(value, memo) for key, value in self.items()}
obj_id = id(self)
if obj_id in memo:
return memo[obj_id] # type: ignore[return-value]
# Create a new ImmutableDict with the copied data
return ImmutableDict(new_data)
placeholder = ImmutableDict({})
memo[obj_id] = placeholder
new_data = {key: copy.deepcopy(value, memo) for key, value in self._proxy.items()}
result = ImmutableDict(new_data)
memo[obj_id] = result
return result
def create_immutable_state(state: dict[str, Any]) -> ImmutableDict:
def create_immutable_state(state: Mapping[str, object]) -> ImmutableDict:
"""Create an immutable version of a state dictionary.
Args:
@@ -314,12 +191,12 @@ def create_immutable_state(state: dict[str, Any]) -> ImmutableDict:
An ImmutableDict that prevents modifications.
"""
# Deep copy to prevent references to mutable objects
return ImmutableDict(copy.deepcopy(state))
return ImmutableDict(copy.deepcopy(dict(state)))
def update_state_immutably(
current_state: dict[str, Any], updates: dict[str, Any]
) -> dict[str, Any]:
current_state: Mapping[str, object], updates: Mapping[str, object]
) -> dict[str, object]:
"""Create a new state with updates applied immutably.
This function creates a new state object by merging the current state
@@ -334,35 +211,60 @@ def update_state_immutably(
"""
# Deep copy the current state into a regular dict
# If it's an ImmutableDict, convert to regular dict first
new_state: dict[str, Any]
if isinstance(current_state, ImmutableDict):
new_state = {key: copy.deepcopy(value) for key, value in current_state.items()}
else:
new_state = copy.deepcopy(current_state)
base_state = dict(current_state)
new_state: dict[str, object] = copy.deepcopy(base_state)
def _as_dict(obj: object) -> dict[str, object] | None:
if isinstance(obj, dict):
return obj
if isinstance(obj, Mapping):
return {key: obj[key] for key in obj.keys()}
return None
def _deep_merge(
target: dict[str, object], source: Mapping[str, object]
) -> dict[str, object]:
for merge_key, merge_value in source.items():
existing_sub = _as_dict(target.get(merge_key))
incoming_sub = _as_dict(merge_value)
if existing_sub is not None and incoming_sub is not None:
target[merge_key] = _deep_merge(
copy.deepcopy(existing_sub), incoming_sub
)
else:
target[merge_key] = copy.deepcopy(merge_value)
return target
# Apply updates
for key, value in updates.items():
# Check for replacement marker
if isinstance(value, tuple) and len(value) == 2 and value[0] == "__REPLACE__":
# Force replacement, ignore existing value
new_state[key] = value[1]
elif (
replace_tuple = (
isinstance(value, tuple)
and len(value) == 2
and value[0] == "__REPLACE__"
)
if replace_tuple:
new_state[key] = copy.deepcopy(value[1])
continue
existing_dict = _as_dict(new_state.get(key))
incoming_dict = _as_dict(value)
if existing_dict is not None and incoming_dict is not None:
new_state[key] = _deep_merge(
copy.deepcopy(existing_dict), incoming_dict
)
continue
if (
key in new_state
and isinstance(new_state[key], list)
and isinstance(value, list)
):
# For lists, create a new list (don't extend in place)
new_state[key] = new_state[key] + value
elif (
key in new_state
and isinstance(new_state[key], dict)
and isinstance(value, dict)
):
# For dicts, merge recursively
new_state[key] = {**new_state[key], **value}
else:
# Direct assignment for other types
new_state[key] = value
merged_list = copy.deepcopy(new_state[key])
merged_list.extend(copy.deepcopy(value))
new_state[key] = merged_list
continue
new_state[key] = copy.deepcopy(value)
return new_state
@@ -400,7 +302,7 @@ def ensure_immutable_node(
return await result if inspect.iscoroutine(result) else result
state = args[0]
if not isinstance(state, dict):
if not isinstance(state, MutableMapping):
result = node_func(*args, **kwargs)
return await result if inspect.iscoroutine(result) else result
@@ -436,7 +338,7 @@ def ensure_immutable_node(
return node_func(*args, **kwargs)
state = args[0]
if not isinstance(state, dict):
if not isinstance(state, MutableMapping):
return node_func(*args, **kwargs)
# PERFORMANCE WARNING: Deep copying state before/after execution
@@ -487,24 +389,17 @@ class StateUpdater:
```
"""
def __init__(self, base_state: dict[str, Any] | Any):
def __init__(self, base_state: Mapping[str, object]):
"""Initialize with a base state.
Args:
base_state: The starting state.
"""
# Handle both regular dicts and ImmutableDict
if isinstance(base_state, ImmutableDict):
# Convert ImmutableDict to regular dict for internal use
self._state: dict[str, Any] = {}
for key, value in base_state.items():
self._state[key] = copy.deepcopy(value)
else:
# Regular dict can be deepcopied normally
self._state = copy.deepcopy(base_state)
self._updates: dict[str, Any] = {}
self._state: dict[str, object] = dict(base_state)
self._updates: dict[str, object] = {}
def set(self, key: str, value: Any) -> StateUpdater:
def set(self, key: str, value: object) -> StateUpdater:
"""Set a value in the state.
Args:
@@ -517,7 +412,7 @@ class StateUpdater:
self._updates[key] = value
return self
def replace(self, key: str, value: Any) -> StateUpdater:
def replace(self, key: str, value: object) -> StateUpdater:
"""Replace a value in the state, ignoring existing value.
This forces replacement even for dicts and lists, unlike set()
@@ -533,7 +428,7 @@ class StateUpdater:
self._updates[key] = ("__REPLACE__", value)
return self
def append(self, key: str, value: Any) -> StateUpdater:
def append(self, key: str, value: object) -> StateUpdater:
"""Append to a list in the state.
Args:
@@ -560,7 +455,7 @@ class StateUpdater:
return self
def extend(self, key: str, values: list[Any]) -> StateUpdater:
def extend(self, key: str, values: Sequence[object]) -> StateUpdater:
"""Extend a list in the state.
Args:
@@ -583,11 +478,11 @@ class StateUpdater:
self._updates[key] = list(current_list)
if isinstance(self._updates[key], list):
self._updates[key].extend(values)
self._updates[key].extend(list(values))
return self
def merge(self, key: str, updates: dict[str, Any]) -> StateUpdater:
def merge(self, key: str, updates: Mapping[str, object]) -> StateUpdater:
"""Merge updates into a dict in the state.
Args:
@@ -610,7 +505,7 @@ class StateUpdater:
self._updates[key] = dict(current_dict)
if isinstance(self._updates[key], dict):
self._updates[key].update(updates)
self._updates[key].update(dict(updates))
return self
@@ -636,7 +531,7 @@ class StateUpdater:
self._updates[key] = current_value + amount
return self
def build(self) -> dict[str, Any]:
def build(self) -> dict[str, object]:
"""Build the final state with all updates applied.
Returns:
@@ -645,7 +540,7 @@ class StateUpdater:
return update_state_immutably(self._state, self._updates)
def validate_state_schema(state: dict[str, Any], schema: type) -> None:
def validate_state_schema(state: Mapping[str, object], schema: type) -> None:
"""Validate that a state conforms to a schema.
Args:

View File

@@ -347,7 +347,7 @@ CleanupFunctionWithArgs = Callable[..., Awaitable[None]]
# Error handling types to break circular imports with core.errors
class ErrorDetails(TypedDict):
class ErrorDetails(TypedDict):
"""Detailed error information."""
type: str
@@ -355,7 +355,7 @@ class ErrorDetails(TypedDict):
severity: str
category: str
timestamp: str
context: dict[str, Any]
context: dict[str, object]
traceback: str | None
retry_exhausted: NotRequired[bool] # Added for retry handling
notification_sent: NotRequired[bool] # Added for notification tracking
@@ -365,7 +365,7 @@ class ErrorDetails(TypedDict):
escalated: NotRequired[bool] # Whether error has been escalated
class ErrorInfo(TypedDict):
class ErrorInfo(TypedDict):
"""Error information structure for state management."""
message: str
@@ -374,7 +374,8 @@ class ErrorInfo(TypedDict):
severity: str
category: str
timestamp: str
context: dict[str, Any]
context: dict[str, object]
details: NotRequired[dict[str, object]]
traceback: str | None

View File

@@ -1,333 +1,201 @@
"""Backward compatibility layer for URL processing.
"""Deprecated bridge to the modern URL processing capability layer.
⚠️ DEPRECATED: This module is deprecated and will be removed in a future version.
Please use biz_bud.tools.capabilities.url_processing instead.
This module used to provide a large backward-compatibility surface for the
legacy URLProcessor implementation. As part of the LangGraph v1 migration we
no longer ship stub fallbacks the new tool-backed service is always required.
This module provides backward compatibility imports for the URL processing
system that has been migrated to the tools layer. Existing imports will
continue to work while displaying deprecation warnings.
Migration Guide:
Old import:
from biz_bud.core.url_processing import URLProcessor
New import:
from biz_bud.tools.capabilities.url_processing import (
validate_url, normalize_url, discover_urls, process_urls_batch
)
Instead of:
processor = URLProcessor()
result = await processor.validate_url("https://example.com")
Use:
result = await validate_url("https://example.com")
Only a thin wrapper remains so existing imports continue to work while emitting
clear deprecation guidance. Callers should migrate to
``biz_bud.tools.capabilities.url_processing`` as soon as possible.
"""
from __future__ import annotations
import asyncio
import warnings
from collections.abc import Coroutine, Mapping
from dataclasses import asdict, is_dataclass
from typing import TYPE_CHECKING, TypeVar
# Issue deprecation warning
warnings.warn(
"biz_bud.core.url_processing is deprecated. "
"Please use biz_bud.tools.capabilities.url_processing instead.",
DeprecationWarning,
stacklevel=2
)
from biz_bud.services.factory import get_global_factory
# Try to import from new location first, fall back to old location
try:
# Import new implementations for backward compatibility
if TYPE_CHECKING: # pragma: no cover - used for static analysis only
from biz_bud.tools.capabilities.url_processing.service import URLProcessingService
# Service is imported above for direct usage
# Legacy URLProcessor wrapper class
class URLProcessor:
"""Legacy URLProcessor for backward compatibility."""
def __init__(self, config=None):
warnings.warn(
"URLProcessor is deprecated. Use URL processing tools instead.",
DeprecationWarning,
stacklevel=2
)
self.config = config
self._service: URLProcessingService | None = None
async def _get_service(self) -> URLProcessingService:
"""Get or create URL processing service."""
if self._service is None:
from biz_bud.services.factory import get_global_factory
factory = await get_global_factory()
self._service = await factory.get_service(URLProcessingService)
assert self._service is not None # For type checker
return self._service
async def validate_url(self, url: str):
"""Validate URL using new service."""
service = await self._get_service()
result = await service.validate_url(url)
# Convert dataclass to dict for backward compatibility
from dataclasses import asdict, is_dataclass
return asdict(result) if is_dataclass(result) else result
async def normalize_url(self, url: str) -> str:
"""Normalize URL using new service."""
service = await self._get_service()
return service.normalize_url(url)
async def discover_urls(self, base_url: str):
"""Discover URLs using new service."""
service = await self._get_service()
result = await service.discover_urls(base_url)
return result.discovered_urls if hasattr(result, 'discovered_urls') else [base_url]
async def deduplicate_urls(self, urls: list[str]):
"""Deduplicate URLs using new service."""
service = await self._get_service()
return await service.deduplicate_urls(urls)
async def process_urls(self, urls: list[str]):
"""Process URLs using new service."""
service = await self._get_service()
return await service.process_urls_batch(urls)
new_tools_available = True
except ImportError:
# Fall back to old implementation if new tools not available
new_tools_available = False
# Re-export key utilities for backward compatibility
from ..utils.url_analyzer import (
URLAnalysisResult,
URLType,
analyze_url_type,
get_url_type,
is_git_repo_url,
is_pdf_url,
is_valid_url,
normalize_url,
)
from ..utils.url_normalizer import URLNormalizer
# Legacy classes for backward compatibility
class URLProcessorConfig:
"""Legacy config class."""
class ValidationLevel:
"""Legacy validation level enum."""
BASIC = "basic"
STANDARD = "standard"
STRICT = "strict"
class URLDiscoverer:
"""Legacy discoverer class."""
class URLFilter:
"""Legacy filter class."""
class URLValidator:
"""Legacy validator class."""
# Legacy exception classes
class URLProcessingError(Exception):
"""Base URL processing error."""
class URLValidationError(URLProcessingError):
"""URL validation error."""
class URLNormalizationError(URLProcessingError):
"""URL normalization error."""
class URLDiscoveryError(URLProcessingError):
"""URL discovery error."""
class URLFilterError(URLProcessingError):
"""URL filter error."""
class URLDeduplicationError(URLProcessingError):
"""URL deduplication error."""
class URLCacheError(URLProcessingError):
"""URL cache error."""
# Legacy data model classes
class ProcessedURL:
"""Legacy processed URL model."""
class URLAnalysis:
"""Legacy URL analysis model."""
class ValidationResult:
"""Legacy validation result model."""
class DiscoveryResult:
"""Legacy discovery result model."""
class FilterResult:
"""Legacy filter result model."""
class DeduplicationResult:
"""Legacy deduplication result model."""
# URLProcessor is defined above in the try block
# Version information
__version__ = "1.0.0"
__author__ = "Business Buddy Team"
# Public API - what gets imported with "from url_processing import *"
if new_tools_available:
__all__ = [
# Main processor class (deprecated)
"URLProcessor",
# New service class
"URLProcessingService",
# Module metadata
"__version__",
"__author__",
# Convenience functions
"validate_url",
"normalize_url_simple",
]
else:
__all__ = [
# Main processor class
"URLProcessor",
# Component classes
"URLValidator",
"URLFilter",
"URLDiscoverer",
# Configuration
"URLProcessorConfig",
"ValidationLevel",
# Data models
"ProcessedURL",
"URLAnalysis",
"ValidationResult",
"DiscoveryResult",
"FilterResult",
"DeduplicationResult",
# Exceptions
"URLProcessingError",
"URLValidationError",
"URLNormalizationError",
"URLDiscoveryError",
"URLFilterError",
"URLDeduplicationError",
"URLCacheError",
# Backward compatibility exports
"analyze_url_type",
"normalize_url",
"is_valid_url",
"is_git_repo_url",
"is_pdf_url",
"get_url_type",
"URLType",
"URLAnalysisResult",
"URLNormalizer",
# Module metadata
"__version__",
"__author__",
]
_T = TypeVar("_T")
warnings.warn(
"biz_bud.core.url_processing is deprecated. "
"Use biz_bud.tools.capabilities.url_processing instead.",
DeprecationWarning,
stacklevel=2,
)
# Module-level convenience functions for backward compatibility
if new_tools_available:
def validate_url(url: str, strict: bool = False) -> bool:
"""Quick URL validation using new service (deprecated).
def _run_sync(coro: Coroutine[object, object, _T]) -> _T:
"""Execute an async operation from a synchronous helper safely."""
⚠️ DEPRECATED: Use biz_bud.tools.capabilities.url_processing.validate_url instead.
"""
warnings.warn(
"validate_url from core.url_processing is deprecated. "
"Use biz_bud.tools.capabilities.url_processing.validate_url instead.",
DeprecationWarning,
stacklevel=2
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No loop running in this thread; safe to create one
return asyncio.run(coro)
if loop.is_running():
raise RuntimeError(
"Synchronous URL processing helpers cannot run inside an active "
"event loop. Call the async variants from "
"biz_bud.tools.capabilities.url_processing instead."
)
import asyncio
async def _validate():
from biz_bud.services.factory import get_global_factory
from biz_bud.tools.capabilities.url_processing.service import (
URLProcessingService,
)
factory = await get_global_factory()
service = await factory.get_service(URLProcessingService)
result = await service.validate_url(url, "strict" if strict else "standard")
return result.is_valid if hasattr(result, 'is_valid') else False
# A loop exists but isn't running in this thread submit the coroutine
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result()
return asyncio.run(_validate())
def normalize_url_simple(url: str) -> str:
"""Quick URL normalization using new service (deprecated).
def _convert_dataclass(result: object) -> object:
"""Convert dataclass results into plain dictionaries recursively."""
⚠️ DEPRECATED: Use biz_bud.tools.capabilities.url_processing.normalize_url instead.
"""
if is_dataclass(result):
return asdict(result)
return result
class URLProcessor:
"""Deprecated wrapper that now forwards to URLProcessingService."""
def __init__(self, config: Mapping[str, object] | None = None) -> None:
warnings.warn(
"normalize_url_simple from core.url_processing is deprecated. "
"Use biz_bud.tools.capabilities.url_processing.normalize_url instead.",
"URLProcessor is deprecated. Use URL processing tools instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
import asyncio
self._config = config or {}
self._service: "URLProcessingService | None" = None
async def _normalize():
from biz_bud.services.factory import get_global_factory
async def _get_service(self) -> "URLProcessingService":
if self._service is None:
from biz_bud.tools.capabilities.url_processing.service import (
URLProcessingService,
URLProcessingService as _Service,
)
factory = await get_global_factory()
service = await factory.get_service(URLProcessingService)
return service.normalize_url(url)
self._service = await factory.get_service(_Service)
return self._service
return asyncio.run(_normalize())
async def validate_url(
self, url: str, *, level: str | None = None, provider: str | None = None
) -> dict[str, object]:
service = await self._get_service()
provider_name = provider or level or self._config.get("validation_level") or "standard"
result = await service.validate_url(url, provider_name)
converted = _convert_dataclass(result)
return converted if isinstance(converted, dict) else {"result": converted}
else:
def validate_url(url: str, strict: bool = False) -> bool:
"""Quick URL validation using default settings.
async def normalize_url(self, url: str, *, provider: str | None = None) -> str:
service = await self._get_service()
provider_name = provider or self._config.get("normalization_provider")
return service.normalize_url(url, provider_name)
Args:
url: URL to validate
strict: Whether to use strict validation
async def discover_urls(
self, base_url: str, *, provider: str | None = None, max_results: int | None = None
) -> list[str]:
service = await self._get_service()
provider_name = provider or self._config.get("discovery_provider")
result = await service.discover_urls(base_url, provider_name)
urls = list(result.discovered_urls)
if max_results is not None and len(urls) > max_results:
return urls[:max_results]
return urls
Returns:
True if URL is valid
async def deduplicate_urls(
self, urls: list[str], *, provider: str | None = None
) -> list[str]:
service = await self._get_service()
provider_name = provider or self._config.get("deduplication_provider")
return await service.deduplicate_urls(urls, provider_name)
Example:
>>> validate_url("https://example.com")
True
>>> validate_url("not-a-url")
False
"""
from ..utils.url_analyzer import is_valid_url
return is_valid_url(url, strict=strict)
def normalize_url_simple(url: str) -> str:
"""Quick URL normalization using default settings.
Args:
url: URL to normalize
Returns:
Normalized URL string
Example:
>>> normalize_url_simple("HTTP://Example.COM/Path?param=1#fragment")
"https://example.com/Path?param=1"
"""
from ..utils.url_analyzer import normalize_url
return normalize_url(url)
async def process_urls(
self, urls: list[str], **options: object
) -> dict[str, object]:
service = await self._get_service()
result = await service.process_urls_batch(
urls,
validation_provider=options.get("validation_provider"),
normalization_provider=options.get("normalization_provider"),
enable_deduplication=bool(options.get("enable_deduplication", True)),
deduplication_provider=options.get("deduplication_provider"),
max_concurrent=options.get("max_concurrent"),
timeout=options.get("timeout"),
)
converted = _convert_dataclass(result)
return converted if isinstance(converted, dict) else {"result": converted}
# Module initialization
def _initialize_module() -> None:
"""Initialize the URL processing module."""
# Module initialization logic can go here
# For now, just validate that required dependencies are available
async def _fetch_service() -> "URLProcessingService":
    """Resolve the URL processing service from the global service factory.

    The capability import happens at call time so merely importing this
    legacy shim does not eagerly import the modern capability module.

    Returns:
        The shared ``URLProcessingService`` instance from the global factory.
    """
    from biz_bud.tools.capabilities.url_processing.service import (
        URLProcessingService as _Service,
    )
    factory = await get_global_factory()
    return await factory.get_service(_Service)
# Run initialization when module is imported
_initialize_module()
def validate_url(url: str, strict: bool = False) -> bool:
    """Synchronous helper that proxies to the async validation tool.

    Args:
        url: URL to validate.
        strict: When true, use the "strict" validation level instead of
            "standard".

    Returns:
        True if the service reports the URL as valid, False otherwise.
    """
    warnings.warn(
        "validate_url from biz_bud.core.url_processing is deprecated. "
        "Use biz_bud.tools.capabilities.url_processing.validate_url instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    validation_level = "strict" if strict else "standard"

    async def _run_validation() -> bool:
        service = await _fetch_service()
        outcome = await service.validate_url(url, validation_level)
        return bool(getattr(outcome, "is_valid", False))

    return _run_sync(_run_validation())
def normalize_url_simple(url: str) -> str:
    """Synchronous helper for quick URL normalization.

    Args:
        url: URL to normalize via the URL processing service.

    Returns:
        The normalized URL string.
    """
    warnings.warn(
        "normalize_url_simple from biz_bud.core.url_processing is deprecated. "
        "Use biz_bud.tools.capabilities.url_processing.normalize_url instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    async def _normalize() -> str:
        service = await _fetch_service()
        # NOTE(review): this awaits service.normalize_url, but
        # URLProcessor.normalize_url returns the same call without awaiting
        # it — confirm whether the service method is sync or async; one of
        # the two call sites looks wrong.
        return await service.normalize_url(url, None)
    return _run_sync(_normalize())
__all__ = [
"URLProcessor",
"URLProcessingService",
"__version__",
"__author__",
"validate_url",
"normalize_url_simple",
]
def __getattr__(name: str) -> object:
    """Lazily resolve ``URLProcessingService`` to avoid an import cycle.

    The capability module is only imported the first time the attribute is
    requested; the resolved class is then cached in module globals so later
    lookups bypass this hook.
    """
    if name != "URLProcessingService":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    from biz_bud.tools.capabilities.url_processing.service import (
        URLProcessingService as _Service,
    )
    # Cache so subsequent attribute access skips __getattr__ entirely.
    globals()[name] = _Service
    return _Service

View File

@@ -1,160 +1,201 @@
"""Helper functions for graph creation and state initialization.
This module consolidates common patterns used in graph creation to reduce
code duplication and improve maintainability.
"""
import json
from typing import Any, TypeVar
from biz_bud.logging import get_logger
logger = get_logger(__name__)
T = TypeVar("T")
def process_state_query(
query: str | None,
messages: list[dict[str, Any]] | None,
state_update: dict[str, Any] | None,
default_query: str
) -> str:
"""Process and extract query from various sources.
Args:
query: Direct query string
messages: Message history to extract query from
state_update: State update that may contain query
default_query: Default query to use if none found
Returns:
Processed query string
"""
# Extract query from messages if not provided
if query is None and messages:
# Look for the last human/user message
for msg in reversed(messages):
# Handle different message formats
role = msg.get("role")
msg_type = msg.get("type")
if role in ("user", "human") or msg_type == "human":
query = msg.get("content", "")
break
return query or default_query
def format_raw_input(
raw_input: str | dict[str, Any] | None,
user_query: str
) -> tuple[str, str]:
"""Format raw input into a consistent string format.
Args:
raw_input: Raw input data (string, dict, or None).
If a dict, must be JSON serializable.
user_query: User query to use as default
Returns:
Tuple of (raw_input_str, extracted_query)
Raises:
TypeError: If raw_input is a dict but not JSON serializable.
"""
if raw_input is None:
return f'{{"query": "{user_query}"}}', user_query
if isinstance(raw_input, dict):
# If raw_input has a query field, use it
extracted_query = raw_input.get("query", user_query)
# Avoid json.dumps for simple cases
if len(raw_input) == 1 and "query" in raw_input:
return f'{{"query": "{raw_input["query"]}"}}', extracted_query
else:
try:
return json.dumps(raw_input), extracted_query
except (TypeError, ValueError) as e:
# Fallback: show error message in string
raw_input_str = f"<non-serializable dict: {e}>"
return raw_input_str, extracted_query
# For unsupported types (already checked that raw_input is str above)
raw_input_str = f"<unsupported type: {type(raw_input).__name__}>"
return raw_input_str, user_query
def extract_state_update_data(
state_update: dict[str, Any] | None,
messages: list[dict[str, Any]] | None,
raw_input: str | dict[str, Any] | None,
thread_id: str | None
) -> tuple[list[dict[str, Any]] | None, str | dict[str, Any] | None, str | None]:
"""Extract data from state update if provided.
Args:
state_update: State update dictionary from LangGraph API
messages: Existing messages
raw_input: Existing raw input
thread_id: Existing thread ID
Returns:
Tuple of (messages, raw_input, thread_id)
"""
if not state_update:
return messages, raw_input, thread_id
# Extract messages if present
if "messages" in state_update and not messages:
messages = state_update["messages"]
# Extract other fields if present
if "raw_input" in state_update and not raw_input:
raw_input = state_update["raw_input"]
if "thread_id" in state_update and not thread_id:
thread_id = state_update["thread_id"]
return messages, raw_input, thread_id
def create_initial_state_dict(
raw_input_str: str,
user_query: str,
messages: list[dict[str, Any]],
thread_id: str,
config_dict: dict[str, Any]
) -> dict[str, Any]:
"""Create the initial state dictionary.
Args:
raw_input_str: Formatted raw input string
user_query: User query
messages: Message history
thread_id: Thread identifier
config_dict: Configuration dictionary
Returns:
Initial state dictionary
"""
return {
"raw_input": raw_input_str,
"parsed_input": {
"raw_payload": {
"query": user_query,
},
"user_query": user_query,
},
"messages": messages,
"initial_input": {
"query": user_query,
},
"thread_id": thread_id,
"config": config_dict,
"input_metadata": {},
"context": {},
"status": "pending",
"errors": [],
"run_metadata": {},
"is_last_step": False,
"final_result": None,
}
"""Helper functions for graph creation and state initialization.
This module consolidates common patterns used in graph creation to reduce
code duplication and improve maintainability.
"""
import json
from collections.abc import Mapping, Sequence
from biz_bud.logging import get_logger
logger = get_logger(__name__)
def process_state_query(
    query: str | None,
    messages: Sequence[Mapping[str, object]] | None,
    state_update: Mapping[str, object] | None,
    default_query: str,
) -> str:
    """Resolve the effective query string for a graph run.

    When ``query`` is not supplied, the message history is scanned
    newest-first for a human/user message with usable text content;
    ``default_query`` is the final fallback.

    Args:
        query: Direct query string, if already known.
        messages: Message history searched for a user query.
        state_update: Accepted for interface compatibility; not consulted.
        default_query: Value returned when no query can be extracted.

    Returns:
        The resolved query string.
    """

    def _mapping_text(part: object) -> str | None:
        # Pull text out of a structured content part, preferring "text".
        if not isinstance(part, Mapping):
            return None
        candidate = part.get("text") or part.get("content")
        return candidate if isinstance(candidate, str) else None

    def _text_from(content: object) -> str | None:
        # Normalize the supported message-content shapes to plain text.
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            fragments: list[str] = []
            for piece in content:
                if isinstance(piece, str):
                    fragments.append(piece)
                else:
                    mapped = _mapping_text(piece)
                    if mapped is not None:
                        fragments.append(mapped)
            return " ".join(fragments) if fragments else None
        if isinstance(content, Mapping):
            return _mapping_text(content)
        return None

    if query is None and messages:
        for message in reversed(messages):
            if not isinstance(message, Mapping):
                continue
            is_user = message.get("role") in ("user", "human")
            if not is_user and message.get("type") != "human":
                continue
            text = _text_from(message.get("content", ""))
            if text:
                query = text
                break
    return query or default_query
def format_raw_input(
    raw_input: str | Mapping[str, object] | None,
    user_query: str,
) -> tuple[str, str]:
    """Format raw input into a consistent JSON string.

    Args:
        raw_input: Raw input data (string, mapping, or None). Mappings are
            serialized with ``json.dumps``; non-serializable content falls
            back to a JSON error payload instead of raising.
        user_query: User query used as the default extracted query.

    Returns:
        Tuple of (raw_input_str, extracted_query); ``raw_input_str`` is
        always a JSON document.
    """
    if raw_input is None:
        return json.dumps({"query": user_query}), user_query
    if isinstance(raw_input, str):
        return json.dumps({"query": raw_input}), raw_input
    if isinstance(raw_input, Mapping):
        raw_input_dict = dict(raw_input)
        # If raw_input has a query field, prefer it over user_query.
        extracted_query = str(raw_input_dict.get("query", user_query))
        # Fix: the old single-key "shortcut" still called json.dumps but sat
        # outside the try/except, so {"query": <non-serializable>} raised
        # while other non-serializable dicts were caught. All mappings now
        # take the same guarded path (identical output for serializable
        # inputs, since the shortcut produced the same JSON anyway).
        try:
            return json.dumps(raw_input_dict), extracted_query
        except (TypeError, ValueError) as e:
            fallback = {
                "error": f"non-serializable dict: {e}",
                "query": extracted_query,
            }
            return json.dumps(fallback), extracted_query
    # Any other runtime type: report it in a JSON payload with the default.
    fallback = {
        "error": f"unsupported type: {type(raw_input).__name__}",
        "query": user_query,
    }
    return json.dumps(fallback), user_query
def extract_state_update_data(
    state_update: Mapping[str, object] | None,
    messages: list[dict[str, object]] | None,
    raw_input: str | Mapping[str, object] | None,
    thread_id: str | None,
) -> tuple[
    list[dict[str, object]] | None,
    str | Mapping[str, object] | None,
    str | None,
]:
    """Extract fields from a LangGraph state update, preferring existing values.

    Args:
        state_update: State update dictionary from the LangGraph API.
        messages: Existing messages; kept when already set.
        raw_input: Existing raw input; kept when already set.
        thread_id: Existing thread ID; kept when already set.

    Returns:
        Tuple of (messages, raw_input, thread_id).
    """
    if not state_update:
        return messages, raw_input, thread_id

    # Only adopt incoming messages when none exist yet and the payload is a
    # list; non-mapping entries are dropped, and an all-dropped list is
    # treated as absent.
    if not messages and "messages" in state_update:
        incoming = state_update["messages"]
        if isinstance(incoming, list):
            normalized = [
                dict(entry) for entry in incoming if isinstance(entry, Mapping)
            ]
            if normalized:
                messages = normalized

    if "raw_input" in state_update and not raw_input:
        raw_input = state_update["raw_input"]
    if "thread_id" in state_update and not thread_id:
        thread_id = state_update["thread_id"]

    return messages, raw_input, thread_id
def create_initial_state_dict(
    raw_input_str: str,
    user_query: str,
    messages: list[dict[str, object]],
    thread_id: str,
    config_dict: Mapping[str, object],
) -> dict[str, object]:
    """Create the initial state dictionary for a graph run.

    Args:
        raw_input_str: Formatted raw input string.
        user_query: User query.
        messages: Message history (stored as-is, not copied).
        thread_id: Thread identifier.
        config_dict: Configuration mapping; copied into a plain dict.

    Returns:
        Initial state dictionary with all bookkeeping fields defaulted.
    """
    # Two independent copies of the query payload so mutating one nested
    # dict cannot affect the other.
    state: dict[str, object] = {
        "raw_input": raw_input_str,
        "parsed_input": {
            "raw_payload": {"query": user_query},
            "user_query": user_query,
        },
        "messages": messages,
        "initial_input": {"query": user_query},
        "thread_id": thread_id,
        "config": dict(config_dict),
    }
    # Bookkeeping defaults for a fresh, not-yet-started run.
    state.update(
        input_metadata={},
        context={},
        status="pending",
        errors=[],
        run_metadata={},
        is_last_step=False,
        final_result=None,
    )
    return state

View File

@@ -1,8 +1,8 @@
"""Validation decorators for functions."""
import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar, cast
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from ..errors import ValidationError
@@ -52,7 +52,7 @@ def validate_args(
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
def wrapper(*args: object, **kwargs: object) -> T:
# Get function signature
import inspect
@@ -98,7 +98,7 @@ def validate_return(
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
def wrapper(*args: object, **kwargs: object) -> T:
result = func(*args, **kwargs)
is_valid, error_msg = validator(result)
if not is_valid:
@@ -126,7 +126,7 @@ def validate_not_none(
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
def wrapper(*args: object, **kwargs: object) -> T:
import inspect
sig = inspect.signature(func)
@@ -168,7 +168,7 @@ def validate_types(
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
def wrapper(*args: object, **kwargs: object) -> T:
import inspect
sig = inspect.signature(func)

View File

@@ -15,8 +15,8 @@ import asyncio
import contextlib
import functools
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, ParamSpec, TypeVar, cast
from collections.abc import Awaitable, Callable, Mapping
from typing import ParamSpec, TypeVar, cast
from pydantic import BaseModel, ValidationError, create_model
@@ -28,7 +28,8 @@ F = TypeVar("F", bound=Callable[..., object])
# Type alias for state dictionaries that can contain any type
# (needed for validation framework)
StateDict = dict[str, Any]
StateDict = dict[str, object]
GraphFactory = Callable[..., Awaitable[object]]
# Registry to track validated functions
_VALIDATED_FUNCTIONS: set[str] = set()
@@ -54,10 +55,13 @@ RECOMMENDED_GRAPH_METHODS: list[str] = [
]
FieldTypes = Mapping[str, type[object]]
def create_validation_model(
name: str,
required_fields: dict[str, type[Any]],
optional_fields: dict[str, type[Any]] | None = None,
required_fields: FieldTypes,
optional_fields: FieldTypes | None = None,
) -> type[BaseModel]:
"""Create a Pydantic model for node input/output validation.
@@ -69,14 +73,14 @@ def create_validation_model(
Returns:
A Pydantic model class with the specified fields and validation rules
"""
field_definitions: dict[str, tuple[Any, Any]] = {
field_definitions: dict[str, tuple[type[object], object]] = {
field_name: (field_type, ...)
for field_name, field_type in required_fields.items()
}
if optional_fields:
for field_name, field_type in optional_fields.items():
field_definitions[field_name] = (field_type, None)
model: type[BaseModel] = create_model(name, **field_definitions) # type: ignore
model = cast(type[BaseModel], create_model(name, **field_definitions))
return model
@@ -96,13 +100,13 @@ class PydanticValidator:
raise NodeValidationError(str(e), validation_type=self.phase) from e
def is_validated(func: Callable[..., Any]) -> bool:
def is_validated(func: Callable[..., object]) -> bool:
"""Check if a function has been validated."""
func_id: str = f"{func.__module__}.{func.__name__}"
return func_id in _VALIDATED_FUNCTIONS
def mark_validated(func: Callable[..., Any]) -> None:
def mark_validated(func: Callable[..., object]) -> None:
"""Mark a function as validated."""
func_id: str = f"{func.__module__}.{func.__name__}"
_VALIDATED_FUNCTIONS.add(func_id)
@@ -209,11 +213,8 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
result = await async_func(state, *args, **kwargs)
result_dict: StateDict = result
try:
pydantic_validator = cast(
"Callable[[dict[str, Any]], BaseModel]", validator
)
validated_result: BaseModel = pydantic_validator(result_dict)
return validated_result.model_dump()
validated_result = validator(result_dict)
return cast(StateDict, validated_result.model_dump())
except NodeValidationError as e:
source: str = f"{func.__module__}.{func.__name__}"
return add_exception_error(result_dict, e, source)
@@ -230,11 +231,8 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
result = sync_func(state, *args, **kwargs)
result_dict: StateDict = result
try:
pydantic_validator = cast(
"Callable[[dict[str, Any]], BaseModel]", validator
)
validated_result: BaseModel = pydantic_validator(result_dict)
return validated_result.model_dump()
validated_result = validator(result_dict)
return cast(StateDict, validated_result.model_dump())
except NodeValidationError as e:
source: str = f"{func.__module__}.{func.__name__}"
return add_exception_error(result_dict, e, source)
@@ -346,7 +344,7 @@ async def ensure_graph_compatibility(
async def validate_all_graphs(
graph_functions: dict[str, Callable[[], Any]],
graph_functions: Mapping[str, GraphFactory],
) -> bool:
"""Validate all graph creation functions.
@@ -362,9 +360,7 @@ async def validate_all_graphs(
# Use inspect to determine if function takes arguments
sig = inspect.signature(func)
if len(sig.parameters) == 0:
# Cast to the expected coroutine function type
coro_func = cast("Callable[[], Awaitable[object]]", func)
graph = await coro_func()
graph = await cast("Callable[[], Awaitable[object]]", func)()
else:
# Skip functions that require arguments for now
all_valid = False

View File

@@ -6,15 +6,23 @@ including interrupt-based validation, state initialization helpers, and proper e
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, cast
import copy
from collections import deque
from collections.abc import Callable, Mapping, Sequence
from typing import Literal, Protocol, TypeVar, cast
try: # pragma: no cover - defensive import guard
import json
except Exception: # pragma: no cover - fallback when json is unavailable
json = None # type: ignore[assignment]
from biz_bud.core.langgraph.state_immutability import StateUpdater
if TYPE_CHECKING:
from collections.abc import Callable, Mapping
# Remove TypedDict bound constraint as it's not allowed
T = TypeVar("T")
# Remove TypedDict bound constraint as it's not allowed
T = TypeVar("T")
StateMapping = Mapping[str, object]
StateDict = dict[str, object]
class ValidationError(Exception):
@@ -25,7 +33,7 @@ class ValidationError(Exception):
message: str,
field: str | None = None,
expected_type: type | None = None,
actual_value: Any = None,
actual_value: object | None = None,
) -> None:
super().__init__(message)
self.field = field
@@ -33,24 +41,24 @@ class ValidationError(Exception):
self.actual_value = actual_value
class StateValidator(Protocol):
class StateValidator(Protocol):
"""Protocol for state validators."""
def __call__(self, state: Mapping[str, Any]) -> bool:
def __call__(self, state: StateMapping) -> bool:
"""Validate state and return True if valid."""
...
class StateInitializer(Protocol):
class StateInitializer(Protocol):
"""Protocol for state initializers."""
def __call__(self, **kwargs: Any) -> dict[str, Any]:
def __call__(self, **kwargs: object) -> StateDict:
"""Initialize state with proper defaults."""
...
def validate_required_fields(
state: Mapping[str, Any],
def validate_required_fields(
state: StateMapping,
required_fields: set[str],
) -> None:
"""Validate that all required fields are present in state.
@@ -68,8 +76,8 @@ def validate_required_fields(
)
def validate_field_type(
state: Mapping[str, Any],
def validate_field_type(
state: StateMapping,
field_name: str,
expected_type: type,
) -> None:
@@ -97,11 +105,11 @@ def validate_field_type(
)
def create_state_validator(
required_fields: set[str] | None = None,
type_checks: dict[str, type] | None = None,
custom_validators: list[StateValidator] | None = None,
) -> StateValidator:
def create_state_validator(
required_fields: set[str] | None = None,
type_checks: Mapping[str, type[object]] | None = None,
custom_validators: Sequence[StateValidator] | None = None,
) -> StateValidator:
"""Create a composite state validator.
Args:
@@ -112,7 +120,7 @@ def create_state_validator(
Returns:
A validator function that can be used in LangGraph nodes
"""
def validator(state: Mapping[str, Any]) -> bool:
def validator(state: StateMapping) -> bool:
try:
# Check required fields
if required_fields:
@@ -137,10 +145,10 @@ def create_state_validator(
return validator
def create_validation_node(
validator: StateValidator,
error_message: str = "State validation failed",
) -> Callable[[Mapping[str, Any]], dict[str, Any]]:
def create_validation_node(
validator: StateValidator,
error_message: str = "State validation failed",
) -> Callable[[StateMapping], StateDict]:
"""Create a LangGraph node that validates state and handles errors.
Args:
@@ -150,7 +158,7 @@ def create_validation_node(
Returns:
A LangGraph node function that validates state
"""
def validation_node(state: Mapping[str, Any]) -> dict[str, Any]:
def validation_node(state: StateMapping) -> StateDict:
"""Validate state and update error status if validation fails."""
is_valid = validator(state)
@@ -163,25 +171,25 @@ def create_validation_node(
"timestamp": None, # Could add timestamp here
})
return {
"errors": errors,
"status": "error",
"is_output_valid": False,
"validation_issues": [error_message],
}
return {
"is_output_valid": True,
"validation_issues": [],
}
return {
"errors": errors,
"status": "error",
"is_output_valid": False,
"validation_issues": [error_message],
}
return {
"is_output_valid": True,
"validation_issues": [],
}
return validation_node
def create_interrupt_validation_node(
validator: StateValidator,
interrupt_message: str = "Human validation required",
) -> Callable[[Mapping[str, Any]], dict[str, Any]]:
def create_interrupt_validation_node(
validator: StateValidator,
interrupt_message: str = "Human validation required",
) -> Callable[[StateMapping], StateDict]:
"""Create a LangGraph node that validates and interrupts for human feedback.
Args:
@@ -191,7 +199,7 @@ def create_interrupt_validation_node(
Returns:
A LangGraph node function that validates and may interrupt
"""
def interrupt_validation_node(state: Mapping[str, Any]) -> dict[str, Any]:
def interrupt_validation_node(state: StateMapping) -> StateDict:
"""Validate state and interrupt if validation fails."""
is_valid = validator(state)
@@ -211,11 +219,11 @@ def create_interrupt_validation_node(
return interrupt_validation_node
def initialize_base_state(
thread_id: str,
config: dict[str, Any] | None = None,
status: Literal["pending", "running", "success", "error", "interrupted"] = "pending",
) -> dict[str, Any]:
def initialize_base_state(
thread_id: str,
config: Mapping[str, object] | None = None,
status: Literal["pending", "running", "success", "error", "interrupted"] = "pending",
) -> StateDict:
"""Initialize base state with proper defaults for LangGraph workflows.
Args:
@@ -229,7 +237,7 @@ def initialize_base_state(
return {
"messages": [],
"errors": [],
"config": config or {},
"config": dict(config) if config is not None else {},
"thread_id": thread_id,
"status": status,
"is_output_valid": None,
@@ -237,9 +245,9 @@ def initialize_base_state(
}
def initialize_search_state(
base_state: dict[str, Any] | None = None,
) -> dict[str, Any]:
def initialize_search_state(
base_state: StateDict | None = None,
) -> StateDict:
"""Initialize search-related state fields with proper defaults.
Args:
@@ -248,7 +256,7 @@ def initialize_search_state(
Returns:
State dictionary with search fields initialized
"""
state = base_state.copy() if base_state else {}
state: StateDict = base_state.copy() if base_state else {}
search_defaults = {
"search_results": [],
@@ -263,9 +271,9 @@ def initialize_search_state(
return updater.build()
def initialize_analysis_state(
base_state: dict[str, Any] | None = None,
) -> dict[str, Any]:
def initialize_analysis_state(
base_state: StateDict | None = None,
) -> StateDict:
"""Initialize analysis-related state fields with proper defaults.
Args:
@@ -274,7 +282,7 @@ def initialize_analysis_state(
Returns:
State dictionary with analysis fields initialized
"""
state = base_state.copy() if base_state else {}
state: StateDict = base_state.copy() if base_state else {}
analysis_defaults = {
"visualizations": [],
@@ -286,11 +294,11 @@ def initialize_analysis_state(
return updater.build()
def create_safe_state_initializer(
state_type: type[T],
default_values: dict[str, Any] | None = None,
required_fields: set[str] | None = None,
) -> Callable[..., T]:
def create_safe_state_initializer(
state_type: type[T],
default_values: Mapping[str, object] | None = None,
required_fields: set[str] | None = None,
) -> Callable[..., T]:
"""Create a safe state initializer that prevents KeyError crashes.
Args:
@@ -301,13 +309,15 @@ def create_safe_state_initializer(
Returns:
A function that safely initializes the state type
"""
def initializer(**kwargs: Any) -> T:
"""Initialize state with safe defaults and validation."""
# Start with provided defaults
state_dict = (default_values or {}).copy()
# Update with provided kwargs
state_dict.update(kwargs)
def initializer(**kwargs: object) -> T:
"""Initialize state with safe defaults and validation."""
base_defaults: dict[str, object] = {}
if default_values is not None:
base_defaults = copy.deepcopy(dict(default_values))
state_dict: dict[str, object] = base_defaults
state_dict.update(kwargs)
# Validate required fields if specified
if required_fields:
@@ -319,14 +329,14 @@ def create_safe_state_initializer(
return initializer
def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
def create_business_buddy_state_initializer() -> Callable[..., StateDict]:
"""Create a safe initializer for BusinessBuddyState with all defaults.
Returns:
A function that initializes BusinessBuddyState with safe defaults
"""
# Define comprehensive defaults for all list fields to prevent KeyErrors
defaults = {
defaults: StateDict = {
# Core required fields (should be provided by caller)
"messages": [],
"errors": [],
@@ -412,10 +422,10 @@ def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
required_fields = {"thread_id", "status", "config"}
def initializer(**kwargs: Any) -> dict[str, Any]:
def initializer(**kwargs: object) -> StateDict:
"""Initialize BusinessBuddyState with comprehensive defaults."""
# Start with defaults
state_dict = defaults.copy()
state_dict: StateDict = defaults.copy()
# Update with provided values
state_dict.update(kwargs)
@@ -428,7 +438,9 @@ def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
return initializer
def create_state_access_helper(state: dict[str, Any]) -> Callable[..., Any]:
def create_state_access_helper(
state: StateMapping,
) -> Callable[[str, object | None], object | None]:
"""Create a safe state accessor that prevents KeyError crashes.
Args:
@@ -437,17 +449,17 @@ def create_state_access_helper(state: dict[str, Any]) -> Callable[..., Any]:
Returns:
A function that safely gets values with defaults
"""
def safe_get(key: str, default: Any = None) -> Any:
def safe_get(key: str, default: object | None = None) -> object | None:
"""Safely get a value from state with default fallback."""
return state.get(key, default)
return safe_get
def validate_state_structure(
state: dict[str, Any],
expected_structure: dict[str, type],
) -> list[str]:
def validate_state_structure(
state: StateMapping,
expected_structure: Mapping[str, type[object]],
) -> list[str]:
"""Validate that state has the expected structure and types.
Args:
@@ -473,10 +485,10 @@ def validate_state_structure(
return errors
def create_state_transition_validator(
pre_transition_checks: list[StateValidator] | None = None,
post_transition_checks: list[StateValidator] | None = None,
) -> Callable[[dict[str, Any], dict[str, Any]], tuple[bool, list[str]]]:
def create_state_transition_validator(
pre_transition_checks: Sequence[StateValidator] | None = None,
post_transition_checks: Sequence[StateValidator] | None = None,
) -> Callable[[StateMapping, StateMapping], tuple[bool, list[str]]]:
"""Create a validator for state transitions.
Args:
@@ -486,10 +498,10 @@ def create_state_transition_validator(
Returns:
A function that validates state transitions
"""
def validate_transition(
old_state: dict[str, Any],
new_state: dict[str, Any]
) -> tuple[bool, list[str]]:
def validate_transition(
old_state: StateMapping,
new_state: StateMapping,
) -> tuple[bool, list[str]]:
"""Validate a state transition.
Returns:
@@ -514,7 +526,7 @@ def create_state_transition_validator(
return validate_transition
def validate_message_list(messages: Any) -> bool:
def validate_message_list(messages: object) -> bool:
"""Validate that messages field contains proper LangChain message objects.
Args:
@@ -542,7 +554,7 @@ def create_message_validator() -> StateValidator:
Returns:
A validator that checks message field integrity
"""
def message_validator(state: Mapping[str, Any]) -> bool:
def message_validator(state: StateMapping) -> bool:
"""Validate message fields in state."""
if "messages" in state:
return validate_message_list(state["messages"])
@@ -551,80 +563,235 @@ def create_message_validator() -> StateValidator:
return message_validator
# Custom reducer functions for LangGraph state management
def unique_add(left: list[Any] | None, right: list[Any] | Any | None) -> list[Any]:
"""Add items to list while avoiding duplicates.
Args:
left: Existing list
right: Items to add (can be single item or list)
Returns:
Combined list without duplicates
"""
if left is None:
left = []
# Convert single item to list
if not isinstance(right, list):
right = [right] if right is not None else []
# Combine and deduplicate while preserving order
seen = set(left)
result = list(left)
for item in right:
if item not in seen:
result.append(item)
seen.add(item)
return result
# Custom reducer functions for LangGraph state management
def _as_object_list(value: Sequence[object] | object | None) -> list[object]:
"""Normalize a value into a list while treating strings as scalars."""
if value is None:
return []
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return list(value)
return [value]
def unique_add(
left: Sequence[object] | None, right: Sequence[object] | object | None
) -> list[object]:
"""Add items to list while avoiding duplicates."""
existing = list(left) if left is not None else []
additions = _as_object_list(right)
def _dedup_key(
item: object,
*,
_seen_ids: set[int] | None = None,
_depth: int = 0,
) -> object:
if _seen_ids is None:
_seen_ids = set()
if _depth > 10:
return ("max_depth", type(item))
obj_id = id(item)
if obj_id in _seen_ids:
return ("cycle", type(item))
_seen_ids.add(obj_id)
try:
hash(item)
except TypeError:
if isinstance(item, Mapping):
key = (
"mapping",
type(item),
tuple(
sorted(
(
key,
_dedup_key(value, _seen_ids=_seen_ids, _depth=_depth + 1),
)
for key, value in item.items()
)
),
)
elif isinstance(item, Sequence) and not isinstance(item, (str, bytes, bytearray)):
key = (
"sequence",
type(item),
tuple(
_dedup_key(value, _seen_ids=_seen_ids, _depth=_depth + 1)
for value in item
),
)
else:
attrs: dict[str, str] = {}
for attr in ("message", "code", "category"):
try:
value = getattr(item, attr)
except Exception:
continue
else:
if value is not None:
attrs[attr] = repr(value)
if attrs:
key = (
"object_attrs",
type(item),
tuple(sorted(attrs.items())),
)
else:
key = ("identity", type(item), id(item))
else:
key = ("hashable", type(item), item)
finally:
_seen_ids.discard(obj_id)
return key
seen = {_dedup_key(item) for item in existing}
for item in additions:
key = _dedup_key(item)
if key not in seen:
existing.append(item)
seen.add(key)
return existing
def error_merge(left: Any, right: Any) -> list[Any]:
    """Merge error lists while avoiding duplicates based on message.

    This function handles ErrorInfo TypedDict objects and provides
    deduplication based on error message content.

    Args:
        left: Existing error list (can be None)
        right: New errors to add (can be None, single item, or list)

    Returns:
        Combined error list without duplicate messages
    """

    def _as_list(value: Any) -> list[Any]:
        # Treat None as empty and wrap lone scalars in a list.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _message_of(error: Any) -> str:
        # Dict-shaped errors dedupe on their "message" field; anything else
        # dedupes on its string form.
        return error.get("message", "") if isinstance(error, dict) else str(error)

    current = _as_list(left)
    incoming = _as_list(right)

    merged = list(current)
    seen_messages = {_message_of(error) for error in current}
    for error in incoming:
        message = _message_of(error)
        if message not in seen_messages:
            merged.append(error)
            seen_messages.add(message)
    return merged
def error_merge(
    left: Sequence[object] | object | None,
    right: Sequence[object] | object | None,
    *,
    max_errors: int = 1000,
    _max_depth: int = 5,
    _max_items: int = 50,
) -> list[object]:
    """Merge error lists while avoiding duplicates using robust keys with predictable performance.

    Items are de-duplicated by a structural key (``_message_key``) rather than
    raw equality so mappings, exceptions, and arbitrary objects can all be
    compared. The merged list is capped at ``max_errors`` items, evicting the
    oldest entries first.

    Args:
        left: Existing errors (None, a single item, or a sequence).
        right: New errors to merge in (None, a single item, or a sequence).
        max_errors: Soft capacity for the merged list; non-positive or
            non-int values fall back to 1000.
        _max_depth: Recursion bound for key/repr construction.
        _max_items: Per-container item bound when building stable reprs.

    Returns:
        Combined, de-duplicated error list, at most ``max_errors`` long.
    """
    # Non-int or non-positive max_errors silently falls back to the default
    # capacity of 1000 rather than raising.
    capacity = max(1, int(max_errors)) if isinstance(max_errors, int) and max_errors > 0 else 1000
    existing = _as_object_list(left)
    if len(existing) > capacity:
        # Keep only the most recent entries if the incoming list is oversized.
        existing = existing[-capacity:]
    additions = _as_object_list(right)

    def _safe_json(obj: object) -> str:
        """Serialize ``obj`` to deterministic JSON, falling back to repr()."""
        if json is None:  # pragma: no cover - json import failed
            return repr(obj)
        try:
            # Encoder that stringifies anything json cannot serialize natively.
            class SafeEncoder(json.JSONEncoder):  # type: ignore[name-defined]
                def default(self, o: object) -> object:  # type: ignore[override]
                    try:
                        return str(o)
                    except Exception:
                        return repr(o)
            # sort_keys makes the output order-independent for mappings.
            return json.dumps(obj, sort_keys=True, ensure_ascii=False, cls=SafeEncoder)  # type: ignore[name-defined]
        except Exception:
            return repr(obj)

    def _stable_repr(value: object, depth: int = 0) -> str:
        """Build a bounded, deterministic textual form of ``value``."""
        if depth >= _max_depth:
            return f"<max_depth:{type(value).__name__}>"
        try:
            if isinstance(value, Mapping):
                # Truncate to _max_items entries, normalize recursively, and
                # sort by stringified key for a stable representation.
                items = list(value.items())[:_max_items]
                try:
                    normalized = {str(k): _stable_repr(v, depth + 1) for k, v in items}
                except Exception:
                    normalized = {str(k): repr(v) for k, v in items}
                return _safe_json(dict(sorted(normalized.items(), key=lambda kv: kv[0])))
            if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
                seq = list(value)[:_max_items]
                try:
                    normalized_list = [_stable_repr(v, depth + 1) for v in seq]
                except Exception:
                    normalized_list = [repr(v) for v in seq]
                return _safe_json(normalized_list)
            return repr(value)
        except Exception:
            return repr(value)

    def _exception_key(exc: BaseException) -> tuple[object, ...]:
        """Key an exception by class name, args, and optional code/message attrs."""
        try:
            args_repr = tuple(str(arg) for arg in getattr(exc, "args", ()))
        except Exception:
            args_repr = ()
        code = getattr(exc, "code", None)
        message = getattr(exc, "message", None)
        return (
            "exception",
            exc.__class__.__name__,
            args_repr,
            str(code) if code is not None else None,
            str(message) if message is not None else None,
        )

    def _message_key(item: object, depth: int = 0) -> tuple[object, ...]:
        """Build the dedup key for an error item (mapping, exception, or other)."""
        if depth >= _max_depth:
            return ("truncated", type(item).__name__)
        if isinstance(item, Mapping):
            # Prefer the well-known ErrorInfo-style fields when any are present.
            code = item.get("code")
            message = item.get("message")
            detail = item.get("detail") or item.get("details")
            category = item.get("category")
            if code is None and message is None and detail is None and category is None:
                # No recognizable error fields: key on the (truncated) content.
                truncated_items = tuple(
                    sorted(
                        (str(k), _stable_repr(v, depth + 1))
                        for k, v in list(item.items())[:_max_items]
                    )
                )
                return ("mapping_fallback", type(item).__name__, _safe_json(truncated_items))
            return (
                "mapping",
                type(item).__name__,
                str(code) if code is not None else None,
                str(message) if message is not None else None,
                str(detail) if detail is not None else None,
                str(category) if category is not None else None,
            )
        if isinstance(item, BaseException):
            return _exception_key(item)
        return ("other", item.__class__.__name__, _stable_repr(item, depth + 1))

    # deques give O(1) front eviction when trimming to capacity below.
    result: deque[object] = deque()
    key_order: deque[tuple[object, ...]] = deque()
    key_counts: dict[tuple[object, ...], int] = {}
    for existing_item in existing:
        key = _message_key(existing_item)
        if key in key_counts:
            continue
        result.append(existing_item)
        key_order.append(key)
        # NOTE(review): since duplicates are skipped above, each key's count is
        # always 1 here; the counting below looks defensive — confirm intent.
        key_counts[key] = key_counts.get(key, 0) + 1
    for item in additions:
        key = _message_key(item)
        if key in key_counts:
            continue
        result.append(item)
        key_order.append(key)
        key_counts[key] = 1
    if len(result) > capacity:
        # Evict oldest entries and keep key bookkeeping consistent.
        evict_count = len(result) - capacity
        for _ in range(evict_count):
            evicted_key = key_order.popleft()
            result.popleft()
            count = key_counts.get(evicted_key)
            if count is None:
                continue
            if count <= 1:
                del key_counts[evicted_key]
            else:
                key_counts[evicted_key] = count - 1
    return list(result)
def status_override(left: str, right: str | None) -> str:
@@ -640,7 +807,7 @@ def status_override(left: str, right: str | None) -> str:
return right if right is not None else left
def config_merge(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
def config_merge(left: StateDict, right: Mapping[str, object]) -> StateDict:
"""Deep merge configuration dictionaries.
Args:
@@ -650,22 +817,26 @@ def config_merge(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
Returns:
Merged configuration dictionary
"""
# left and right are guaranteed to be dict[str, Any] by type annotations
# ``left`` and ``right`` are guaranteed to use ``StateDict``-compatible payloads.
result = left.copy()
result: StateDict = left.copy()
for key, value in right.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
# Recursively merge nested dictionaries
result[key] = config_merge(result[key], value)
else:
# Override with new value
result[key] = value
return result
def safe_list_add(left: list[Any] | None, right: list[Any] | Any | None) -> list[Any]:
merged_value = config_merge(cast(StateDict, result[key]), value)
result[key] = merged_value
else:
# Override with new value
result[key] = value
return result
def safe_list_add(
left: Sequence[object] | None,
right: Sequence[object] | object | None,
) -> list[object]:
"""Safely add items to a list, handling None values gracefully.
Args:
@@ -675,19 +846,16 @@ def safe_list_add(left: list[Any] | None, right: list[Any] | Any | None) -> list
Returns:
Combined list with proper None handling
"""
if left is None:
left = []
if right is None:
return list(left)
if not isinstance(right, list):
right = [right]
return list(left) + list(right)
def dict_list_merge(left: list[dict[str, Any]] | None, right: list[dict[str, Any]] | dict[str, Any] | None) -> list[dict[str, Any]]:
existing = list(left) if left is not None else []
additions = _as_object_list(right)
return existing + additions
def dict_list_merge(
left: Sequence[Mapping[str, object]] | None,
right: Sequence[Mapping[str, object]] | Mapping[str, object] | None,
) -> list[Mapping[str, object]]:
"""Merge lists of dictionaries, useful for complex state objects.
Args:
@@ -697,13 +865,10 @@ def dict_list_merge(left: list[dict[str, Any]] | None, right: list[dict[str, Any
Returns:
Combined list of dictionaries
"""
if left is None:
left = []
if right is None:
return list(left)
if isinstance(right, dict):
right = [right]
return list(left) + list(right)
existing = list(left) if left is not None else []
if right is None:
return existing
additions = [right] if isinstance(right, Mapping) else list(right)
return existing + additions

View File

@@ -25,8 +25,8 @@ The Business Buddy framework provides extensive utilities through the core packa
- **`normalize_errors_to_list`**: Ensure errors are always list format
### Type Compatibility
- **`create_type_safe_wrapper`**: Wrap functions for LangGraph type safety
- **`wrap_for_langgraph`**: Decorator for type-safe conditional edges
- **Mapping-based routers**: `route_error_severity` and `route_llm_output` accept LangGraph state directly
- **`StateProtocol` utilities**: Build custom routers that interoperate with both dict and TypedDict states
### Caching Infrastructure
- **`LLMCache`**: Specialized cache for LLM responses
@@ -126,38 +126,34 @@ def create_state_from_input(user_input: str | dict) -> dict:
return state
```
### Type-Safe Wrappers for LangGraph
### Typed Routers for LangGraph
```python
from biz_bud.core.langgraph import (
create_type_safe_wrapper,
wrap_for_langgraph,
route_error_severity,
route_llm_output
)
from biz_bud.core.langgraph import route_error_severity, route_llm_output
from langgraph.graph import StateGraph
# Create type-safe wrappers for routing functions
safe_error_router = create_type_safe_wrapper(route_error_severity)
safe_llm_router = create_type_safe_wrapper(route_llm_output)
# Or use decorator pattern
@wrap_for_langgraph(dict)
def custom_router(state: MyTypedState) -> str:
"""Custom routing logic with automatic type casting."""
return route_error_severity(state)
# Use in graph construction
# Router functions operate on ``Mapping[str, object]`` so TypedDict states work
# with LangGraph without additional adapters.
builder = StateGraph(MyTypedState)
builder.add_conditional_edges(
"process_node",
safe_error_router, # No type errors!
route_error_severity,
{
"retry": "process_node",
"error": "error_handler",
"continue": "next_node"
}
)
builder.add_conditional_edges(
"call_model",
route_llm_output,
{
"tool_executor": "tools",
"output": "END",
"error_handling": "error_handler"
}
)
```
### Async/Sync Factory Pattern
@@ -592,7 +588,6 @@ from biz_bud.core.langgraph import (
standard_node,
handle_errors,
log_node_execution,
create_type_safe_wrapper,
route_error_severity,
route_llm_output
)
@@ -609,10 +604,6 @@ class ProductionState(InputState):
results: Annotated[list[dict], operator.add] = []
confidence_score: float = 0.0
# Create type-safe routers
safe_error_router = create_type_safe_wrapper(route_error_severity)
safe_llm_router = create_type_safe_wrapper(route_llm_output)
@standard_node
@handle_errors
@log_node_execution
@@ -670,7 +661,7 @@ async def create_production_graph_with_caching():
# Use type-safe routing
builder.add_conditional_edges(
"process",
safe_error_router,
route_error_severity,
{
"retry": "process",
"error": "error_handler",
@@ -900,7 +891,7 @@ async def process_large_dataset(items: list[dict]) -> list[dict]:
7. **Reuse Edge Helpers**: Use project's routing utilities to prevent duplication
8. **Leverage Core Utilities**:
- `generate_config_hash` for cache keys
- `create_type_safe_wrapper` for LangGraph compatibility
- Mapping-based routers (`route_error_severity`, `route_llm_output`) for LangGraph compatibility
- `normalize_errors_to_list` for consistent error handling
- `GraphCache` for thread-safe graph caching
- `handle_sync_async_context` for context-aware execution
@@ -942,8 +933,6 @@ from biz_bud.core.utils import normalize_errors_to_list
# Type compatibility
from biz_bud.core.langgraph import (
create_type_safe_wrapper,
wrap_for_langgraph,
standard_node,
handle_errors,
log_node_execution,

View File

@@ -7,10 +7,11 @@ data visualization, trend analysis, and business insights generation.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, TypedDict, cast
from collections.abc import Mapping
from typing import TypedDict, cast
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CachePolicy, RetryPolicy
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
from pydantic import BaseModel, Field
from biz_bud.core.edge_helpers.error_handling import handle_error
@@ -23,9 +24,6 @@ from biz_bud.core.langgraph.graph_builder import (
)
from biz_bud.logging import get_logger
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
# Import consolidated nodes
# Import analysis-specific nodes from local module
from biz_bud.graphs.analysis.nodes import (
@@ -49,8 +47,8 @@ class AnalysisGraphInput(BaseModel):
"""Input schema for the analysis graph."""
task: str = Field(description="The analysis task or request")
data: dict[str, Any] | None = Field(
default=None, description="Data to analyze (DataFrames, dicts, etc.)"
data: object | None = Field(
default=None, description="Data to analyze (DataFrames, mappings, etc.)"
)
include_visualizations: bool = Field(
default=True, description="Whether to generate visualizations"
@@ -69,11 +67,11 @@ class AnalysisGraphContext(TypedDict, total=False):
class AnalysisGraphOutput(TypedDict, total=False):
"""Output schema describing the terminal payload from the analysis graph."""
report: dict[str, Any] | str | None
analysis_results: dict[str, Any] | None
visualizations: list[dict[str, Any]]
report: dict[str, object] | str | None
analysis_results: dict[str, object] | None
visualizations: list[dict[str, object]]
status: str
errors: list[dict[str, Any]]
errors: list[dict[str, object]]
# Graph metadata for dynamic discovery
@@ -127,7 +125,7 @@ _handle_analysis_errors = handle_error(
)
def create_analysis_graph() -> "CompiledStateGraph":
def create_analysis_graph() -> CompiledStateGraph[AnalysisState]:
"""Create the data analysis workflow graph.
This graph implements a comprehensive analysis workflow:
@@ -263,7 +261,9 @@ def create_analysis_graph() -> "CompiledStateGraph":
return build_graph_from_config(config)
def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def analysis_graph_factory(
config: RunnableConfig,
) -> CompiledStateGraph[AnalysisState]:
"""Create analysis graph for graph-as-tool pattern.
This factory function follows the standard pattern for graphs
@@ -288,7 +288,9 @@ def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
# Async factory function for LangGraph API
async def analysis_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
async def analysis_graph_factory_async(
config: RunnableConfig,
) -> CompiledStateGraph[AnalysisState]:
"""Async wrapper for analysis_graph_factory to avoid blocking calls."""
import asyncio
return await asyncio.to_thread(analysis_graph_factory, config)
@@ -301,9 +303,9 @@ analysis_graph = create_analysis_graph()
# Helper function for easy invocation
async def analyze_data(
task: str,
data: dict[str, Any] | None = None,
data: object | None = None,
include_visualizations: bool = True,
config: dict[str, Any] | None = None,
config: Mapping[str, object] | None = None,
) -> AnalysisState:
"""Analyze data using the analysis workflow.

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from collections.abc import Mapping
from langchain_core.runnables import RunnableConfig
try: # pragma: no cover - optional checkpoint backend
@@ -10,6 +10,7 @@ try: # pragma: no cover - optional checkpoint backend
except ModuleNotFoundError: # pragma: no cover - fallback when postgres extra missing
PostgresSaver = None # type: ignore[assignment]
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from biz_bud.core.edge_helpers.core import create_bool_router, create_enum_router
from biz_bud.core.edge_helpers.error_handling import handle_error
@@ -20,9 +21,6 @@ from biz_bud.core.langgraph.graph_builder import (
build_graph_from_config,
)
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from biz_bud.nodes.error_handling import (
error_analyzer_node,
error_interceptor_node,
@@ -64,7 +62,7 @@ GRAPH_METADATA = {
def create_error_handling_graph(
checkpointer: PostgresSaver | None = None,
) -> "CompiledStateGraph":
) -> CompiledStateGraph[ErrorHandlingState]:
"""Create the error handling agent graph.
This graph can be used as a subgraph in any BizBud workflow
@@ -141,7 +139,7 @@ def check_recovery_success(state: ErrorHandlingState) -> bool:
return bool(state.get("recovery_success", False))
def check_for_errors(state: dict[str, Any]) -> str:
def check_for_errors(state: Mapping[str, object]) -> str:
"""Compatibility function that checks for errors in state."""
errors = state.get("errors", [])
status = state.get("status", "")
@@ -160,8 +158,8 @@ def check_error_recovery(state: ErrorHandlingState) -> str:
def add_error_handling_to_graph(
main_graph: "StateGraph",
error_handler: "CompiledStateGraph",
main_graph: StateGraph,
error_handler: CompiledStateGraph[ErrorHandlingState],
nodes_to_protect: list[str],
error_node_name: str = "handle_error",
next_node_mapping: dict[str, str] | None = None,
@@ -230,7 +228,7 @@ def create_error_handling_config(
retry_max_delay: int = 60,
enable_llm_analysis: bool = True,
recovery_timeout: int = 300,
) -> dict[str, Any]:
) -> dict[str, object]:
"""Create error handling configuration.
This is a public helper function that creates standardized error handling
@@ -300,7 +298,9 @@ def create_error_handling_config(
}
def error_handling_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def error_handling_graph_factory(
config: RunnableConfig,
) -> CompiledStateGraph[ErrorHandlingState]:
"""Create error handling graph for LangGraph API.
Args:
@@ -314,7 +314,9 @@ def error_handling_graph_factory(config: RunnableConfig) -> "CompiledStateGraph"
# Async factory function for LangGraph API
async def error_handling_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
async def error_handling_graph_factory_async(
config: RunnableConfig,
) -> CompiledStateGraph[ErrorHandlingState]:
"""Async wrapper for error_handling_graph_factory to avoid blocking calls."""
import asyncio
return await asyncio.to_thread(error_handling_graph_factory, config)

View File

@@ -17,7 +17,8 @@ The main graph represents a sophisticated agent workflow that can:
import asyncio
import os
from typing import TYPE_CHECKING, Any, TypeVar, cast
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from biz_bud.services.factory import ServiceFactory
@@ -25,7 +26,11 @@ if TYPE_CHECKING:
# Import existing lazy loading infrastructure
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
from langgraph.graph.state import (
CachePolicy,
CompiledStateGraph as CompiledGraph,
RetryPolicy,
)
from biz_bud.core.caching import InMemoryCache
from biz_bud.core.cleanup_registry import get_cleanup_registry
@@ -38,10 +43,8 @@ from biz_bud.core.config.schemas import AppConfig
from biz_bud.core.edge_helpers import create_enum_router, detect_errors_list
from biz_bud.core.langgraph import (
NodeConfig,
create_type_safe_wrapper,
handle_errors,
log_node_execution,
route_error_severity,
route_llm_output,
standard_node,
)
@@ -65,8 +68,8 @@ from biz_bud.nodes import call_model_node, parse_and_validate_initial_payload
from biz_bud.services.factory import get_global_factory
from biz_bud.states.base import InputState
# Type variable for state types
StateType = TypeVar('StateType')
StateMapping = Mapping[str, object] | InputState
JsonDict = dict[str, object]
# Get logger instance
logger = get_logger(__name__)
@@ -90,7 +93,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
"""
def create_sync_factory():
def sync_factory(config: RunnableConfig) -> CompiledStateGraph:
def sync_factory(config: RunnableConfig) -> CompiledGraph:
"""Create graph for LangGraph API with RunnableConfig (optimized)."""
from langchain_core.runnables import RunnableConfig
@@ -109,7 +112,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
return sync_factory
def create_async_factory():
async def async_factory(config: RunnableConfig) -> CompiledStateGraph:
async def async_factory(config: RunnableConfig) -> CompiledGraph:
"""Create graph for LangGraph API with RunnableConfig (async optimized)."""
from langchain_core.runnables import RunnableConfig
@@ -130,7 +133,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
return create_sync_factory(), create_async_factory()
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledStateGraph:
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledGraph:
"""Handle sync/async context detection for graph creation.
This function uses centralized context detection from core.networking
@@ -156,7 +159,7 @@ def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceF
# Enhanced input validation using edge helpers
def _validate_parsed_input(state: dict[str, Any] | InputState) -> bool:
def _validate_parsed_input(state: StateMapping) -> bool:
"""Enhanced validation for parsed input content.
This helper checks both the presence and validity of parsed input data
@@ -171,7 +174,7 @@ def _validate_parsed_input(state: dict[str, Any] | InputState) -> bool:
return bool(user_query and user_query.strip())
# Create comprehensive input validation router using edge helpers
def _enhanced_input_check(state: dict[str, Any] | InputState) -> str:
def _enhanced_input_check(state: StateMapping) -> str:
"""Enhanced input validation that combines error detection with content validation.
Uses the Business Buddy pattern of checking both errors and input validity
@@ -195,13 +198,8 @@ check_initial_input_simple = detect_errors_list(
)
# Use type-safe wrappers from core.langgraph
route_error_severity_wrapper = create_type_safe_wrapper(route_error_severity)
route_llm_output_wrapper = create_type_safe_wrapper(route_llm_output)
# Replace manual routing with edge helper-based recovery routing
def _determine_recovery_action(state: dict[str, Any]) -> str:
def _determine_recovery_action(state: Mapping[str, object]) -> str:
"""Determine recovery action based on error handler decisions.
This function consolidates the recovery logic and provides a clear
@@ -230,7 +228,7 @@ route_error_recovery = create_enum_router(
)
# Enhanced router that sets the recovery action in state
def _enhanced_error_recovery(state: dict[str, Any]) -> str:
def _enhanced_error_recovery(state: Mapping[str, object]) -> str:
"""Enhanced error recovery that uses edge helpers pattern."""
recovery_action = _determine_recovery_action(state)
@@ -273,83 +271,205 @@ GRAPH_METADATA = {
@standard_node()
@handle_errors()
@log_node_execution("search")
async def search(state: Any) -> Any: # noqa: ANN401
async def search(state: StateMapping) -> JsonDict:
"""Execute web search using the unified search tool with proper error handling."""
from collections.abc import Mapping as _Mapping
from biz_bud.tools.capabilities.search.tool import web_search
# Extract search query from state
search_query = None
messages = state.get("messages", [])
state_mapping = state if isinstance(state, Mapping) else {}
search_query: str | None = None
messages_value = state_mapping.get("messages", [])
messages: list[object] = []
if isinstance(messages_value, Sequence) and not isinstance(
messages_value, (str, bytes, bytearray)
):
messages = list(messages_value)
# Look for the most recent tool call or human message to extract query
for message in reversed(messages):
if hasattr(message, "tool_calls") and message.tool_calls:
# Extract query from tool call arguments
for tool_call in message.tool_calls:
if tool_call.get("name") == "web_search" or "search" in tool_call.get("name", ""):
args = tool_call.get("args", {})
search_query = args.get("query") or args.get("search_query")
break
if search_query:
tool_calls: object | None = None
if isinstance(message, Mapping):
tool_calls = message.get("tool_calls")
if tool_calls is None:
additional_kwargs = message.get("additional_kwargs")
if isinstance(additional_kwargs, Mapping):
tool_calls = additional_kwargs.get("tool_calls")
if tool_calls is None:
try:
tool_calls = getattr(message, "tool_calls")
except Exception:
tool_calls = None
tool_query_found = False
if isinstance(tool_calls, Sequence) and not isinstance(
tool_calls, (str, bytes, bytearray)
) and tool_calls:
for tool_call in tool_calls:
if not isinstance(tool_call, _Mapping):
continue
name = str(tool_call.get("name", ""))
if name == "web_search" or "search" in name:
args = tool_call.get("args", {}) or {}
if not isinstance(args, _Mapping):
args = {}
candidate = (
args.get("query")
or args.get("search_query")
or args.get("q")
or args.get("keywords")
)
if isinstance(candidate, str) and candidate.strip():
search_query = candidate.strip()
tool_query_found = True
break
if tool_query_found:
break
elif hasattr(message, "content") and message.content:
# Fall back to using message content as query
search_query = str(message.content)[:200] # Limit query length
# Do not fall back to message content when tool calls were present
continue
content: object | None = None
if isinstance(message, Mapping):
content = message.get("content")
if content is None:
try:
content = getattr(message, "content")
except Exception:
content = None
if isinstance(content, str) and content.strip():
search_query = content.strip()
break
# Default fallback query if none found
if not search_query:
search_query = state.get("user_query", "business intelligence search")
if not isinstance(search_query, str) or not search_query.strip():
user_query_value = state_mapping.get("user_query")
if isinstance(user_query_value, str) and user_query_value.strip():
search_query = user_query_value.strip()
else:
search_query = "business intelligence search"
logger.info(f"Executing web search for query: {search_query}")
normalized_query = "".join(ch for ch in search_query if ch.isprintable())
normalized_query = normalized_query.replace("\n", " ").replace("\r", " ").strip()
safe_query = normalized_query[:200] if normalized_query else "business intelligence search"
def _mask_for_log(query: str) -> str:
cleaned = "".join(ch for ch in query if ch.isprintable()).replace("\n", " ").replace("\r", " ")
trimmed = cleaned.strip()
return trimmed if len(trimmed) <= 32 else f"{trimmed[:29]}..."
logger.info("Executing web search for query: %s", _mask_for_log(safe_query))
status = state_mapping.get("status", "running")
validation_issues_value = state_mapping.get("validation_issues")
if isinstance(validation_issues_value, Sequence) and not isinstance(
validation_issues_value, (str, bytes, bytearray)
):
validation_issues = list(validation_issues_value)
else:
validation_issues = []
try:
# Use the web_search tool with proper tool invocation
# Since web_search is decorated with @tool, we need to invoke it properly
search_results = await web_search.ainvoke({
"query": search_query,
"provider": None, # Auto-select best available provider
"max_results": 5 # Reasonable number of results for context
raw_results = await web_search.ainvoke({
"query": safe_query,
"provider": None,
"max_results": 5,
})
# Format search results for the state
if search_results:
# Create a summary of search results
result_summary = f"Found {len(search_results)} search results for '{search_query}':\n\n"
for i, result in enumerate(search_results[:3], 1): # Show top 3 results
result_summary += f"{i}. {result.get('title', 'No title')}\n"
result_summary += f" URL: {result.get('url', 'No URL')}\n"
if result.get('snippet'):
result_summary += f" Summary: {result.get('snippet')[:150]}...\n"
result_summary += "\n"
results_list: list[dict[str, object]] = []
MAX_FIELD_LENGTH = 500
# Update state with search results
state_update = {
**state,
"search_results": search_results,
"final_response": result_summary,
"is_last_step": True
def _normalize_text(value: object, default: str) -> str:
text = value if isinstance(value, str) else (str(value) if value is not None else default)
cleaned = "".join(ch for ch in text if ch.isprintable()).strip()
if not cleaned:
cleaned = default
return cleaned[:MAX_FIELD_LENGTH]
def _normalize_url(value: object) -> str:
raw = value if isinstance(value, str) else (str(value) if value is not None else "")
cleaned = "".join(ch for ch in raw if ch.isprintable()).strip()
return cleaned[:MAX_FIELD_LENGTH] if cleaned else ""
if isinstance(raw_results, Sequence) and not isinstance(
raw_results, (str, bytes, bytearray)
):
for item in raw_results:
if isinstance(item, _Mapping):
title = _normalize_text(item.get("title"), "No title")
url = _normalize_url(item.get("url")) or "No URL"
snippet_source = item.get("snippet") or item.get("description")
else:
title = _normalize_text(getattr(item, "title", None), "No title")
url = _normalize_url(getattr(item, "url", None)) or "No URL"
snippet_source = getattr(item, "snippet", None) or getattr(
item, "description", None
)
snippet = _normalize_text(snippet_source, "")
if title or url or snippet:
results_list.append({"title": title, "url": url, "snippet": snippet})
if results_list:
result_summary_lines: list[str] = []
for index, result in enumerate(results_list[:3], start=1):
title = result.get("title", "No title")
url = result.get("url", "No URL")
snippet = result.get("snippet") or result.get("description")
result_summary_lines.append(f"{index}. {title}")
result_summary_lines.append(f" URL: {url}")
if isinstance(snippet, str) and snippet:
safe_snippet = "".join(ch for ch in snippet if ch.isprintable()).strip()
result_summary_lines.append(f" Summary: {safe_snippet[:150]}...")
result_summary_lines.append("")
state_update: JsonDict = {
"search_query": safe_query,
"search_results": list(results_list),
"final_response": "\n".join(result_summary_lines).strip(),
"is_last_step": True,
"status": "success" if status != "error" else status,
"validation_issues": validation_issues,
}
logger.info(f"Web search completed successfully with {len(search_results)} results")
logger.info(
"Web search completed successfully with %s results",
len(results_list),
)
else:
# No results found
masked_query = _mask_for_log(safe_query)
state_update = {
**state,
"final_response": f"No search results found for query: '{search_query}'. Please try a different search term.",
"is_last_step": True
"search_query": safe_query,
"final_response": (
"No search results found for query: "
f"'{masked_query}'. Please try a different search term."
),
"is_last_step": True,
"status": "success" if status != "error" else status,
"validation_issues": validation_issues,
}
logger.warning(f"No search results found for query: {search_query}")
logger.warning("No search results found for query: %s", masked_query)
except Exception as e:
# Handle search errors gracefully
logger.error(f"Web search failed for query '{search_query}': {e}")
except Exception as exc: # pragma: no cover - defensive logging path
masked_query = _mask_for_log(safe_query)
logger.error("Web search failed for query '%s': %s", masked_query, exc)
errors_value = state_mapping.get("errors")
if isinstance(errors_value, Sequence) and not isinstance(
errors_value, (str, bytes, bytearray)
):
errors_list = list(errors_value)
elif errors_value is None:
errors_list = []
else:
errors_list = [errors_value]
errors_list.append(f"Search error: {exc}")
if len(errors_list) > 1000:
errors_list = errors_list[-1000:]
validation_issues.append("search_failed")
state_update = {
**state,
"final_response": f"Search service temporarily unavailable. Error: {str(e)}",
"search_query": safe_query,
"final_response": (
"Search service temporarily unavailable. Error: %s" % (exc,)
),
"is_last_step": True,
"errors": state.get("errors", []) + [f"Search error: {str(e)}"]
"errors": errors_list,
"status": "error",
"validation_issues": validation_issues,
}
return state_update
@@ -357,7 +477,7 @@ async def search(state: Any) -> Any: # noqa: ANN401
def create_graph() -> CompiledStateGraph:
def create_graph() -> CompiledGraph:
"""Build and compile the complete StateGraph for Business Buddy agent execution.
This function constructs the main workflow graph that serves as the execution
@@ -443,7 +563,7 @@ def create_graph() -> CompiledStateGraph:
),
ConditionalEdgeConfig(
"call_model_node",
route_llm_output_wrapper,
route_llm_output,
{
"tool_executor": "tools",
"output": "END",
@@ -470,7 +590,7 @@ def create_graph() -> CompiledStateGraph:
async def create_graph_with_services(
app_config: AppConfig, service_factory: "ServiceFactory"
) -> CompiledStateGraph:
) -> CompiledGraph:
"""Create graph with service factory injection using caching.
Args:
@@ -562,8 +682,8 @@ async def _get_or_create_service_factory_async(config_hash: str, app_config: App
async def create_graph_with_overrides_async(
config: dict[str, Any],
) -> CompiledStateGraph:
config: Mapping[str, object],
) -> CompiledGraph:
"""Async version of create_graph_with_overrides.
Same functionality as create_graph_with_overrides but uses async config loading
@@ -634,7 +754,7 @@ def _load_config_with_logging() -> AppConfig:
_graph_cache_manager = InMemoryCache[CompiledStateGraph](max_size=100)
_graph_cache_manager = InMemoryCache[CompiledGraph](max_size=100)
_graph_creation_locks: dict[str, asyncio.Lock] = {}
_locks_lock = asyncio.Lock() # Lock for managing the locks dict
_graph_cache_lock = asyncio.Lock() # Keep for backward compatibility
@@ -646,7 +766,7 @@ cleanup_registry.register_cleanup("graph_cache", _graph_cache_manager.clear)
# Configuration cache using lazy loader
_config_loader = create_lazy_loader(_load_config_with_logging)
_state_template_cache: dict[str, Any] | None = None
_state_template_cache: JsonDict | None = None
_state_template_lock = asyncio.Lock()
@@ -670,7 +790,7 @@ async def get_cached_graph(
config_hash: str = "default",
service_factory: "ServiceFactory | None" = None,
use_caching: bool = True
) -> CompiledStateGraph:
) -> CompiledGraph:
"""Get cached compiled graph with optional service injection using GraphCache.
Args:
@@ -682,7 +802,7 @@ async def get_cached_graph(
Compiled and cached graph instance
"""
# Use GraphCache for thread-safe caching
async def build_graph_for_cache() -> CompiledStateGraph:
async def build_graph_for_cache() -> CompiledGraph:
logger.info(f"Creating new graph instance for config: {config_hash}")
# Get or create service factory
@@ -728,7 +848,7 @@ async def get_cached_graph(
),
ConditionalEdgeConfig(
"call_model_node",
route_llm_output_wrapper,
route_llm_output,
{
"tool_executor": "tools",
"output": "END",
@@ -798,7 +918,7 @@ async def get_cached_graph(
return graph
def get_graph() -> CompiledStateGraph:
def get_graph() -> CompiledGraph:
"""Get the singleton graph instance (backward compatibility)."""
try:
asyncio.get_running_loop()
@@ -816,9 +936,9 @@ def get_graph() -> CompiledStateGraph:
# For backward compatibility - direct access (lazy initialization)
_module_graph: CompiledStateGraph | None = None
_module_graph: CompiledGraph | None = None
def get_module_graph() -> CompiledStateGraph:
def get_module_graph() -> CompiledGraph:
"""Get module-level graph instance (lazy initialization)."""
global _module_graph
if _module_graph is None:
@@ -882,7 +1002,7 @@ def reset_caches() -> None:
logger.info("Reset graph and state caches using CleanupRegistry")
def get_cache_stats() -> dict[str, Any]:
def get_cache_stats() -> JsonDict:
"""Get cache statistics for monitoring using InMemoryCache.
Returns:
@@ -916,12 +1036,12 @@ debug_highlight("Cache infrastructure: graphs, service factories, state template
# Create lazy config access for backward compatibility
async def get_config_dict() -> dict[str, Any]:
async def get_config_dict() -> JsonDict:
"""Get the config dictionary lazily (async version)."""
app_config = await get_app_config()
debug_highlight(f"[GRAPH_INIT] app_config: {app_config}", category="GRAPH_INIT")
# Ensure the config has the required structure
config_dict: dict[str, Any] = {
config_dict: JsonDict = {
**app_config.model_dump(), # Include all existing app_config as dict
"llm_config": app_config.llm_config.model_dump()
if app_config.llm_config
@@ -942,12 +1062,12 @@ async def get_config_dict() -> dict[str, Any]:
return config_dict
def get_config_dict_sync() -> dict[str, Any]:
def get_config_dict_sync() -> JsonDict:
"""Get the config dictionary synchronously (backward compatibility)."""
app_config = get_app_config_sync()
debug_highlight(f"[GRAPH_INIT] app_config: {app_config}", category="GRAPH_INIT")
# Ensure the config has the required structure
config_dict: dict[str, Any] = {
config_dict: JsonDict = {
**app_config.model_dump(), # Include all existing app_config as dict
"llm_config": app_config.llm_config.model_dump()
if app_config.llm_config
@@ -968,7 +1088,7 @@ def get_config_dict_sync() -> dict[str, Any]:
return config_dict
async def get_state_template() -> dict[str, Any]:
async def get_state_template() -> JsonDict:
"""Get cached state template for fast state creation.
Returns:
@@ -997,10 +1117,10 @@ async def get_state_template() -> dict[str, Any]:
async def create_initial_state(
query: str | None = None,
raw_input: str | dict[str, Any] | None = None,
messages: list[dict[str, Any]] | None = None,
raw_input: str | Mapping[str, object] | None = None,
messages: list[dict[str, object]] | None = None,
thread_id: str | None = None,
state_update: dict[str, Any] | None = None,
state_update: Mapping[str, object] | None = None,
) -> InputState:
"""Create initial state with dynamic user input (optimized async version).
@@ -1052,10 +1172,10 @@ async def create_initial_state(
def create_initial_state_sync(
query: str | None = None,
raw_input: str | dict[str, Any] | None = None,
messages: list[dict[str, Any]] | None = None,
raw_input: str | Mapping[str, object] | None = None,
messages: list[dict[str, object]] | None = None,
thread_id: str | None = None,
state_update: dict[str, Any] | None = None,
state_update: Mapping[str, object] | None = None,
) -> InputState:
"""Create initial state synchronously (backward compatibility).
@@ -1101,8 +1221,8 @@ async def get_initial_state_async() -> InputState:
def run_graph(
query: str | None = None,
raw_input: str | dict[str, Any] | None = None,
messages: list[dict[str, Any]] | None = None,
raw_input: str | Mapping[str, object] | None = None,
messages: list[dict[str, object]] | None = None,
thread_id: str | None = None,
) -> InputState:
"""Execute the Business Buddy agent graph with optional custom input.
@@ -1136,8 +1256,8 @@ def run_graph(
async def run_graph_async(
query: str | None = None,
raw_input: str | dict[str, Any] | None = None,
messages: list[dict[str, Any]] | None = None,
raw_input: str | Mapping[str, object] | None = None,
messages: list[dict[str, object]] | None = None,
thread_id: str | None = None,
config_hash: str = "default",
) -> InputState:

View File

@@ -60,14 +60,14 @@ from biz_bud.tools.capabilities.external.paperless.tool import (
from biz_bud.tools.capabilities.search.tool import list_search_providers, web_search
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
from biz_bud.services.factory import ServiceFactory
logger = get_logger(__name__)
# Module-level caches for performance optimization
_compiled_graph_cache: dict[str, "CompiledStateGraph"] = {}
_compiled_graph_cache: dict[str, "CompiledGraph"] = {}
_global_factory: ServiceFactory | None = None
_llm_cache: LLMCache[str] | None = None
_cache_ttl = 300 # 5 minutes default TTL for LLM responses
@@ -676,7 +676,7 @@ def should_continue(state: dict[str, Any]) -> str:
def create_paperless_agent(
config: dict[str, Any] | str | None = None,
) -> "CompiledStateGraph":
) -> "CompiledGraph":
"""Create a Paperless agent using Business Buddy patterns with caching.
This creates an agent that can bind tools to the LLM and execute them

View File

@@ -44,7 +44,7 @@ from biz_bud.states.base import BaseState
from biz_bud.states.receipt import ReceiptState
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
logger = get_logger(__name__)
@@ -162,7 +162,7 @@ def _create_receipt_processing_graph_internal(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph:
) -> CompiledGraph:
"""Internal function to create a focused receipt processing graph using ReceiptState."""
# Define nodes
nodes = {
@@ -202,7 +202,7 @@ def _create_receipt_processing_graph_internal(
return compiled_graph
def create_receipt_processing_graph(config: RunnableConfig) -> CompiledStateGraph:
def create_receipt_processing_graph(config: RunnableConfig) -> CompiledGraph:
"""Create a focused receipt processing graph for LangGraph API.
Args:
@@ -218,7 +218,7 @@ def create_receipt_processing_graph_direct(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph:
) -> CompiledGraph:
"""Create a focused receipt processing graph for direct usage.
Args:
@@ -236,7 +236,7 @@ def create_paperless_graph(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph:
) -> CompiledGraph:
"""Create the standardized Paperless NGX document management graph.
This graph provides intelligent document management workflows with:
@@ -402,7 +402,7 @@ def create_paperless_graph(
return compiled_graph
def paperless_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
def paperless_graph_factory(config: RunnableConfig) -> CompiledGraph:
"""Create Paperless graph for LangGraph API.
Args:
@@ -421,7 +421,7 @@ async def paperless_graph_factory_async(config: RunnableConfig) -> Any: # noqa:
return await asyncio.to_thread(paperless_graph_factory, config)
def receipt_processing_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
def receipt_processing_graph_factory(config: RunnableConfig) -> CompiledGraph:
"""Create receipt processing graph for LangGraph API.
Args:

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Literal, cast
from collections.abc import Callable
from typing import TYPE_CHECKING, AsyncGenerator, Literal, cast
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CachePolicy, RetryPolicy
@@ -43,10 +44,10 @@ from biz_bud.graphs.rag.nodes.scraping import (
scrape_status_summary_node,
)
from biz_bud.nodes import finalize_status_node, preserve_url_fields_node
from biz_bud.states.url_to_rag import URLToRAGState
from biz_bud.states.url_to_rag import StatePayload, URLToRAGState
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
from biz_bud.services.factory import ServiceFactory
@@ -55,7 +56,7 @@ class URLToRAGGraphInput(TypedDict, total=False):
url: str
input_url: str
config: dict[str, Any]
config: StatePayload
collection_name: str | None
force_refresh: bool
@@ -65,10 +66,10 @@ class URLToRAGGraphOutput(TypedDict, total=False):
status: Literal["pending", "running", "success", "error"]
error: str | None
r2r_info: dict[str, Any] | None
scraped_content: list[dict[str, Any]]
r2r_info: StatePayload | None
scraped_content: list[StatePayload]
repomix_output: str | None
upload_tracker: dict[str, Any] | None
upload_tracker: StatePayload | None
class URLToRAGGraphContext(TypedDict, total=False):
@@ -76,7 +77,7 @@ class URLToRAGGraphContext(TypedDict, total=False):
service_factory: "ServiceFactory" | None
request_id: str | None
metadata: dict[str, Any]
metadata: StatePayload
logger = get_logger(__name__)
@@ -105,7 +106,7 @@ GRAPH_METADATA = {
def _create_initial_state(
url: str,
config: dict[str, Any],
config: StatePayload,
collection_name: str | None = None,
force_refresh: bool = False,
) -> URLToRAGState:
@@ -190,7 +191,7 @@ _should_scrape_or_skip = create_list_length_router(
_should_process_next_url = create_bool_router("finalize", "check_duplicate", "batch_complete")
def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> "CompiledStateGraph":
def create_url_to_r2r_graph(config: StatePayload | None = None) -> "CompiledGraph":
"""Create the URL to R2R processing graph with iterative URL processing.
This graph processes URLs one at a time through the complete pipeline,
@@ -464,7 +465,7 @@ def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> "CompiledSt
# Factory function for LangGraph API
def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
def url_to_r2r_graph_factory(config: RunnableConfig) -> "CompiledGraph":
"""Create URL to R2R graph for LangGraph API with RunnableConfig."""
# Use centralized config resolution to handle all overrides at entry point
# Resolve configuration with any RunnableConfig overrides (sync version)
@@ -481,7 +482,7 @@ def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
service_factory = ServiceFactory(app_config)
# Convert AppConfig to dict format expected by the graph
config_dict = app_config.model_dump()
config_dict: StatePayload = app_config.model_dump()
# Inject service factory into config for nodes to access
config_dict["service_factory"] = service_factory
@@ -490,7 +491,7 @@ def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
# Async factory function for LangGraph API
async def url_to_r2r_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
async def url_to_r2r_graph_factory_async(config: RunnableConfig) -> "CompiledGraph":
"""Async wrapper for url_to_r2r_graph_factory to avoid blocking calls."""
import asyncio
return await asyncio.to_thread(url_to_r2r_graph_factory, config)
@@ -503,7 +504,7 @@ url_to_r2r_graph = create_url_to_r2r_graph
# Usage example
async def _process_url_to_r2r(
url: str,
config: dict[str, Any],
config: StatePayload,
collection_name: str | None = None,
force_refresh: bool = False,
) -> URLToRAGState:
@@ -555,10 +556,10 @@ async def _process_url_to_r2r(
async def _stream_url_to_r2r(
url: str,
config: dict[str, Any],
config: StatePayload,
collection_name: str | None = None,
force_refresh: bool = False,
) -> AsyncGenerator[dict[str, Any], None]:
) -> AsyncGenerator[StatePayload, None]:
"""Process a URL and upload to R2R, yielding streaming updates.
Args:
@@ -603,8 +604,8 @@ async def _stream_url_to_r2r(
async def _process_url_to_r2r_with_streaming(
url: str,
config: dict[str, Any],
on_update: Callable[[dict[str, Any]], None] | None = None,
config: StatePayload,
on_update: Callable[[StatePayload], None] | None = None,
collection_name: str | None = None,
force_refresh: bool = False,
) -> URLToRAGState:
@@ -665,7 +666,7 @@ stream_url_to_r2r = _stream_url_to_r2r
process_url_to_r2r_with_streaming = _process_url_to_r2r_with_streaming
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledGraph":
"""Create URL to RAG graph for graph-as-tool pattern.
This factory function follows the standard pattern for graphs

View File

@@ -7,7 +7,7 @@ and various vector store integrations.
from __future__ import annotations
from typing import Any
from collections.abc import Mapping, Sequence
from langchain_core.runnables import RunnableConfig
@@ -34,10 +34,23 @@ firecrawl_scrape_node = None
firecrawl_batch_scrape_node = None
StatePayload = dict[str, object]
def _coerce_sequence(value: object) -> Sequence[Mapping[str, object]]:
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
items: list[Mapping[str, object]] = []
for element in value:
if isinstance(element, Mapping):
items.append(element)
return items
return []
@standard_node(node_name="vector_store_upload", metric_name="vector_upload")
async def vector_store_upload_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
state: Mapping[str, object], config: RunnableConfig
) -> StatePayload:
"""Upload prepared content to vector store.
This node handles the upload of prepared content to various vector stores
@@ -52,17 +65,17 @@ async def vector_store_upload_node(
"""
info_highlight("Uploading content to vector store...", category="VectorUpload")
prepared_content = state.get("rag_prepared_content", [])
prepared_content = _coerce_sequence(state.get("rag_prepared_content"))
if not prepared_content:
warning_highlight("No prepared content to upload", category="VectorUpload")
return {"upload_results": {"documents_uploaded": 0}}
collection_name = state.get("collection_name", "default")
collection_name = str(state.get("collection_name", "default"))
try:
# Simulate upload (real implementation would use actual vector store service)
uploaded_count = 0
failed_uploads = []
failed_uploads: list[Mapping[str, object]] = []
for item in prepared_content:
if not item.get("ready_for_upload"):
@@ -77,7 +90,7 @@ async def vector_store_upload_node(
category="VectorUpload",
)
upload_results = {
upload_results: StatePayload = {
"documents_uploaded": uploaded_count,
"failed_uploads": failed_uploads,
"collection_name": collection_name,
@@ -109,8 +122,8 @@ async def vector_store_upload_node(
@standard_node(node_name="git_repo_processor", metric_name="git_processing")
async def process_git_repository_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
state: Mapping[str, object], config: RunnableConfig
) -> StatePayload:
"""Process Git repository for RAG ingestion.
This node handles the special case of processing Git repositories,
@@ -125,10 +138,10 @@ async def process_git_repository_node(
"""
info_highlight("Processing Git repository...", category="GitProcessor")
if not state.get("is_git_repo"):
return state
if not bool(state.get("is_git_repo")):
return {}
input_url = state.get("input_url", "")
input_url = str(state.get("input_url", ""))
try:
# Use repomix if available

View File

@@ -8,7 +8,8 @@ from biz_bud.core for optimal performance and maintainability.
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, TypedDict, cast
from langchain_core.runnables import RunnableConfig
@@ -32,15 +33,16 @@ from biz_bud.core.langgraph.graph_builder import (
NodeConfig,
build_graph_from_config,
)
from langgraph.graph.state import CachePolicy, RetryPolicy
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
from biz_bud.core.utils import create_lazy_loader
from biz_bud.logging import get_logger
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.pregel import Pregel
from biz_bud.services.factory import ServiceFactory
from biz_bud.core.config.loader import load_config, load_config_async
from biz_bud.core.config.schemas import AppConfig
@@ -48,15 +50,11 @@ from biz_bud.core.config.schemas import AppConfig
from biz_bud.core.edge_helpers import create_conditional_langgraph_command_router
# Import research-specific nodes from local module
from biz_bud.graphs.research.nodes import (
derive_research_query_node,
)
from biz_bud.graphs.research.nodes import derive_research_query_node
from biz_bud.graphs.research.nodes import (
synthesize_search_results as synthesize_research_results_node,
)
from biz_bud.graphs.research.nodes import (
validate_research_synthesis_node,
)
from biz_bud.graphs.research.nodes import validate_research_synthesis_node
# Import consolidated core nodes
from biz_bud.nodes import (
@@ -65,16 +63,12 @@ from biz_bud.nodes import (
research_web_search_node,
semantic_extract_node,
)
from biz_bud.states.base import StateMetadata
from biz_bud.states.research import ResearchState
# Import nodes that haven't been consolidated yet
try:
from biz_bud.graphs.rag.nodes.rag_enhance import rag_enhance_node
from biz_bud.nodes.validation.human_feedback import human_feedback_node
except ImportError:
# These will be moved to appropriate graph modules later
rag_enhance_node = None
human_feedback_node = None
from biz_bud.graphs.rag.nodes.rag_enhance import rag_enhance_node
from biz_bud.nodes.validation.human_feedback import human_feedback_node
logger = get_logger(__name__)
@@ -93,17 +87,17 @@ class ResearchGraphOutput(TypedDict, total=False):
status: Literal["pending", "running", "complete", "error"]
synthesis: str
sources: list[dict[str, Any]]
validation_summary: dict[str, Any]
human_feedback: dict[str, Any]
sources: list[StateMetadata]
validation_summary: StateMetadata
human_feedback: StateMetadata
class ResearchGraphContext(TypedDict, total=False):
"""Optional runtime context injected into research graph executions."""
service_factory: Any | None
service_factory: "ServiceFactory" | None
request_id: str | None
metadata: dict[str, Any]
metadata: StateMetadata
# Graph metadata for dynamic discovery
GRAPH_METADATA = {
@@ -172,17 +166,6 @@ def _get_cached_config() -> AppConfig:
return load_config()
# Placeholder node for missing imports
async def _placeholder_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Provide placeholder node for features not yet migrated."""
logger.warning("Placeholder node called - feature not yet migrated")
return state
# Create string-based router for conditional edges compatibility
def _search_retry_router(state: ResearchState) -> str:
"""Route search results based on retry logic.
@@ -321,7 +304,7 @@ _handle_synthesis_errors = handle_error(
)
def _prepare_synthesis_routing(state: ResearchState) -> dict[str, Any]:
def _prepare_synthesis_routing(state: ResearchState) -> dict[str, object]:
"""Prepare state for synthesis quality routing.
This helper adds synthesis_length to state for threshold routing.
@@ -338,7 +321,7 @@ def _prepare_synthesis_routing(state: ResearchState) -> dict[str, Any]:
return {"synthesis_length": synthesis_length}
def _prepare_search_results(state: ResearchState) -> dict[str, Any]:
def _prepare_search_results(state: ResearchState) -> dict[str, object]:
"""Prepare search results state for Command-based routing.
Args:
@@ -358,7 +341,7 @@ def _prepare_search_results(state: ResearchState) -> dict[str, Any]:
def create_research_graph(
checkpointer: PostgresSaver | None = None,
) -> "CompiledStateGraph":
) -> CompiledStateGraph[ResearchState]:
"""Create the consolidated research workflow graph.
This graph uses consolidated nodes, edge helper factories, and global
@@ -389,7 +372,7 @@ def create_research_graph(
cache_policy=CachePolicy(ttl=900),
),
"rag_enhance": NodeConfig(
func=rag_enhance_node or _placeholder_node,
func=rag_enhance_node,
metadata={
"category": "augmentation",
"description": "Enhance the research scope with retrieval augmented prompts",
@@ -448,7 +431,7 @@ def create_research_graph(
retry_policy=RetryPolicy(max_attempts=2),
),
"human_feedback": NodeConfig(
func=human_feedback_node or _placeholder_node,
func=human_feedback_node,
metadata={
"category": "feedback",
"description": "Escalate to human feedback when automated validation fails",
@@ -494,7 +477,7 @@ def create_research_graph(
]
# Wrapper function to extract route from Command object for LangGraph compatibility
def _synthesis_quality_route_extractor(state: dict[str, Any]) -> str:
def _synthesis_quality_route_extractor(state: Mapping[str, object]) -> str:
"""Extract route string from Command object for LangGraph compatibility."""
command = _synthesis_quality_router(state)
# Handle the case where goto might be a Send object or sequence
@@ -567,7 +550,9 @@ def create_research_graph(
# Factory function for LangGraph API
def research_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def research_graph_factory(
config: RunnableConfig,
) -> CompiledStateGraph[ResearchState]:
"""Create research graph for LangGraph API with RunnableConfig.
This factory extracts configuration from RunnableConfig and sets up
@@ -595,14 +580,18 @@ def research_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
# Async factory function for LangGraph API
async def research_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
async def research_graph_factory_async(
config: RunnableConfig,
) -> CompiledStateGraph[ResearchState]:
"""Async wrapper for research_graph_factory to avoid blocking calls."""
import asyncio
return await asyncio.to_thread(research_graph_factory, config)
# Async factory function that follows Business Buddy patterns
async def create_research_graph_async(config: Any = None) -> "CompiledStateGraph":
async def create_research_graph_async(
config: RunnableConfig | None = None,
) -> CompiledStateGraph[ResearchState]:
"""Create research graph using async patterns with service factory integration.
This function follows Business Buddy architectural patterns for async
@@ -688,7 +677,7 @@ def get_research_graph(
# Usage example for testing
async def process_research_query(
query: str, config: dict[str, Any] | None = None, derive_query: bool = True
query: str, config: dict[str, object] | None = None, derive_query: bool = True
) -> ResearchState:
"""Process a research query using the consolidated graph.

View File

@@ -23,7 +23,7 @@ from biz_bud.nodes import parse_and_validate_initial_payload
from biz_bud.states.base import BaseState
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
logger = get_logger(__name__)
@@ -282,7 +282,7 @@ async def finalize_scraping(
# --- Graph Construction ---
def create_scraping_graph() -> "CompiledStateGraph":
def create_scraping_graph() -> "CompiledGraph":
"""Create the web scraping workflow graph.
This graph demonstrates parallel URL processing using the Send API,
@@ -361,7 +361,7 @@ GRAPH_METADATA = {
# Factory function
def scraping_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def scraping_graph_factory(config: RunnableConfig) -> "CompiledGraph":
"""Create scraping graph for LangGraph API."""
return create_scraping_graph()

View File

@@ -7,73 +7,69 @@ suggesting recovery actions, and updating workflow state to reflect error status
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from biz_bud.core.langgraph import StateUpdater, ensure_immutable_node, standard_node
from biz_bud.logging import error_highlight, get_logger, info_highlight
if TYPE_CHECKING:
from collections.abc import Sequence
from biz_bud.core import ErrorInfo, ErrorRecoveryTypedDict
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TypedDict, cast
from langchain_core.runnables import RunnableConfig
logger = get_logger(__name__)
from biz_bud.core import ErrorInfo, ErrorRecoveryTypedDict
from biz_bud.core.langgraph import StateUpdater, ensure_immutable_node, standard_node
from biz_bud.logging import error_highlight, info_highlight
WorkflowState = MutableMapping[str, object]
ErrorMapping = Mapping[str, object]
class ValidationErrorSummary(TypedDict):
"""Structured summary returned when validation fails."""
validation_status: str
failure_reason: list[str]
error_count: int
error_distribution: dict[str, int]
recommendation: str
@standard_node(node_name="handle_graph_error", metric_name="error_handling")
@ensure_immutable_node
async def handle_graph_error(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Central error handler for the workflow graph.
state: WorkflowState, config: RunnableConfig
) -> WorkflowState:
"""Central error handler for the workflow graph."""
Args:
state (BusinessBuddyState): The current workflow state containing error information.
errors_raw = state.get("errors")
errors: list[ErrorInfo] = []
if isinstance(errors_raw, Sequence) and not isinstance(
errors_raw, (str, bytes, bytearray)
):
for error in errors_raw:
if isinstance(error, Mapping):
errors.append(cast(ErrorInfo, error))
Returns:
BusinessBuddyState: The updated state with error status, recovery actions, and workflow status.
This node function:
- Logs all errors found in the state.
- Determines if any errors are critical (preventing workflow continuation).
- Suggests recovery actions for non-critical errors based on error phase.
- Updates the state with error recovery actions and workflow status.
- Handles both critical and non-critical errors, marking the workflow as failed or recovered.
"""
# Get errors from state with proper typing
errors_raw = state.get("errors", []) or []
errors: Sequence[ErrorInfo] = cast(
"Sequence[ErrorInfo]", errors_raw if isinstance(errors_raw, list) else []
)
if not errors:
info_highlight("No errors found in state.")
return state # No errors to handle
return state
info_highlight(f"Handling {len(errors)} error(s) found in state.")
# Initialize variables to prevent potential UnboundLocalError
error_detail: ErrorInfo = cast("ErrorInfo", {})
phase = ""
error_msg = ""
for error_detail in cast("list[ErrorInfo]", errors):
for error_detail in errors:
phase = error_detail.get("phase", "unknown")
error_msg = error_detail.get("error", "unknown error")
error_highlight(f"Error in phase '{phase}': {error_msg}")
# Define phases considered critical where failure prevents continuation
critical_phases = ["input", "initialization", "config_load", "authentication"]
critical_phases: tuple[str, ...] = (
"input",
"initialization",
"config_load",
"authentication",
)
has_critical_error: bool = any(
error_detail.get("phase", "") in critical_phases
for error_detail in cast("list[ErrorInfo]", errors)
has_critical_error = any(
error_detail.get("phase", "") in critical_phases for error_detail in errors
)
if has_critical_error:
error_highlight("Critical error detected, cannot continue workflow.")
# Use StateUpdater for immutable updates
updater = StateUpdater(state)
return (
updater.set("critical_error", True)
@@ -82,9 +78,8 @@ async def handle_graph_error(
.build()
)
# Attempt to define recovery actions for non-critical errors
recovery_actions: list[ErrorRecoveryTypedDict] = []
for error_detail in cast("list[ErrorInfo]", errors):
for error_detail in errors:
phase = error_detail.get("phase", "unknown")
error_msg = error_detail.get("error", "")
action: ErrorRecoveryTypedDict = {
@@ -92,7 +87,6 @@ async def handle_graph_error(
"description": f"Logged error: {error_msg}",
}
# Simple phase-based recovery suggestions (expand as needed)
if phase in ["search_query", "web_search"]:
action = {
"action": "retry_with_different_query",
@@ -111,7 +105,10 @@ async def handle_graph_error(
elif phase in ["synthesis", "validation", "interpret", "report"]:
action = {
"action": "simplify_or_request_review",
"description": f"Consider simplifying scope or requesting human review due to error in {phase}.",
"description": (
"Consider simplifying scope or requesting human review due to error in "
f"{phase}."
),
}
elif phase == "human_feedback":
action = {
@@ -126,16 +123,16 @@ async def handle_graph_error(
recovery_actions.append(action)
# Use StateUpdater for immutable updates
updater = StateUpdater(state)
current_status = state.get("workflow_status", "recovered")
# Track retry attempts
retry_count = state.get("retry_count", 0)
if isinstance(retry_count, int):
retry_count += 1
else:
retry_count = 1
raw_retry = state.get("retry_count", 0)
try:
retry_count_int = int(raw_retry) # type: ignore[arg-type]
if retry_count_int < 0:
retry_count_int = 0
except (TypeError, ValueError):
retry_count_int = 0
retry_count = retry_count_int + 1
new_state = (
updater.set("error_recovery", recovery_actions)
@@ -151,145 +148,98 @@ async def handle_graph_error(
return new_state
@standard_node(
node_name="handle_validation_failure", metric_name="validation_error_handling"
)
@standard_node(node_name="handle_validation_failure", metric_name="validation_error_handling")
@ensure_immutable_node
async def handle_validation_failure(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Handle validation failures. Updates 'validation_error_summary' and 'validation_passed' in the workflow state.
state: WorkflowState, config: RunnableConfig | None
) -> WorkflowState:
"""Handle validation failures."""
Args:
state (BusinessBuddyState): The current workflow state containing validation results and errors.
errors_value = state.get("errors")
errors: Sequence[ErrorMapping]
if isinstance(errors_value, Sequence):
errors = tuple(
cast(ErrorMapping, error)
for error in errors_value
if isinstance(error, Mapping)
)
else:
errors = ()
Returns:
BusinessBuddyState: The updated state with validation error summary, pass/fail status, and workflow status.
This node function:
- Extracts validation errors and results from the state.
- Checks for failures in fact checking, logic validation, consistency, and human feedback.
- Aggregates issues and determines the overall validation status.
- Updates the state with a summary of validation failures and recommendations.
- Handles and logs any errors encountered during validation.
"""
from typing import cast
# Cast state to dict for dynamic access
state_dict = cast("dict[str, object]", state)
errors_raw = state_dict.get("errors", []) or []
errors: Sequence[dict[str, object]] = (
errors_raw if isinstance(errors_raw, list) else []
)
validation_phases: list[str] = [
validation_phases: tuple[str, ...] = (
"criteria",
"fact_check",
"logic_check",
"consistency_check",
"human_feedback",
"validate",
]
)
validation_errors: list[dict[str, object]] = []
for err in errors or []:
if isinstance(err, dict) and err.get("phase", "") in validation_phases:
# Convert to dict[str, object] by ensuring all values are objects
obj_err: dict[str, object] = {k: str(v) for k, v in err.items()}
for err in errors:
if err.get("phase", "") in validation_phases:
obj_err: dict[str, object] = {str(k): str(v) for k, v in err.items()}
validation_errors.append(obj_err)
# Get values from state with proper typing
fact_check_raw = state_dict.get("fact_check_results")
fact_check_results: dict[str, object] | None = (
fact_check_raw if isinstance(fact_check_raw, dict) else None
fact_check_raw = state.get("fact_check_results")
fact_check_results: Mapping[str, object] | None = (
fact_check_raw if isinstance(fact_check_raw, Mapping) else None
)
logic_raw = state_dict.get("logic_validation")
logic_validation: dict[str, object] | None = (
logic_raw if isinstance(logic_raw, dict) else None
logic_raw = state.get("logic_validation")
logic_validation: Mapping[str, object] | None = (
logic_raw if isinstance(logic_raw, Mapping) else None
)
consistency_raw = state_dict.get("consistency_check")
consistency_check: dict[str, object] | None = (
consistency_raw if isinstance(consistency_raw, dict) else None
consistency_raw = state.get("consistency_check")
consistency_check: Mapping[str, object] | None = (
consistency_raw if isinstance(consistency_raw, Mapping) else None
)
feedback_raw = state_dict.get("human_feedback")
human_feedback: dict[str, object] | None = (
feedback_raw if isinstance(feedback_raw, dict) else None
feedback_raw = state.get("human_feedback")
human_feedback: Mapping[str, object] | None = (
feedback_raw if isinstance(feedback_raw, Mapping) else None
)
valid_raw = state_dict.get("is_output_valid")
valid_raw = state.get("is_output_valid")
is_output_valid: bool | None = valid_raw if isinstance(valid_raw, bool) else None
failed: bool = False
failed = False
issues: list[str] = [
str(err.get("error", "Unknown validation issue")) for err in validation_errors
]
error_counts: dict[str, int] = {}
for err in validation_errors:
phase = str(err.get("phase", "unknown"))
error_counts[phase] = cast("dict[str, int]", error_counts).get(phase, 0) + 1
error_counts[phase] = error_counts.get(phase, 0) + 1
def check_score(
result: dict[str, object] | None, label: str, threshold: float
) -> None:
"""Check if a validation result's score is below a threshold and update failure state.
Args:
result: The validation result, or None if not available.
label: The label for the validation type (e.g., 'Fact check').
threshold: The minimum acceptable score.
Side Effects:
Modifies the enclosing function's 'failed' and 'issues' variables if the score is too low.
"""
def check_score(result: Mapping[str, object] | None, label: str, threshold: float) -> None:
nonlocal failed, issues
if result is not None:
score_raw = result.get("score", 1.0)
score = float(score_raw) if isinstance(score_raw, (int, float)) else 1.0
if score < threshold:
failed = True
issues.append(f"{label} score too low: {score}")
issues_raw = result.get("issues")
if isinstance(issues_raw, list):
issues.extend(
str(issue) for issue in issues_raw if issue is not None
)
if result is None:
return
score_raw = result.get("score", 1.0)
score = float(score_raw) if isinstance(score_raw, (int, float)) else 1.0
if score < threshold:
failed = True
issues.append(f"{label} score too low: {score}")
issues_raw = result.get("issues")
if isinstance(issues_raw, Sequence) and not isinstance(
issues_raw, (str, bytes, bytearray)
):
issues.extend(str(issue) for issue in issues_raw if issue is not None)
if is_output_valid is False:
failed = True
# Safely handle validation_issues which might be None or not a list
validation_issues = state_dict.get("validation_issues")
if isinstance(validation_issues, list):
issues.extend(
str(issue) for issue in validation_issues if issue is not None
)
validation_issues = state.get("validation_issues")
if isinstance(validation_issues, Sequence):
issues.extend(str(issue) for issue in validation_issues if issue is not None)
def safe_check_score(
result: dict[str, object] | None, label: str, threshold: float
) -> None:
"""Safely check score with proper type handling.
check_score(fact_check_results, "Fact check", 0.6)
check_score(logic_validation, "Logic validation", 0.6)
check_score(consistency_check, "Consistency", 0.6)
check_score(human_feedback, "Human feedback", 0.5)
Args:
result: The validation result to check
label: Label for the validation type
threshold: Minimum acceptable score
"""
if result is not None:
check_score(result, label, threshold)
else:
check_score(None, label, threshold)
# Check all validation results
safe_check_score(fact_check_results, "Fact check", 0.6)
safe_check_score(logic_validation, "Logic validation", 0.6)
safe_check_score(consistency_check, "Consistency", 0.6)
safe_check_score(human_feedback, "Human feedback", 0.5)
def get_safe_score(result: dict[str, object] | None) -> float:
"""Safely extract score from validation result."""
def get_safe_score(result: Mapping[str, object] | None) -> float:
if result is None:
return 1.0
score_raw = result.get("score", 1.0)
@@ -304,7 +254,7 @@ async def handle_validation_failure(
.build()
)
status: str = "failed_validation"
status = "failed_validation"
if error_counts.get("fact_check", 0) > 0 or (
fact_check_results and get_safe_score(fact_check_results) < 0.6
):
@@ -322,26 +272,15 @@ async def handle_validation_failure(
):
status = "failed_human_feedback"
class ValidationErrorSummaryTypedDict(dict[str, object]):
"""TypedDict for validation error summary structure."""
validation_status: str
failure_reason: list[str]
error_count: int
error_distribution: dict[str, int]
recommendation: str
summary_dict = {
summary: ValidationErrorSummary = {
"validation_status": status,
"failure_reason": issues,
"error_count": len(validation_errors),
"error_distribution": error_counts,
"recommendation": "Revise content based on validation feedback or request human review.",
}
summary = cast("ValidationErrorSummaryTypedDict", summary_dict)
error_highlight(f"Validation failed. Summary: {summary}")
# Use StateUpdater for immutable updates
updater = StateUpdater(state)
return (
updater.set("validation_error_summary", summary)

View File

@@ -1,6 +1,7 @@
"""Error analyzer node for classifying errors and determining recovery strategies."""
import re
from collections.abc import Mapping
from typing import Any
from langchain_core.runnables import RunnableConfig
@@ -112,6 +113,9 @@ def _rule_based_analysis(error: ErrorInfo, context: ErrorContext) -> ErrorAnalys
error_type = error.get("error_type", "unknown")
error_message = error.get("message", "")
error_category = error.get("category", ErrorCategory.UNKNOWN)
raw_details = error.get("details")
if isinstance(raw_details, Mapping):
error_category = raw_details.get("category", error_category)
# Analyze based on error category
if error_category == ErrorCategory.LLM:

View File

@@ -11,6 +11,7 @@ from biz_bud.core.langgraph import ConfigurationProvider, standard_node
from biz_bud.logging import get_logger
from biz_bud.prompts.error_handling import ERROR_SUMMARY_PROMPT, USER_GUIDANCE_PROMPT
from biz_bud.states.error_handling import ErrorHandlingState
from biz_bud.services.llm.client import LangchainLLMClient
logger = get_logger(__name__)
@@ -128,7 +129,9 @@ async def _generate_resolution_steps(
if service_factory is None:
raise ConfigurationError("ServiceFactory not found in RunnableConfig")
llm_client = await service_factory.get_llm_client()
llm_client = cast(
LangchainLLMClient, await service_factory.get_llm_client()
)
# Prepare context for guidance generation
context = {

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import re
import warnings
from typing import Any
from biz_bud.nodes.url_processing._typing import StateMapping
from urllib.parse import urljoin, urlparse
from langchain_core.runnables import RunnableConfig
@@ -30,8 +30,8 @@ from .scrape_url import scrape_url_node
@standard_node(node_name="discover_urls", metric_name="url_discovery")
async def discover_urls_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
state: StateMapping, config: RunnableConfig | None
) -> dict[str, object]:
"""Discover URLs from a website through sitemaps and crawling.
⚠️ DEPRECATED: This function is deprecated. Use biz_bud.nodes.url_processing.discover_urls_node instead.

View File

@@ -0,0 +1,103 @@
"""Shared typing helpers for URL processing nodes."""
from __future__ import annotations
from collections.abc import Iterable, Mapping
StateMapping = Mapping[str, object]
def coerce_str(value: object | None) -> str | None:
"""Return ``value`` if it is a string, otherwise ``None``."""
if isinstance(value, str):
return value
return None
def coerce_bool(value: object | None, default: bool = False) -> bool:
"""Coerce arbitrary objects into booleans with a default."""
if isinstance(value, bool):
return value
if isinstance(value, str):
lowered = value.strip().lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
return default
def coerce_int(value: object | None, default: int) -> int:
"""Return an integer when possible, otherwise the provided default."""
if isinstance(value, bool):
return default
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str):
try:
return int(value)
except ValueError:
return default
return default
def coerce_float(value: object | None, default: float = 0.0) -> float:
"""Return a floating-point number when possible."""
if isinstance(value, bool):
return default
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
try:
return float(value)
except ValueError:
return default
return default
def coerce_str_list(value: object | None) -> list[str]:
"""Create a list of strings from an arbitrary iterable value."""
if isinstance(value, str):
return [value]
if isinstance(value, Iterable):
return [item for item in value if isinstance(item, str)]
return []
def coerce_object_dict(value: object | None) -> dict[str, object]:
"""Convert arbitrary mapping-like objects into ``dict[str, object]``."""
if isinstance(value, Mapping):
return {str(key): item for key, item in value.items()}
return {}
def coerce_object_list(value: object | None) -> list[dict[str, object]]:
"""Convert an iterable of mappings into concrete dictionaries."""
if isinstance(value, Iterable) and not isinstance(value, (str, bytes, bytearray)):
result: list[dict[str, object]] = []
for item in value:
if isinstance(item, Mapping):
result.append({str(key): item_value for key, item_value in item.items()})
return result
return []
__all__ = [
"StateMapping",
"coerce_bool",
"coerce_float",
"coerce_int",
"coerce_object_dict",
"coerce_object_list",
"coerce_str",
"coerce_str_list",
]

View File

@@ -6,7 +6,15 @@ using the Business Buddy URL processing tools.
from __future__ import annotations
from typing import Any
from biz_bud.nodes.url_processing._typing import (
StateMapping,
coerce_bool,
coerce_float,
coerce_int,
coerce_object_dict,
coerce_str,
coerce_str_list,
)
from langchain_core.runnables import RunnableConfig
@@ -20,9 +28,9 @@ from biz_bud.tools.capabilities.url_processing import (
@standard_node(node_name="discover_urls", metric_name="url_discovery")
async def discover_urls_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
async def discover_urls_node(
state: StateMapping, config: RunnableConfig | None
) -> dict[str, object]:
"""Discover URLs from a website using URL processing tools.
This node discovers URLs from a base URL using various discovery methods
@@ -48,13 +56,17 @@ async def discover_urls_node(
"""
info_highlight("Starting URL discovery...", category="URLDiscovery")
base_url = state.get("base_url") or state.get("input_url") or state.get("url", "")
if not base_url:
error_highlight("No base URL provided for discovery", category="URLDiscovery")
return {
"discovered_urls": [],
"discovery_summary": {"total_discovered": 0, "base_url": None, "error": "No base URL provided"},
"errors": [
base_url_value = (
coerce_str(state.get("base_url"))
or coerce_str(state.get("input_url"))
or coerce_str(state.get("url"))
)
if not base_url_value:
error_highlight("No base URL provided for discovery", category="URLDiscovery")
return {
"discovered_urls": [],
"discovery_summary": {"total_discovered": 0, "base_url": None, "error": "No base URL provided"},
"errors": [
create_error_info(
message="No base URL provided for discovery",
node="discover_urls",
@@ -64,9 +76,10 @@ async def discover_urls_node(
],
}
discovery_provider = state.get("discovery_provider", "comprehensive")
max_results = state.get("max_results", 1000)
detailed_results = state.get("detailed_results", False)
base_url = base_url_value
discovery_provider = coerce_str(state.get("discovery_provider")) or "comprehensive"
max_results = coerce_int(state.get("max_results"), 1000)
detailed_results = coerce_bool(state.get("detailed_results"))
debug_highlight(
f"Discovering URLs from {base_url} using {discovery_provider} method",
@@ -74,43 +87,46 @@ async def discover_urls_node(
)
try:
if detailed_results:
# Get detailed discovery information
result = await discover_urls_detailed.ainvoke({
"base_url": base_url,
"provider": discovery_provider
})
discovered_urls = result.get("discovered_urls", [])
# Apply max results limit
if max_results and len(discovered_urls) > max_results:
discovered_urls = discovered_urls[:max_results]
discovery_details = {
"base_url": result.get("base_url"),
"total_discovered": result.get("total_discovered", 0),
"discovery_method": result.get("discovery_method"),
"sitemap_urls": result.get("sitemap_urls", []),
"robots_txt_urls": result.get("robots_txt_urls", []),
"html_links": result.get("html_links", []),
"is_successful": result.get("is_successful", False),
"discovery_time": result.get("discovery_time", 0.0),
"metadata": result.get("metadata", {}),
}
discovery_summary = {
"total_discovered": len(discovered_urls),
"base_url": base_url,
"discovery_method": result.get("discovery_method"),
"discovery_time": result.get("discovery_time", 0.0),
"limited_to": len(discovered_urls) if max_results and len(discovered_urls) >= max_results else None,
}
info_highlight(
f"Detailed URL discovery completed: {len(discovered_urls)} URLs found using {result.get('discovery_method')} method",
category="URLDiscovery"
)
if detailed_results:
# Get detailed discovery information
result = await discover_urls_detailed.ainvoke(
{"base_url": base_url, "provider": discovery_provider}
)
result_data = coerce_object_dict(result)
discovered_urls = coerce_str_list(result_data.get("discovered_urls"))
# Apply max results limit
if max_results and len(discovered_urls) > max_results:
discovered_urls = discovered_urls[:max_results]
discovery_method = coerce_str(result_data.get("discovery_method"))
discovery_details = {
"base_url": coerce_str(result_data.get("base_url")) or base_url,
"total_discovered": coerce_int(result_data.get("total_discovered"), len(discovered_urls)),
"discovery_method": discovery_method,
"sitemap_urls": coerce_str_list(result_data.get("sitemap_urls")),
"robots_txt_urls": coerce_str_list(result_data.get("robots_txt_urls")),
"html_links": coerce_str_list(result_data.get("html_links")),
"is_successful": coerce_bool(result_data.get("is_successful")),
"discovery_time": coerce_float(result_data.get("discovery_time")),
"metadata": coerce_object_dict(result_data.get("metadata")),
}
discovery_summary = {
"total_discovered": len(discovered_urls),
"base_url": base_url,
"discovery_method": discovery_method,
"discovery_time": coerce_float(result_data.get("discovery_time")),
"limited_to": len(discovered_urls) if max_results and len(discovered_urls) >= max_results else None,
}
info_highlight(
"Detailed URL discovery completed: "
f"{len(discovered_urls)} URLs found using {discovery_method or 'unknown'} method",
category="URLDiscovery"
)
return {
"discovered_urls": discovered_urls,
@@ -118,13 +134,17 @@ async def discover_urls_node(
"discovery_summary": discovery_summary,
}
else:
# Simple discovery
discovered_urls = await discover_urls.ainvoke({
"base_url": base_url,
"provider": discovery_provider,
"max_results": max_results
})
else:
# Simple discovery
discovered_urls = coerce_str_list(
await discover_urls.ainvoke(
{
"base_url": base_url,
"provider": discovery_provider,
"max_results": max_results,
}
)
)
discovery_summary = {
"total_discovered": len(discovered_urls),
@@ -143,22 +163,33 @@ async def discover_urls_node(
"discovery_summary": discovery_summary,
}
except Exception as e:
error_highlight(f"URL discovery failed: {e}", category="URLDiscovery")
return {
"discovered_urls": [base_url], # Fallback to base URL
"discovery_summary": {
"total_discovered": 1,
"base_url": base_url,
"error": str(e),
"fallback_used": True,
},
"errors": [
create_error_info(
message=f"URL discovery failed: {str(e)}",
node="discover_urls",
severity="warning",
category="discovery_error",
)
],
except Exception as e:
error_highlight(f"URL discovery failed: {e}", category="URLDiscovery")
safe_base_url = coerce_str(base_url) or None
fallback_urls = [url for url in [safe_base_url] if isinstance(url, str) and url]
error_type = type(e).__name__
return {
"discovered_urls": fallback_urls,
"discovery_summary": {
"total_discovered": len(fallback_urls),
"base_url": safe_base_url,
"error": f"{error_type}: {str(e)}",
"fallback_used": True,
"status": "error",
"provider": discovery_provider,
"requested_max_results": max_results,
},
"errors": [
create_error_info(
message=f"{error_type}: {str(e)}",
node="discover_urls",
severity="error",
category="discovery_error",
context={
"base_url": safe_base_url,
"provider": discovery_provider,
"max_results": max_results,
},
)
],
}

View File

@@ -6,7 +6,15 @@ using the Business Buddy URL processing tools.
from __future__ import annotations
from typing import Any
from biz_bud.nodes.url_processing._typing import (
StateMapping,
coerce_bool,
coerce_int,
coerce_object_dict,
coerce_object_list,
coerce_str,
coerce_str_list,
)
from langchain_core.runnables import RunnableConfig
@@ -16,9 +24,9 @@ from biz_bud.logging import debug_highlight, error_highlight, info_highlight
@standard_node(node_name="process_urls", metric_name="url_processing")
async def process_urls_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
async def process_urls_node(
state: StateMapping, config: RunnableConfig | None
) -> dict[str, object]:
"""Process multiple URLs using URL processing tools.
This node processes a batch of URLs using various URL processing
@@ -46,55 +54,64 @@ async def process_urls_node(
info_highlight("Starting batch URL processing...", category="URLProcessing")
urls = state.get("urls", [])
if not urls:
info_highlight("No URLs to process", category="URLProcessing")
return {
"processed_urls": [],
"processing_summary": {"total": 0, "successful": 0, "failed": 0},
}
urls = coerce_str_list(state.get("urls"))
if not urls:
info_highlight("No URLs to process", category="URLProcessing")
return {
"processed_urls": [],
"processing_summary": {"total": 0, "successful": 0, "failed": 0},
}
processing_operations = coerce_str_list(state.get("processing_operations"))
if not processing_operations:
processing_operations = ["validate", "normalize"]
validation_level = coerce_str(state.get("validation_level")) or "standard"
max_concurrent = coerce_int(state.get("max_concurrent"), 5)
processing_operations = state.get("processing_operations", ["validate", "normalize"])
validation_level = state.get("validation_level", "standard")
max_concurrent = state.get("max_concurrent", 5)
debug_highlight(
f"Processing {len(urls)} URLs with operations: {processing_operations}",
category="URLProcessing"
)
total_urls = len(urls)
debug_highlight(
f"Processing {total_urls} URLs with operations: {processing_operations}",
category="URLProcessing"
)
try:
# Use the batch processing tool
# Call the LangChain tool using invoke
input_dict = {
"urls": urls,
"validation_level": validation_level,
"normalization_provider": None,
"enable_deduplication": True,
"deduplication_provider": None,
"max_concurrent": max_concurrent,
"timeout": 30.0
}
result = await process_urls_batch.ainvoke(input_dict)
input_dict: dict[str, object] = {
"urls": urls,
"validation_level": validation_level,
"normalization_provider": None,
"enable_deduplication": True,
"deduplication_provider": None,
"max_concurrent": max_concurrent,
"timeout": 30.0,
}
result = await process_urls_batch.ainvoke(input_dict)
result_data = coerce_object_dict(result)
processed_urls = coerce_object_list(result_data.get("results"))
successful_count = sum(
1 for url_result in processed_urls if coerce_bool(url_result.get("success"))
)
failed_count = len(processed_urls) - successful_count
success_rate = (successful_count / total_urls) * 100 if total_urls else 0.0
processing_summary = {
"total": total_urls,
"successful": successful_count,
"failed": failed_count,
"success_rate": success_rate,
"operations_performed": processing_operations,
}
processed_urls = result.get("results", [])
successful_count = sum(bool(url_result.get("success", False))
for url_result in processed_urls)
failed_count = len(processed_urls) - successful_count
processing_summary = {
"total": len(urls),
"successful": successful_count,
"failed": failed_count,
"success_rate": (successful_count / len(urls)) * 100 if urls else 0,
"operations_performed": processing_operations,
}
info_highlight(
f"URL processing completed: {successful_count}/{len(urls)} successful "
f"({processing_summary['success_rate']:.1f}% success rate)",
category="URLProcessing"
)
info_highlight(
f"URL processing completed: {successful_count}/{len(urls)} successful "
f"({success_rate:.1f}% success rate)",
category="URLProcessing"
)
return {
"processed_urls": processed_urls,
@@ -103,14 +120,14 @@ async def process_urls_node(
except Exception as e:
error_highlight(f"URL processing node failed: {e}", category="URLProcessing")
return {
"processed_urls": [],
"processing_summary": {"total": len(urls), "successful": 0, "failed": len(urls)},
"errors": [
create_error_info(
message=f"URL processing failed: {str(e)}",
node="process_urls",
severity="error",
return {
"processed_urls": [],
"processing_summary": {"total": total_urls, "successful": 0, "failed": total_urls},
"errors": [
create_error_info(
message=f"URL processing failed: {str(e)}",
node="process_urls",
severity="error",
category="processing_error",
)
],

View File

@@ -6,7 +6,13 @@ using the Business Buddy URL processing tools.
from __future__ import annotations
from typing import Any
from biz_bud.nodes.url_processing._typing import (
StateMapping,
coerce_bool,
coerce_object_dict,
coerce_str,
coerce_str_list,
)
from langchain_core.runnables import RunnableConfig
@@ -17,9 +23,9 @@ from biz_bud.tools.capabilities.url_processing import validate_url
@standard_node(node_name="validate_urls", metric_name="url_validation")
async def validate_urls_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
async def validate_urls_node(
state: StateMapping, config: RunnableConfig | None
) -> dict[str, object]:
"""Validate URLs using URL processing tools.
This node validates a list of URLs using the URL processing validation
@@ -45,46 +51,45 @@ async def validate_urls_node(
"""
info_highlight("Starting URL validation...", category="URLValidation")
urls = state.get("urls", [])
if not urls:
info_highlight("No URLs to validate", category="URLValidation")
return {
"validated_urls": [],
"valid_urls": [],
"invalid_urls": [],
"validation_summary": {"total": 0, "valid": 0, "invalid": 0},
}
urls = coerce_str_list(state.get("urls"))
if not urls:
info_highlight("No URLs to validate", category="URLValidation")
return {
"validated_urls": [],
"valid_urls": [],
"invalid_urls": [],
"validation_summary": {"total": 0, "valid": 0, "invalid": 0},
}
validation_level = coerce_str(state.get("validation_level")) or "standard"
validation_provider = coerce_str(state.get("validation_provider"))
validation_level = state.get("validation_level", "standard")
validation_provider = state.get("validation_provider")
total_urls = len(urls)
debug_highlight(
f"Validating {total_urls} URLs with level: {validation_level}",
category="URLValidation"
)
debug_highlight(
f"Validating {len(urls)} URLs with level: {validation_level}",
category="URLValidation"
)
try:
validated_urls: list[dict[str, object]] = []
valid_urls: list[str] = []
invalid_urls: list[str] = []
try:
validated_urls = []
valid_urls = []
invalid_urls = []
for url in urls:
try:
result = await validate_url.ainvoke({
"url": url,
"level": validation_level,
"provider": validation_provider
})
validated_urls.append({
"url": url,
**result
})
if result.get("is_valid", False):
valid_urls.append(url)
else:
invalid_urls.append(url)
for url in urls:
try:
result = await validate_url.ainvoke(
{"url": url, "level": validation_level, "provider": validation_provider}
)
result_data = coerce_object_dict(result)
validated_urls.append({"url": url, **result_data})
if coerce_bool(result_data.get("is_valid")):
valid_urls.append(url)
else:
invalid_urls.append(url)
except Exception as e:
error_highlight(
@@ -99,18 +104,22 @@ async def validate_urls_node(
})
invalid_urls.append(url)
validation_summary = {
"total": len(urls),
"valid": len(valid_urls),
"invalid": len(invalid_urls),
"success_rate": (len(valid_urls) / len(urls)) * 100 if urls else 0,
}
info_highlight(
f"URL validation completed: {validation_summary['valid']}/{validation_summary['total']} valid "
f"({validation_summary['success_rate']:.1f}% success rate)",
category="URLValidation"
)
valid_count = len(valid_urls)
invalid_count = len(invalid_urls)
success_rate = (valid_count / total_urls) * 100 if total_urls else 0.0
validation_summary = {
"total": total_urls,
"valid": valid_count,
"invalid": invalid_count,
"success_rate": success_rate,
}
info_highlight(
f"URL validation completed: {valid_count}/{total_urls} valid "
f"({success_rate:.1f}% success rate)",
category="URLValidation"
)
return {
"validated_urls": validated_urls,
@@ -121,11 +130,11 @@ async def validate_urls_node(
except Exception as e:
error_highlight(f"URL validation node failed: {e}", category="URLValidation")
return {
"validated_urls": [],
"valid_urls": [],
"invalid_urls": urls, # Mark all as invalid on error
"validation_summary": {"total": len(urls), "valid": 0, "invalid": len(urls)},
return {
"validated_urls": [],
"valid_urls": [],
"invalid_urls": list(urls), # Mark all as invalid on error
"validation_summary": {"total": total_urls, "valid": 0, "invalid": total_urls},
"errors": [
create_error_info(
message=f"URL validation failed: {str(e)}",

View File

@@ -3,44 +3,22 @@
from __future__ import annotations
from operator import add
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict, TypeVar
from typing import Annotated, Literal, TypedDict
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep
if TYPE_CHECKING:
from langchain_core.messages import AIMessage as LangchainAIMessage
from langgraph.managed import IsLastStep
from biz_bud.core import (
ApiResponseTypedDict,
ErrorInfo,
InputMetadataTypedDict,
ParsedInputTypedDict,
ToolCallTypedDict,
)
from biz_bud.core import (
ApiResponseTypedDict,
ErrorInfo,
InputMetadataTypedDict,
ParsedInputTypedDict,
ToolCallTypedDict,
)
else:
# Create dummy types for runtime
Configuration = dict
ApiResponseTypedDict = dict
ErrorInfo = dict
InputMetadataTypedDict = dict
ParsedInputTypedDict = dict
ToolCallTypedDict = dict
IsLastStep = Any
LangchainAIMessage = Any # Will be set after TYPE_CHECKING
# Alias for compatibility
Message = Any # Will be set after TYPE_CHECKING
# The types previously here (AIMessage, IsLastStep, Configuration, utils)
# were moved into TYPE_CHECKING to satisfy TCH linting rules.
# If LangGraph has issues resolving these at runtime, specific noqa directives
# for TCH rules (TCH001-TCH004) might be needed on those specific imports
# if they are kept outside the TYPE_CHECKING block, or LangGraph's documentation
# should be consulted for the preferred pattern.
T = TypeVar("T")
StateMetadata = dict[str, object]
class ContextTypedDict(TypedDict, total=False):
@@ -50,13 +28,13 @@ class ContextTypedDict(TypedDict, total=False):
task: str
"""Task description or identifier."""
session_data: dict[str, Any]
session_data: StateMetadata
"""Session-specific data."""
user_preferences: dict[str, Any]
user_preferences: StateMetadata
"""User preferences and settings."""
workflow_metadata: dict[str, Any]
workflow_metadata: StateMetadata
"""Metadata about the workflow execution."""
unique_marker: str
@@ -76,7 +54,7 @@ class InitialInputTypedDict(TypedDict, total=False):
session_id: str
"""Session identifier."""
metadata: dict[str, Any]
metadata: StateMetadata
"""Additional metadata for the input."""
@@ -103,7 +81,7 @@ class RunMetadataTypedDict(TypedDict, total=False):
tags: list[str]
"""Tags associated with the run."""
custom_metadata: dict[str, Any]
custom_metadata: StateMetadata
"""Custom metadata for the run."""
@@ -127,7 +105,7 @@ class BaseStateRequired(TypedDict):
initial_input: InitialInputTypedDict
"""The original input dictionary that initiated the graph run. Preserved for context."""
config: dict[str, Any]
config: StateMetadata
"""Immutable application configuration (API keys, model settings, agent params, etc.).
Passed during graph invocation. Provides context to all nodes."""
@@ -156,11 +134,11 @@ class BaseStateRequired(TypedDict):
LangGraph managed field indicating the final step before recursion limit.
Note:
The type `IsLastStep` is an opaque marker imported from `langgraph.managed` and is defined as `Any`.
This is intentional, as LangGraph uses it as a special sentinel value rather than a concrete type.
See: src/langgraph/managed.pyi
LangGraph annotates this field with ``Annotated[bool, _IsLastStepSentinel()]`` to indicate when a
graph execution is about to reach its recursion limit. The concrete value is a boolean flag but the
annotation allows LangGraph to manage the lifecycle automatically.
[Source: LangGraph Docs, User Example]
[Source: LangGraph Docs]
"""
@@ -190,10 +168,10 @@ class BaseStateOptional(TypedDict, total=False):
assistant_message_for_history: LangchainAIMessage
"""Assistant message to be appended to message history."""
claims_to_check: list[dict[str, Any]]
claims_to_check: list[StateMetadata]
"""Claims extracted from content for fact-checking."""
fact_check_results: dict[str, Any] | None
fact_check_results: StateMetadata | None
"""Results from fact-checking claims."""
is_output_valid: bool | None

View File

@@ -4,9 +4,9 @@ This module defines the state structure for Buddy, the intelligent graph
orchestrator that coordinates complex workflows across the Business Buddy system.
"""
from typing import Any, Literal, TypedDict
from typing import Literal, TypedDict
from biz_bud.states.base import BaseState
from biz_bud.states.base import BaseState, StateMetadata
from biz_bud.states.planner import ExecutionPlan, QueryStep
@@ -18,7 +18,7 @@ class ExecutionRecord(TypedDict):
start_time: float
end_time: float
status: Literal["running", "completed", "failed", "skipped"]
result: Any | None
result: object | None
error: str | None
@@ -50,7 +50,7 @@ class BuddyState(BaseState):
# Execution tracking
execution_history: list[ExecutionRecord]
intermediate_results: dict[str, Any] # step_id -> result mapping
intermediate_results: dict[str, object] # step_id -> result mapping
adaptation_count: int
parallel_execution_enabled: bool
completed_step_ids: list[str]
@@ -75,8 +75,8 @@ class BuddyState(BaseState):
complexity_reasoning: str # Why this complexity was chosen
# Synthesis data (for capability introspection and other direct synthesis)
extracted_info: dict[str, Any] | None # Information extracted for synthesis
sources: list[dict[str, Any]] | None # Source metadata for synthesis
extracted_info: StateMetadata | None # Information extracted for synthesis
sources: list[StateMetadata] | None # Source metadata for synthesis
is_capability_introspection: (
bool # Whether this is a capability introspection query
)

View File

@@ -10,7 +10,7 @@ This module contains TypedDict definitions organized by business domains:
- Catalog: Catalog research and component analysis types
"""
from typing import Any, Literal, NotRequired, TypedDict
from typing import Literal, NotRequired, TypedDict
# =============================================================================
# Core Domain Types
@@ -45,7 +45,7 @@ class DataDict(TypedDict):
id: str
type: str
content: str | dict[str, Any] | list[Any]
content: str | dict[str, object] | list[object]
metadata: NotRequired[MetadataDict]
@@ -247,7 +247,7 @@ class CatalogComponentResearchResult(TypedDict):
item_id: str
item_name: str
search_query: NotRequired[str]
component_research: NotRequired[dict[str, Any]]
component_research: NotRequired[dict[str, object]]
from_cache: NotRequired[bool]
cache_age_days: NotRequired[int]
error: NotRequired[str]
@@ -263,7 +263,7 @@ class CatalogComponentResearchDict(TypedDict):
cached_items: NotRequired[int]
searched_items: NotRequired[int]
research_results: NotRequired[list[CatalogComponentResearchResult]]
metadata: NotRequired[dict[str, Any]]
metadata: NotRequired[dict[str, object]]
message: NotRequired[str]

View File

@@ -1,10 +1,10 @@
"""Error handling state definitions for the error recovery agent."""
from typing import Any, Literal, TypedDict
from typing import Literal, TypedDict
from biz_bud.core import ErrorInfo
from .base import BaseState
from .base import BaseState, StateMetadata
# Note: These TypedDict definitions are specific to the error handling state
# and include additional fields required for state management.
@@ -16,7 +16,7 @@ class ErrorContext(TypedDict):
node_name: str
graph_name: str
timestamp: str
input_state: dict[str, Any]
input_state: StateMetadata
execution_count: int
@@ -34,7 +34,7 @@ class RecoveryAction(TypedDict):
"""A recovery action to attempt."""
action_type: Literal["retry", "modify_input", "fallback", "skip", "abort"]
parameters: dict[str, Any]
parameters: StateMetadata
priority: int
expected_success_rate: float
@@ -44,7 +44,7 @@ class RecoveryResult(TypedDict, total=False):
success: bool
message: str
new_state: dict[str, Any] # Optional
new_state: StateMetadata # Optional
duration_seconds: float # Optional

View File

@@ -8,8 +8,8 @@ from __future__ import annotations
from typing import Annotated, Literal, TypedDict, cast
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from biz_bud.core import SearchResultTypedDict as SearchResult
from biz_bud.core.types import ErrorInfo
@@ -174,30 +174,72 @@ def create_message_only_state() -> MessageOnlyState:
# Example validation functions that work with focused states
def validate_search_state(state: SearchWorkflowState) -> bool:
"""Validate that search state is properly formed.
Args:
state: The search state to validate
Returns:
True if state is valid, False otherwise
"""
# Required fields check
return bool(state.get("status")) if state.get("thread_id") else False
def validate_search_state(state: SearchWorkflowState) -> bool:
"""Validate that search state is properly formed.
The modern LangGraph runtime is stricter about reducer inputs, so we make
sure list-based fields are actually lists (and contain dictionaries for
search results) before allowing the state to proceed. This mirrors
LangGraph's runtime validation and keeps our focused states resilient
across v0.x and v1.x releases.
Args:
state: The search state to validate
Returns:
True if state is valid, False otherwise
"""
thread_id = state.get("thread_id")
status = state.get("status")
messages = state.get("messages")
search_results = state.get("search_results")
if not isinstance(thread_id, str) or not thread_id.strip():
return False
if status not in {"pending", "running", "success", "error", "interrupted"}:
return False
if not isinstance(messages, list):
return False
if not isinstance(search_results, list):
return False
if any(not isinstance(result, dict) for result in search_results):
return False
return True
def validate_validation_state(state: ValidationWorkflowState) -> bool:
"""Validate that validation state is properly formed.
Args:
state: The validation state to validate
Returns:
True if state is valid, False otherwise
"""
# Must have content to validate
return bool(state.get("thread_id")) if state.get("content") else False
def validate_validation_state(state: ValidationWorkflowState) -> bool:
"""Validate that validation state is properly formed.
Args:
state: The validation state to validate
Returns:
True if state is valid, False otherwise
"""
thread_id = state.get("thread_id")
content = state.get("content")
issues = state.get("validation_issues")
if not isinstance(thread_id, str) or not thread_id.strip():
return False
if not isinstance(content, str) or not content:
return False
if issues is None:
issues = []
if not isinstance(issues, list):
return False
return True
# Export focused states and utilities

View File

@@ -3,9 +3,9 @@
from __future__ import annotations
from typing import Any, Literal, TypedDict
from typing import Literal, TypedDict
from biz_bud.states.base import BaseState
from biz_bud.states.base import BaseState, StateMetadata
class RAGAgentStateRequired(TypedDict):
@@ -25,7 +25,7 @@ class RAGAgentStateRequired(TypedDict):
url_hash: str | None
"""SHA256 hash of the URL for deduplication (first 16 chars)."""
existing_content: dict[str, Any] | None
existing_content: StateMetadata | None
"""Existing content metadata from knowledge stores."""
content_age_days: int | None
@@ -38,14 +38,14 @@ class RAGAgentStateRequired(TypedDict):
"""Human-readable reason for processing decision."""
# Processing parameters
scrape_params: dict[str, Any]
scrape_params: StateMetadata
"""Parameters for web scraping (depth, patterns, selectors)."""
r2r_params: dict[str, Any]
r2r_params: StateMetadata
"""Parameters for R2R upload (chunking, metadata)."""
# Results
processing_result: dict[str, Any] | None
processing_result: StateMetadata | None
"""Result from url_to_rag graph execution."""
rag_status: Literal["checking", "decided", "processing", "completed", "error"]
@@ -62,7 +62,7 @@ class RAGAgentStateOptional(TypedDict, total=False):
"""Optional fields for RAG agent workflow."""
# ReAct agent fields
intermediate_steps: list[tuple[Any, str]]
intermediate_steps: list[tuple[object, str]]
"""Intermediate steps for ReAct agent execution."""
final_answer: str | None

View File

@@ -2,20 +2,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from typing import Literal, TypedDict
from biz_bud.states.base import BaseState
from biz_bud.states.base import BaseState, StateMetadata
from biz_bud.tools.clients.r2r import R2RSearchResult
if TYPE_CHECKING:
from biz_bud.tools.clients.r2r import R2RSearchResult
else:
# Runtime placeholders for type checking
R2RSearchResult = Any
# Legacy types from deleted rag modules - using Any for now
FilteredChunk = Any
GenerationResult = Any
RetrievalResult = Any
FilteredChunk = StateMetadata
GenerationResult = StateMetadata
RetrievalResult = StateMetadata
class RAGOrchestratorStateRequired(TypedDict):
@@ -56,7 +50,7 @@ class RAGOrchestratorStateRequired(TypedDict):
urls_to_ingest: list[str]
"""URLs that need to be ingested."""
ingestion_results: dict[str, Any]
ingestion_results: StateMetadata
"""Results from the ingestion component."""
ingestion_status: Literal["pending", "processing", "completed", "failed", "skipped"]
@@ -120,14 +114,14 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
"""Timing information for each component."""
# Context and metadata
user_context: dict[str, Any]
user_context: StateMetadata
"""Additional context provided by the user."""
previous_interactions: list[dict[str, Any]]
previous_interactions: list[StateMetadata]
"""History of previous interactions in this session."""
# Advanced retrieval options
retrieval_filters: dict[str, Any]
retrieval_filters: StateMetadata
"""Filters to apply during retrieval."""
max_chunks: int
@@ -147,10 +141,10 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
"""Style for citations in the response."""
# Error handling and monitoring
error_history: list[dict[str, Any]]
error_history: list[StateMetadata]
"""History of errors encountered during workflow."""
error_analysis: dict[str, Any]
error_analysis: StateMetadata
"""Analysis results from the error handling graph."""
should_retry_node: bool
@@ -165,10 +159,10 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
recovery_successful: bool
"""Whether error recovery was successful."""
performance_metrics: dict[str, Any]
performance_metrics: StateMetadata
"""Performance metrics for monitoring."""
debug_info: dict[str, Any]
debug_info: StateMetadata
"""Debug information for troubleshooting."""
# Integration with legacy fields

View File

@@ -7,7 +7,7 @@ replacing the monolithic BusinessBuddyState with focused, type-safe alternatives
from __future__ import annotations
from operator import add
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict
from typing import TYPE_CHECKING, Annotated, Literal, TypedDict
# Import from biz_bud.core for search results
from biz_bud.core import SearchResultTypedDict as _SearchResult
@@ -126,7 +126,7 @@ class ResearchStateOptional(TypedDict, total=False):
scraped_results: dict[str, dict[str, str | None]]
"""Scraped content from URLs, keyed by URL."""
semantic_extraction_results: dict[str, Any]
semantic_extraction_results: dict[str, object]
"""Results from semantic extraction including vector IDs."""
vector_ids: Annotated[list[str], add]

View File

@@ -9,7 +9,7 @@ Pydantic models in biz_bud.states.validation_models which provide:
"""
from typing import Any, Literal, TypedDict, TypeGuard
from typing import Literal, TypedDict, TypeGuard
from langchain_core.messages import AIMessage, ToolMessage
@@ -27,10 +27,10 @@ class ToolInputData(TypedDict, total=False):
source_context: str
class ToolOutputData(TypedDict, total=False):
class ToolOutputData(TypedDict, total=False):
"""Structured output data from tool execution."""
result: str | dict[str, Any] | list[Any]
result: str | dict[str, object] | list[object]
status: Literal["success", "error", "partial"]
metadata: dict[str, str | int | bool]
processing_time_ms: int
@@ -260,7 +260,7 @@ class WebToolsStateTypedDict(SourceMetadataBase, ProcessingStatusBase, Configura
# Type guard functions for runtime type checking
def is_tool_input_data(obj: dict[str, Any]) -> TypeGuard[ToolInputData]:
def is_tool_input_data(obj: dict[str, object]) -> TypeGuard[ToolInputData]:
"""Type guard to check if object is valid ToolInputData."""
return (
isinstance(obj.get("operation", ""), str) and
@@ -269,7 +269,7 @@ def is_tool_input_data(obj: dict[str, Any]) -> TypeGuard[ToolInputData]:
)
def is_tool_output_data(obj: dict[str, Any]) -> TypeGuard[ToolOutputData]:
def is_tool_output_data(obj: dict[str, object]) -> TypeGuard[ToolOutputData]:
"""Type guard to check if object is valid ToolOutputData."""
return (
obj.get("status") in ("success", "error", "partial") and
@@ -278,7 +278,7 @@ def is_tool_output_data(obj: dict[str, Any]) -> TypeGuard[ToolOutputData]:
)
def is_search_result_dict(obj: dict[str, Any]) -> TypeGuard[SearchResultDict]:
def is_search_result_dict(obj: dict[str, object]) -> TypeGuard[SearchResultDict]:
"""Type guard to check if object is valid SearchResultDict."""
return (
isinstance(obj.get("title", ""), str) and
@@ -290,7 +290,7 @@ isinstance(obj.get("title", ""), str) and
)
def is_scraped_content_dict(obj: dict[str, Any]) -> TypeGuard[ScrapedContentDict]:
def is_scraped_content_dict(obj: dict[str, object]) -> TypeGuard[ScrapedContentDict]:
"""Type guard to check if object is valid ScrapedContentDict."""
return (
isinstance(obj.get("url", ""), str) and
@@ -301,7 +301,7 @@ isinstance(obj.get("url", ""), str) and
)
def is_extracted_statistic(obj: dict[str, Any]) -> TypeGuard[ExtractedStatistic]:
def is_extracted_statistic(obj: dict[str, object]) -> TypeGuard[ExtractedStatistic]:
"""Type guard to check if object is valid ExtractedStatistic."""
return (
isinstance(obj.get("name", ""), str) and
@@ -313,7 +313,7 @@ isinstance(obj.get("name", ""), str) and
)
def is_extracted_fact(obj: dict[str, Any]) -> TypeGuard[ExtractedFact]:
def is_extracted_fact(obj: dict[str, object]) -> TypeGuard[ExtractedFact]:
"""Type guard to check if object is valid ExtractedFact."""
return (
isinstance(obj.get("claim", ""), str) and
@@ -328,7 +328,7 @@ isinstance(obj.get("claim", ""), str) and
# Higher-level type guards for state classes
def is_tool_state_data(obj: dict[str, Any]) -> TypeGuard[ToolStateTypedDict]:
def is_tool_state_data(obj: dict[str, object]) -> TypeGuard[ToolStateTypedDict]:
"""Type guard to check if object has valid ToolState structure."""
return (
(obj.get("tool_name") is None or isinstance(obj.get("tool_name"), str)) and
@@ -337,7 +337,7 @@ def is_tool_state_data(obj: dict[str, Any]) -> TypeGuard[ToolStateTypedDict]:
)
def is_search_state_data(obj: dict[str, Any]) -> TypeGuard[SearchStateTypedDict]:
def is_search_state_data(obj: dict[str, object]) -> TypeGuard[SearchStateTypedDict]:
"""Type guard to check if object has valid SearchState structure."""
return (
(obj.get("query") is None or isinstance(obj.get("query"), str)) and
@@ -349,7 +349,7 @@ def is_search_state_data(obj: dict[str, Any]) -> TypeGuard[SearchStateTypedDict]
)
def is_scraping_state_data(obj: dict[str, Any]) -> TypeGuard[ScrapingStateTypedDict]:
def is_scraping_state_data(obj: dict[str, object]) -> TypeGuard[ScrapingStateTypedDict]:
"""Type guard to check if object has valid ScrapingState structure."""
return (
(obj.get("scraper_name") is None or
@@ -362,7 +362,7 @@ def is_scraping_state_data(obj: dict[str, Any]) -> TypeGuard[ScrapingStateTypedD
)
def is_extraction_state_data(obj: dict[str, Any]) -> TypeGuard[ExtractionStateTypedDict]:
def is_extraction_state_data(obj: dict[str, object]) -> TypeGuard[ExtractionStateTypedDict]:
"""Type guard to check if object has valid ExtractionState structure."""
return (
(obj.get("extracted_statistics") is None or
@@ -376,32 +376,32 @@ def is_extraction_state_data(obj: dict[str, Any]) -> TypeGuard[ExtractionStateTy
# Utility functions for runtime validation with type guards
def validate_and_cast_tool_input(obj: dict[str, Any]) -> ToolInputData | None:
def validate_and_cast_tool_input(obj: dict[str, object]) -> ToolInputData | None:
"""Validate and cast object to ToolInputData if valid, otherwise return None."""
return obj if is_tool_input_data(obj) else None
def validate_and_cast_search_result(obj: dict[str, Any]) -> SearchResultDict | None:
def validate_and_cast_search_result(obj: dict[str, object]) -> SearchResultDict | None:
"""Validate and cast object to SearchResultDict if valid, otherwise return None."""
return obj if is_search_result_dict(obj) else None
def validate_and_cast_scraped_content(obj: dict[str, Any]) -> ScrapedContentDict | None:
def validate_and_cast_scraped_content(obj: dict[str, object]) -> ScrapedContentDict | None:
"""Validate and cast object to ScrapedContentDict if valid, otherwise return None."""
return obj if is_scraped_content_dict(obj) else None
def filter_valid_search_results(results: list[dict[str, Any]]) -> list[SearchResultDict]:
def filter_valid_search_results(results: list[dict[str, object]]) -> list[SearchResultDict]:
"""Filter list to only include valid SearchResultDict objects."""
return [result for result in results if is_search_result_dict(result)]
def filter_valid_extracted_statistics(stats: list[dict[str, Any]]) -> list[ExtractedStatistic]:
def filter_valid_extracted_statistics(stats: list[dict[str, object]]) -> list[ExtractedStatistic]:
"""Filter list to only include valid ExtractedStatistic objects."""
return [stat for stat in stats if is_extracted_statistic(stat)]
def filter_valid_extracted_facts(facts: list[dict[str, Any]]) -> list[ExtractedFact]:
def filter_valid_extracted_facts(facts: list[dict[str, object]]) -> list[ExtractedFact]:
"""Filter list to only include valid ExtractedFact objects."""
return [fact for fact in facts if is_extracted_fact(fact)]
@@ -449,7 +449,7 @@ def create_tool_input_data(
def create_tool_output_data(
status: Literal["success", "error", "partial"],
*,
result: str | dict[str, Any] | list[Any] | None = None,
result: str | dict[str, object] | list[object] | None = None,
metadata: dict[str, str | int | bool] | None = None,
processing_time_ms: int = 0,
warnings: list[str] | None = None,
@@ -539,15 +539,15 @@ def create_search_result_dict(
return search_result
def create_scraped_content_dict(
url: str,
content: str,
*,
title: str | None = None,
html: str | None = None,
metadata: dict[str, str] | None = None,
scrape_timestamp: str | None = None,
content_length: int | None = None,
def create_scraped_content_dict(
url: str,
content: str | None,
*,
title: str | None = None,
html: str | None = None,
metadata: dict[str, str] | None = None,
scrape_timestamp: str | None = None,
content_length: int | None = None,
success: bool = True,
error_message: str | None = None,
) -> ScrapedContentDict:
@@ -567,26 +567,30 @@ def create_scraped_content_dict(
Returns:
Validated ScrapedContentDict instance
Raises:
ValueError: If url/content are empty, content_length is negative,
or error_message is present when success is True
"""
if not url or url.isspace():
raise ValidationError("URL cannot be empty or whitespace")
if not content or content.isspace():
raise ValidationError("Content cannot be empty or whitespace")
if content_length is not None and content_length < 0:
raise ValidationError("Content length cannot be negative")
if not success and not error_message:
error_message = "An unknown error occurred during scraping."
if success and error_message:
raise ValidationError("Error message should not be present when success is True")
scraped_content: ScrapedContentDict = {
"url": url.strip(),
"content": content,
"success": success,
}
Raises:
ValidationError: If url/content are empty, content_length is negative,
or error_message is present when success is True
"""
if not url or url.isspace():
raise ValidationError("URL cannot be empty or whitespace")
if success:
if not content or content.isspace():
raise ValidationError("Content cannot be empty or whitespace when success is True")
else:
if not content or content.isspace():
content = ""
if content_length is not None and content_length < 0:
raise ValidationError("Content length cannot be negative")
if not success and not error_message:
error_message = "An unknown error occurred during scraping."
if success and error_message:
raise ValidationError("Error message should not be present when success is True")
scraped_content: ScrapedContentDict = {
"url": url.strip(),
"content": content or "",
"success": success,
}
if title is not None:
scraped_content["title"] = title

View File

@@ -7,12 +7,10 @@ to handle accumulating fields correctly.
from __future__ import annotations
from operator import add
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Literal, TypedDict
from langchain_core.messages import AnyMessage
# Import LangGraph functionality
from langgraph.graph.message import add_messages
from langgraph.graph import add_messages
# Import runtime types needed for forward references
# These must be available at runtime for LangGraph's get_type_hints() calls
@@ -61,6 +59,8 @@ from .catalog import CatalogIntelState
# Types are now imported directly at runtime for LangGraph compatibility
StatePayload = dict[str, object]
class Organization(TypedDict):
"""Organization entity."""
@@ -126,10 +126,10 @@ class BaseStateOptional(TypedDict, total=False):
persistence_error: str
"""Error message if result persistence fails."""
claims_to_check: list[dict[str, Any]]
claims_to_check: list[StatePayload]
"""Claims extracted from content for fact-checking."""
fact_check_results: dict[str, Any]
fact_check_results: StatePayload
"""Results from fact-checking claims."""
is_output_valid: bool | None
@@ -305,7 +305,7 @@ class ResearchStateOptional(TypedDict, total=False):
scraped_results: dict[str, dict[str, str | None]]
"""Scraped content from URLs, keyed by URL."""
semantic_extraction_results: dict[str, Any]
semantic_extraction_results: StatePayload
"""Results from semantic extraction including vector IDs."""
vector_ids: Annotated[list[str], add]
@@ -325,7 +325,7 @@ class ResearchState(
Combines base state with search and validation capabilities.
"""
extracted_info: dict[str, Any]
extracted_info: StatePayload
"""Information extracted from search results in source_0, source_1, etc. format."""
synthesis: str
@@ -382,7 +382,7 @@ class BusinessBuddyState(
extracted_info: ExtractedInfoDict
"""Information extracted from search results."""
extracted_content: dict[str, Any]
extracted_content: StatePayload
"""Content extracted from various sources (e.g., catalog data)."""
synthesis: str
@@ -491,7 +491,7 @@ class BusinessBuddyState(
catalog_component_research: CatalogComponentResearchDict
"""Results from catalog component research."""
extracted_components: dict[str, Any]
extracted_components: StatePayload
"""Extracted components from catalog items."""
# Catalog intelligence fields
@@ -501,22 +501,22 @@ class BusinessBuddyState(
batch_component_queries: list[str]
"""List of components for batch analysis."""
component_news_impact_reports: list[dict[str, Any]]
component_news_impact_reports: list[StatePayload]
"""Impact reports for components based on news/market conditions."""
catalog_optimization_suggestions: list[dict[str, Any]]
catalog_optimization_suggestions: list[StatePayload]
"""Optimization suggestions for catalog items."""
optimization_summary: dict[str, Any]
optimization_summary: StatePayload
"""Summary of optimization recommendations."""
catalog_items_linked_to_component: list[dict[str, Any]]
catalog_items_linked_to_component: list[StatePayload]
"""Catalog items affected by specific component."""
data_source_used: str | None
"""Source of catalog data (database, yaml, default)."""
component_analytics: dict[str, Any]
component_analytics: StatePayload
"""Analytics and insights from component analysis."""

View File

@@ -1,21 +1,13 @@
"""Minimal state for URL to R2R processing workflow."""
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict
# Import AnyMessage and add_messages for proper LangGraph support
try:
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
except ImportError:
if TYPE_CHECKING:
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
else:
# Fallback for runtime if langchain is not installed
AnyMessage = Any
add_messages = list
"""Minimal state for URL to R2R processing workflow."""
from __future__ import annotations
from typing import Annotated, Literal, TypedDict
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
StatePayload = dict[str, object]
class URLToRAGState(TypedDict, total=False):
@@ -33,7 +25,7 @@ class URLToRAGState(TypedDict, total=False):
url: str
"""Alternative URL field for compatibility."""
config: dict[str, Any]
config: StatePayload
"""Configuration containing API keys."""
# Processing fields
@@ -46,17 +38,17 @@ class URLToRAGState(TypedDict, total=False):
discovered_urls: list[str]
"""List of URLs discovered during processing."""
scraped_content: list[dict[str, Any]]
scraped_content: list[StatePayload]
"""Scraped content from Firecrawl."""
processed_content: dict[str, Any]
processed_content: StatePayload
"""Processed content ready for upload."""
repomix_output: str | None
"""Repomix output for git repositories."""
# R2R configuration
r2r_info: dict[str, Any] | None
r2r_info: StatePayload | None
"""R2R upload information and metadata."""
# R2R fields
@@ -78,7 +70,7 @@ class URLToRAGState(TypedDict, total=False):
"""Status messages during processing."""
# For tests
errors: list[dict[str, Any]]
errors: list[StatePayload]
"""Error list for tests."""
# Track processing state
@@ -86,10 +78,10 @@ class URLToRAGState(TypedDict, total=False):
"""Number of pages that have been processed by analyzer."""
# Upload tracking (optional)
upload_tracker: dict[str, Any] | None
upload_tracker: StatePayload | None
"""Upload tracking summary with success/failure counts."""
upload_details: dict[str, Any] | None
upload_details: StatePayload | None
"""Detailed upload information per URL."""
# Additional fields that may be returned
@@ -119,7 +111,7 @@ class URLToRAGState(TypedDict, total=False):
existing_document_id: str | None
"""Document ID if URL was already processed."""
existing_document_metadata: dict[str, Any] | None
existing_document_metadata: StatePayload | None
"""Metadata of existing document if found."""
skipped_urls_count: int
@@ -157,7 +149,7 @@ class URLToRAGState(TypedDict, total=False):
url_hash: str | None
"""SHA256 hash of URL for efficient lookup."""
existing_content: dict[str, Any] | None
existing_content: StatePayload | None
"""Found content metadata for deduplication."""
content_age_days: int | None
@@ -169,8 +161,8 @@ class URLToRAGState(TypedDict, total=False):
processing_reason: str | None
"""Human-readable explanation of processing decision."""
scrape_params: dict[str, Any]
scrape_params: StatePayload
"""Parameters for scraping operations."""
r2r_params: dict[str, Any]
r2r_params: StatePayload
"""Parameters for R2R processing."""

View File

@@ -1,6 +1,8 @@
"""State for URL to R2R workflow."""
from typing import Any, TypedDict
"""State for URL to R2R workflow."""
from typing import TypedDict
from langchain_core.messages import AnyMessage
class R2RInfo(TypedDict, total=False):
@@ -13,12 +15,12 @@ class R2RInfo(TypedDict, total=False):
upload_timestamp: str
class URLToRAGState(TypedDict, total=False):
"""State for URL to R2R processing."""
url: str
processed_content: dict[str, str | dict[str, str]]
r2r_info: R2RInfo
upload_complete: bool
messages: list[Any]
errors: list[dict[str, str]]
class URLToRAGState(TypedDict, total=False):
"""State for URL to R2R processing."""
url: str
processed_content: dict[str, str | dict[str, str]]
r2r_info: R2RInfo
upload_complete: bool
messages: list[AnyMessage]
errors: list[dict[str, str]]

View File

@@ -6,7 +6,7 @@ in tools.py, enabling runtime validation, type coercion, and error reporting.
from __future__ import annotations
from typing import Any, Literal
from typing import Literal
from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
@@ -38,9 +38,9 @@ class ToolInputDataModel(BaseModel):
class ToolOutputDataModel(BaseModel):
"""Pydantic model for ToolOutputData validation."""
result: str | dict[str, Any] | list[Any] | None = Field(
None, description="Tool execution result"
)
result: str | dict[str, object] | list[object] | None = Field(
None, description="Tool execution result"
)
status: Literal["success", "error", "partial"] = Field(
..., description="Execution status"
)
@@ -428,56 +428,56 @@ class WebToolsStateModel(SourceMetadataModel, ProcessingStatusModel, Configurati
# Validation helper functions
def validate_tool_input_data(data: dict[str, Any]) -> ToolInputDataModel:
def validate_tool_input_data(data: dict[str, object]) -> ToolInputDataModel:
"""Validate tool input data using Pydantic model."""
return ToolInputDataModel.model_validate(data)
def validate_tool_output_data(data: dict[str, Any]) -> ToolOutputDataModel:
def validate_tool_output_data(data: dict[str, object]) -> ToolOutputDataModel:
"""Validate tool output data using Pydantic model."""
return ToolOutputDataModel.model_validate(data)
def validate_search_result(data: dict[str, Any]) -> SearchResultDictModel:
def validate_search_result(data: dict[str, object]) -> SearchResultDictModel:
"""Validate search result data using Pydantic model."""
return SearchResultDictModel.model_validate(data)
def validate_scraped_content(data: dict[str, Any]) -> ScrapedContentDictModel:
def validate_scraped_content(data: dict[str, object]) -> ScrapedContentDictModel:
"""Validate scraped content data using Pydantic model."""
return ScrapedContentDictModel.model_validate(data)
def validate_extracted_statistic(data: dict[str, Any]) -> ExtractedStatisticModel:
def validate_extracted_statistic(data: dict[str, object]) -> ExtractedStatisticModel:
"""Validate extracted statistic data using Pydantic model."""
return ExtractedStatisticModel.model_validate(data)
def validate_extracted_fact(data: dict[str, Any]) -> ExtractedFactModel:
def validate_extracted_fact(data: dict[str, object]) -> ExtractedFactModel:
"""Validate extracted fact data using Pydantic model."""
return ExtractedFactModel.model_validate(data)
def validate_tool_state(data: dict[str, Any]) -> ToolStateModel:
def validate_tool_state(data: dict[str, object]) -> ToolStateModel:
"""Validate tool state data using Pydantic model."""
return ToolStateModel.model_validate(data)
def validate_search_state(data: dict[str, Any]) -> SearchStateModel:
def validate_search_state(data: dict[str, object]) -> SearchStateModel:
"""Validate search state data using Pydantic model."""
return SearchStateModel.model_validate(data)
def validate_scraping_state(data: dict[str, Any]) -> ScrapingStateModel:
def validate_scraping_state(data: dict[str, object]) -> ScrapingStateModel:
"""Validate scraping state data using Pydantic model."""
return ScrapingStateModel.model_validate(data)
def validate_extraction_state(data: dict[str, Any]) -> ExtractionStateModel:
def validate_extraction_state(data: dict[str, object]) -> ExtractionStateModel:
"""Validate extraction state data using Pydantic model."""
return ExtractionStateModel.model_validate(data)
def validate_web_tools_state(data: dict[str, Any]) -> WebToolsStateModel:
def validate_web_tools_state(data: dict[str, object]) -> WebToolsStateModel:
"""Validate web tools state data using Pydantic model."""
return WebToolsStateModel.model_validate(data)

View File

@@ -1,7 +1,7 @@
"""Root pytest configuration with hierarchical fixtures."""
import asyncio
import asyncio
import inspect
import importlib
import importlib.util
import os
@@ -170,9 +170,27 @@ def _ensure_optional_dependency(package: str) -> None:
_install_stub(package)
_REQUIRED_DEPENDENCIES: tuple[str, ...] = ("langgraph", "langchain_core")
def _require_dependency(package: str) -> None:
"""Ensure a required dependency is importable for the test suite."""
try:
importlib.import_module(package)
except ModuleNotFoundError as exc: # pragma: no cover - defensive guard
message = (
f"The '{package}' package is required to run the test suite. "
"Install the project dependencies with `pip install -r requirements.txt`."
)
raise RuntimeError(message) from exc
for required_package in _REQUIRED_DEPENDENCIES:
_require_dependency(required_package)
for optional_package in (
"langgraph",
"langchain_core",
"langchain_anthropic",
"langchain_openai",
"pydantic",
@@ -278,10 +296,19 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
if marker is None:
return None
if pyfuncitem.config.pluginmanager.hasplugin("asyncio"):
return None
loop = asyncio.new_event_loop()
try:
kwargs = {name: pyfuncitem.funcargs[name] for name in pyfuncitem._fixtureinfo.argnames} # type: ignore[attr-defined]
loop.run_until_complete(pyfuncitem.obj(**kwargs))
kwargs = {
name: pyfuncitem.funcargs[name]
for name in pyfuncitem._fixtureinfo.argnames # type: ignore[attr-defined]
}
result = pyfuncitem.obj(**kwargs)
if not inspect.isawaitable(result):
return None
loop.run_until_complete(result)
finally:
loop.close()
return True

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from collections.abc import MutableMapping
from typing import TYPE_CHECKING, cast
from unittest.mock import AsyncMock, patch
import pytest
@@ -12,7 +13,7 @@ from langchain_core.runnables import RunnableConfig
from tests.helpers.mocks.mock_builders import MockLLMBuilder
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
@pytest.mark.e2e
@@ -48,7 +49,7 @@ class TestAnalysisWorkflowE2E:
return mock_service_factory
@pytest.fixture
def sample_data(self) -> dict[str, Any]:
def sample_data(self) -> dict[str, object]:
"""Sample data for analysis."""
return {
"sales_data": [
@@ -146,7 +147,7 @@ class TestAnalysisWorkflowE2E:
mock_call_model_node: AsyncMock,
mock_get_global_factory: AsyncMock,
mock_analyze_data: AsyncMock,
sample_data: dict[str, Any],
sample_data: dict[str, object],
mock_data_analyzer: AsyncMock,
mock_llm_analyzer: AsyncMock,
) -> None:
@@ -204,7 +205,7 @@ class TestAnalysisWorkflowE2E:
# Create analysis graph (simplified for test) - using optimized version
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
# Initial state with data
initial_state = {
@@ -246,8 +247,8 @@ class TestAnalysisWorkflowE2E:
# Mock the LLM to handle the error
async def mock_llm_response(
state: dict[str, Any], config=None
) -> dict[str, Any]:
state: MutableMapping[str, object], config=None
) -> MutableMapping[str, object]:
# Return a simple error acknowledgment
from langchain_core.messages import AIMessage
@@ -268,7 +269,7 @@ class TestAnalysisWorkflowE2E:
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
# Invalid data
invalid_data = {
@@ -313,7 +314,7 @@ class TestAnalysisWorkflowE2E:
mock_get_global_factory: AsyncMock,
mock_create_viz: AsyncMock,
mock_analyze_data: AsyncMock,
sample_data: dict[str, Any],
sample_data: dict[str, object],
mock_data_analyzer: AsyncMock,
mock_llm_analyzer: AsyncMock,
) -> None:
@@ -373,7 +374,7 @@ class TestAnalysisWorkflowE2E:
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
initial_state = {
"messages": [
@@ -409,7 +410,7 @@ class TestAnalysisWorkflowE2E:
mock_call_model_node: AsyncMock,
mock_get_global_factory: AsyncMock,
mock_analyze_data: AsyncMock,
sample_data: dict[str, Any],
sample_data: dict[str, object],
mock_data_analyzer: AsyncMock,
) -> None:
"""Test analysis workflow that includes strategic planning."""
@@ -484,7 +485,7 @@ class TestAnalysisWorkflowE2E:
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
initial_state = {
"messages": [
@@ -621,7 +622,7 @@ class TestAnalysisWorkflowE2E:
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
initial_state = {
"messages": [HumanMessage(content="Analyze large transaction dataset")],
@@ -723,7 +724,7 @@ class TestAnalysisWorkflowE2E:
from biz_bud.graphs.graph import get_graph
analysis_graph: CompiledStateGraph[Any] = get_graph()
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
initial_state = {
"messages": [HumanMessage(content="Compare 2023 vs 2024 performance")],

View File

@@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
from biz_bud.states.error_handling import ErrorHandlingState
from biz_bud.states.research import ResearchState
@@ -18,7 +18,7 @@ class TestResearchGraph:
from biz_bud.graphs.research.graph import create_research_graph
graph = create_research_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_research_graph_factory(self):
@@ -27,7 +27,7 @@ class TestResearchGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = research_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_research_graph_factory_async(self):
@@ -36,7 +36,7 @@ class TestResearchGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = await research_graph_factory_async(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_research_graph_has_expected_nodes(self):
"""Test that research graph contains expected nodes."""
@@ -58,7 +58,7 @@ class TestCatalogGraph:
from biz_bud.graphs.catalog.graph import create_catalog_graph
graph = create_catalog_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_catalog_graph_factory(self):
"""Test catalog graph factory function."""
@@ -66,7 +66,7 @@ class TestCatalogGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = catalog_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_catalog_graph_factory_async(self):
@@ -76,7 +76,7 @@ class TestCatalogGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
import asyncio
graph = await asyncio.to_thread(catalog_graph_factory, config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestAnalysisGraph:
@@ -87,7 +87,7 @@ class TestAnalysisGraph:
from biz_bud.graphs.analysis.graph import create_analysis_graph
graph = create_analysis_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_analysis_graph_factory(self):
"""Test analysis graph factory function."""
@@ -95,7 +95,7 @@ class TestAnalysisGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = analysis_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_analysis_graph_factory_async(self):
@@ -104,7 +104,7 @@ class TestAnalysisGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = await analysis_graph_factory_async(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestScrapingGraph:
@@ -115,7 +115,7 @@ class TestScrapingGraph:
from biz_bud.graphs.scraping.graph import create_scraping_graph
graph = create_scraping_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_scraping_graph_factory(self):
"""Test scraping graph factory function."""
@@ -123,7 +123,7 @@ class TestScrapingGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = scraping_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_scraping_graph_factory_async(self):
@@ -132,7 +132,7 @@ class TestScrapingGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = await scraping_graph_factory_async(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestRAGGraph:
@@ -143,7 +143,7 @@ class TestRAGGraph:
from biz_bud.graphs.rag.graph import create_url_to_r2r_graph
graph = create_url_to_r2r_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_rag_graph_with_config(self):
"""Test RAG graph creation with config."""
@@ -151,7 +151,7 @@ class TestRAGGraph:
config = {"test": "config"}
graph = create_url_to_r2r_graph(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestPlannerGraph:
@@ -162,7 +162,7 @@ class TestPlannerGraph:
from biz_bud.graphs.planner import create_planner_graph
graph = create_planner_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_planner_graph_factory(self):
"""Test planner graph factory function."""
@@ -170,7 +170,7 @@ class TestPlannerGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = planner_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_planner_graph_factory_async(self):
@@ -179,7 +179,7 @@ class TestPlannerGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = await planner_graph_factory_async(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestErrorHandlingGraph:
@@ -190,7 +190,7 @@ class TestErrorHandlingGraph:
from biz_bud.graphs.error_handling import create_error_handling_graph
graph = create_error_handling_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_error_handling_graph_factory(self):
"""Test error handling graph factory function."""
@@ -198,7 +198,7 @@ class TestErrorHandlingGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = error_handling_graph_factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
@pytest.mark.asyncio
async def test_error_handling_graph_factory_async(self):
@@ -207,7 +207,7 @@ class TestErrorHandlingGraph:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = await error_handling_graph_factory_async(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestGraphStateCompatibility:
@@ -316,7 +316,7 @@ class TestGraphBuilderPatternConsistency:
]
for graph in graphs:
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
assert hasattr(graph, 'astream')
@@ -344,7 +344,7 @@ class TestGraphBuilderPatternConsistency:
for factory in factories:
graph = factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestGraphExecutionCompatibility:
@@ -493,7 +493,7 @@ class TestEndToEndIntegration:
for name, creator in graph_creators:
try:
graph = creator()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
successful_graphs.append(name)
except Exception as e:
failed_graphs.append((name, str(e)))

View File

@@ -2,7 +2,7 @@
import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
class TestGraphBuilderIntegration:
@@ -13,7 +13,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.research.graph import create_research_graph
graph = create_research_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_catalog_graph_creation(self):
@@ -21,7 +21,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.catalog.graph import create_catalog_graph
graph = create_catalog_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_analysis_graph_creation(self):
@@ -29,7 +29,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.analysis.graph import create_analysis_graph
graph = create_analysis_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_scraping_graph_creation(self):
@@ -37,7 +37,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.scraping.graph import create_scraping_graph
graph = create_scraping_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_rag_graph_creation(self):
@@ -45,7 +45,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.rag.graph import create_url_to_r2r_graph
graph = create_url_to_r2r_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_planner_graph_creation(self):
@@ -53,7 +53,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.planner import create_planner_graph
graph = create_planner_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_error_handling_graph_creation(self):
@@ -61,7 +61,7 @@ class TestGraphBuilderIntegration:
from biz_bud.graphs.error_handling import create_error_handling_graph
graph = create_error_handling_graph()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
@pytest.mark.integration
@@ -91,7 +91,7 @@ class TestGraphBuilderIntegration:
for name, creator in graph_creators:
try:
graph = creator()
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
successful_graphs.append(name)
except Exception as e:
@@ -132,7 +132,7 @@ class TestGraphBuilderIntegration:
for factory in factories:
graph = factory(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
assert hasattr(graph, 'ainvoke')
def test_builder_pattern_consistency(self):

View File

@@ -1,5 +0,0 @@
"""Lightweight test-oriented stub of the LangGraph package."""
from .graph import END, START, StateGraph
__all__ = ["StateGraph", "START", "END"]

View File

@@ -1,6 +0,0 @@
"""Cache subpackage for LangGraph test stubs."""
from .base import BaseCache
from .memory import InMemoryCache
__all__ = ["BaseCache", "InMemoryCache"]

View File

@@ -1,22 +0,0 @@
"""Cache protocol used by the LangGraph test stub."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class BaseCache(Protocol):
"""Protocol capturing the minimal cache interface exercised in tests."""
async def aget(self, key: str) -> Any: # pragma: no cover - interface only
...
async def aset(self, key: str, value: Any) -> None: # pragma: no cover - interface only
...
async def adelete(self, key: str) -> None: # pragma: no cover - interface only
...
__all__ = ["BaseCache"]

View File

@@ -1,29 +0,0 @@
"""Minimal in-memory cache implementation for tests."""
from __future__ import annotations
from typing import Any, Dict, Optional
from .base import BaseCache
class InMemoryCache(BaseCache):
"""Dictionary-backed async cache used in unit tests."""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
async def aget(self, key: str) -> Optional[Any]:
return self._store.get(key)
async def aset(self, key: str, value: Any) -> None:
self._store[key] = value
async def adelete(self, key: str) -> None:
self._store.pop(key, None)
def clear(self) -> None:
self._store.clear()
__all__ = ["InMemoryCache"]

View File

@@ -1,19 +0,0 @@
"""Checkpoint base classes used by the LangGraph stub."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class BaseCheckpointSaver(Protocol):
"""Protocol capturing the limited API exercised by the tests."""
def save(self, state: Any) -> None: # pragma: no cover - interface only
...
def load(self) -> Any: # pragma: no cover - interface only
...
__all__ = ["BaseCheckpointSaver"]

View File

@@ -1,23 +0,0 @@
"""In-memory checkpoint saver used by LangGraph tests."""
from __future__ import annotations
from typing import Any
from .base import BaseCheckpointSaver
class InMemorySaver(BaseCheckpointSaver):
"""Trivial checkpoint saver that stores the latest state in memory."""
def __init__(self) -> None:
self._state: Any | None = None
def save(self, state: Any) -> None:
self._state = state
def load(self) -> Any:
return self._state
__all__ = ["InMemorySaver"]

View File

@@ -1,264 +0,0 @@
"""Minimal subset of LangGraph graph primitives for unit testing."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, List, Mapping, Sequence, Tuple, TypeVar, TYPE_CHECKING
from langgraph.cache.base import BaseCache
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.store.base import BaseStore
from langgraph.types import All
if TYPE_CHECKING: # pragma: no cover - hinting only
from langgraph.graph.state import CachePolicy, RetryPolicy
else: # pragma: no cover - fallback for runtime without circular imports
CachePolicy = Any # type: ignore[assignment]
RetryPolicy = Any # type: ignore[assignment]
StateT = TypeVar("StateT")
NodeCallable = Callable[[Any], Any]
RouterCallable = Callable[[Any], str]
START: str = "__start__"
END: str = "__end__"
@dataclass(slots=True)
class NodeSpec:
    """Lightweight representation of a configured LangGraph node.

    Created by :meth:`StateGraph.add_node` and surfaced to tests through the
    ``StateGraph.nodes`` property for introspection.
    """

    # Callable executed when the node runs.
    func: NodeCallable
    # Free-form metadata supplied via ``add_node(metadata=...)``; always a copy.
    metadata: Dict[str, Any]
    # Retry configuration (single policy or sequence); stored, not enforced.
    retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
    # Cache policy for node results; stored, not enforced by this stub.
    cache_policy: CachePolicy | None
    # LangGraph ``defer`` flag; recorded for introspection only here.
    defer: bool
    # Optional per-node input schema overriding the graph's state schema.
    input_schema: type[Any] | None
@dataclass
class _ConditionalEdges:
    """Internal representation of conditional routing information."""

    # Node name (or the START sentinel) the routing originates from.
    source: str | object
    # Callable that inspects state and returns a routing key.
    router: RouterCallable
    # Either router-result -> destination mapping, or a plain sequence of
    # destinations (sequences are not validated at compile time).
    mapping: Mapping[str, str] | Sequence[str]
class StateGraph(Generic[StateT]):
    """Simplified builder that mimics the LangGraph public API.

    The stub records nodes and edges declaratively and validates the wiring
    at :meth:`compile` time; it does not execute graphs. ``START``/``END``
    are the module-level sentinels recognised in edge endpoints.
    """

    def __init__(
        self,
        state_schema: type[StateT],
        *,
        context_schema: type[Any] | None = None,
        input_schema: type[Any] | None = None,
        output_schema: type[Any] | None = None,
    ) -> None:
        """Store the schemas and initialise empty node/edge registries."""
        self.state_schema = state_schema
        self.context_schema = context_schema
        self.input_schema = input_schema
        self.output_schema = output_schema
        self._nodes: Dict[str, NodeSpec] = {}
        self._edges: List[Tuple[str | object, str | object]] = []
        self._conditional_edges: List[_ConditionalEdges] = []
        self._entry_point: str | None = None

    # ------------------------------------------------------------------
    # Graph mutation helpers
    # ------------------------------------------------------------------
    def add_node(
        self,
        name: str,
        func: NodeCallable,
        *,
        metadata: Dict[str, Any] | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
        cache_policy: CachePolicy | None = None,
        defer: bool = False,
        input_schema: type[Any] | None = None,
        **_: Any,
    ) -> None:
        """Register *func* as node *name*, overwriting any existing node.

        Extra keyword arguments are swallowed by ``**_`` so newer LangGraph
        node options do not break the stub.
        """
        self._nodes[name] = NodeSpec(
            func=func,
            # Copy so caller-side mutation of the dict cannot leak in.
            metadata=dict(metadata or {}),
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            defer=defer,
            input_schema=input_schema,
        )

    def add_edge(self, source: str | object, target: str | object) -> None:
        """Record an unconditional edge; endpoints are checked at compile time."""
        self._edges.append((source, target))

    def add_conditional_edges(
        self,
        source: str | object,
        router: RouterCallable,
        mapping: Mapping[str, str] | Sequence[str],
    ) -> None:
        """Record a conditional edge driven by *router*'s return value."""
        self._conditional_edges.append(
            _ConditionalEdges(source=source, router=router, mapping=mapping)
        )

    def set_entry_point(self, node: str) -> None:
        """Explicitly choose the entry node (takes precedence over START edges)."""
        self._entry_point = node

    # ------------------------------------------------------------------
    # Compilation
    # ------------------------------------------------------------------
    def compile(
        self,
        *,
        checkpointer: BaseCheckpointSaver[Any] | None = None,
        cache: BaseCache[Any] | None = None,
        store: BaseStore[Any] | None = None,
        interrupt_before: All | Sequence[str] | None = None,
        interrupt_after: All | Sequence[str] | None = None,
        debug: bool = False,
        name: str | None = None,
    ) -> "CompiledStateGraph[StateT]":
        """Validate the wiring and wrap this builder in a compiled view.

        When *name* is given, the context/input/output schemas are renamed
        (``<name>_context`` etc.) via :func:`_alias_schema`; note that this
        mutates the schema classes themselves in place.
        """
        self._validate()
        alias = name
        context_schema = self.context_schema
        input_schema = self.input_schema
        output_schema = self.output_schema
        if alias:
            if context_schema is not None:
                context_schema = _alias_schema(context_schema, f"{alias}_context")
            if input_schema is not None:
                input_schema = _alias_schema(input_schema, f"{alias}_input")
            if output_schema is not None:
                output_schema = _alias_schema(output_schema, f"{alias}_output")
        compiled = CompiledStateGraph(
            builder=self,
            checkpointer=checkpointer,
            cache=cache,
            store=store,
            interrupt_before_nodes=_normalise_interrupts(interrupt_before),
            interrupt_after_nodes=_normalise_interrupts(interrupt_after),
            debug=debug,
            name=name,
            context_schema=context_schema,
            input_schema=input_schema,
            output_schema=output_schema,
        )
        return compiled

    # ------------------------------------------------------------------
    # Introspection helpers used in tests
    # ------------------------------------------------------------------
    @property
    def nodes(self) -> Mapping[str, NodeSpec]:
        """Snapshot copy of the registered nodes (safe for callers to mutate)."""
        return dict(self._nodes)

    @property
    def edges(self) -> Sequence[Tuple[str | object, str | object]]:
        """Snapshot copy of the unconditional edges."""
        return list(self._edges)

    @property
    def conditional_edges(self) -> Sequence[_ConditionalEdges]:
        """Snapshot copy of the conditional edge records."""
        return list(self._conditional_edges)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _validate(self) -> None:
        """Raise ``ValueError`` for a missing entry point or dangling edges."""
        if not self._nodes:
            raise ValueError("Graph must have an entrypoint")
        # Determine the effective entry point.
        entry_point = self._entry_point
        if entry_point is None:
            # Fall back to the first START -> <node> edge, if any.
            for source, target in self._edges:
                if source == START and isinstance(target, str):
                    entry_point = target
                    break
        if entry_point is None:
            raise ValueError("Graph must have an entrypoint")
        # NOTE(review): the resolved entry point is not itself verified to be
        # a registered node.
        # Validate that all referenced nodes exist (ignoring START/END sentinels).
        for source, target in self._edges:
            if source not in {START, END} and source not in self._nodes:
                raise ValueError("Found edge starting at unknown node: {0}".format(source))
            if target not in {START, END} and target not in self._nodes:
                raise ValueError("Found edge ending at unknown node: {0}".format(target))
        for cond in self._conditional_edges:
            if cond.source not in {START, END} and cond.source not in self._nodes:
                raise ValueError(
                    "Found conditional edge starting at unknown node: {0}".format(
                        cond.source
                    )
                )
            # Only mapping-style targets can be checked; sequence mappings are
            # accepted as-is.
            if isinstance(cond.mapping, Mapping):
                missing = [
                    dest
                    for dest in cond.mapping.values()
                    if dest not in {END} and dest not in self._nodes
                ]
                if missing:
                    raise ValueError(
                        "Found conditional edge ending at unknown node: {0}".format(
                            ", ".join(missing)
                        )
                    )
def _normalise_interrupts(value: All | Sequence[str] | None) -> List[str] | All | None:
    """Coerce an interrupt specification into a list of node names.

    ``None`` passes through, a bare string (including the ``"*"`` sentinel)
    becomes a one-element list, and any other non-bytes sequence is copied
    into a fresh list. Anything else is returned unchanged.
    """
    if value is None:
        return None
    if isinstance(value, str):
        # A single node name (or "*") counts as a one-element interrupt list.
        return [value]
    is_plain_sequence = isinstance(value, Sequence) and not isinstance(
        value, (str, bytes, bytearray)
    )
    return list(value) if is_plain_sequence else value
def _alias_schema(schema: type[Any], alias: str) -> type[Any]:
    """Best-effort rename of *schema* so compiled graphs expose aliased names.

    Mutates ``schema.__name__`` in place and returns the same class object.
    Classes that refuse attribute assignment (e.g. builtins) are returned
    unchanged; the failure is swallowed deliberately.
    """
    try:
        schema.__name__ = alias
    except Exception:  # pragma: no cover - defensive fallback
        pass
    return schema
@dataclass
class CompiledStateGraph(Generic[StateT]):
    """Runtime representation returned by :meth:`StateGraph.compile`.

    A passive record of compile-time configuration: tests introspect the
    builder and options through this object, but the stub implements no
    invocation APIs.
    """

    # Builder that produced this compiled view (shared, not copied).
    builder: StateGraph[StateT]
    # Optional runtime services supplied at compile time.
    checkpointer: BaseCheckpointSaver[Any] | None = None
    cache: BaseCache[Any] | None = None
    store: BaseStore[Any] | None = None
    # Normalised interrupt specs. Defaults are empty lists, but ``compile``
    # may supply ``None`` or the ``All`` sentinel via _normalise_interrupts.
    interrupt_before_nodes: List[str] | All | None = field(default_factory=list)
    interrupt_after_nodes: List[str] | All | None = field(default_factory=list)
    debug: bool = False
    # Optional graph alias; when set, ``compile`` renamed the schemas below.
    name: str | None = None
    context_schema: type[Any] | None = None
    input_schema: type[Any] | None = None
    output_schema: type[Any] | None = None

    @property
    def InputType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""
        return self.input_schema

    @property
    def OutputType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""
        return self.output_schema

    @property
    def ContextType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""
        return self.context_schema
__all__ = ["StateGraph", "START", "END", "CompiledStateGraph", "NodeSpec"]

View File

@@ -1,17 +0,0 @@
"""Stubbed message utilities for LangGraph integration."""
from __future__ import annotations
from typing import Any, Iterable
def add_messages(state: dict[str, Any], messages: Iterable[Any]) -> dict[str, Any]:
"""Append messages to the state's ``messages`` list."""
existing = state.setdefault("messages", [])
if isinstance(existing, list):
existing.extend(list(messages))
return state
__all__ = ["add_messages"]

View File

@@ -1,47 +0,0 @@
"""LangGraph graph state helpers used across the Biz Bud tests."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Sequence
from . import CompiledStateGraph
@dataclass(slots=True)
class RetryPolicy:
    """Stubbed retry policy with configurable attempt limits.

    Only carries the attribute values tests inspect; no retry behaviour is
    implemented by the stub itself.
    """

    # Maximum attempt count (semantics presumed to follow LangGraph's
    # RetryPolicy — confirm against the real API before relying on it).
    max_attempts: int = 1
    # Backoff knobs, stored but never interpreted here.
    min_backoff: float | None = None
    max_backoff: float | None = None
    backoff_multiplier: float | None = None
    jitter: float | None = None

    def as_sequence(self) -> Sequence["RetryPolicy"]:
        """Return the policy as a singleton sequence for convenience."""
        return (self,)
@dataclass(slots=True)
class CachePolicy:
"""Minimal cache policy matching the attributes used in tests."""
namespace: str | None = None
key: str | None = None
ttl: float | None = None
populate: bool = True
def describe(self) -> dict[str, Any]:
"""Expose a serialisable representation for debugging."""
return {
"namespace": self.namespace,
"key": self.key,
"ttl": self.ttl,
"populate": self.populate,
}
__all__ = ["CachePolicy", "RetryPolicy", "CompiledStateGraph"]

View File

@@ -1,19 +0,0 @@
"""Store base classes used by the LangGraph stub."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class BaseStore(Protocol):
"""Protocol capturing the minimal store interface used in tests."""
def put(self, key: str, value: Any) -> None: # pragma: no cover - interface only
...
def get(self, key: str) -> Any: # pragma: no cover - interface only
...
__all__ = ["BaseStore"]

View File

@@ -1,26 +0,0 @@
"""Simple in-memory store implementation for tests."""
from __future__ import annotations
from typing import Any, Dict, Optional
from .base import BaseStore
class InMemoryStore(BaseStore):
"""Minimal dictionary-backed store."""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
def put(self, key: str, value: Any) -> None:
self._store[key] = value
def get(self, key: str) -> Optional[Any]:
return self._store.get(key)
def clear(self) -> None:
self._store.clear()
__all__ = ["InMemoryStore"]

View File

@@ -1,43 +0,0 @@
"""Common type hints and helper payloads exposed by the LangGraph stub."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, Literal, Mapping, TypeVar
All = Literal["*"]
TTarget = TypeVar("TTarget")
@dataclass(slots=True)
class Command(Generic[TTarget]):
"""Simplified representation of LangGraph's command-based routing payload."""
goto: TTarget | None = None
update: Dict[str, Any] = field(default_factory=dict)
graph: str | None = None
# Class-level sentinel used to signal returning to the parent graph.
PARENT: str = "__parent__"
def as_dict(self) -> dict[str, Any]:
"""Return a serialisable view of the command for debugging."""
return {"goto": self.goto, "update": dict(self.update), "graph": self.graph}
@dataclass(slots=True)
class Send(Generic[TTarget]):
"""Parallel dispatch payload used by LangGraph for fan-out patterns."""
target: TTarget
state: Mapping[str, Any]
def with_updates(self, updates: Mapping[str, Any]) -> "Send[TTarget]":
merged: Dict[str, Any] = dict(self.state)
merged.update(dict(updates))
return Send(target=self.target, state=merged)
__all__ = ["All", "Command", "Send"]

View File

@@ -4,7 +4,7 @@ from typing import Any, TypedDict
from unittest.mock import MagicMock
import pytest
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
from langgraph.graph.state import CachePolicy, CompiledStateGraph as CompiledGraph, RetryPolicy
from langgraph.cache.memory import InMemoryCache
from langgraph.store.memory import InMemoryStore
@@ -160,7 +160,7 @@ class TestBuildGraphFromConfig:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_graph_with_conditional_edges(self, sample_node, sample_router):
"""Test building graph with conditional edges."""
@@ -189,7 +189,7 @@ class TestBuildGraphFromConfig:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_graph_with_entry_point(self, sample_node):
"""Test building graph with custom entry point."""
@@ -202,7 +202,7 @@ class TestBuildGraphFromConfig:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_graph_with_checkpointer(self, sample_node):
"""Test building graph with checkpointer."""
@@ -217,7 +217,7 @@ class TestBuildGraphFromConfig:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_graph_with_node_config_options(self, sample_node):
"""GraphBuilderConfig should accept modern node configuration options."""
@@ -238,7 +238,7 @@ class TestBuildGraphFromConfig:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
node = graph.builder.nodes["process"]
assert node.metadata == {"role": "primary"}
@@ -288,7 +288,7 @@ class TestGraphBuilder:
.build()
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_fluent_api_with_conditional_edge(self, sample_node, sample_router):
"""Test fluent API with conditional edges."""
@@ -311,7 +311,7 @@ class TestGraphBuilder:
.build()
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_fluent_api_with_metadata(self, sample_node):
"""Test fluent API with metadata."""
@@ -324,7 +324,7 @@ class TestGraphBuilder:
.build()
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
# Note: LangGraph may not expose config on compiled graphs
# The metadata is used internally during graph building
@@ -368,7 +368,7 @@ class TestHelperFunctions:
nodes = [("step1", node1), ("step2", node2), ("step3", node3)]
graph = create_simple_linear_graph(TestState, nodes)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_create_simple_linear_graph_empty_nodes(self):
"""Test creating linear graph with empty nodes raises error."""
@@ -382,7 +382,7 @@ class TestHelperFunctions:
nodes = [("step1", node1)]
graph = create_simple_linear_graph(TestState, nodes, mock_checkpointer)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_create_simple_linear_graph_with_node_options(self):
"""Linear helper should accept per-node LangGraph options."""
@@ -428,7 +428,7 @@ class TestHelperFunctions:
sample_router,
branches
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_create_branching_graph_with_checkpointer(self, sample_router):
"""Test creating branching graph with checkpointer."""
@@ -443,7 +443,7 @@ class TestHelperFunctions:
branches,
mock_checkpointer
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_create_branching_graph_empty_branches(self, sample_router):
"""Test creating branching graph handles empty branches."""
@@ -456,7 +456,7 @@ class TestHelperFunctions:
sample_router,
branches
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestErrorHandling:
@@ -500,7 +500,7 @@ class TestErrorHandling:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
class TestIntegrationWithExistingGraphs:
@@ -535,7 +535,7 @@ class TestIntegrationWithExistingGraphs:
.build()
)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_research_pattern_simulation(self):
"""Test pattern similar to research graph refactoring."""
@@ -577,7 +577,7 @@ class TestIntegrationWithExistingGraphs:
)
graph = build_graph_from_config(config)
assert isinstance(graph, CompiledStateGraph)
assert isinstance(graph, CompiledGraph)
def test_fluent_api_with_runtime_options(self, sample_node):
"""Fluent builder should expose new LangGraph runtime options."""

View File

@@ -55,16 +55,17 @@ def base_error_state() -> ErrorHandlingState:
input_state={},
execution_count=0,
),
current_error={
"message": "Test error",
"error_type": "TestError",
"node": "test_node",
"severity": "error",
"category": ErrorCategory.UNKNOWN.value,
"timestamp": "2024-01-01T00:00:00",
"context": {},
"traceback": None,
},
current_error={
"message": "Test error",
"error_type": "TestError",
"node": "test_node",
"severity": "error",
"category": ErrorCategory.UNKNOWN.value,
"timestamp": "2024-01-01T00:00:00",
"context": {},
"details": {},
"traceback": None,
},
attempted_actions=[],
)
@@ -830,7 +831,7 @@ class TestConfigurationFunctions:
config: RunnableConfig = {"configurable": {"test": "config"}}
graph = error_handling_graph_factory(config)
assert graph is not None
# CompiledStateGraph has ainvoke method
# CompiledGraph has ainvoke method
assert hasattr(graph, "ainvoke")
# NOTE: get_next_node_function was removed - functionality moved to edge helpers

View File

@@ -8,7 +8,8 @@ from langchain_core.messages import AIMessage, ToolMessage
if TYPE_CHECKING:
from biz_bud.states.tools import ToolState
from biz_bud.states.tools import create_scraped_content_dict
from biz_bud.core.errors import ValidationError
from biz_bud.states.tools import create_scraped_content_dict
class TestToolState:
@@ -499,41 +500,53 @@ class TestScrapedContentDictCreator:
assert result.get("error_message") == special_error
def test_create_scraped_content_dict_validates_empty_url(self):
"""Test that empty URL raises ValueError."""
with pytest.raises(ValueError, match="URL cannot be empty or whitespace"):
create_scraped_content_dict(url="", content="Content")
def test_create_scraped_content_dict_validates_whitespace_url(self):
"""Test that whitespace-only URL raises ValueError."""
with pytest.raises(ValueError, match="URL cannot be empty or whitespace"):
create_scraped_content_dict(url=" ", content="Content")
def test_create_scraped_content_dict_validates_empty_content(self):
"""Test that empty content raises ValueError."""
with pytest.raises(ValueError, match="Content cannot be empty or whitespace"):
create_scraped_content_dict(url="https://example.com", content="")
def test_create_scraped_content_dict_validates_whitespace_content(self):
"""Test that whitespace-only content raises ValueError."""
with pytest.raises(ValueError, match="Content cannot be empty or whitespace"):
create_scraped_content_dict(url="https://example.com", content=" ")
def test_create_scraped_content_dict_validates_negative_content_length(self):
"""Test that negative content length raises ValueError."""
with pytest.raises(ValueError, match="Content length cannot be negative"):
create_scraped_content_dict(
url="https://example.com",
content="Content",
content_length=-1
)
def test_create_scraped_content_dict_validates_success_with_error_message(self):
"""Test that success=True with error_message raises ValueError."""
with pytest.raises(ValueError, match="Error message should not be present when success is True"):
create_scraped_content_dict(
url="https://example.com",
content="Content",
success=True,
error_message="This should not be here"
)
    def test_create_scraped_content_dict_validates_empty_url(self):
        """Test that empty URL raises ValidationError."""
        # ``match`` is a regex; this literal message has no metacharacters.
        with pytest.raises(ValidationError, match="URL cannot be empty or whitespace"):
            create_scraped_content_dict(url="", content="Content")
    def test_create_scraped_content_dict_validates_whitespace_url(self):
        """Test that whitespace-only URL raises ValidationError."""
        # Whitespace-only URLs are rejected the same way as empty ones.
        with pytest.raises(ValidationError, match="URL cannot be empty or whitespace"):
            create_scraped_content_dict(url="   ", content="Content")
    def test_create_scraped_content_dict_validates_empty_content(self):
        """Test that empty content raises ValidationError when success is True."""
        # Relies on ``success`` presumably defaulting to True — confirm
        # against the factory signature.
        with pytest.raises(ValidationError, match="Content cannot be empty or whitespace"):
            create_scraped_content_dict(url="https://example.com", content="")
    def test_create_scraped_content_dict_validates_whitespace_content(self):
        """Test that whitespace-only content raises ValidationError when success is True."""
        # Whitespace-only content is treated the same as empty content.
        with pytest.raises(ValidationError, match="Content cannot be empty or whitespace"):
            create_scraped_content_dict(url="https://example.com", content="   ")
    def test_create_scraped_content_dict_validates_negative_content_length(self):
        """Test that negative content length raises ValidationError."""
        # An explicit negative length must be rejected even when content is valid.
        with pytest.raises(ValidationError, match="Content length cannot be negative"):
            create_scraped_content_dict(
                url="https://example.com",
                content="Content",
                content_length=-1
            )
    def test_create_scraped_content_dict_validates_success_with_error_message(self):
        """Test that success=True with error_message raises ValidationError."""
        # A successful result carrying an error message is contradictory
        # and must be rejected by the factory.
        with pytest.raises(ValidationError, match="Error message should not be present when success is True"):
            create_scraped_content_dict(
                url="https://example.com",
                content="Content",
                success=True,
                error_message="This should not be here"
            )
    def test_create_scraped_content_dict_failure_normalizes_empty_content(self):
        """Ensure failed scraping normalizes missing content to an empty string."""
        result = create_scraped_content_dict(
            url="https://example.com",
            content="   ",
            success=False,
            error_message="failure",
        )
        # On failure, whitespace-only content is coerced to "" instead of
        # raising — the non-empty-content rule applies only on success.
        assert result["content"] == ""
        assert result["success"] is False