Modernize LangGraph routing helpers for mapping-based state (#61)
* Modernize LangGraph routing helpers for mapping-based state * Refine LangGraph v1 reducers and search node handling * Harden search workflow state handling * Refine LangGraph routing and search handling * Harden LangGraph search routing and reducer safety * Improve LangGraph search hygiene and scraping validation * Refine LangGraph integration and search sanitization * Refine state access protocol for routing helpers * Refine edge helper state access and numeric coercion * Harden search sanitization and routing utilities * refactor: remove url processing compatibility layer * Harden langgraph reducers and routing helpers * Break circular import in legacy URL processing shim * Fix pytest async hook interaction * Modernize LangGraph helpers and sanitizers * Address follow-up LangGraph modernization issues * Harden LangGraph routing and URL helpers * Refine RAG integration nodes for typed state
This commit is contained in:
@@ -43,6 +43,8 @@ from biz_bud.services.factory import ServiceFactory
|
||||
from biz_bud.states.buddy import BuddyState
|
||||
from biz_bud.tools.capabilities.workflow.execution import ResponseFormatter
|
||||
|
||||
CompiledGraph = CompiledStateGraph[BuddyState]
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +59,7 @@ __all__ = [
|
||||
|
||||
def create_buddy_orchestrator_graph(
|
||||
config: AppConfig | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> CompiledGraph:
|
||||
"""Create the Buddy orchestrator graph with all components.
|
||||
|
||||
Args:
|
||||
@@ -124,7 +126,7 @@ def create_buddy_orchestrator_graph(
|
||||
def create_buddy_orchestrator_agent(
|
||||
config: AppConfig | None = None,
|
||||
service_factory: ServiceFactory | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> CompiledGraph:
|
||||
"""Create the Buddy orchestrator agent.
|
||||
|
||||
Args:
|
||||
@@ -149,13 +151,13 @@ def create_buddy_orchestrator_agent(
|
||||
|
||||
|
||||
# Direct singleton instance
|
||||
_buddy_agent_instance: "CompiledStateGraph | None" = None
|
||||
_buddy_agent_instance: CompiledGraph | None = None
|
||||
|
||||
|
||||
def get_buddy_agent(
|
||||
config: AppConfig | None = None,
|
||||
service_factory: ServiceFactory | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> CompiledGraph:
|
||||
"""Get or create the Buddy agent instance.
|
||||
|
||||
Uses singleton pattern for default instance.
|
||||
@@ -326,12 +328,12 @@ async def stream_buddy_agent(
|
||||
|
||||
|
||||
# Export for LangGraph API
|
||||
def buddy_agent_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def buddy_agent_factory(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Create factory function for LangGraph API."""
|
||||
return get_buddy_agent()
|
||||
|
||||
|
||||
async def buddy_agent_factory_async(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
async def buddy_agent_factory_async(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Async factory function for LangGraph API to avoid blocking calls."""
|
||||
# Use asyncio.to_thread to run the synchronous initialization in a thread
|
||||
# This prevents blocking the event loop
|
||||
|
||||
@@ -8,8 +8,13 @@ mappings and condition evaluation.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Hashable, Mapping
|
||||
from typing import TypeVar, cast
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
KeyT = TypeVar("KeyT", bound=Hashable)
|
||||
|
||||
|
||||
class BasicRouters:
|
||||
@@ -18,9 +23,9 @@ class BasicRouters:
|
||||
@staticmethod
|
||||
def route_on_key(
|
||||
state_key: str,
|
||||
mapping: dict[Any, str],
|
||||
mapping: Mapping[KeyT, str],
|
||||
default: str = "end",
|
||||
) -> Callable[[dict[str, Any]], str]:
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Route based on a state key value.
|
||||
|
||||
Simple routing pattern that looks up a value in state and maps it
|
||||
@@ -45,18 +50,23 @@ class BasicRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> str:
|
||||
value = state.get(state_key)
|
||||
return mapping.get(value, default)
|
||||
def router(state: StateLike) -> str:
|
||||
value = get_state_value(state, state_key)
|
||||
if isinstance(value, Hashable):
|
||||
try:
|
||||
return mapping.get(cast(KeyT, value), default)
|
||||
except TypeError:
|
||||
return default
|
||||
return default
|
||||
|
||||
return router
|
||||
|
||||
@staticmethod
|
||||
def route_on_condition(
|
||||
condition: Callable[[dict[str, Any]], bool],
|
||||
condition: Callable[[StateLike], bool],
|
||||
true_target: str,
|
||||
false_target: str,
|
||||
) -> Callable[[dict[str, Any]], str]:
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Route based on a boolean condition.
|
||||
|
||||
Simple binary routing based on evaluating a condition function.
|
||||
@@ -81,7 +91,7 @@ class BasicRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> str:
|
||||
def router(state: StateLike) -> str:
|
||||
return true_target if condition(state) else false_target
|
||||
|
||||
return router
|
||||
@@ -92,7 +102,7 @@ class BasicRouters:
|
||||
threshold: float,
|
||||
above_target: str,
|
||||
below_target: str,
|
||||
) -> Callable[[dict[str, Any]], str]:
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Route based on numeric threshold comparison.
|
||||
|
||||
Compares a numeric value in state against a threshold and routes
|
||||
@@ -119,8 +129,8 @@ class BasicRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> str:
|
||||
value = state.get(state_key, 0)
|
||||
def router(state: StateLike) -> str:
|
||||
value = get_state_value(state, state_key, 0)
|
||||
try:
|
||||
numeric_value = float(value)
|
||||
return above_target if numeric_value >= threshold else below_target
|
||||
|
||||
@@ -4,17 +4,17 @@ This module provides the Buddy-specific routing logic that handles
|
||||
orchestration phases and execution flow.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from .router_factories import create_command_router
|
||||
from .routing_rules import CommandRoutingRule
|
||||
|
||||
|
||||
def create_buddy_command_router() -> Callable[[Any], Command[str]]:
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from .core import StateLike
|
||||
from .router_factories import create_command_router
|
||||
from .routing_rules import CommandRoutingRule
|
||||
|
||||
|
||||
def create_buddy_command_router() -> Callable[[StateLike], Command[str]]:
|
||||
"""Create a specialized Command router for Buddy orchestration patterns.
|
||||
|
||||
This provides the standard Buddy routing logic using Commands for enhanced
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
"""Command and Send pattern implementations for LangGraph edge routing.
|
||||
|
||||
This module provides advanced routing patterns using LangGraph's Command
|
||||
and Send objects for dynamic control flow and map-reduce patterns.
|
||||
"""
|
||||
"""Command and Send pattern implementations for LangGraph edge routing."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
|
||||
from langgraph.types import Command, Send
|
||||
|
||||
@@ -16,29 +11,140 @@ from biz_bud.logging import debug_highlight, get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
StateMapping = Mapping[str, object]
|
||||
|
||||
|
||||
def _is_json_compatible(value: object) -> bool:
|
||||
if value is None or isinstance(value, (bool, int, float, str)):
|
||||
return True
|
||||
if isinstance(value, (list, tuple)):
|
||||
return all(_is_json_compatible(item) for item in value)
|
||||
if isinstance(value, Mapping):
|
||||
return all(isinstance(key, str) and _is_json_compatible(val) for key, val in value.items())
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_value(value: object) -> object:
|
||||
if value is None or isinstance(value, (bool, int, float, str)):
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_sanitize_value(item) for item in value]
|
||||
if isinstance(value, Mapping):
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, val in value.items():
|
||||
sanitized[str(key)] = _sanitize_value(val)
|
||||
return sanitized
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _sanitize_mapping(mapping: Mapping[str, object]) -> dict[str, object]:
|
||||
sanitized: dict[str, object] = {}
|
||||
for key, value in mapping.items():
|
||||
sanitized_key = key if isinstance(key, str) else str(key)
|
||||
sanitized[sanitized_key] = _sanitize_value(value)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _estimate_size(obj: object, limit: int = 64_000) -> int:
|
||||
try:
|
||||
import json
|
||||
|
||||
serialized = json.dumps(obj, default=lambda _: "<non-serializable>")
|
||||
return len(serialized.encode("utf-8"))
|
||||
except Exception:
|
||||
return limit + 1
|
||||
|
||||
|
||||
# Private helper functions for DRY code
|
||||
def _log_command(
|
||||
target: str, updates: dict[str, Any] | None, category: str
|
||||
target: str, updates: Mapping[str, object] | None, category: str
|
||||
) -> Command[str]:
|
||||
"""Create a Command object with consistent logging."""
|
||||
msg = f"{category}: routing to '{target}'"
|
||||
if updates:
|
||||
msg += f" with updates: {list(updates.keys())}"
|
||||
debug_highlight(msg, category=category)
|
||||
return Command(goto=target, update=updates or {})
|
||||
if updates is not None:
|
||||
safe_updates = _sanitize_mapping(updates)
|
||||
else:
|
||||
safe_updates = {}
|
||||
return Command(goto=target, update=safe_updates)
|
||||
|
||||
|
||||
def _log_sends(targets: list[tuple[str, dict[str, Any]]], category: str) -> list[Send]:
|
||||
def _log_sends(
|
||||
targets: Sequence[tuple[str, Mapping[str, object]]] | object,
|
||||
category: str,
|
||||
) -> list[Send]:
|
||||
"""Create Send objects with consistent logging."""
|
||||
sends = [Send(t, state) for t, state in targets]
|
||||
debug_highlight(f"{category}: dispatching {len(sends)} branches", category=category)
|
||||
return sends
|
||||
if not isinstance(targets, Sequence) or isinstance(targets, (str, bytes, bytearray)):
|
||||
debug_highlight(
|
||||
f"{category}: invalid targets container: {type(targets)!r}",
|
||||
category=category,
|
||||
)
|
||||
raise TypeError(f"{category}: 'targets' must be a sequence of (str, Mapping) tuples")
|
||||
|
||||
MAX_BRANCHES = 1000
|
||||
if len(targets) > MAX_BRANCHES:
|
||||
debug_highlight(
|
||||
f"{category}: too many target branches: {len(targets)} (limit {MAX_BRANCHES})",
|
||||
category=category,
|
||||
)
|
||||
raise ValueError(f"{category}: too many branches ({len(targets)})")
|
||||
|
||||
MAX_PAYLOAD_BYTES = 64_000
|
||||
safe_sends: list[Send] = []
|
||||
invalid_count = 0
|
||||
for idx, item in enumerate(targets):
|
||||
if not isinstance(item, tuple) or len(item) != 2:
|
||||
debug_highlight(
|
||||
f"{category}: invalid target entry at index {idx}: {item!r}",
|
||||
category=category,
|
||||
)
|
||||
invalid_count += 1
|
||||
continue
|
||||
target, state = item
|
||||
if not isinstance(target, str) or not target:
|
||||
debug_highlight(
|
||||
f"{category}: invalid target at index {idx}: {target!r}",
|
||||
category=category,
|
||||
)
|
||||
invalid_count += 1
|
||||
continue
|
||||
if not isinstance(state, Mapping):
|
||||
debug_highlight(
|
||||
f"{category}: invalid state mapping at index {idx}: {type(state)!r}",
|
||||
category=category,
|
||||
)
|
||||
invalid_count += 1
|
||||
continue
|
||||
payload = _sanitize_mapping(state)
|
||||
if _estimate_size(payload) > MAX_PAYLOAD_BYTES:
|
||||
debug_highlight(
|
||||
f"{category}: payload for '{target}' exceeds {MAX_PAYLOAD_BYTES} bytes; dropping entry",
|
||||
category=category,
|
||||
)
|
||||
invalid_count += 1
|
||||
continue
|
||||
safe_sends.append(Send(target, payload))
|
||||
|
||||
if len(targets) > 0 and not safe_sends:
|
||||
debug_highlight(
|
||||
f"{category}: all {len(targets)} target entries were invalid; aborting dispatch",
|
||||
category=category,
|
||||
)
|
||||
raise ValueError(f"{category}: no valid target entries to dispatch")
|
||||
debug_highlight(
|
||||
f"{category}: dispatching {len(safe_sends)} branches (invalid: {invalid_count})",
|
||||
category=category,
|
||||
)
|
||||
return safe_sends
|
||||
|
||||
|
||||
def create_command_router(
|
||||
routing_logic: Callable[[dict[str, Any]], tuple[str, dict[str, Any] | None]],
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
routing_logic: Callable[
|
||||
[StateMapping], tuple[str, Mapping[str, object] | None]
|
||||
],
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create a router that returns Command objects for combined state update and routing.
|
||||
|
||||
This factory creates routers that can both update state and control flow
|
||||
@@ -63,7 +169,7 @@ def create_command_router(
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
target, updates = routing_logic(state)
|
||||
return _log_command(target, updates, category="CommandRouter")
|
||||
|
||||
@@ -71,8 +177,10 @@ def create_command_router(
|
||||
|
||||
|
||||
def create_dynamic_send_router(
|
||||
target_generator: Callable[[dict[str, Any]], list[tuple[str, dict[str, Any]]]],
|
||||
) -> Callable[[dict[str, Any]], list[Send]]:
|
||||
target_generator: Callable[
|
||||
[StateMapping], Sequence[tuple[str, Mapping[str, object]]]
|
||||
],
|
||||
) -> Callable[[StateMapping], list[Send]]:
|
||||
"""Create a router that generates Send objects for dynamic fan-out patterns.
|
||||
|
||||
This factory creates routers for map-reduce patterns where you need to
|
||||
@@ -97,7 +205,7 @@ def create_dynamic_send_router(
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> list[Send]:
|
||||
def router(state: StateMapping) -> list[Send]:
|
||||
targets = target_generator(state)
|
||||
return _log_sends(targets, category="SendRouter")
|
||||
|
||||
@@ -105,12 +213,12 @@ def create_dynamic_send_router(
|
||||
|
||||
|
||||
def create_conditional_command_router(
|
||||
conditions: list[
|
||||
tuple[Callable[[dict[str, Any]], bool], str, dict[str, Any] | None]
|
||||
conditions: Sequence[
|
||||
tuple[Callable[[StateMapping], bool], str, Mapping[str, object] | None]
|
||||
],
|
||||
default_target: str = "end",
|
||||
default_updates: dict[str, Any] | None = None,
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
default_updates: Mapping[str, object] | None = None,
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create a router with multiple conditions that returns Command objects.
|
||||
|
||||
Args:
|
||||
@@ -131,20 +239,22 @@ def create_conditional_command_router(
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
for condition_fn, target, updates in conditions:
|
||||
if condition_fn(state):
|
||||
debug_highlight(
|
||||
f"Conditional command router: condition met, routing to '{target}'",
|
||||
category="ConditionalCommand",
|
||||
)
|
||||
return Command(goto=target, update=updates or {})
|
||||
update_payload = dict(updates) if updates is not None else {}
|
||||
return Command(goto=target, update=update_payload)
|
||||
|
||||
debug_highlight(
|
||||
f"Conditional command router: no conditions met, using default '{default_target}'",
|
||||
category="ConditionalCommand",
|
||||
)
|
||||
return Command(goto=default_target, update=default_updates or {})
|
||||
default_payload = dict(default_updates) if default_updates is not None else {}
|
||||
return Command(goto=default_target, update=default_payload)
|
||||
|
||||
return router
|
||||
|
||||
@@ -154,7 +264,7 @@ def create_map_reduce_router(
|
||||
processor_node: str = "process_item",
|
||||
reducer_node: str = "reduce_results",
|
||||
item_state_key: str = "current_item",
|
||||
) -> Callable[[dict[str, Any]], list[Send] | Command[str]]:
|
||||
) -> Callable[[StateMapping], list[Send] | Command[str]]:
|
||||
"""Create a router for map-reduce patterns using Send.
|
||||
|
||||
This router dispatches items to parallel processors and then
|
||||
@@ -180,9 +290,20 @@ def create_map_reduce_router(
|
||||
```
|
||||
"""
|
||||
# Create reusable send router for item processing
|
||||
def _gen(state):
|
||||
items = state.get(items_key, [])
|
||||
start = state.get("processed_count", 0)
|
||||
def _gen(state: StateMapping) -> list[tuple[str, Mapping[str, object]]]:
|
||||
raw_items = state.get(items_key, [])
|
||||
if not isinstance(raw_items, Sequence) or isinstance(
|
||||
raw_items, (str, bytes, bytearray)
|
||||
):
|
||||
return []
|
||||
|
||||
items = list(raw_items)
|
||||
processed_raw = state.get("processed_count", 0)
|
||||
try:
|
||||
start = max(int(processed_raw), 0)
|
||||
except (TypeError, ValueError):
|
||||
start = 0
|
||||
|
||||
return [
|
||||
(
|
||||
processor_node,
|
||||
@@ -197,14 +318,24 @@ def create_map_reduce_router(
|
||||
|
||||
send_router = create_dynamic_send_router(_gen)
|
||||
|
||||
def router(state: dict[str, Any]) -> list[Send] | Command[str]:
|
||||
items = state.get(items_key, [])
|
||||
processed_count = state.get("processed_count", 0)
|
||||
def router(state: StateMapping) -> list[Send] | Command[str]:
|
||||
raw_items = state.get(items_key, [])
|
||||
if (
|
||||
not isinstance(raw_items, Sequence)
|
||||
or isinstance(raw_items, (str, bytes, bytearray))
|
||||
):
|
||||
raw_items = ()
|
||||
|
||||
if processed_count < len(items):
|
||||
processed_raw = state.get("processed_count", 0)
|
||||
try:
|
||||
processed_count = max(int(processed_raw), 0)
|
||||
except (TypeError, ValueError):
|
||||
processed_count = 0
|
||||
|
||||
if processed_count < len(raw_items):
|
||||
return send_router(state)
|
||||
else:
|
||||
return _log_command(reducer_node, None, category="MapReduce")
|
||||
|
||||
return _log_command(reducer_node, None, category="MapReduce")
|
||||
|
||||
return router
|
||||
|
||||
@@ -216,7 +347,7 @@ def create_retry_command_router(
|
||||
failure_node: str = "failure",
|
||||
attempt_key: str = "retry_attempts",
|
||||
success_key: str = "is_successful",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create a router that handles retry logic with Command pattern.
|
||||
|
||||
Args:
|
||||
@@ -241,9 +372,14 @@ def create_retry_command_router(
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
attempts = state.get(attempt_key, 0)
|
||||
is_successful = state.get(success_key, False)
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
attempts_raw = state.get(attempt_key, 0)
|
||||
try:
|
||||
attempts = max(int(attempts_raw), 0)
|
||||
except (TypeError, ValueError):
|
||||
attempts = 0
|
||||
|
||||
is_successful = bool(state.get(success_key, False))
|
||||
|
||||
if is_successful:
|
||||
debug_highlight(
|
||||
@@ -260,7 +396,7 @@ def create_retry_command_router(
|
||||
goto=retry_node,
|
||||
update={
|
||||
attempt_key: attempts + 1,
|
||||
"last_retry_timestamp": None, # Would be actual timestamp
|
||||
"last_retry_timestamp": None,
|
||||
},
|
||||
)
|
||||
else:
|
||||
@@ -280,10 +416,10 @@ def create_retry_command_router(
|
||||
|
||||
|
||||
def create_subgraph_command_router(
|
||||
subgraph_mapping: dict[str, tuple[str, dict[str, Any] | None]],
|
||||
subgraph_mapping: Mapping[str, tuple[str, Mapping[str, object] | None]],
|
||||
state_key: str = "task_type",
|
||||
parent_return_node: str = "consolidate_results",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create a router for delegating to subgraphs with Command.PARENT support.
|
||||
|
||||
Args:
|
||||
@@ -304,15 +440,15 @@ def create_subgraph_command_router(
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
task_type = state.get(state_key)
|
||||
|
||||
if task_type and isinstance(task_type, str) and task_type in subgraph_mapping:
|
||||
subgraph_node, initial_state = subgraph_mapping[task_type]
|
||||
|
||||
updates = {"delegated_to": subgraph_node}
|
||||
updates: dict[str, object] = {"delegated_to": subgraph_node}
|
||||
if initial_state:
|
||||
updates.update(initial_state)
|
||||
updates.update(dict(initial_state))
|
||||
|
||||
debug_highlight(
|
||||
f"Subgraph router: delegating '{task_type}' to '{subgraph_node}'",
|
||||
@@ -343,10 +479,10 @@ def route_on_success(
|
||||
success_node: str = "continue",
|
||||
failure_node: str = "error_handler",
|
||||
success_key: str = "success",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create simple success/failure router with Command pattern."""
|
||||
return create_conditional_command_router(
|
||||
[(lambda s: s.get(success_key, False), success_node, None)],
|
||||
[(lambda s: bool(s.get(success_key, False)), success_node, None)],
|
||||
default_target=failure_node,
|
||||
default_updates={"status": "failed"},
|
||||
)
|
||||
@@ -355,14 +491,18 @@ def route_on_success(
|
||||
def fan_out_tasks(
|
||||
tasks_key: str = "tasks",
|
||||
processor_node: str = "process_task",
|
||||
) -> Callable[[dict[str, Any]], list[Send]]:
|
||||
) -> Callable[[StateMapping], list[Send]]:
|
||||
"""Create simple fan-out router for parallel task processing."""
|
||||
|
||||
def generator(state: dict[str, Any]) -> list[tuple[str, dict[str, Any]]]:
|
||||
tasks = state.get(tasks_key, [])
|
||||
def generator(state: StateMapping) -> list[tuple[str, Mapping[str, object]]]:
|
||||
tasks_value = state.get(tasks_key, [])
|
||||
if not isinstance(tasks_value, Sequence) or isinstance(
|
||||
tasks_value, (str, bytes, bytearray)
|
||||
):
|
||||
return []
|
||||
return [
|
||||
(processor_node, {"task": task, "task_id": i})
|
||||
for i, task in enumerate(tasks)
|
||||
for i, task in enumerate(tasks_value)
|
||||
]
|
||||
|
||||
return create_dynamic_send_router(generator)
|
||||
@@ -373,9 +513,9 @@ class CommandRouters:
|
||||
|
||||
@staticmethod
|
||||
def command_route_with_update(
|
||||
routing_fn: Callable[[dict[str, Any]], str],
|
||||
update_fn: Callable[[dict[str, Any]], dict[str, Any]],
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
routing_fn: Callable[[StateMapping], str],
|
||||
update_fn: Callable[[StateMapping], Mapping[str, object]],
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create Command router that updates state while routing.
|
||||
|
||||
Combines routing decision with state updates in a single operation.
|
||||
@@ -401,7 +541,7 @@ class CommandRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
target = routing_fn(state)
|
||||
updates = update_fn(state)
|
||||
return _log_command(target, updates, category="CommandUpdate")
|
||||
@@ -410,13 +550,13 @@ class CommandRouters:
|
||||
|
||||
@staticmethod
|
||||
def command_route_with_retry(
|
||||
success_check: Callable[[dict[str, Any]], bool],
|
||||
success_check: Callable[[StateMapping], bool],
|
||||
success_target: str,
|
||||
retry_target: str,
|
||||
failure_target: str,
|
||||
max_attempts: int = 3,
|
||||
attempt_key: str = "attempts",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateMapping], Command[str]]:
|
||||
"""Create Command router with retry logic.
|
||||
|
||||
Implements retry pattern with attempt counting and eventual failure routing.
|
||||
@@ -446,8 +586,12 @@ class CommandRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
attempts = state.get(attempt_key, 0)
|
||||
def router(state: StateMapping) -> Command[str]:
|
||||
attempts_raw = state.get(attempt_key, 0)
|
||||
try:
|
||||
attempts = max(int(attempts_raw), 0)
|
||||
except (TypeError, ValueError):
|
||||
attempts = 0
|
||||
|
||||
if success_check(state):
|
||||
return _log_command(
|
||||
@@ -475,7 +619,7 @@ class CommandRouters:
|
||||
processor_node: str,
|
||||
item_key: str = "item",
|
||||
include_index: bool = True,
|
||||
) -> Callable[[dict[str, Any]], list[Send]]:
|
||||
) -> Callable[[StateMapping], list[Send]]:
|
||||
"""Create Send router for parallel processing of items.
|
||||
|
||||
Distributes items from state to parallel processor nodes using Send objects.
|
||||
@@ -502,14 +646,21 @@ class CommandRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> list[Send]:
|
||||
items = state.get(items_key, [])
|
||||
targets = []
|
||||
def router(state: StateMapping) -> list[Send]:
|
||||
items_value = state.get(items_key, [])
|
||||
if not isinstance(items_value, Sequence) or isinstance(
|
||||
items_value, (str, bytes, bytearray)
|
||||
):
|
||||
return []
|
||||
|
||||
for i, item in enumerate(items):
|
||||
processor_state = {item_key: item}
|
||||
targets: list[tuple[str, Mapping[str, object]]] = []
|
||||
|
||||
for i, item in enumerate(items_value):
|
||||
processor_state: dict[str, object] = {item_key: item}
|
||||
if include_index:
|
||||
processor_state.update({"item_index": i, "total_items": len(items)})
|
||||
processor_state.update(
|
||||
{"item_index": i, "total_items": len(items_value)}
|
||||
)
|
||||
targets.append((processor_node, processor_state))
|
||||
|
||||
return _log_sends(targets, category="ParallelProcessor")
|
||||
@@ -518,12 +669,12 @@ class CommandRouters:
|
||||
|
||||
@staticmethod
|
||||
def send_conditional(
|
||||
condition_fn: Callable[[dict[str, Any], Any], bool],
|
||||
condition_fn: Callable[[StateMapping, object], bool],
|
||||
items_key: str,
|
||||
true_node: str,
|
||||
false_node: str,
|
||||
item_key: str = "item",
|
||||
) -> Callable[[dict[str, Any]], list[Send]]:
|
||||
) -> Callable[[StateMapping], list[Send]]:
|
||||
"""Create conditional Send router for item filtering.
|
||||
|
||||
Evaluates each item against a condition and sends to different nodes
|
||||
@@ -555,11 +706,16 @@ class CommandRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> list[Send]:
|
||||
items = state.get(items_key, [])
|
||||
targets = []
|
||||
def router(state: StateMapping) -> list[Send]:
|
||||
items_value = state.get(items_key, [])
|
||||
if not isinstance(items_value, Sequence) or isinstance(
|
||||
items_value, (str, bytes, bytearray)
|
||||
):
|
||||
return []
|
||||
|
||||
for item in items:
|
||||
targets: list[tuple[str, Mapping[str, object]]] = []
|
||||
|
||||
for item in items_value:
|
||||
target_node = true_node if condition_fn(state, item) else false_node
|
||||
targets.append((target_node, {item_key: item}))
|
||||
|
||||
|
||||
@@ -11,14 +11,14 @@ maintainability and organization.
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from .basic_routing import BasicRouters
|
||||
from .command_patterns import CommandRouters
|
||||
from .workflow_routing import WorkflowRouters
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
# Issue deprecation warning when this module is imported
|
||||
warnings.warn(
|
||||
@@ -38,7 +38,7 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
|
||||
error_target: str = "error_handler",
|
||||
success_target: str = "continue",
|
||||
threshold: int = 1,
|
||||
) -> Callable[[dict[str, Any]], str]:
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Route based on error presence - DEPRECATED.
|
||||
|
||||
Use error handling modules for new code.
|
||||
@@ -49,15 +49,18 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def router(state: dict[str, Any]) -> str:
|
||||
errors = state.get(error_key, [])
|
||||
def router(state: StateLike) -> str:
|
||||
raw_errors = get_state_value(state, error_key, [])
|
||||
|
||||
# Normalize errors inline to avoid circular import
|
||||
if not errors:
|
||||
errors = []
|
||||
elif not isinstance(errors, list):
|
||||
errors = [errors]
|
||||
error_count = len(errors)
|
||||
if not raw_errors:
|
||||
normalized: list[object] = []
|
||||
elif isinstance(raw_errors, list):
|
||||
normalized = raw_errors
|
||||
else:
|
||||
normalized = [raw_errors]
|
||||
|
||||
error_count = len(normalized)
|
||||
|
||||
return error_target if error_count >= threshold else success_target
|
||||
|
||||
@@ -66,10 +69,10 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
|
||||
@staticmethod
|
||||
def command_route_on_error_with_recovery(
|
||||
error_key: str = "errors",
|
||||
recovery_strategies: dict[str, str] | None = None,
|
||||
recovery_strategies: Mapping[str, str] | None = None,
|
||||
max_recovery_attempts: int = 2,
|
||||
final_failure_target: str = "human_intervention",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateLike], Command[str]]:
|
||||
"""Create error-aware Command router with recovery - DEPRECATED.
|
||||
|
||||
Use error handling modules for new code.
|
||||
@@ -87,22 +90,32 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
|
||||
"parsing": "alternative_parser",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
errors = state.get(error_key, [])
|
||||
recovery_attempts = state.get("recovery_attempts", 0)
|
||||
def router(state: StateLike) -> Command[str]:
|
||||
raw_errors = get_state_value(state, error_key, [])
|
||||
recovery_attempts = get_state_value(state, "recovery_attempts", 0)
|
||||
|
||||
if not errors:
|
||||
try:
|
||||
attempts = int(recovery_attempts)
|
||||
except (TypeError, ValueError):
|
||||
attempts = 0
|
||||
|
||||
if not raw_errors:
|
||||
return Command(goto="continue", update={"status": "no_errors"})
|
||||
|
||||
if recovery_attempts >= max_recovery_attempts:
|
||||
normalized_errors = raw_errors if isinstance(raw_errors, list) else [raw_errors]
|
||||
|
||||
if attempts >= max_recovery_attempts:
|
||||
return Command(
|
||||
goto=final_failure_target,
|
||||
update={"status": "recovery_exhausted", "final_errors": errors},
|
||||
update={
|
||||
"status": "recovery_exhausted",
|
||||
"final_errors": normalized_errors,
|
||||
},
|
||||
)
|
||||
|
||||
error_type = "unknown"
|
||||
if isinstance(errors, list) and errors:
|
||||
first_error = errors[0]
|
||||
if normalized_errors:
|
||||
first_error = normalized_errors[0]
|
||||
if isinstance(first_error, dict):
|
||||
error_type = first_error.get("type", "unknown")
|
||||
|
||||
@@ -111,9 +124,9 @@ class EdgeHelpers(BasicRouters, CommandRouters, WorkflowRouters):
|
||||
return Command(
|
||||
goto=recovery_target,
|
||||
update={
|
||||
"recovery_attempts": recovery_attempts + 1,
|
||||
"recovery_attempts": attempts + 1,
|
||||
"recovery_strategy": error_type,
|
||||
"original_errors": errors,
|
||||
"original_errors": normalized_errors,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,29 +1,57 @@
|
||||
"""Core routing factory functions for edge helpers.
|
||||
|
||||
This module provides factory functions for creating flexible routing functions
|
||||
that can be used as conditional edges in LangGraph workflows.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, Protocol, TypeVar
|
||||
|
||||
# Generic state type for edge functions
|
||||
StateT = TypeVar("StateT")
|
||||
"""Core routing factory functions for edge helpers.
|
||||
|
||||
This module provides factory functions for creating flexible routing functions
|
||||
that can be used as conditional edges in LangGraph workflows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class StateProtocol(Protocol):
    """Protocol for state objects that provide ``get`` access.

    Satisfied structurally by any object exposing a mapping-style ``get``
    method; ``runtime_checkable`` permits the ``isinstance`` dispatch used
    by the state-access helpers in this module.
    """

    def get(self, key: str, default: object | None = None) -> object | None:
        """Return the stored value for ``key`` when available."""
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class AttributeStateProtocol(Protocol):
    """Protocol for state objects that expose attributes for lookup.

    NOTE(review): a ``runtime_checkable`` check for ``__getattr__`` only
    matches classes that define ``__getattr__`` explicitly — ordinary
    attribute-backed objects (plain classes, dataclasses) do not; confirm
    this matches the intended dispatch.
    """

    def __getattr__(self, key: str) -> object:
        """Return the value for ``key`` when exposed as an attribute."""
        ...
|
||||
|
||||
|
||||
# Union of state shapes accepted by the routing helpers: plain mappings,
# objects with a mapping-style ``get``, or attribute-backed objects.
StateLike = Mapping[str, object] | StateProtocol | AttributeStateProtocol
|
||||
|
||||
|
||||
def get_state_value(
    state: StateLike,
    key: str,
    default: object | None = None,
) -> object | None:
    """Safely extract a value from mapping or mapping-like state objects.

    Args:
        state: A mapping, an object with a mapping-style ``get`` method, or
            an attribute-backed state object.
        key: Name of the state entry to read.
        default: Value returned when ``key`` is not present.

    Returns:
        The stored value for ``key``, or ``default`` when it is absent.
    """
    if isinstance(state, Mapping):
        return state.get(key, default)
    if isinstance(state, StateProtocol):
        return state.get(key, default)
    # Attribute-backed states and any other object fall back to ``getattr``;
    # this single branch already covers AttributeStateProtocol instances, so
    # the previous explicit isinstance check for it was redundant dead code.
    return getattr(state, key, default)
|
||||
|
||||
|
||||
class StateProtocol(Protocol):
|
||||
"""Protocol for state objects that can be used with edge helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a value from the state."""
|
||||
...
|
||||
|
||||
|
||||
def create_enum_router(
|
||||
enum_to_target: dict[str, str],
|
||||
state_key: str = "routing_decision",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_enum_router(
|
||||
enum_to_target: Mapping[str, str],
|
||||
state_key: str = "routing_decision",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router that maps enum values to target nodes.
|
||||
|
||||
Args:
|
||||
@@ -43,22 +71,18 @@ def create_enum_router(
|
||||
graph.add_conditional_edges("source_node", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
enum_value = state.get(state_key)
|
||||
else:
|
||||
enum_value = getattr(state, state_key, None)
|
||||
|
||||
return enum_to_target.get(str(enum_value), default_target)
|
||||
def router(state: StateLike) -> str:
|
||||
enum_value = get_state_value(state, state_key)
|
||||
return enum_to_target.get(str(enum_value), default_target)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_bool_router(
|
||||
true_target: str,
|
||||
false_target: str,
|
||||
state_key: str = "condition",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_bool_router(
|
||||
true_target: str,
|
||||
false_target: str,
|
||||
state_key: str = "condition",
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router based on a boolean condition in state.
|
||||
|
||||
Args:
|
||||
@@ -74,22 +98,18 @@ def create_bool_router(
|
||||
graph.add_conditional_edges("validator", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
condition = state.get(state_key, False)
|
||||
else:
|
||||
condition = getattr(state, state_key, False)
|
||||
|
||||
return true_target if condition else false_target
|
||||
def router(state: StateLike) -> str:
|
||||
condition = get_state_value(state, state_key, False)
|
||||
return true_target if bool(condition) else false_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_status_router(
|
||||
status_mapping: dict[str, str],
|
||||
state_key: str = "status",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_status_router(
|
||||
status_mapping: Mapping[str, str],
|
||||
state_key: str = "status",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router based on operation status.
|
||||
|
||||
Args:
|
||||
@@ -112,13 +132,13 @@ def create_status_router(
|
||||
return create_enum_router(status_mapping, state_key, default_target)
|
||||
|
||||
|
||||
def create_threshold_router(
|
||||
threshold: float,
|
||||
above_target: str,
|
||||
below_target: str,
|
||||
state_key: str = "score",
|
||||
equal_target: str | None = None,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_threshold_router(
|
||||
threshold: float,
|
||||
above_target: str,
|
||||
below_target: str,
|
||||
state_key: str = "score",
|
||||
equal_target: str | None = None,
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router based on numeric threshold comparison.
|
||||
|
||||
Args:
|
||||
@@ -139,15 +159,11 @@ def create_threshold_router(
|
||||
graph.add_conditional_edges("scorer", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
value = state.get(state_key, 0.0)
|
||||
else:
|
||||
value = getattr(state, state_key, 0.0)
|
||||
|
||||
try:
|
||||
numeric_value = float(value)
|
||||
if numeric_value > threshold:
|
||||
def router(state: StateLike) -> str:
|
||||
value = get_state_value(state, state_key, 0.0)
|
||||
try:
|
||||
numeric_value = float(value)
|
||||
if numeric_value > threshold:
|
||||
return above_target
|
||||
elif numeric_value < threshold:
|
||||
return below_target
|
||||
@@ -159,11 +175,11 @@ def create_threshold_router(
|
||||
return router
|
||||
|
||||
|
||||
def create_field_presence_router(
|
||||
required_fields: list[str],
|
||||
complete_target: str,
|
||||
incomplete_target: str,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_field_presence_router(
|
||||
required_fields: Sequence[str],
|
||||
complete_target: str,
|
||||
incomplete_target: str,
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router based on presence of required fields in state.
|
||||
|
||||
Args:
|
||||
@@ -183,27 +199,24 @@ def create_field_presence_router(
|
||||
graph.add_conditional_edges("data_checker", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
for field in required_fields:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
value = state.get(field)
|
||||
else:
|
||||
value = getattr(state, field, None)
|
||||
|
||||
if value is None or value == "":
|
||||
return incomplete_target
|
||||
|
||||
return complete_target
|
||||
def router(state: StateLike) -> str:
|
||||
for field in required_fields:
|
||||
value = get_state_value(state, field)
|
||||
|
||||
if value is None or value == "":
|
||||
return incomplete_target
|
||||
|
||||
return complete_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_list_length_router(
|
||||
min_length: int,
|
||||
sufficient_target: str,
|
||||
insufficient_target: str,
|
||||
state_key: str = "items",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
def create_list_length_router(
|
||||
min_length: int,
|
||||
sufficient_target: str,
|
||||
insufficient_target: str,
|
||||
state_key: str = "items",
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Create a router based on list length in state.
|
||||
|
||||
Args:
|
||||
@@ -222,27 +235,18 @@ def create_list_length_router(
|
||||
graph.add_conditional_edges("result_checker", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
items = state.get(state_key, [])
|
||||
else:
|
||||
items = getattr(state, state_key, [])
|
||||
|
||||
# Only process actual lists/sequences, not strings or other types
|
||||
if not isinstance(items, (list, tuple)):
|
||||
return insufficient_target
|
||||
|
||||
try:
|
||||
return (
|
||||
sufficient_target if len(items) >= min_length else insufficient_target
|
||||
)
|
||||
except TypeError:
|
||||
return insufficient_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# Type aliases for common router patterns
|
||||
BoolRouter = Callable[[StateT], Literal["true", "false"]]
|
||||
StatusRouter = Callable[[StateT], Literal["pending", "running", "completed", "failed"]]
|
||||
ContinueRouter = Callable[[StateT], Literal["continue", "end"]]
|
||||
def router(state: StateLike) -> str:
|
||||
items = get_state_value(state, state_key, [])
|
||||
|
||||
# Only process actual sequences (excluding strings/bytes)
|
||||
if not isinstance(items, Sequence) or isinstance(items, (str, bytes)):
|
||||
return insufficient_target
|
||||
|
||||
try:
|
||||
return (
|
||||
sufficient_target if len(items) >= min_length else insufficient_target
|
||||
)
|
||||
except TypeError:
|
||||
return insufficient_target
|
||||
|
||||
return router
|
||||
|
||||
@@ -1,395 +1,269 @@
|
||||
"""Error handling edge helpers for robust workflow management.
|
||||
|
||||
This module provides edge helpers for detecting errors, implementing retry logic,
|
||||
and routing to appropriate error handling or recovery nodes.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from .core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def detect_errors_list(
|
||||
error_target: str = "error_handler",
|
||||
success_target: str = "continue",
|
||||
errors_key: str = "errors",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that detects errors in a list format.
|
||||
|
||||
This is specifically designed for states that use an 'errors' list
|
||||
to accumulate error information.
|
||||
|
||||
Args:
|
||||
error_target: Target node when errors are detected
|
||||
success_target: Target node when no errors present
|
||||
errors_key: Key in state containing list of errors
|
||||
|
||||
Returns:
|
||||
Router function that routes based on error presence
|
||||
|
||||
Example:
|
||||
error_detector = detect_errors_list(
|
||||
error_target="error_handler",
|
||||
success_target="next_step"
|
||||
)
|
||||
graph.add_conditional_edges("some_node", error_detector)
|
||||
"""
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if isinstance(state, dict):
|
||||
errors = state.get(errors_key, [])
|
||||
else:
|
||||
errors = getattr(state, errors_key, [])
|
||||
|
||||
# Check if we have any errors
|
||||
return error_target if len(errors) > 0 else success_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def handle_error(
|
||||
error_types: dict[str, str] | None = None,
|
||||
error_key: str = "error",
|
||||
default_target: str = "generic_error_handler",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that handles different types of errors.
|
||||
|
||||
Args:
|
||||
error_types: Mapping of error type/class names to handler nodes
|
||||
error_key: Key in state containing error information
|
||||
default_target: Default target if error type not in mapping
|
||||
|
||||
Returns:
|
||||
Router function that routes based on error type
|
||||
|
||||
Example:
|
||||
error_router = handle_error({
|
||||
"ValidationError": "validation_recovery",
|
||||
"NetworkError": "network_retry",
|
||||
"AuthenticationError": "auth_failure_handler"
|
||||
})
|
||||
graph.add_conditional_edges("error_detector", error_router)
|
||||
"""
|
||||
if error_types is None:
|
||||
error_types = {
|
||||
"ValidationError": "validation_error_handler",
|
||||
"NetworkError": "network_error_handler",
|
||||
"TimeoutError": "timeout_error_handler",
|
||||
"AuthenticationError": "auth_error_handler",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
|
||||
if error is None:
|
||||
return "no_error"
|
||||
|
||||
# Handle different error formats
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type:
|
||||
return error_types.get(error_type, default_target)
|
||||
|
||||
return default_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def retry_on_failure(
|
||||
max_retries: int = 3,
|
||||
retry_count_key: str = "retry_count",
|
||||
error_key: str = "error",
|
||||
) -> Callable[
|
||||
[dict[str, Any] | StateProtocol],
|
||||
Literal["retry", "max_retries_exceeded", "success"],
|
||||
]:
|
||||
"""Create a router that implements retry logic for failures.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_count_key: Key in state tracking retry attempts
|
||||
error_key: Key in state containing error information
|
||||
|
||||
Returns:
|
||||
Router function that returns "retry", "max_retries_exceeded", or "success"
|
||||
|
||||
Example:
|
||||
retry_router = retry_on_failure(max_retries=5)
|
||||
graph.add_conditional_edges("failure_handler", retry_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["retry", "max_retries_exceeded", "success"]:
|
||||
# Check if there's an error
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
retry_count = state.get(retry_count_key, 0)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
retry_count = getattr(state, retry_count_key, 0)
|
||||
|
||||
# No error means success
|
||||
if error is None:
|
||||
return "success"
|
||||
|
||||
try:
|
||||
current_retries = int(retry_count)
|
||||
return "max_retries_exceeded" if current_retries >= max_retries else "retry"
|
||||
except (ValueError, TypeError):
|
||||
return "retry"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def fallback_to_default(
|
||||
fallback_conditions: list[str] | None = None,
|
||||
output_key: str = "output",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["use_output", "fallback"]]:
|
||||
"""Create a router that falls back to default when output is inadequate.
|
||||
|
||||
Args:
|
||||
fallback_conditions: List of conditions that trigger fallback
|
||||
output_key: Key in state containing output to check
|
||||
|
||||
Returns:
|
||||
Router function that returns "use_output" or "fallback"
|
||||
|
||||
Example:
|
||||
fallback_router = fallback_to_default([
|
||||
"empty_output", "low_confidence", "invalid_format"
|
||||
])
|
||||
graph.add_conditional_edges("output_validator", fallback_router)
|
||||
"""
|
||||
if fallback_conditions is None:
|
||||
fallback_conditions = ["empty_output", "low_confidence", "error"]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["use_output", "fallback"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
output = state.get(output_key)
|
||||
else:
|
||||
output = getattr(state, output_key, None)
|
||||
|
||||
# Check for empty or null output
|
||||
if output is None or output == "":
|
||||
return "fallback"
|
||||
|
||||
# Check for other fallback conditions in state
|
||||
for condition in fallback_conditions:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
condition_value = state.get(condition, False)
|
||||
else:
|
||||
condition_value = getattr(state, condition, False)
|
||||
|
||||
if condition_value:
|
||||
return "fallback"
|
||||
|
||||
return "use_output"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_critical_error(
|
||||
critical_error_types: list[str] | None = None,
|
||||
error_key: str = "error",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["critical", "non_critical"]]:
|
||||
"""Create a router that identifies critical errors requiring immediate attention.
|
||||
|
||||
Args:
|
||||
critical_error_types: List of error types considered critical
|
||||
error_key: Key in state containing error information
|
||||
|
||||
Returns:
|
||||
Router function that returns "critical" or "non_critical"
|
||||
|
||||
Example:
|
||||
critical_router = check_critical_error([
|
||||
"SecurityError", "DataCorruptionError", "SystemFailure"
|
||||
])
|
||||
graph.add_conditional_edges("error_classifier", critical_router)
|
||||
"""
|
||||
if critical_error_types is None:
|
||||
critical_error_types = [
|
||||
"SecurityError",
|
||||
"AuthenticationError",
|
||||
"AuthorizationError",
|
||||
"DataCorruptionError",
|
||||
"SystemFailure",
|
||||
"OutOfMemoryError",
|
||||
]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["critical", "non_critical"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
|
||||
if error is None:
|
||||
return "non_critical"
|
||||
|
||||
# Determine error type
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type and error_type in critical_error_types:
|
||||
return "critical"
|
||||
|
||||
return "non_critical"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def escalation_policy(
|
||||
escalation_threshold: int = 3,
|
||||
failure_count_key: str = "consecutive_failures",
|
||||
escalation_types: dict[str, str] | None = None,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that implements error escalation policies.
|
||||
|
||||
Args:
|
||||
escalation_threshold: Number of failures before escalation
|
||||
failure_count_key: Key in state tracking consecutive failures
|
||||
escalation_types: Mapping of failure types to escalation targets
|
||||
|
||||
Returns:
|
||||
Router function that handles escalation logic
|
||||
|
||||
Example:
|
||||
escalation_router = escalation_policy(
|
||||
escalation_threshold=3,
|
||||
escalation_types={
|
||||
"timeout": "timeout_escalation",
|
||||
"validation": "validation_escalation"
|
||||
}
|
||||
)
|
||||
graph.add_conditional_edges("failure_monitor", escalation_router)
|
||||
"""
|
||||
if escalation_types is None:
|
||||
escalation_types = {
|
||||
"timeout": "timeout_escalation",
|
||||
"network": "network_escalation",
|
||||
"validation": "validation_escalation",
|
||||
"auth": "security_escalation",
|
||||
"security": "security_escalation",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
failure_count = state.get(failure_count_key, 0)
|
||||
error = state.get("error")
|
||||
else:
|
||||
failure_count = getattr(state, failure_count_key, 0)
|
||||
error = getattr(state, "error", None)
|
||||
|
||||
try:
|
||||
failures = int(failure_count)
|
||||
if failures < escalation_threshold:
|
||||
return "continue_monitoring"
|
||||
|
||||
# Determine escalation type based on error
|
||||
if error:
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error.lower()
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type:
|
||||
for escalation_key, target in escalation_types.items():
|
||||
if escalation_key.lower() in error_type.lower():
|
||||
return target
|
||||
|
||||
return "generic_escalation"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return "continue_monitoring"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def circuit_breaker(
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: int = 60,
|
||||
state_key: str = "circuit_breaker_state",
|
||||
failure_count_key: str = "failure_count",
|
||||
last_failure_key: str = "last_failure_time",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["allow", "reject", "probe"]]:
|
||||
"""Create a router that implements circuit breaker pattern.
|
||||
|
||||
Args:
|
||||
failure_threshold: Number of failures before opening circuit
|
||||
recovery_timeout: Seconds to wait before trying to close circuit
|
||||
state_key: Key storing circuit state (closed/open/half_open)
|
||||
failure_count_key: Key tracking failure count
|
||||
last_failure_key: Key storing last failure timestamp
|
||||
|
||||
Returns:
|
||||
Router function implementing circuit breaker logic
|
||||
|
||||
Example:
|
||||
breaker_router = circuit_breaker(failure_threshold=3, recovery_timeout=30)
|
||||
graph.add_conditional_edges("service_call", breaker_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["allow", "reject", "probe"]:
|
||||
import time
|
||||
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
circuit_state = state.get(state_key, "closed")
|
||||
failure_count = state.get(failure_count_key, 0)
|
||||
last_failure = state.get(last_failure_key, 0)
|
||||
else:
|
||||
circuit_state = getattr(state, state_key, "closed")
|
||||
failure_count = getattr(state, failure_count_key, 0)
|
||||
last_failure = getattr(state, last_failure_key, 0)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Convert numeric values with error handling
|
||||
try:
|
||||
failure_count = int(failure_count)
|
||||
except (ValueError, TypeError):
|
||||
failure_count = 0
|
||||
|
||||
try:
|
||||
last_failure = float(last_failure)
|
||||
except (ValueError, TypeError):
|
||||
last_failure = 0.0
|
||||
|
||||
if circuit_state == "open":
|
||||
# Check if recovery timeout has passed
|
||||
return (
|
||||
"probe" if current_time - last_failure >= recovery_timeout else "reject"
|
||||
)
|
||||
|
||||
elif circuit_state == "half_open":
|
||||
return "probe" # Allow one request to test
|
||||
|
||||
else: # closed state
|
||||
return "reject" if failure_count > failure_threshold else "allow"
|
||||
|
||||
return router
|
||||
"""Error handling routing helpers for LangGraph workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from numbers import Real
|
||||
from typing import Literal
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
def _extract_error_type(error: object | None) -> str | None:
|
||||
"""Return a normalized error type string when possible."""
|
||||
|
||||
if error is None:
|
||||
return None
|
||||
if isinstance(error, Mapping):
|
||||
raw_type = error.get("type") or error.get("error_type")
|
||||
return str(raw_type) if raw_type is not None else None
|
||||
if isinstance(error, str):
|
||||
return error
|
||||
return error.__class__.__name__ if hasattr(error, "__class__") else None
|
||||
|
||||
|
||||
def _to_int(value: object | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, bool): # bool is an int subclass; preserve explicit intent
|
||||
return int(value)
|
||||
if isinstance(value, Real):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
try:
|
||||
return int(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _to_float(value: object | None) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, bool):
|
||||
return float(value)
|
||||
if isinstance(value, Real):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
try:
|
||||
return float(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _has_content(value: object | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, str):
|
||||
return bool(value)
|
||||
if isinstance(value, bool):
|
||||
return True
|
||||
if isinstance(value, Real):
|
||||
return True
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
return len(value) > 0
|
||||
return bool(value)
|
||||
|
||||
|
||||
def detect_errors_list(
    error_target: str = "error_handler",
    success_target: str = "continue",
    errors_key: str = "errors",
) -> Callable[[StateLike], str]:
    """Create a router that detects errors stored in a list.

    Args:
        error_target: Node to route to when the error list is non-empty.
        success_target: Node to route to otherwise.
        errors_key: State key holding the accumulated error list.

    Returns:
        A conditional-edge router returning one of the two targets.
    """

    def router(state: StateLike) -> str:
        errors = get_state_value(state, errors_key, [])
        # Strings/bytes are sequences too, but never count as an error list.
        is_error_list = isinstance(errors, Sequence) and not isinstance(
            errors, (str, bytes, bytearray)
        )
        return error_target if is_error_list and len(errors) > 0 else success_target

    return router
|
||||
|
||||
|
||||
def handle_error(
    error_types: Mapping[str, str] | None = None,
    error_key: str = "error",
    default_target: str = "generic_error_handler",
) -> Callable[[StateLike], str]:
    """Create a router that dispatches on the detected error type.

    Args:
        error_types: Optional mapping of error type names to handler nodes;
            a default mapping of common error classes is used when omitted.
        error_key: State key containing the error payload.
        default_target: Node used when the error type is not in the mapping.

    Returns:
        A router returning ``"no_error"``, a mapped handler node, or the
        default target.
    """

    default_handlers = {
        "ValidationError": "validation_error_handler",
        "NetworkError": "network_error_handler",
        "TimeoutError": "timeout_error_handler",
        "AuthenticationError": "auth_error_handler",
    }
    # Copy the caller's mapping so later mutation cannot affect the router.
    handlers = default_handlers if error_types is None else dict(error_types)

    def router(state: StateLike) -> str:
        error = get_state_value(state, error_key)
        if error is None:
            return "no_error"
        error_type = _extract_error_type(error)
        if not error_type:
            return default_target
        return handlers.get(error_type, default_target)

    return router
|
||||
|
||||
|
||||
def retry_on_failure(
    max_retries: int = 3,
    retry_count_key: str = "retry_count",
    error_key: str = "error",
) -> Callable[[StateLike], Literal["retry", "max_retries_exceeded", "success"]]:
    """Create a router implementing bounded retry logic for failures.

    Args:
        max_retries: Attempt budget before giving up.
        retry_count_key: State key tracking attempts made so far.
        error_key: State key containing the error payload.

    Returns:
        A router returning ``"success"`` when no error is present,
        ``"max_retries_exceeded"`` once the budget is spent, or ``"retry"``.
    """

    def router(state: StateLike) -> Literal["retry", "max_retries_exceeded", "success"]:
        if get_state_value(state, error_key) is None:
            return "success"

        attempts = _to_int(get_state_value(state, retry_count_key, 0))
        # An unparseable retry count is treated as "no attempts yet".
        if attempts is not None and attempts >= max_retries:
            return "max_retries_exceeded"
        return "retry"

    return router
|
||||
|
||||
|
||||
def fallback_to_default(
    fallback_conditions: Sequence[str] | None = None,
    output_key: str = "output",
) -> Callable[[StateLike], Literal["use_output", "fallback"]]:
    """Create a router that falls back when output is missing or flagged.

    Args:
        fallback_conditions: State keys whose truthy value forces fallback;
            defaults to ``empty_output``/``low_confidence``/``error``.
        output_key: State key holding the produced output.

    Returns:
        A router returning ``"use_output"`` or ``"fallback"``.
    """

    if fallback_conditions is None:
        conditions = ["empty_output", "low_confidence", "error"]
    else:
        # Snapshot the caller's sequence so later mutation has no effect.
        conditions = list(fallback_conditions)

    def router(state: StateLike) -> Literal["use_output", "fallback"]:
        if not _has_content(get_state_value(state, output_key)):
            return "fallback"
        triggered = any(
            bool(get_state_value(state, condition, False)) for condition in conditions
        )
        return "fallback" if triggered else "use_output"

    return router
|
||||
|
||||
|
||||
def check_critical_error(
    critical_error_types: Sequence[str] | None = None,
    error_key: str = "error",
) -> Callable[[StateLike], Literal["critical", "non_critical"]]:
    """Create a router that flags errors requiring immediate attention.

    Args:
        critical_error_types: Error type names considered critical; a default
            set of security/system failure types is used when omitted or
            empty.
        error_key: State key containing the error payload.

    Returns:
        A router returning ``"critical"`` or ``"non_critical"``.
    """

    default_critical = (
        "SecurityError",
        "AuthenticationError",
        "AuthorizationError",
        "DataCorruptionError",
        "SystemFailure",
        "OutOfMemoryError",
    )
    # Set membership keeps the per-invocation lookup O(1).
    critical = set(critical_error_types) if critical_error_types else set(default_critical)

    def router(state: StateLike) -> Literal["critical", "non_critical"]:
        error = get_state_value(state, error_key)
        if error is None:
            return "non_critical"
        error_type = _extract_error_type(error)
        is_critical = bool(error_type) and error_type in critical
        return "critical" if is_critical else "non_critical"

    return router
|
||||
|
||||
|
||||
def escalation_policy(
    escalation_threshold: int = 3,
    failure_count_key: str = "consecutive_failures",
    escalation_types: Mapping[str, str] | None = None,
) -> Callable[[StateLike], str]:
    """Create a router that escalates after repeated failures.

    Args:
        escalation_threshold: Consecutive-failure count that triggers
            escalation.
        failure_count_key: State key tracking consecutive failures.
        escalation_types: Mapping of error-type substrings to escalation
            nodes; a default mapping is used when omitted.

    Returns:
        A router returning ``"continue_monitoring"`` below the threshold, a
        matched escalation target, or ``"generic_escalation"``.
    """

    if escalation_types is None:
        escalations: dict[str, str] = {
            "timeout": "timeout_escalation",
            "network": "network_escalation",
            "validation": "validation_escalation",
            "auth": "security_escalation",
            "security": "security_escalation",
        }
    else:
        escalations = dict(escalation_types)

    def router(state: StateLike) -> str:
        failures = _to_int(get_state_value(state, failure_count_key, 0))
        if failures is None or failures < escalation_threshold:
            return "continue_monitoring"

        error_type = _extract_error_type(get_state_value(state, "error"))
        if error_type:
            needle = error_type.lower()
            # Substring match lets e.g. "NetworkTimeoutError" hit "network";
            # first match in insertion order wins.
            for fragment, target in escalations.items():
                if fragment.lower() in needle:
                    return target
        return "generic_escalation"

    return router
|
||||
|
||||
|
||||
def circuit_breaker(
    failure_threshold: int = 5,
    recovery_timeout: int = 60,
    state_key: str = "circuit_breaker_state",
    failure_count_key: str = "failure_count",
    last_failure_key: str = "last_failure_time",
) -> Callable[[StateLike], Literal["allow", "reject", "probe"]]:
    """Create a router that implements the circuit breaker pattern.

    Args:
        failure_threshold: Failures tolerated while the circuit is closed.
        recovery_timeout: Seconds to wait in the open state before probing.
        state_key: State key holding ``"closed"``/``"open"``/``"half_open"``.
        failure_count_key: State key tracking the failure count.
        last_failure_key: State key holding the last failure timestamp.

    Returns:
        A router returning ``"allow"``, ``"reject"``, or ``"probe"``.
    """

    def router(state: StateLike) -> Literal["allow", "reject", "probe"]:
        breaker_state = str(get_state_value(state, state_key, "closed"))
        # Unparseable counters/timestamps degrade to zero rather than raising.
        failures = _to_int(get_state_value(state, failure_count_key, 0)) or 0
        last_failure = _to_float(get_state_value(state, last_failure_key, 0.0)) or 0.0
        now = time.time()

        if breaker_state == "open":
            # Re-probe only once the cool-down window has elapsed.
            return "probe" if now - last_failure >= recovery_timeout else "reject"
        if breaker_state == "half_open":
            return "probe"
        # Closed (or unknown) state: allow until failures exceed threshold.
        return "reject" if failures > failure_threshold else "allow"

    return router
|
||||
|
||||
|
||||
# Explicit public API: limits ``from ... import *`` to the router factories
# and keeps the coercion helpers (_to_int, _extract_error_type, ...) private.
__all__ = [
    "detect_errors_list",
    "handle_error",
    "retry_on_failure",
    "fallback_to_default",
    "check_critical_error",
    "escalation_policy",
    "circuit_breaker",
]
|
||||
|
||||
@@ -1,262 +1,183 @@
|
||||
"""Flow control edge helpers for managing workflow progression.
|
||||
|
||||
This module provides edge helpers for controlling the flow of execution
|
||||
in LangGraph workflows, including continuation logic, timeout checks,
|
||||
and multi-step progress tracking.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from .core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def should_continue(
    state: dict[str, Any] | StateProtocol,
) -> Literal["continue", "end"]:
    """Decide whether to continue or end based on tool calls in the last message.

    Args:
        state: State containing a ``messages`` sequence; each message may be a
            dict or a message object carrying tool-call metadata.

    Returns:
        "continue" when the last message carries tool calls, "end" otherwise.

    Example:
        graph.add_conditional_edges("agent", should_continue)
    """
    mapping_like = hasattr(state, "get") or isinstance(state, dict)
    messages = (
        state.get("messages", []) if mapping_like else getattr(state, "messages", [])
    )

    if not messages:
        return "end"

    last = messages[-1]

    def _is_call_list(candidate: object) -> bool:
        # A usable tool-call payload is a truthy, non-empty list.
        return bool(candidate) and isinstance(candidate, list) and len(candidate) > 0

    if isinstance(last, dict):
        # Dict-shaped message: look in additional_kwargs, then top level,
        # then the legacy function_call field.
        kwargs = last.get("additional_kwargs", {})
        if _is_call_list(kwargs.get("tool_calls")):
            return "continue"
        if _is_call_list(last.get("tool_calls")):
            return "continue"
        if last.get("function_call"):
            return "continue"

    # Object-shaped message: check the tool_calls attribute, then the
    # additional_kwargs attribute (dicts simply miss both attributes).
    if _is_call_list(getattr(last, "tool_calls", None)):
        return "continue"

    kwargs_attr = getattr(last, "additional_kwargs", None) or {}
    if isinstance(kwargs_attr, dict) and _is_call_list(kwargs_attr.get("tool_calls")):
        return "continue"

    return "end"
|
||||
|
||||
|
||||
def timeout_check(
    timeout_seconds: float = 300.0,
    start_time_key: str = "start_time",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks for timeout conditions.

    Args:
        timeout_seconds: Maximum allowed execution time in seconds
        start_time_key: Key in state containing start timestamp

    Returns:
        Router function that returns "timeout" or "continue"

    Example:
        timeout_router = timeout_check(timeout_seconds=60.0)
        graph.add_conditional_edges("long_task", timeout_router)
    """

    def router(state: dict[str, Any] | StateProtocol) -> Literal["timeout", "continue"]:
        if hasattr(state, "get") or isinstance(state, dict):
            start_time = state.get(start_time_key)
        else:
            start_time = getattr(state, start_time_key, None)

        # Coerce to float so numeric strings work and junk values do not
        # raise a TypeError when subtracted from time.time(). float(None)
        # raises TypeError, so the missing-value case also falls through here.
        try:
            start_ts = float(start_time)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            # No usable start time recorded, assume we should continue
            return "continue"

        elapsed = time.time() - start_ts
        return "timeout" if elapsed > timeout_seconds else "continue"

    return router
|
||||
|
||||
|
||||
def multi_step_progress(
    total_steps: int,
    step_key: str = "current_step",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router for multi-step workflow progress tracking.

    Args:
        total_steps: Total number of steps in the workflow.
        step_key: Key in state containing the current step number.

    Returns:
        Router function that returns "next_step", "complete", or "error".

    Example:
        progress_router = multi_step_progress(total_steps=5)
        graph.add_conditional_edges("step_processor", progress_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["next_step", "complete", "error"]:
        raw_step = (
            state.get(step_key, 0)
            if hasattr(state, "get") or isinstance(state, dict)
            else getattr(state, step_key, 0)
        )

        try:
            step_index = int(raw_step)
        except (ValueError, TypeError):
            # Non-numeric counters are reported as errors.
            return "error"

        if step_index < 0:
            return "error"
        return "complete" if step_index >= total_steps else "next_step"

    return router
|
||||
|
||||
|
||||
def check_iteration_limit(
    max_iterations: int = 10,
    iteration_key: str = "iteration_count",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that enforces iteration limits to prevent infinite loops.

    Args:
        max_iterations: Maximum allowed iterations.
        iteration_key: Key in state containing the iteration counter.

    Returns:
        Router function that returns "continue" or "limit_reached".

    Example:
        limit_router = check_iteration_limit(max_iterations=5)
        graph.add_conditional_edges("retry_node", limit_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["continue", "limit_reached"]:
        raw_count = (
            state.get(iteration_key, 0)
            if hasattr(state, "get") or isinstance(state, dict)
            else getattr(state, iteration_key, 0)
        )

        try:
            count = int(raw_count)
        except (ValueError, TypeError):
            # Unparseable counters never trip the limit.
            return "continue"

        return "limit_reached" if count >= max_iterations else "continue"

    return router
|
||||
|
||||
|
||||
def check_completion_criteria(
    required_conditions: list[str],
    condition_prefix: str = "completed_",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks if all completion criteria are met.

    Args:
        required_conditions: Condition names that must all be truthy in state.
        condition_prefix: Prefix applied to each name to build its state key.

    Returns:
        Router function that returns "complete" or "continue".

    Example:
        completion_router = check_completion_criteria([
            "data_validated", "report_generated", "notifications_sent"
        ])
        graph.add_conditional_edges("final_check", completion_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["complete", "continue"]:
        mapping_like = hasattr(state, "get") or isinstance(state, dict)

        for condition in required_conditions:
            key = f"{condition_prefix}{condition}"
            flag = state.get(key, False) if mapping_like else getattr(state, key, False)
            # Any unmet condition short-circuits back to "continue".
            if not flag:
                return "continue"

        return "complete"

    return router
|
||||
|
||||
|
||||
def check_workflow_state(
    state_transitions: dict[str, str],
    state_key: str = "workflow_state",
    default_target: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router based on workflow state transitions.

    Args:
        state_transitions: Mapping of current states to next targets.
        state_key: Key in state containing the current workflow state.
        default_target: Target used when the state has no mapped transition.

    Returns:
        Router function that returns the target for the current workflow state.

    Example:
        workflow_router = check_workflow_state({
            "initialized": "processing",
            "processing": "validation",
            "validation": "completion",
            "completion": "end"
        })
        graph.add_conditional_edges("state_manager", workflow_router)
    """

    def router(state: dict[str, Any] | StateProtocol) -> str:
        if hasattr(state, "get") or isinstance(state, dict):
            current = state.get(state_key)
        else:
            current = getattr(state, state_key, None)
        # Stringify so non-string states still hit string-keyed transitions.
        return state_transitions.get(str(current), default_target)

    return router
|
||||
"""Flow control routing helpers for LangGraph workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
def _to_float(value: object | None) -> float | None:
|
||||
"""Convert ``value`` to ``float`` when possible."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _to_int(value: object | None) -> int | None:
|
||||
"""Convert ``value`` to ``int`` when possible."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def should_continue(state: StateLike) -> Literal["continue", "end"]:
    """Route to "continue" when the newest assistant message carries tool calls."""

    messages = get_state_value(state, "messages", [])
    if not isinstance(messages, Sequence) or not messages:
        return "end"

    def _role_of(message: object) -> object | None:
        # Messages may expose their role under "role" or "type",
        # as a mapping key or an attribute.
        if isinstance(message, Mapping):
            return message.get("role") or message.get("type")
        return getattr(message, "role", None) or getattr(message, "type", None)

    assistant_roles = ("assistant", "ai", "ai_message", "assistant_message")
    candidate = next(
        (msg for msg in reversed(messages) if _role_of(msg) in assistant_roles),
        None,
    )
    if candidate is None:
        return "end"

    def _populated(calls: object) -> bool:
        # Tool calls must be a real (non-string) sequence with entries.
        if not isinstance(calls, Sequence):
            return False
        if isinstance(calls, (str, bytes, bytearray)):
            return False
        return bool(calls)

    if isinstance(candidate, Mapping):
        extra = candidate.get("additional_kwargs")
        if isinstance(extra, Mapping) and _populated(extra.get("tool_calls")):
            return "continue"
        if _populated(candidate.get("tool_calls")):
            return "continue"
        if candidate.get("function_call"):
            return "continue"
        return "end"

    if _populated(getattr(candidate, "tool_calls", None)):
        return "continue"
    extra_attr = getattr(candidate, "additional_kwargs", None)
    if isinstance(extra_attr, Mapping) and _populated(extra_attr.get("tool_calls")):
        return "continue"
    return "end"
|
||||
|
||||
|
||||
def timeout_check(
    timeout_seconds: float = 300.0,
    start_time_key: str = "start_time",
) -> Callable[[StateLike], Literal["timeout", "continue"]]:
    """Create a router that flags runs exceeding ``timeout_seconds``."""

    def router(state: StateLike) -> Literal["timeout", "continue"]:
        started = _to_float(get_state_value(state, start_time_key))
        if started is None:
            # Without a usable start time there is nothing to time out.
            return "continue"
        return "timeout" if time.time() - started > timeout_seconds else "continue"

    return router
|
||||
|
||||
|
||||
def multi_step_progress(
    total_steps: int,
    step_key: str = "current_step",
) -> Callable[[StateLike], Literal["next_step", "complete", "error"]]:
    """Create a router for multi-step workflow progress tracking."""

    def router(state: StateLike) -> Literal["next_step", "complete", "error"]:
        step = _to_int(get_state_value(state, step_key, 0))
        # Missing/unparseable or negative counters are routed to "error".
        if step is None or step < 0:
            return "error"
        return "complete" if step >= total_steps else "next_step"

    return router
|
||||
|
||||
|
||||
def check_iteration_limit(
    max_iterations: int = 10,
    iteration_key: str = "iteration_count",
) -> Callable[[StateLike], Literal["continue", "limit_reached"]]:
    """Create a router that enforces iteration limits to prevent infinite loops."""

    def router(state: StateLike) -> Literal["continue", "limit_reached"]:
        count = _to_int(get_state_value(state, iteration_key, 0))
        # Unparseable counters never trip the limit.
        if count is not None and count >= max_iterations:
            return "limit_reached"
        return "continue"

    return router
|
||||
|
||||
|
||||
def check_completion_criteria(
    required_conditions: Sequence[str],
    condition_prefix: str = "completed_",
) -> Callable[[StateLike], Literal["complete", "continue"]]:
    """Create a router that checks if all completion criteria are met."""

    def router(state: StateLike) -> Literal["complete", "continue"]:
        satisfied = all(
            bool(get_state_value(state, f"{condition_prefix}{name}", False))
            for name in required_conditions
        )
        return "complete" if satisfied else "continue"

    return router
|
||||
|
||||
|
||||
def check_workflow_state(
    state_transitions: Mapping[str, str],
    state_key: str = "workflow_state",
    default_target: str = "error",
) -> Callable[[StateLike], str]:
    """Create a router based on workflow state transitions."""

    # Snapshot the mapping so later caller mutations cannot change routing.
    transition_table = dict(state_transitions)

    def router(state: StateLike) -> str:
        current = get_state_value(state, state_key)
        return transition_table.get(str(current), default_target)

    return router
|
||||
|
||||
|
||||
# Public API of the flow-control edge helpers module.
__all__ = [
    "should_continue",
    "timeout_check",
    "multi_step_progress",
    "check_iteration_limit",
    "check_completion_criteria",
    "check_workflow_state",
]
|
||||
|
||||
@@ -1,400 +1,300 @@
|
||||
"""Monitoring and operational edge helpers for system management.
|
||||
|
||||
This module provides edge helpers for monitoring system health, checking
|
||||
resource availability, triggering notifications, and managing operational
|
||||
concerns like load balancing and rate limiting.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from .core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def log_and_monitor(
    log_levels: dict[str, str] | None = None,
    log_level_key: str = "log_level",
    should_log_key: str = "should_log",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that determines logging and monitoring actions.

    Args:
        log_levels: Mapping of log levels to monitoring actions.
        log_level_key: Key in state containing the log level.
        should_log_key: Key in state indicating whether logging is enabled.

    Returns:
        Router function that returns a monitoring action or "no_logging".

    Example:
        monitor_router = log_and_monitor({
            "debug": "debug_monitor",
            "info": "info_monitor",
            "warning": "alert_monitor",
            "error": "urgent_monitor"
        })
        graph.add_conditional_edges("logger", monitor_router)
    """
    if log_levels is None:
        log_levels = {
            "debug": "debug_log",
            "info": "info_log",
            "warning": "warning_alert",
            "error": "error_alert",
            "critical": "critical_alert",
        }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        mapping_like = hasattr(state, "get") or isinstance(state, dict)
        should_log = (
            state.get(should_log_key, True)
            if mapping_like
            else getattr(state, should_log_key, True)
        )
        if not should_log:
            return "no_logging"

        log_level = (
            state.get(log_level_key, "info")
            if mapping_like
            else getattr(state, log_level_key, "info")
        )
        # Unknown levels fall back to the plain info action.
        return log_levels.get(str(log_level).lower(), "info_log")

    return router
|
||||
|
||||
|
||||
def check_resource_availability(
    resource_thresholds: dict[str, float] | None = None,
    resources_key: str = "resource_usage",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks system resource availability.

    Args:
        resource_thresholds: Mapping of resource names to usage thresholds.
        resources_key: Key in state containing resource usage data.

    Returns:
        Router returning "resources_available", "resources_unknown", or
        "<resource>_constrained" for the first resource at/over its threshold.

    Example:
        resource_router = check_resource_availability({
            "cpu": 0.8,
            "memory": 0.9,
            "disk": 0.95
        })
        graph.add_conditional_edges("resource_check", resource_router)
    """
    if resource_thresholds is None:
        resource_thresholds = {
            "cpu": 0.8,
            "memory": 0.9,
            "disk": 0.95,
            "network": 0.8,
        }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        resources = (
            state.get(resources_key)
            if hasattr(state, "get") or isinstance(state, dict)
            else getattr(state, resources_key, None)
        )
        if not isinstance(resources, dict):
            return "resources_unknown"

        for name, threshold in resource_thresholds.items():
            try:
                usage = float(resources.get(name, 0.0))
            except (ValueError, TypeError):
                # Skip unreadable usage values rather than failing the check.
                continue
            if usage >= threshold:
                return f"{name}_constrained"

        return "resources_available"

    return router
|
||||
|
||||
|
||||
def trigger_notifications(
    notification_rules: dict[str, str] | None = None,
    alert_level_key: str = "alert_level",
    notification_enabled_key: str = "notifications_enabled",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that triggers appropriate notifications.

    Args:
        notification_rules: Mapping of alert levels to notification types.
        alert_level_key: Key in state containing the alert level.
        notification_enabled_key: Key in state toggling notifications.

    Returns:
        Router function that returns a notification type or "no_notification".

    Example:
        notify_router = trigger_notifications({
            "low": "email_notification",
            "medium": "slack_notification",
            "high": "sms_notification",
            "critical": "phone_notification"
        })
        graph.add_conditional_edges("alert_handler", notify_router)
    """
    if notification_rules is None:
        notification_rules = {
            "low": "email_notification",
            "medium": "slack_notification",
            "high": "sms_notification",
            "critical": "phone_notification",
            "emergency": "all_channels",
        }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        mapping_like = hasattr(state, "get") or isinstance(state, dict)
        enabled = (
            state.get(notification_enabled_key, True)
            if mapping_like
            else getattr(state, notification_enabled_key, True)
        )
        level = (
            state.get(alert_level_key)
            if mapping_like
            else getattr(state, alert_level_key, None)
        )

        if not enabled or level is None:
            return "no_notification"
        return notification_rules.get(str(level).lower(), "no_notification")

    return router
|
||||
|
||||
|
||||
def check_rate_limiting(
    rate_limits: dict[str, dict[str, float]] | None = None,
    request_counts_key: str = "request_counts",
    time_window_key: str = "time_window",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that enforces rate limiting policies.

    Args:
        rate_limits: Mapping of endpoints/operations to rate limit configs
            (each config holds "max_requests" and "window_seconds")
        request_counts_key: Key in state containing request count data
        time_window_key: Key in state containing time window info

    Returns:
        Router function that returns "allowed" or "rate_limited"

    Example:
        rate_router = check_rate_limiting({
            "api_calls": {"max_requests": 100, "window_seconds": 60},
            "file_uploads": {"max_requests": 10, "window_seconds": 60}
        })
        graph.add_conditional_edges("rate_limiter", rate_router)
    """
    if rate_limits is None:
        rate_limits = {
            "default": {"max_requests": 100, "window_seconds": 60},
            "api_calls": {"max_requests": 100, "window_seconds": 60},
            "heavy_operations": {"max_requests": 10, "window_seconds": 300},
        }

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["allowed", "rate_limited"]:
        if hasattr(state, "get") or isinstance(state, dict):
            request_counts = state.get(request_counts_key, {})
            time_window_data = state.get(time_window_key, {})
        else:
            request_counts = getattr(state, request_counts_key, {})
            time_window_data = getattr(state, time_window_key, {})

        # Guard against malformed state: non-dict values previously raised
        # AttributeError on the .get() calls below.
        if not isinstance(request_counts, dict):
            request_counts = {}
        if not isinstance(time_window_data, dict):
            time_window_data = {}

        # Use time from time_window_data if available (coerced to float),
        # otherwise fall back to system time.
        try:
            current_time = float(time_window_data.get("current_time", time.time()))
        except (TypeError, ValueError):
            current_time = time.time()

        # Check rate limits for each configured endpoint
        for endpoint, limits in rate_limits.items():
            if endpoint not in request_counts:
                continue

            max_requests = limits.get("max_requests", 100)
            window_seconds = limits.get("window_seconds", 60)

            # Get request history for this endpoint
            endpoint_data = request_counts.get(endpoint, {})
            if not isinstance(endpoint_data, dict):
                continue
            request_count = endpoint_data.get("count", 0)
            window_start = endpoint_data.get("window_start", current_time)

            # Check if we're still in the same time window; skip entries
            # whose counters or timestamps are not numeric.
            try:
                in_window = current_time - float(window_start) < float(window_seconds)
                over_limit = float(request_count) > float(max_requests)
            except (TypeError, ValueError):
                continue

            if in_window and over_limit:
                return "rate_limited"

        return "allowed"

    return router
|
||||
|
||||
|
||||
def load_balance(
    load_balancing_strategy: str = "round_robin",
    available_nodes: list[str] | None = None,
    node_status_key: str = "node_status",
    current_node_key: str = "current_node_index",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that implements load balancing across nodes.

    Args:
        load_balancing_strategy: Strategy to use ("round_robin", "least_loaded");
            any other value falls back to the first healthy node
        available_nodes: List of available node names
        node_status_key: Key in state containing node health status
        current_node_key: Key in state tracking current node index

    Returns:
        Router function that returns selected node name, or
        "no_available_nodes" when every node is marked unhealthy

    Example:
        balance_router = load_balance(
            load_balancing_strategy="round_robin",
            available_nodes=["node1", "node2", "node3"]
        )
        graph.add_conditional_edges("load_balancer", balance_router)
    """
    if available_nodes is None:
        available_nodes = ["primary_node", "secondary_node"]

    def router(state: dict[str, Any] | StateProtocol) -> str:
        if hasattr(state, "get") or isinstance(state, dict):
            node_status = state.get(node_status_key, {}) or {}
            current_index = state.get(current_node_key, 0)
        else:
            node_status = getattr(state, node_status_key, {}) or {}
            current_index = getattr(state, current_node_key, 0)

        # Filter healthy nodes
        # NOTE(review): assumes node_status values are dicts — a non-dict
        # entry would raise AttributeError on .get(); confirm upstream shape.
        healthy_nodes = []
        for node in available_nodes:
            node_health = node_status.get(node, {})
            if node_health.get("healthy", True):  # Default to healthy
                healthy_nodes.append(node)

        if not healthy_nodes:
            return "no_available_nodes"

        if load_balancing_strategy == "round_robin":
            try:
                # Modulo keeps the stored index valid even after nodes drop out.
                index = int(current_index) % len(healthy_nodes)
                return healthy_nodes[index]
            except (ValueError, TypeError):
                # Non-numeric index: fall back to the first healthy node.
                return healthy_nodes[0]

        elif load_balancing_strategy == "least_loaded":
            # Find node with lowest load
            min_load = float("inf")
            selected_node = healthy_nodes[0]

            for node in healthy_nodes:
                node_data = node_status.get(node, {})
                load = node_data.get("load", 0.0)
                try:
                    if float(load) < min_load:
                        min_load = float(load)
                        selected_node = node
                except (ValueError, TypeError):
                    # Skip nodes whose load is not numeric.
                    continue

            return selected_node

        else:  # Default to first healthy node
            return healthy_nodes[0]

    return router
|
||||
|
||||
|
||||
def health_check(
    health_criteria: dict[str, Any] | None = None,
    health_status_key: str = "health_status",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that performs system health checks.

    Args:
        health_criteria: Criteria for determining system health; keys are
            metric names and values are thresholds
        health_status_key: Key in state containing health check results

    Returns:
        Router function that returns "healthy", "degraded", or "unhealthy"

    Example:
        health_router = health_check({
            "response_time_ms": 1000,
            "error_rate": 0.05,
            "uptime_percent": 0.99
        })
        graph.add_conditional_edges("health_monitor", health_router)
    """
    if health_criteria is None:
        health_criteria = {
            "response_time_ms": 1000,
            "error_rate": 0.05,
            "cpu_usage": 0.8,
            "memory_usage": 0.9,
        }

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["healthy", "degraded", "unhealthy"]:
        if hasattr(state, "get") or isinstance(state, dict):
            health_status = state.get(health_status_key, {})
        else:
            health_status = getattr(state, health_status_key, {})

        # Non-dict health data cannot be assessed; treat as unhealthy.
        if not isinstance(health_status, dict):
            return "unhealthy"

        # If health_status is empty, we have no health data to assess
        if not health_status:
            return "unhealthy"

        unhealthy_count = 0
        degraded_count = 0

        # Metrics absent from health_status, or with non-numeric
        # values/thresholds, are silently skipped below.
        for metric, threshold in health_criteria.items():
            current_value = health_status.get(metric)
            if current_value is None:
                continue

            try:
                value = float(current_value)
                threshold_val = float(threshold)

                # Different metrics have different "good" directions
                if metric in [
                    "uptime_percent",
                    "availability",
                    "success_rate",
                ]:
                    # Higher is better (like uptime_percent)
                    if value < threshold_val * 0.8:  # 80% of threshold
                        unhealthy_count += 1
                    elif value < threshold_val:
                        degraded_count += 1
                else:
                    # Lower is better (default for most metrics including unknown ones)
                    if abs(threshold_val) < 1e-9:  # Use tolerance for float comparison
                        # Special case for zero threshold - any positive value is degraded
                        if value > 1e-9:
                            degraded_count += 1
                    elif value > threshold_val * 1.5:  # 150% of threshold
                        unhealthy_count += 1
                    elif value > threshold_val:
                        degraded_count += 1

            except (ValueError, TypeError):
                continue

        # Any unhealthy metric dominates; otherwise any degraded one.
        if unhealthy_count > 0:
            return "unhealthy"
        elif degraded_count > 0:
            return "degraded"
        else:
            return "healthy"

    return router
|
||||
"""Monitoring and observability routing helpers for LangGraph graphs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Literal
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
def _as_mapping(value: object) -> Mapping[str, object]:
|
||||
"""Return ``value`` when it behaves like a mapping, otherwise an empty dict."""
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return value
|
||||
return {}
|
||||
|
||||
|
||||
def log_and_monitor(
    log_levels: Mapping[str, str] | None = None,
    log_level_key: str = "log_level",
    should_log_key: str = "should_log",
) -> Callable[[StateLike], str]:
    """Create a router that determines logging and monitoring actions."""

    default_levels = {
        "debug": "debug_log",
        "info": "info_log",
        "warning": "warning_alert",
        "error": "error_alert",
        "critical": "critical_alert",
    }
    # Copy caller-supplied rules so later mutations cannot change routing.
    levels = default_levels if log_levels is None else dict(log_levels)

    def router(state: StateLike) -> str:
        if not bool(get_state_value(state, should_log_key, True)):
            return "no_logging"
        level = str(get_state_value(state, log_level_key, "info")).lower()
        return levels.get(level, "info_log")

    return router
|
||||
|
||||
|
||||
def check_resource_availability(
    resource_thresholds: Mapping[str, float] | None = None,
    resources_key: str = "resource_usage",
) -> Callable[[StateLike], str]:
    """Create a router that checks system resource availability."""

    thresholds = (
        {"cpu": 0.8, "memory": 0.9, "disk": 0.95, "network": 0.8}
        if resource_thresholds is None
        else dict(resource_thresholds)
    )

    def router(state: StateLike) -> str:
        usage_map = get_state_value(state, resources_key)
        if not isinstance(usage_map, Mapping):
            return "resources_unknown"

        for name, limit in thresholds.items():
            raw_usage = usage_map.get(name, 0.0)
            try:
                usage = float(raw_usage)
            except (TypeError, ValueError):
                # Skip unreadable usage values rather than failing the check.
                continue
            if usage >= limit:
                return f"{name}_constrained"

        return "resources_available"

    return router
|
||||
|
||||
|
||||
def trigger_notifications(
    notification_rules: Mapping[str, str] | None = None,
    alert_level_key: str = "alert_level",
    notification_enabled_key: str = "notifications_enabled",
) -> Callable[[StateLike], str]:
    """Create a router that triggers appropriate notifications."""

    rules = (
        {
            "low": "email_notification",
            "medium": "slack_notification",
            "high": "sms_notification",
            "critical": "phone_notification",
            "emergency": "all_channels",
        }
        if notification_rules is None
        else dict(notification_rules)
    )

    def router(state: StateLike) -> str:
        enabled = get_state_value(state, notification_enabled_key, True)
        level = get_state_value(state, alert_level_key)
        # Disabled notifications and missing levels both short-circuit.
        if not bool(enabled) or level is None:
            return "no_notification"
        return rules.get(str(level).lower(), "no_notification")

    return router
|
||||
|
||||
|
||||
def check_rate_limiting(
    rate_limits: Mapping[str, Mapping[str, float]] | None = None,
    request_counts_key: str = "request_counts",
    time_window_key: str = "time_window",
) -> Callable[[StateLike], Literal["allowed", "rate_limited"]]:
    """Create a router that enforces rate limiting policies.

    Args:
        rate_limits: Per-endpoint limits, each a mapping with
            ``max_requests`` and ``window_seconds`` entries. A small
            built-in policy set is used when omitted.
        request_counts_key: State key holding per-endpoint counter data
            (``count`` and ``window_start`` per endpoint).
        time_window_key: State key holding timing data (``current_time``).

    Returns:
        Router returning ``"rate_limited"`` when any tracked endpoint is
        strictly over its limit inside its window, otherwise ``"allowed"``.
    """

    # Copy caller-supplied limits so later mutation of the input mapping
    # cannot change routing behavior.
    limits = (
        {endpoint: dict(config) for endpoint, config in rate_limits.items()}
        if rate_limits is not None
        else {
            "default": {"max_requests": 100, "window_seconds": 60},
            "api_calls": {"max_requests": 100, "window_seconds": 60},
            "heavy_operations": {"max_requests": 10, "window_seconds": 300},
        }
    )

    def router(state: StateLike) -> Literal["allowed", "rate_limited"]:
        request_counts_value = get_state_value(state, request_counts_key, {})
        request_counts = _as_mapping(request_counts_value)

        time_window_value = get_state_value(state, time_window_key, {})
        time_window_data = _as_mapping(time_window_value)

        # Prefer the clock provided in state (reproducible in tests); fall
        # back to the wall clock when missing or non-numeric.
        current_time = time_window_data.get("current_time", time.time())
        try:
            current_time = float(current_time)
        except (TypeError, ValueError):
            current_time = time.time()

        for endpoint, config in limits.items():
            # Endpoints with no recorded traffic cannot be rate limited.
            if endpoint not in request_counts:
                continue

            # Malformed limit values degrade to the documented defaults
            # rather than raising inside the router.
            try:
                max_requests = int(config.get("max_requests", 100))  # type: ignore[arg-type]
            except (TypeError, ValueError):
                max_requests = 100
            try:
                window_seconds = float(config.get("window_seconds", 60))  # type: ignore[arg-type]
            except (TypeError, ValueError):
                window_seconds = 60.0

            endpoint_value = request_counts.get(endpoint, {})
            endpoint_data = _as_mapping(endpoint_value)

            request_count = endpoint_data.get("count", 0)
            window_start = endpoint_data.get("window_start", current_time)

            # An uncoercible count means this endpoint's counter data is
            # unusable; skip the endpoint entirely.
            try:
                request_count = int(request_count)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                continue

            # A bad window start is treated as "window opened just now".
            try:
                window_start = float(window_start)  # type: ignore[arg-type]
            except (TypeError, ValueError):
                window_start = current_time

            # Limited only while still inside the window AND strictly over
            # the allowed count (== max_requests is still allowed).
            if (
                current_time - window_start < window_seconds
                and request_count > max_requests
            ):
                return "rate_limited"

        return "allowed"

    return router
|
||||
|
||||
|
||||
def load_balance(
    load_balancing_strategy: str = "round_robin",
    available_nodes: list[str] | None = None,
    node_status_key: str = "node_status",
    current_node_key: str = "current_node_index",
) -> Callable[[StateLike], str]:
    """Create a router that implements load balancing across nodes.

    Args:
        load_balancing_strategy: ``"round_robin"`` or ``"least_loaded"``;
            any other value falls back to the first healthy node.
        available_nodes: Candidate node names; defaults to a primary and a
            secondary node.
        node_status_key: State key holding per-node health/load data.
        current_node_key: State key holding the round-robin cursor.

    Returns:
        Router returning the selected node name, or
        ``"no_available_nodes"`` when every node is unhealthy.
    """
    pool = (
        ["primary_node", "secondary_node"]
        if available_nodes is None
        else list(available_nodes)
    )

    def router(state: StateLike) -> str:
        status = _as_mapping(get_state_value(state, node_status_key, {}))
        rotation_cursor = get_state_value(state, current_node_key, 0)

        # Nodes are healthy unless their status explicitly says otherwise.
        candidates = [
            node
            for node in pool
            if bool(_as_mapping(status.get(node, {})).get("healthy", True))
        ]
        if not candidates:
            return "no_available_nodes"

        if load_balancing_strategy == "round_robin":
            try:
                position = int(rotation_cursor) % len(candidates)
            except (TypeError, ValueError):
                position = 0
            return candidates[position]

        if load_balancing_strategy == "least_loaded":
            chosen = candidates[0]
            lowest = float("inf")
            for candidate in candidates:
                details = _as_mapping(status.get(candidate, {}))
                try:
                    candidate_load = float(details.get("load", 0.0))
                except (TypeError, ValueError):
                    # Nodes with unreadable load figures are not considered.
                    continue
                if candidate_load < lowest:
                    lowest = candidate_load
                    chosen = candidate
            return chosen

        # Unknown strategies degrade to the first healthy node.
        return candidates[0]

    return router
|
||||
|
||||
|
||||
def health_check(
    health_criteria: Mapping[str, float] | None = None,
    health_status_key: str = "health_status",
) -> Callable[[StateLike], Literal["healthy", "degraded", "unhealthy"]]:
    """Create a router that performs system health checks.

    Args:
        health_criteria: Metric thresholds; defaults cover response time,
            error rate, and CPU/memory usage.
        health_status_key: State key holding current metric readings.

    Returns:
        Router returning ``"unhealthy"`` if any metric fails badly (or no
        status data exists), ``"degraded"`` if any metric is marginal,
        otherwise ``"healthy"``.
    """
    checks = dict(health_criteria) if health_criteria is not None else {
        "response_time_ms": 1000,
        "error_rate": 0.05,
        "cpu_usage": 0.8,
        "memory_usage": 0.9,
    }

    # Metrics where larger readings are better; everything else is treated
    # as "lower is better" (latency, error rates, utilization).
    higher_is_better = {"uptime_percent", "availability", "success_rate"}

    def classify(metric: str, reading: float, limit: float) -> str | None:
        """Grade one metric: "unhealthy"/"degraded", or None when it passes."""
        if metric in higher_is_better:
            if reading < limit * 0.8:
                return "unhealthy"
            if reading < limit:
                return "degraded"
            return None
        if abs(limit) < 1e-9:
            # Zero-tolerance metric: any measurable reading degrades health.
            return "degraded" if reading > 1e-9 else None
        if reading > limit * 1.5:
            return "unhealthy"
        if reading > limit:
            return "degraded"
        return None

    def router(state: StateLike) -> Literal["healthy", "degraded", "unhealthy"]:
        snapshot = _as_mapping(get_state_value(state, health_status_key, {}))
        if not snapshot:
            # No health data at all is treated as the worst case.
            return "unhealthy"

        verdicts: list[str] = []
        for metric, limit in checks.items():
            raw = snapshot.get(metric)
            if raw is None:
                continue
            try:
                reading = float(raw)
                boundary = float(limit)
            except (TypeError, ValueError):
                # Unreadable metrics are ignored rather than penalized.
                continue
            verdict = classify(metric, reading, boundary)
            if verdict is not None:
                verdicts.append(verdict)

        if "unhealthy" in verdicts:
            return "unhealthy"
        if "degraded" in verdicts:
            return "degraded"
        return "healthy"

    return router
|
||||
|
||||
|
||||
# Public routing-helper factories exported by this module.
__all__ = [
    "log_and_monitor",
    "check_resource_availability",
    "trigger_notifications",
    "check_rate_limiting",
    "load_balance",
    "health_check",
]
|
||||
|
||||
@@ -1,27 +1,25 @@
|
||||
"""Factory functions for creating different types of command routers.
|
||||
|
||||
This module provides factory functions for creating various types of
|
||||
command-based routers that follow LangGraph best practices.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
from ..validation.condition_security import ConditionSecurityError, ConditionValidator
|
||||
from .routing_rules import CommandRoutingRule
|
||||
|
||||
logger = get_logger(__name__)
|
||||
"""Factory functions for creating different types of command routers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
from ..validation.condition_security import ConditionSecurityError, ConditionValidator
|
||||
from .core import StateLike
|
||||
from .routing_rules import CommandRoutingRule
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_command_router(
|
||||
rules: list[CommandRoutingRule],
|
||||
default_target: str = "__end__",
|
||||
allow_state_updates: bool = True,
|
||||
) -> Callable[[Any], Command[str]]:
|
||||
def create_command_router(
|
||||
rules: list[CommandRoutingRule],
|
||||
default_target: str = "__end__",
|
||||
allow_state_updates: bool = True,
|
||||
) -> Callable[[StateLike], Command[str]]:
|
||||
"""Create a Command-based router from a list of routing rules.
|
||||
|
||||
Args:
|
||||
@@ -45,7 +43,7 @@ def create_command_router(
|
||||
# Sort rules by priority (highest first)
|
||||
sorted_rules = sorted(rules, key=lambda r: r.priority, reverse=True)
|
||||
|
||||
def router_func(state: Any) -> Command[str]:
|
||||
def router_func(state: StateLike) -> Command[str]:
|
||||
"""Router function that evaluates rules and returns Command."""
|
||||
# Evaluate rules in priority order
|
||||
for rule in sorted_rules:
|
||||
@@ -107,11 +105,11 @@ def _safe_construct_condition(field_name: str, operator: str, value: str) -> str
|
||||
return condition
|
||||
|
||||
|
||||
def create_status_command_router(
|
||||
status_mappings: dict[str, str],
|
||||
status_field: str = "status",
|
||||
default_target: str = "__end__",
|
||||
) -> Callable[[Any], Command[str]]:
|
||||
def create_status_command_router(
|
||||
status_mappings: Mapping[str, str],
|
||||
status_field: str = "status",
|
||||
default_target: str = "__end__",
|
||||
) -> Callable[[StateLike], Command[str]]:
|
||||
"""Create a Command router based on status field values.
|
||||
|
||||
Args:
|
||||
@@ -153,10 +151,10 @@ def create_status_command_router(
|
||||
return create_command_router(rules, default_target)
|
||||
|
||||
|
||||
def create_conditional_command_router(
|
||||
condition_mappings: dict[str, str],
|
||||
default_target: str = "__end__",
|
||||
) -> Callable[[Any], Command[str]]:
|
||||
def create_conditional_command_router(
|
||||
condition_mappings: Mapping[str, str],
|
||||
default_target: str = "__end__",
|
||||
) -> Callable[[StateLike], Command[str]]:
|
||||
"""Create a Command router based on arbitrary conditions.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
"""Core routing rule classes for command-based routing.
|
||||
|
||||
This module provides the base classes and protocols for building
|
||||
command routing rules with security validation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
"""Core routing rule classes for command-based routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.logging import get_logger
|
||||
from .core import StateLike, StateProtocol
|
||||
|
||||
from ..validation.condition_security import (
|
||||
ConditionSecurityError,
|
||||
@@ -19,15 +20,11 @@ from ..validation.condition_security import (
|
||||
validate_condition_for_security,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StateProtocol(Protocol):
|
||||
"""Protocol for state objects that can be used with Command routers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a value from the state."""
|
||||
...
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
SupportsGet = StateProtocol
|
||||
ConditionCallable = Callable[[StateLike], bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -47,19 +44,19 @@ class CommandRoutingRule:
|
||||
description: Human-readable description of the rule
|
||||
"""
|
||||
|
||||
condition: str | Callable[[Any], bool]
|
||||
target: str
|
||||
state_updates: dict[str, Any] = field(default_factory=dict)
|
||||
condition: str | ConditionCallable
|
||||
target: str
|
||||
state_updates: dict[str, object] = field(default_factory=dict)
|
||||
priority: int = 0
|
||||
description: str = ""
|
||||
|
||||
def evaluate(self, state: Any) -> bool:
|
||||
def evaluate(self, state: StateLike) -> bool:
|
||||
"""Evaluate if this rule applies to the given state."""
|
||||
if callable(self.condition):
|
||||
return self.condition(state)
|
||||
return self._evaluate_string_condition(self.condition, state)
|
||||
|
||||
def _evaluate_string_condition(self, condition: str, state: Any) -> bool:
|
||||
def _evaluate_string_condition(self, condition: str, state: StateLike) -> bool:
|
||||
"""Evaluate a string-based condition with comprehensive security validation.
|
||||
|
||||
This method provides robust security against injection attacks and ensures
|
||||
@@ -104,14 +101,14 @@ class CommandRoutingRule:
|
||||
expected_value = validated_condition[operator_position + len(operator):].strip()
|
||||
|
||||
# Define supported operators with their functions
|
||||
operators = {
|
||||
">=": lambda a, b: a >= b,
|
||||
"<=": lambda a, b: a <= b,
|
||||
"==": lambda a, b: a == b,
|
||||
"!=": lambda a, b: a != b,
|
||||
">": lambda a, b: a > b,
|
||||
"<": lambda a, b: a < b,
|
||||
}
|
||||
operators: dict[str, Callable[[object, object], bool]] = {
|
||||
">=": lambda a, b: bool(a >= b),
|
||||
"<=": lambda a, b: bool(a <= b),
|
||||
"==": lambda a, b: bool(a == b),
|
||||
"!=": lambda a, b: bool(a != b),
|
||||
">": lambda a, b: bool(a > b),
|
||||
"<": lambda a, b: bool(a < b),
|
||||
}
|
||||
|
||||
if operator not in operators:
|
||||
raise ValidationError(f"Unsupported operator: {operator}")
|
||||
@@ -147,7 +144,7 @@ class CommandRoutingRule:
|
||||
)
|
||||
raise ValidationError(f"Condition evaluation failed: {type(e).__name__}") from e
|
||||
|
||||
def _parse_condition_value(self, value_str: str) -> Any:
|
||||
def _parse_condition_value(self, value_str: str) -> object:
|
||||
"""Parse a condition value string into appropriate Python type.
|
||||
|
||||
Args:
|
||||
@@ -156,33 +153,31 @@ class CommandRoutingRule:
|
||||
Returns:
|
||||
Parsed value with appropriate type
|
||||
"""
|
||||
value_str = value_str.strip()
|
||||
value_str = value_str.strip()
|
||||
|
||||
if value_str.lower() in ("true", "false"):
|
||||
return value_str.lower() == "true"
|
||||
if value_str.lower() in {"null", "none"}:
|
||||
return None
|
||||
|
||||
if (value_str.startswith("'") and value_str.endswith("'")) or (
|
||||
value_str.startswith('"') and value_str.endswith('"')
|
||||
):
|
||||
try:
|
||||
return ast.literal_eval(value_str)
|
||||
except (ValueError, SyntaxError):
|
||||
return value_str[1:-1]
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
return int(value_str)
|
||||
with contextlib.suppress(ValueError):
|
||||
return float(value_str)
|
||||
|
||||
return value_str
|
||||
|
||||
# Handle quoted strings (remove quotes but preserve content)
|
||||
if (value_str.startswith('"') and value_str.endswith('"')) or (
|
||||
value_str.startswith("'") and value_str.endswith("'")
|
||||
):
|
||||
return value_str[1:-1] # Return string without quotes
|
||||
|
||||
# Handle boolean values
|
||||
if value_str.lower() == "true":
|
||||
return True
|
||||
elif value_str.lower() == "false":
|
||||
return False
|
||||
elif value_str.lower() in {"null", "none"}:
|
||||
return None
|
||||
|
||||
# Handle numeric values
|
||||
try:
|
||||
# Try integer first (handles both positive and negative integers)
|
||||
return (
|
||||
int(value_str) if value_str.lstrip("-").isdigit() else float(value_str)
|
||||
)
|
||||
except ValueError:
|
||||
# If all else fails, treat as string
|
||||
return value_str
|
||||
|
||||
def _safe_get_state_value(self, state: Any, field_name: str) -> Any:
|
||||
def _safe_get_state_value(
|
||||
self, state: StateLike, field_name: str
|
||||
) -> object | None:
|
||||
"""Safely get value from state object.
|
||||
|
||||
Args:
|
||||
@@ -193,24 +188,27 @@ class CommandRoutingRule:
|
||||
Field value or None if not found
|
||||
"""
|
||||
# For simple field names (no dots), handle directly
|
||||
if "." not in field_name:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
return state.get(field_name)
|
||||
else:
|
||||
return getattr(state, field_name, None)
|
||||
if "." not in field_name:
|
||||
if isinstance(state, Mapping):
|
||||
return state.get(field_name)
|
||||
if isinstance(state, StateProtocol):
|
||||
return state.get(field_name)
|
||||
return getattr(state, field_name, None)
|
||||
|
||||
# For nested access (with dots), traverse safely
|
||||
current = state
|
||||
for part in field_name.split("."):
|
||||
if current is None:
|
||||
return None
|
||||
if hasattr(current, "get") or isinstance(current, dict):
|
||||
current = current.get(part)
|
||||
else:
|
||||
current = getattr(current, part, None)
|
||||
return current
|
||||
|
||||
def create_command(self, state: Any) -> Command[str]:
|
||||
if isinstance(current, Mapping):
|
||||
current = current.get(part)
|
||||
elif isinstance(current, StateProtocol):
|
||||
current = current.get(part)
|
||||
else:
|
||||
current = getattr(current, part, None)
|
||||
return current
|
||||
|
||||
def create_command(self, state: StateLike) -> Command[str]:
|
||||
"""Create a Command object for this rule.
|
||||
|
||||
Args:
|
||||
@@ -231,4 +229,9 @@ class CommandRoutingRule:
|
||||
return Command(goto=self.target, update=updates)
|
||||
|
||||
|
||||
__all__ = ["StateProtocol", "CommandRoutingRule"]
|
||||
__all__ = [
|
||||
"CommandRoutingRule",
|
||||
"SupportsGet",
|
||||
"StateLike",
|
||||
"StateProtocol",
|
||||
]
|
||||
|
||||
@@ -7,7 +7,8 @@ resource monitoring, and safe execution contexts for LangGraph workflows.
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
from collections.abc import Callable, Mapping, MutableMapping
|
||||
from typing import Literal, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Command
|
||||
@@ -47,12 +48,12 @@ class SecureGraphRouter:
|
||||
async def secure_graph_execution(
|
||||
self,
|
||||
graph_name: str,
|
||||
graph_info: dict[str, Any],
|
||||
execution_state: dict[str, Any],
|
||||
graph_info: Mapping[str, object],
|
||||
execution_state: Mapping[str, object],
|
||||
config: RunnableConfig | None = None,
|
||||
step_id: str | None = None,
|
||||
# noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
"""Execute a graph securely with comprehensive validation and monitoring.
|
||||
|
||||
Args:
|
||||
@@ -83,16 +84,19 @@ class SecureGraphRouter:
|
||||
self.validator.check_concurrent_limit()
|
||||
|
||||
# SECURITY: Validate state data
|
||||
validated_state = self.validator.validate_state_data(execution_state.copy())
|
||||
validated_state = self.validator.validate_state_data(
|
||||
dict(execution_state)
|
||||
)
|
||||
|
||||
# Get and validate factory function
|
||||
factory_function = graph_info.get("factory_function")
|
||||
if not factory_function:
|
||||
factory_obj = graph_info.get("factory_function")
|
||||
if not callable(factory_obj):
|
||||
raise SecurityValidationError(
|
||||
f"No factory function for graph: {validated_graph_name}",
|
||||
validated_graph_name,
|
||||
"missing_factory",
|
||||
)
|
||||
factory_function = cast(Callable[[], object], factory_obj)
|
||||
|
||||
# SECURITY: Validate factory function
|
||||
self.execution_manager.validate_factory_function(
|
||||
@@ -133,7 +137,7 @@ class SecureGraphRouter:
|
||||
def create_security_failure_command(
|
||||
self,
|
||||
error: SecurityValidationError | ResourceLimitExceededError,
|
||||
execution_plan: dict[str, Any],
|
||||
execution_plan: MutableMapping[str, object],
|
||||
step_id: str | None = None,
|
||||
) -> Command[Literal["__end__"]]:
|
||||
"""Create a command for handling security failures.
|
||||
@@ -169,7 +173,7 @@ class SecureGraphRouter:
|
||||
},
|
||||
)
|
||||
|
||||
def get_execution_statistics(self) -> dict[str, Any]:
|
||||
def get_execution_statistics(self) -> dict[str, object]:
|
||||
"""Get current execution statistics from the security manager.
|
||||
|
||||
Returns:
|
||||
@@ -196,12 +200,12 @@ def get_secure_router() -> SecureGraphRouter:
|
||||
|
||||
async def execute_graph_securely(
|
||||
graph_name: str,
|
||||
graph_info: dict[str, Any],
|
||||
execution_state: dict[str, Any],
|
||||
graph_info: Mapping[str, object],
|
||||
execution_state: Mapping[str, object],
|
||||
config: RunnableConfig | None = None,
|
||||
step_id: str | None = None,
|
||||
# noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
"""Execute graph securely with validation and error handling.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,326 +1,211 @@
|
||||
"""User interaction edge helpers for human-in-the-loop workflows.
|
||||
|
||||
This module provides edge helpers for managing user interactions, interrupts,
|
||||
feedback loops, and escalation to human operators.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from .core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def human_interrupt(
    interrupt_signals: list[str] | None = None,
    interrupt_key: str = "human_interrupt",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that detects human interruption requests.

    Args:
        interrupt_signals: Signals that indicate an interrupt request;
            defaults to stop/pause/cancel/abort/interrupt/halt.
        interrupt_key: State key containing the interrupt signal.

    Returns:
        Router function that returns "interrupt" or "continue".
    """
    signals = (
        ["stop", "pause", "cancel", "abort", "interrupt", "halt"]
        if interrupt_signals is None
        else interrupt_signals
    )

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["interrupt", "continue"]:
        # Mapping-like states expose .get; fall back to attribute access.
        if hasattr(state, "get") or isinstance(state, dict):
            raw_signal = state.get(interrupt_key)
        else:
            raw_signal = getattr(state, interrupt_key, None)

        if raw_signal is None:
            return "continue"

        requested = str(raw_signal).lower().strip()
        # Normalized per call so a caller-mutated signal list stays honored.
        recognized = {entry.lower() for entry in signals}
        return "interrupt" if requested in recognized else "continue"

    return router
|
||||
|
||||
|
||||
def pass_status_to_user(
    status_levels: dict[str, str] | None = None,
    status_key: str = "status",
    notify_key: str = "notify_user",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that determines when to notify users of status.

    Args:
        status_levels: Mapping of status value to notification urgency.
        status_key: State key containing the current status.
        notify_key: State key indicating whether the user wants notices.

    Returns:
        Router returning the notification urgency or "no_notification".
    """
    levels = status_levels if status_levels is not None else {
        "error": "urgent",
        "failed": "urgent",
        "warning": "medium",
        "completed": "low",
        "success": "low",
        "in_progress": "info",
    }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        # Mapping-like states expose .get; fall back to attribute access.
        if hasattr(state, "get") or isinstance(state, dict):
            current_status = state.get(status_key)
            wants_notice = state.get(notify_key, True)
        else:
            current_status = getattr(state, status_key, None)
            wants_notice = getattr(state, notify_key, True)

        if current_status is None or not wants_notice:
            return "no_notification"

        return levels.get(str(current_status).lower(), "no_notification")

    return router
|
||||
|
||||
|
||||
def user_feedback_loop(
    feedback_required_conditions: list[str] | None = None,
    feedback_key: str = "requires_feedback",
    feedback_type_key: str = "feedback_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that manages user feedback collection.

    Args:
        feedback_required_conditions: Condition flags that, when truthy in
            state, trigger a feedback request.
        feedback_key: State key indicating feedback is explicitly required.
        feedback_type_key: State key naming the kind of feedback needed.

    Returns:
        Router returning "feedback_<type>" or "no_feedback".
    """
    conditions = (
        feedback_required_conditions
        if feedback_required_conditions is not None
        else [
            "low_confidence",
            "ambiguous_input",
            "multiple_options",
            "validation_failed",
        ]
    )

    def router(state: dict[str, Any] | StateProtocol) -> str:
        def read(key: str, default: Any = None) -> Any:
            # Mapping-like states expose .get; otherwise use attributes.
            if hasattr(state, "get") or isinstance(state, dict):
                return state.get(key, default)
            return getattr(state, key, default)

        # An explicit request takes priority over inferred conditions.
        if read(feedback_key, False):
            requested_kind = read(feedback_type_key, "general")
            return f"feedback_{requested_kind}"

        for condition in conditions:
            if read(condition, False):
                return f"feedback_{condition}"

        return "no_feedback"

    return router
|
||||
|
||||
|
||||
def escalate_to_human(
    escalation_triggers: dict[str, str] | None = None,
    auto_escalate_key: str = "auto_escalate",
    escalation_reason_key: str = "escalation_reason",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that escalates to human operators when needed.

    Args:
        escalation_triggers: Mapping of trigger conditions to escalation
            types (e.g. ``{"critical_error": "immediate"}``).
        auto_escalate_key: State key for the automatic-escalation flag.
        escalation_reason_key: State key containing the escalation reason.

    Returns:
        Router returning "escalate_<type>" or "continue_automated".
    """
    triggers = escalation_triggers if escalation_triggers is not None else {
        "critical_error": "immediate",
        "security_issue": "immediate",
        "multiple_failures": "urgent",
        "user_request": "standard",
        "manual_review": "standard",
        "complex_case": "expert",
    }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        def read(key: str, default: Any = None) -> Any:
            # Mapping-like states expose .get; otherwise use attributes.
            if hasattr(state, "get") or isinstance(state, dict):
                return state.get(key, default)
            return getattr(state, key, default)

        reason = read(escalation_reason_key)
        if read(auto_escalate_key, False) and reason:
            # Match the stated reason against known trigger substrings.
            lowered_reason = str(reason).lower()
            for trigger, escalation_type in triggers.items():
                if trigger in lowered_reason:
                    return f"escalate_{escalation_type}"
            return "escalate_standard"

        # Otherwise look for per-trigger flags set directly in state.
        for trigger, escalation_type in triggers.items():
            if read(trigger, False):
                return f"escalate_{escalation_type}"

        return "continue_automated"

    return router
|
||||
|
||||
|
||||
def check_user_authorization(
    required_permissions: list[str] | None = None,
    user_key: str = "user",
    permissions_key: str = "permissions",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks user authorization levels.

    Args:
        required_permissions: Permissions the user must hold; defaults to
            ``["basic_access"]``.
        user_key: State key containing user information.
        permissions_key: Key/attribute on the user holding the permission list.

    Returns:
        Router returning "authorized" or "unauthorized".
    """
    needed = (
        required_permissions if required_permissions is not None else ["basic_access"]
    )

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["authorized", "unauthorized"]:
        # Mapping-like states expose .get; fall back to attribute access.
        if hasattr(state, "get") or isinstance(state, dict):
            account = state.get(user_key)
        else:
            account = getattr(state, user_key, None)

        if account is None:
            return "unauthorized"

        # Permissions may live on a dict entry or an object attribute.
        granted: Any = []
        if isinstance(account, dict):
            granted = account.get(permissions_key, [])
        elif hasattr(account, permissions_key):
            granted = getattr(account, permissions_key, [])

        # Only a real list counts; anything else is treated as no access.
        if not isinstance(granted, list):
            return "unauthorized"

        missing = [perm for perm in needed if perm not in granted]
        return "unauthorized" if missing else "authorized"

    return router
|
||||
|
||||
|
||||
def collect_user_input(
    input_types: dict[str, str] | None = None,
    pending_input_key: str = "pending_user_input",
    input_type_key: str = "input_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that manages user input collection.

    Args:
        input_types: Mapping of input type to collection method.
        pending_input_key: State key indicating pending input.
        input_type_key: State key specifying the type of input needed.

    Returns:
        Router returning the input collection method or "no_input_needed".
    """
    methods = input_types if input_types is not None else {
        "text": "text_input",
        "choice": "choice_input",
        "confirmation": "confirm_input",
        "file": "file_input",
        "numeric": "number_input",
    }

    def router(state: dict[str, Any] | StateProtocol) -> str:
        # Mapping-like states expose .get; fall back to attribute access.
        if hasattr(state, "get") or isinstance(state, dict):
            awaiting = state.get(pending_input_key, False)
            kind = state.get(input_type_key, "text")
        else:
            awaiting = getattr(state, pending_input_key, False)
            kind = getattr(state, input_type_key, "text")

        if not awaiting:
            return "no_input_needed"

        # Unknown input types fall back to plain text collection.
        return methods.get(str(kind).lower(), "text_input")

    return router
|
||||
"""User interaction routing helpers for LangGraph workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Literal
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
def _normalize_strings(values: Sequence[str]) -> set[str]:
|
||||
"""Return normalized lowercase strings for comparison."""
|
||||
|
||||
return {value.lower().strip() for value in values}
|
||||
|
||||
|
||||
def human_interrupt(
    interrupt_signals: Sequence[str] | None = None,
    interrupt_key: str = "human_interrupt",
) -> Callable[[StateLike], Literal["interrupt", "continue"]]:
    """Create a router that detects human interruption requests.

    Args:
        interrupt_signals: Signal words that trigger an interrupt; a default
            set ("stop", "pause", ...) is used when empty or omitted.
        interrupt_key: State key holding the raw interrupt signal.

    Returns:
        Router returning "interrupt" when the state's signal matches one of
        the configured words (case-insensitive), otherwise "continue".
    """

    default_signals = ("stop", "pause", "cancel", "abort", "interrupt", "halt")
    signals = _normalize_strings(interrupt_signals or default_signals)

    def router(state: StateLike) -> Literal["interrupt", "continue"]:
        raw = get_state_value(state, interrupt_key)
        if raw is None:
            return "continue"
        # Normalize the same way as the configured signals before comparing.
        return "interrupt" if str(raw).lower().strip() in signals else "continue"

    return router
|
||||
|
||||
|
||||
def pass_status_to_user(
    status_levels: Mapping[str, str] | None = None,
    status_key: str = "status",
    notify_key: str = "notify_user",
) -> Callable[[StateLike], str]:
    """Create a router that decides whether/how urgently to notify the user.

    Args:
        status_levels: Maps status names to notification levels; keys are
            matched case-insensitively. Defaults cover error/warning/success.
        status_key: State key holding the current status.
        notify_key: State key enabling notifications (defaults to enabled).

    Returns:
        Router returning the mapped notification level, or "no_notification"
        when notifications are disabled, status is absent, or unmapped.
    """

    if status_levels is None:
        levels = {
            "error": "urgent",
            "failed": "urgent",
            "warning": "medium",
            "completed": "low",
            "success": "low",
            "in_progress": "info",
        }
    else:
        levels = {name.lower(): level for name, level in status_levels.items()}

    def router(state: StateLike) -> str:
        # Notifications default to enabled when the flag is absent.
        if not get_state_value(state, notify_key, True):
            return "no_notification"

        status = get_state_value(state, status_key)
        if status is None:
            return "no_notification"

        return levels.get(str(status).lower(), "no_notification")

    return router
|
||||
|
||||
|
||||
def user_feedback_loop(
    feedback_required_conditions: Sequence[str] | None = None,
    feedback_key: str = "requires_feedback",
    feedback_type_key: str = "feedback_type",
) -> Callable[[StateLike], str]:
    """Create a router that manages user feedback collection.

    Args:
        feedback_required_conditions: State flags that, when truthy, demand
            feedback. Defaults cover low-confidence and validation cases.
        feedback_key: Flag requesting feedback explicitly.
        feedback_type_key: Key naming the kind of feedback requested.

    Returns:
        Router returning "feedback_<type>", "feedback_<condition>", or
        "no_feedback".
    """

    default_conditions = (
        "low_confidence",
        "ambiguous_input",
        "multiple_options",
        "validation_failed",
    )
    conditions = (
        tuple(feedback_required_conditions)
        if feedback_required_conditions is not None
        else default_conditions
    )

    def router(state: StateLike) -> str:
        # An explicit request wins: route by the declared feedback type.
        if get_state_value(state, feedback_key, False):
            kind = get_state_value(state, feedback_type_key, "general")
            return f"feedback_{kind}"

        # Otherwise route on the first triggered condition flag, in order.
        for condition in conditions:
            if get_state_value(state, condition, False):
                return f"feedback_{condition}"

        return "no_feedback"

    return router
|
||||
|
||||
|
||||
def escalate_to_human(
    escalation_triggers: Mapping[str, str] | None = None,
    auto_escalate_key: str = "auto_escalate",
    escalation_reason_key: str = "escalation_reason",
) -> Callable[[StateLike], str]:
    """Create a router that escalates to human operators when needed.

    Args:
        escalation_triggers: Maps trigger names to escalation levels;
            defaults cover critical errors, security issues, etc.
        auto_escalate_key: Flag enabling reason-based auto-escalation.
        escalation_reason_key: Key holding a free-text escalation reason.

    Returns:
        Router returning "escalate_<level>", "escalate_standard", or
        "continue_automated".
    """

    default_triggers = {
        "critical_error": "immediate",
        "security_issue": "immediate",
        "multiple_failures": "urgent",
        "user_request": "standard",
        "manual_review": "standard",
        "complex_case": "expert",
    }
    triggers = default_triggers if escalation_triggers is None else dict(escalation_triggers)

    def router(state: StateLike) -> str:
        reason = get_state_value(state, escalation_reason_key)

        if get_state_value(state, auto_escalate_key, False) and reason:
            # Match the stated reason (case-insensitive substring) against
            # known triggers; unmatched reasons escalate at standard level.
            lowered = str(reason).lower()
            for trigger, level in triggers.items():
                if trigger in lowered:
                    return f"escalate_{level}"
            return "escalate_standard"

        # No reason-based escalation: look for truthy trigger flags in state.
        for trigger, level in triggers.items():
            if get_state_value(state, trigger, False):
                return f"escalate_{level}"

        return "continue_automated"

    return router
|
||||
|
||||
|
||||
def check_user_authorization(
    required_permissions: Sequence[str] | None = None,
    user_key: str = "user",
    permissions_key: str = "permissions",
) -> Callable[[StateLike], Literal["authorized", "unauthorized"]]:
    """Create a router that checks user authorization levels.

    Args:
        required_permissions: Permissions the user must hold; defaults to
            ["basic_access"] when empty or omitted.
        user_key: State key holding the user object or mapping.
        permissions_key: Key/attribute on the user holding its permissions.

    Returns:
        Router returning "authorized" only when every required permission is
        present in the user's permission sequence; otherwise "unauthorized".
    """

    needed = set(required_permissions or ["basic_access"])

    def router(state: StateLike) -> Literal["authorized", "unauthorized"]:
        user = get_state_value(state, user_key)
        if user is None:
            return "unauthorized"

        # Support both mapping-style and attribute-style user objects.
        if isinstance(user, dict):
            raw_permissions = user.get(permissions_key)
        else:
            raw_permissions = getattr(user, permissions_key, None)

        # Strings are Sequences too; reject them along with non-sequences.
        if not isinstance(raw_permissions, Sequence) or isinstance(
            raw_permissions, (str, bytes, bytearray)
        ):
            return "unauthorized"

        granted = {str(entry) for entry in raw_permissions}
        return "authorized" if needed.issubset(granted) else "unauthorized"

    return router
|
||||
|
||||
|
||||
def collect_user_input(
    input_types: Mapping[str, str] | None = None,
    pending_input_key: str = "pending_user_input",
    input_type_key: str = "input_type",
) -> Callable[[StateLike], str]:
    """Create a router that manages user input collection.

    Args:
        input_types: Maps input-type names to collection node names;
            defaults cover text/choice/confirmation/file/numeric.
        pending_input_key: Flag indicating input is awaited.
        input_type_key: Key naming the requested input type.

    Returns:
        Router returning the collection node name, "text_input" for unknown
        types, or "no_input_needed" when nothing is pending.
    """

    if input_types is None:
        routes = {
            "text": "text_input",
            "choice": "choice_input",
            "confirmation": "confirm_input",
            "file": "file_input",
            "numeric": "number_input",
        }
    else:
        routes = dict(input_types)

    def router(state: StateLike) -> str:
        if not get_state_value(state, pending_input_key, False):
            return "no_input_needed"

        requested = str(get_state_value(state, input_type_key, "text")).lower()
        # Unknown input types fall back to plain text collection.
        return routes.get(requested, "text_input")

    return router
|
||||
|
||||
|
||||
__all__ = [
|
||||
"human_interrupt",
|
||||
"pass_status_to_user",
|
||||
"user_feedback_loop",
|
||||
"escalate_to_human",
|
||||
"check_user_authorization",
|
||||
"collect_user_input",
|
||||
]
|
||||
|
||||
@@ -1,406 +1,270 @@
|
||||
"""Validation edge helpers for quality control and data integrity.
|
||||
|
||||
This module provides edge helpers for validating outputs, checking accuracy,
|
||||
confidence levels, format compliance, and data privacy requirements.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from .core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def check_accuracy(
    threshold: float = 0.8,
    accuracy_key: str = "accuracy_score",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks if accuracy meets threshold.

    Args:
        threshold: Minimum accuracy threshold (0.0 to 1.0).
        accuracy_key: Key in state containing accuracy score.

    Returns:
        Router function that returns "high_accuracy" or "low_accuracy".

    Example:
        accuracy_router = check_accuracy(threshold=0.85)
        graph.add_conditional_edges("quality_check", accuracy_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["high_accuracy", "low_accuracy"]:
        raw = (
            state.get(accuracy_key, 0.0)
            if isinstance(state, dict)
            else getattr(state, accuracy_key, 0.0)
        )
        # Non-numeric scores are treated as failing the threshold.
        try:
            meets = float(raw) >= threshold
        except (TypeError, ValueError):
            return "low_accuracy"
        return "high_accuracy" if meets else "low_accuracy"

    return router
|
||||
|
||||
|
||||
def check_confidence_level(
    threshold: float = 0.7,
    confidence_key: str = "confidence",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router based on confidence score threshold.

    Args:
        threshold: Minimum confidence threshold (0.0 to 1.0).
        confidence_key: Key in state containing confidence score.

    Returns:
        Router function that returns "high_confidence" or "low_confidence".

    Example:
        confidence_router = check_confidence_level(threshold=0.75)
        graph.add_conditional_edges("llm_output", confidence_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["high_confidence", "low_confidence"]:
        raw = (
            state.get(confidence_key, 0.0)
            if isinstance(state, dict)
            else getattr(state, confidence_key, 0.0)
        )
        # Non-numeric confidence values are treated as failing the threshold.
        try:
            meets = float(raw) >= threshold
        except (TypeError, ValueError):
            return "low_confidence"
        return "high_confidence" if meets else "low_confidence"

    return router
|
||||
|
||||
|
||||
def validate_output_format(
    expected_format: str = "json",
    output_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that validates output format.

    Args:
        expected_format: Expected format ("json", "xml", "csv", "text").
        output_key: Key in state containing output to validate.

    Returns:
        Router function that returns "valid_format" or "invalid_format".

    Example:
        format_router = validate_output_format(expected_format="json")
        graph.add_conditional_edges("formatter", format_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid_format", "invalid_format"]:
        output = (
            state.get(output_key, "")
            if isinstance(state, dict)
            else getattr(state, output_key, "")
        )
        if not output:
            return "invalid_format"

        fmt = expected_format.lower()
        try:
            if fmt == "json":
                json.loads(str(output))
                return "valid_format"
            if fmt == "xml":
                # Heuristic only: checks tag-like structure, not a real parse.
                text = str(output).strip()
                well_formed = text.startswith("<") and text.endswith(">") and (
                    text.endswith("/>")
                    or ("</" in text and text.count("<") >= 2)
                )
                return "valid_format" if well_formed else "invalid_format"
            if fmt == "csv":
                # Minimal CSV check: the header row contains a comma.
                header = str(output).strip().split("\n")[0]
                return "valid_format" if "," in header else "invalid_format"
            # Default "text" handling: any non-blank string is valid.
            return "valid_format" if str(output).strip() else "invalid_format"
        except Exception:
            return "invalid_format"

    return router
|
||||
|
||||
|
||||
def check_privacy_compliance(
    sensitive_patterns: list[str] | None = None,
    content_key: str = "content",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    r"""Create a router that checks for privacy compliance.

    Args:
        sensitive_patterns: List of regex patterns for sensitive data.
        content_key: Key in state containing content to check.

    Returns:
        Router function that returns "compliant" or "privacy_violation".

    Example:
        privacy_router = check_privacy_compliance([
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN
            r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b'  # Credit card
        ])
        graph.add_conditional_edges("privacy_check", privacy_router)
    """
    if sensitive_patterns is None:
        sensitive_patterns = [
            r"\b\d{3}-\d{2}-\d{4}\b",  # SSN pattern
            r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",  # Credit card pattern
            # Fixed: the TLD class previously read [A-Z|a-z], which also
            # matched a literal "|" character inside the TLD.
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",  # Email
            r"\b\d{3}[- ]\d{3}[- ]\d{4}\b",  # Phone number (require separators)
        ]

    # Compile once so each routing call avoids re-parsing the patterns.
    compiled = [re.compile(p, re.IGNORECASE) for p in sensitive_patterns]

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["compliant", "privacy_violation"]:
        content = (
            state.get(content_key, "")
            if isinstance(state, dict)
            else getattr(state, content_key, "")
        )
        text = str(content)

        for pattern in compiled:
            if pattern.search(text):
                return "privacy_violation"

        return "compliant"

    return router
|
||||
|
||||
|
||||
def check_data_freshness(
    max_age_seconds: int = 3600,  # 1 hour default
    timestamp_key: str = "timestamp",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks if data is fresh enough.

    Args:
        max_age_seconds: Maximum age in seconds before data is stale.
        timestamp_key: Key in state containing data timestamp.

    Returns:
        Router function that returns "fresh" or "stale".

    Example:
        freshness_router = check_data_freshness(max_age_seconds=1800)  # 30 minutes
        graph.add_conditional_edges("data_check", freshness_router)
    """

    def router(state: dict[str, Any] | StateProtocol) -> Literal["fresh", "stale"]:
        raw = (
            state.get(timestamp_key)
            if isinstance(state, dict)
            else getattr(state, timestamp_key, None)
        )
        if raw is None:
            return "stale"

        try:
            if isinstance(raw, str):
                # Try a numeric epoch first, then fall back to ISO-8601.
                try:
                    epoch = float(raw)
                except (ValueError, TypeError):
                    from datetime import datetime

                    parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
                    epoch = parsed.timestamp()
            else:
                epoch = float(raw)
        except (ValueError, TypeError):
            return "stale"

        return "fresh" if time.time() - epoch <= max_age_seconds else "stale"

    return router
|
||||
|
||||
|
||||
def validate_required_fields(
    required_fields: list[str],
    strict_mode: bool = True,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that validates presence of required fields.

    Args:
        required_fields: List of field names that must be present.
        strict_mode: If True, fields must be non-empty; if False, just present.

    Returns:
        Router function that returns "valid" or "missing_fields".

    Example:
        field_router = validate_required_fields([
            "user_id", "request_data", "timestamp"
        ])
        graph.add_conditional_edges("input_validator", field_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid", "missing_fields"]:
        for name in required_fields:
            value = (
                state.get(name)
                if isinstance(state, dict)
                else getattr(state, name, None)
            )
            if value is None:
                return "missing_fields"

            # Strict mode additionally rejects empty strings and empty lists.
            is_empty = value == "" or (isinstance(value, list) and not value)
            if strict_mode and is_empty:
                return "missing_fields"

        return "valid"

    return router
|
||||
|
||||
|
||||
def check_output_length(
    min_length: int = 1,
    max_length: int | None = None,
    content_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that validates output length constraints.

    Args:
        min_length: Minimum required length.
        max_length: Maximum allowed length (None for no limit).
        content_key: Key in state containing content to check.

    Returns:
        Router function that returns "valid_length", "too_short", or "too_long".

    Example:
        length_router = check_output_length(min_length=10, max_length=1000)
        graph.add_conditional_edges("length_validator", length_router)
    """

    def router(
        state: dict[str, Any] | StateProtocol,
    ) -> Literal["valid_length", "too_short", "too_long"]:
        content = (
            state.get(content_key, "")
            if isinstance(state, dict)
            else getattr(state, content_key, "")
        )
        # Length is measured on the string representation of the value.
        length = len(str(content))

        if length < min_length:
            return "too_short"
        if max_length is not None and length > max_length:
            return "too_long"
        return "valid_length"

    return router
|
||||
|
||||
|
||||
def create_content_availability_router(
    content_keys: list[str] | None = None,
    success_target: str = "analyze_content",
    failure_target: str = "status_summary",
    error_key: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
    """Create a router that checks for content availability and success conditions.

    This router is designed for workflows that need to verify if processing
    was successful and content is available for further processing.

    Args:
        content_keys: List of state keys to check for content availability.
        success_target: Target node when content is available and no errors.
        failure_target: Target node when content is missing or errors present.
        error_key: Key in state containing error information.

    Returns:
        Router function that returns success_target or failure_target.

    Example:
        content_router = create_content_availability_router(
            content_keys=["scraped_content", "repomix_output"],
            success_target="analyze_content",
            failure_target="status_summary"
        )
        graph.add_conditional_edges("processing_check", content_router)
    """
    if content_keys is None:
        content_keys = ["scraped_content", "repomix_output"]

    def _read(state: dict[str, Any] | StateProtocol, key: str) -> Any:
        # Single access helper so dict and attribute states share one path
        # (the original duplicated the whole availability loop per branch).
        if isinstance(state, dict):
            return state.get(key)
        return getattr(state, key, None)

    def _is_content(value: Any) -> bool:
        # Lists need items, strings need non-whitespace text, anything else
        # counts by truthiness.
        if isinstance(value, list):
            return len(value) > 0
        if isinstance(value, str):
            return len(value.strip()) > 0
        return bool(value)

    def router(state: dict[str, Any] | StateProtocol) -> str:
        has_error = bool(_read(state, error_key))
        has_content = any(_is_content(_read(state, key)) for key in content_keys)

        # Route based on content availability and error status.
        return success_target if has_content and not has_error else failure_target

    return router
|
||||
"""Validation and compliance routing helpers for LangGraph graphs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from numbers import Real
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
|
||||
def _to_float(value: object, default: float = 0.0) -> float:
|
||||
"""Convert ``value`` to ``float`` when possible, otherwise ``default``."""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return float(value)
|
||||
if isinstance(value, Real):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return default
|
||||
try:
|
||||
return float(stripped)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
def _to_str(value: object, default: str = "") -> str:
|
||||
"""Return a string representation of ``value``."""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
|
||||
def _has_content(value: object) -> bool:
|
||||
"""Determine whether ``value`` carries usable content."""
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, str):
|
||||
return bool(value)
|
||||
if isinstance(value, bool):
|
||||
return True
|
||||
if isinstance(value, Real):
|
||||
return True
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
return len(value) > 0
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _parse_timestamp(value: object) -> float | None:
|
||||
"""Parse various timestamp representations into a Unix epoch (UTC)."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
try:
|
||||
return float(stripped)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
dt = datetime.fromisoformat(stripped.replace("Z", "+00:00"))
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.timestamp()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_accuracy(
    threshold: float = 0.8,
    accuracy_key: str = "accuracy_score",
) -> Callable[[StateLike], Literal["high_accuracy", "low_accuracy"]]:
    """Create a router that checks if accuracy meets the provided threshold.

    Args:
        threshold: Minimum accuracy required to route high (0.0 to 1.0).
        accuracy_key: State key holding the accuracy score.

    Returns:
        Router returning "high_accuracy" or "low_accuracy".
    """

    def router(state: StateLike) -> Literal["high_accuracy", "low_accuracy"]:
        # Missing or non-numeric scores coerce to 0.0 and route low.
        if _to_float(get_state_value(state, accuracy_key, 0.0)) >= threshold:
            return "high_accuracy"
        return "low_accuracy"

    return router
|
||||
|
||||
|
||||
def check_confidence_level(
    threshold: float = 0.7,
    confidence_key: str = "confidence",
) -> Callable[[StateLike], Literal["high_confidence", "low_confidence"]]:
    """Create a router that evaluates whether confidence exceeds ``threshold``.

    Args:
        threshold: Minimum confidence required to route high (0.0 to 1.0).
        confidence_key: State key holding the confidence score.

    Returns:
        Router returning "high_confidence" or "low_confidence".
    """

    def router(state: StateLike) -> Literal["high_confidence", "low_confidence"]:
        # Missing or non-numeric scores coerce to 0.0 and route low.
        if _to_float(get_state_value(state, confidence_key, 0.0)) >= threshold:
            return "high_confidence"
        return "low_confidence"

    return router
|
||||
|
||||
|
||||
def validate_output_format(
    expected_format: str = "json",
    output_key: str = "output",
) -> Callable[[StateLike], Literal["valid_format", "invalid_format"]]:
    """Create a router that validates the serialized format of ``output_key``.

    Args:
        expected_format: One of "json", "xml", "csv"; anything else is
            treated as plain text (non-blank is valid).
        output_key: State key holding the output to validate.

    Returns:
        Router returning "valid_format" or "invalid_format".
    """

    expected = expected_format.lower()

    def _looks_like_xml(text: str) -> bool:
        # Heuristic only: tag-shaped start/end, not a real XML parse.
        if not (text.startswith("<") and text.endswith(">")):
            return False
        return text.endswith("/>") or ("</" in text and text.count("<") >= 2)

    def router(state: StateLike) -> Literal["valid_format", "invalid_format"]:
        output_str = _to_str(get_state_value(state, output_key, ""))
        if not output_str:
            return "invalid_format"

        try:
            if expected == "json":
                json.loads(output_str)
                return "valid_format"
            if expected == "xml":
                return (
                    "valid_format"
                    if _looks_like_xml(output_str.strip())
                    else "invalid_format"
                )
            if expected == "csv":
                # Minimal CSV check: the header row contains a comma.
                header = output_str.strip().split("\n")[0]
                return "valid_format" if "," in header else "invalid_format"

            # Any other format: non-blank text is accepted.
            return "valid_format" if output_str.strip() else "invalid_format"
        except Exception:  # noqa: BLE001
            return "invalid_format"

    return router
|
||||
|
||||
|
||||
def check_privacy_compliance(
    sensitive_patterns: Sequence[str] | None = None,
    content_key: str = "content",
) -> Callable[[StateLike], Literal["compliant", "privacy_violation"]]:
    r"""Create a router that checks content for sensitive data patterns.

    Args:
        sensitive_patterns: Regex patterns flagged as sensitive; defaults
            cover SSNs, credit cards, emails, and phone numbers.
        content_key: State key whose value is scanned (stringified first).

    Returns:
        Router returning "privacy_violation" when any pattern matches the
        content (case-insensitively), "compliant" otherwise.
    """

    default_patterns = [
        r"\b\d{3}-\d{2}-\d{4}\b",  # SSN
        r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",  # Credit card
        # Fixed: the TLD class previously read [A-Z|a-z], which also matched
        # a literal "|" character inside the TLD.
        r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",  # Email
        r"\b\d{3}[- ]\d{3}[- ]\d{4}\b",  # Phone number with separators
    ]
    raw_patterns = (
        list(sensitive_patterns) if sensitive_patterns is not None else default_patterns
    )
    # Compile once so each routing call avoids re-parsing the patterns.
    compiled = [re.compile(p, re.IGNORECASE) for p in raw_patterns]

    def router(state: StateLike) -> Literal["compliant", "privacy_violation"]:
        content_str = _to_str(get_state_value(state, content_key, ""))
        for pattern in compiled:
            if pattern.search(content_str):
                return "privacy_violation"
        return "compliant"

    return router
|
||||
|
||||
|
||||
def check_data_freshness(
    max_age_seconds: int = 3600,
    timestamp_key: str = "timestamp",
) -> Callable[[StateLike], Literal["fresh", "stale"]]:
    """Create a router that validates whether stored data is still fresh.

    Args:
        max_age_seconds: Maximum allowed age in seconds (default 1 hour).
        timestamp_key: State key holding the data timestamp.

    Returns:
        Router returning "fresh" when the timestamp is parseable and within
        the age window, otherwise "stale".
    """

    def router(state: StateLike) -> Literal["fresh", "stale"]:
        epoch = _parse_timestamp(get_state_value(state, timestamp_key))
        if epoch is None:
            # Missing or unparseable timestamps are treated as stale.
            return "stale"
        return "fresh" if time.time() - epoch <= max_age_seconds else "stale"

    return router
|
||||
|
||||
|
||||
def validate_required_fields(
    required_fields: Sequence[str],
    strict_mode: bool = True,
) -> Callable[[StateLike], Literal["valid", "missing_fields"]]:
    """Create a router that checks the presence of required state fields.

    Args:
        required_fields: Field names that must exist in the state.
        strict_mode: When True, fields must also carry non-empty content.

    Returns:
        Router returning "valid" or "missing_fields".
    """

    def router(state: StateLike) -> Literal["valid", "missing_fields"]:
        for name in required_fields:
            value = get_state_value(state, name)
            if value is None:
                return "missing_fields"
            # Strict mode additionally rejects empty strings/sequences.
            if strict_mode and not _has_content(value):
                return "missing_fields"

        return "valid"

    return router
|
||||
|
||||
|
||||
def check_output_length(
    min_length: int = 1,
    max_length: int | None = None,
    content_key: str = "output",
) -> Callable[[StateLike], Literal["valid_length", "too_short", "too_long"]]:
    """Create a router that validates content length constraints.

    Args:
        min_length: Minimum required length of the stringified content.
        max_length: Maximum allowed length, or None for no upper bound.
        content_key: State key holding the content to measure.

    Returns:
        Router returning "valid_length", "too_short", or "too_long".
    """

    def router(state: StateLike) -> Literal["valid_length", "too_short", "too_long"]:
        # Length is measured on the string representation of the value.
        length = len(_to_str(get_state_value(state, content_key, "")))

        if length < min_length:
            return "too_short"
        if max_length is not None and length > max_length:
            return "too_long"
        return "valid_length"

    return router
|
||||
|
||||
|
||||
def create_content_availability_router(
    content_keys: Sequence[str] | None = None,
    success_target: str = "analyze_content",
    failure_target: str = "status_summary",
    error_key: str = "error",
) -> Callable[[StateLike], str]:
    """Create a router that checks for content availability and error flags.

    Args:
        content_keys: State keys probed for content; defaults to
            scraped_content/repomix_output.
        success_target: Node to route to when content exists and no error.
        failure_target: Node to route to otherwise.
        error_key: State key holding error information.

    Returns:
        Router returning ``success_target`` or ``failure_target``.
    """

    default_keys = ("scraped_content", "repomix_output")
    keys = tuple(content_keys) if content_keys is not None else default_keys

    def _present(value: object | None) -> bool:
        # Whitespace-only strings and empty sequences do not count.
        if value is None:
            return False
        if isinstance(value, str):
            return bool(value.strip())
        if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
            return len(value) > 0
        return bool(value)

    def router(state: StateLike) -> str:
        # Any error forces the failure branch regardless of content.
        if get_state_value(state, error_key):
            return failure_target
        for key in keys:
            if _present(get_state_value(state, key)):
                return success_target
        return failure_target

    return router
|
||||
|
||||
|
||||
__all__ = [
|
||||
"check_accuracy",
|
||||
"check_confidence_level",
|
||||
"validate_output_format",
|
||||
"check_privacy_compliance",
|
||||
"check_data_freshness",
|
||||
"validate_required_fields",
|
||||
"check_output_length",
|
||||
"create_content_availability_router",
|
||||
]
|
||||
|
||||
@@ -8,12 +8,13 @@ routing and provide utilities for building composite routing logic.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
|
||||
from .core import StateLike, get_state_value
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from biz_bud.logging import debug_highlight
|
||||
from biz_bud.logging import debug_highlight, error_highlight
|
||||
|
||||
|
||||
class WorkflowRouters:
|
||||
@@ -24,7 +25,7 @@ class WorkflowRouters:
|
||||
stage_key: str = "workflow_stage",
|
||||
stage_mapping: dict[str, str] | None = None,
|
||||
default_stage: str = "end",
|
||||
) -> Callable[[dict[str, Any]], str]:
|
||||
) -> Callable[[StateLike], str]:
|
||||
"""Route based on workflow stage progression.
|
||||
|
||||
Routes based on current workflow stage, useful for sequential
|
||||
@@ -55,9 +56,11 @@ class WorkflowRouters:
|
||||
if stage_mapping is None:
|
||||
stage_mapping = {}
|
||||
|
||||
def router(state: dict[str, Any]) -> str:
|
||||
current_stage = state.get(stage_key, "unknown")
|
||||
return stage_mapping.get(current_stage, default_stage)
|
||||
def router(state: StateLike) -> str:
|
||||
current_stage = get_state_value(state, stage_key, "unknown")
|
||||
if isinstance(current_stage, str):
|
||||
return stage_mapping.get(current_stage, default_stage)
|
||||
return default_stage
|
||||
|
||||
return router
|
||||
|
||||
@@ -67,7 +70,7 @@ class WorkflowRouters:
|
||||
subgraph_mapping: dict[str, str],
|
||||
default_subgraph: str = "main",
|
||||
parent_return: str = "consolidate",
|
||||
) -> Callable[[dict[str, Any]], Command[str]]:
|
||||
) -> Callable[[StateLike], Command[str]]:
|
||||
"""Route to subgraphs using Command pattern.
|
||||
|
||||
Routes to different subgraphs based on state values, useful for
|
||||
@@ -97,8 +100,9 @@ class WorkflowRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any]) -> Command[str]:
|
||||
task_type = state.get(subgraph_key, "default")
|
||||
def router(state: StateLike) -> Command[str]:
|
||||
raw_task_type = get_state_value(state, subgraph_key, "default")
|
||||
task_type = raw_task_type if isinstance(raw_task_type, str) else str(raw_task_type)
|
||||
target_subgraph = subgraph_mapping.get(task_type, default_subgraph)
|
||||
|
||||
debug_highlight(
|
||||
@@ -119,9 +123,9 @@ class WorkflowRouters:
|
||||
|
||||
@staticmethod
|
||||
def combine_routers(
|
||||
routers: Sequence[tuple[Callable[[dict[str, Any]], Any], str]],
|
||||
routers: Sequence[tuple[Callable[[StateLike], object], str]],
|
||||
combination_logic: str = "first_match",
|
||||
) -> Callable[[dict[str, Any]], Any]:
|
||||
) -> Callable[[StateLike], object]:
|
||||
"""Combine multiple routers with specified logic.
|
||||
|
||||
Allows composition of multiple routing functions with different
|
||||
@@ -145,8 +149,8 @@ class WorkflowRouters:
|
||||
```
|
||||
"""
|
||||
|
||||
def combined_router(state: dict[str, Any]) -> Any:
|
||||
results = []
|
||||
def combined_router(state: StateLike) -> object:
|
||||
results: list[tuple[str, object]] = []
|
||||
|
||||
for router_func, router_name in routers:
|
||||
try:
|
||||
@@ -174,41 +178,37 @@ class WorkflowRouters:
|
||||
|
||||
@staticmethod
|
||||
def create_debug_router(
|
||||
inner_router: Callable[[dict[str, Any]], Any],
|
||||
inner_router: Callable[[StateLike], object],
|
||||
name: str = "unnamed_router",
|
||||
) -> Callable[[dict[str, Any]], Any]:
|
||||
"""Wrap a router with debug logging.
|
||||
fallback: object | None = "end",
|
||||
rethrow: bool = False,
|
||||
) -> Callable[[StateLike], object]:
|
||||
"""Wrap a router with debug logging."""
|
||||
|
||||
Adds comprehensive debug logging around router execution,
|
||||
useful for troubleshooting routing decisions in complex workflows.
|
||||
|
||||
Args:
|
||||
inner_router: The router function to wrap
|
||||
name: Name for logging identification
|
||||
|
||||
Returns:
|
||||
Debug-wrapped router function
|
||||
|
||||
Example:
|
||||
```python
|
||||
basic_router = BasicRouters.route_on_key("status", {"ok": "continue"})
|
||||
debug_router = WorkflowRouters.create_debug_router(
|
||||
basic_router,
|
||||
"status_router"
|
||||
def debug_router(state: StateLike) -> object:
|
||||
state_keys: list[str] = (
|
||||
list(state.keys()) if isinstance(state, Mapping) else []
|
||||
)
|
||||
graph.add_conditional_edges("processor", debug_router)
|
||||
```
|
||||
"""
|
||||
|
||||
def debug_router(state: dict[str, Any]) -> Any:
|
||||
debug_highlight(
|
||||
f"Router '{name}' evaluating state keys: {list(state.keys())}",
|
||||
f"Router '{name}' evaluating state keys: {state_keys}",
|
||||
category="EdgeRouter",
|
||||
)
|
||||
|
||||
result = inner_router(state)
|
||||
debug_highlight(f"Router '{name}' result: {result}", category="EdgeRouter")
|
||||
return result
|
||||
try:
|
||||
result = inner_router(state)
|
||||
debug_highlight(
|
||||
f"Router '{name}' produced route: {result!r}",
|
||||
category="EdgeRouter",
|
||||
)
|
||||
return result
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
error_highlight(
|
||||
f"Router '{name}' failed with error: {exc}",
|
||||
category="EdgeRouter",
|
||||
)
|
||||
if rethrow:
|
||||
raise
|
||||
return fallback
|
||||
|
||||
return debug_router
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar, cast
|
||||
from typing import TypeVar, cast
|
||||
|
||||
from .cross_cutting import (
|
||||
handle_errors,
|
||||
@@ -38,28 +39,30 @@ from .state_immutability import (
|
||||
update_state_immutably,
|
||||
)
|
||||
|
||||
StateT = TypeVar("StateT")
|
||||
StateT = TypeVar("StateT", bound=Mapping[str, object])
|
||||
ReturnT = TypeVar("ReturnT")
|
||||
|
||||
|
||||
def create_type_safe_wrapper(
|
||||
func: Callable[[StateT], ReturnT]
|
||||
) -> Callable[[dict[str, Any]], ReturnT]:
|
||||
) -> Callable[[dict[str, object]], ReturnT]:
|
||||
"""Wrap a router or helper to satisfy LangGraph's ``dict``-based typing.
|
||||
|
||||
LangGraph nodes and routers receive ``dict[str, Any]`` state objects at
|
||||
LangGraph nodes and routers receive ``dict[str, object]`` state objects at
|
||||
runtime, but most of our helpers are annotated with TypedDict subclasses or
|
||||
custom mapping types. This utility provides a lightweight adapter that
|
||||
custom mapping types. This utility provides a lightweight adapter that
|
||||
casts the dynamic runtime state into the helper's expected type while
|
||||
preserving the original return value and function metadata.
|
||||
|
||||
The wrapper intentionally performs a shallow cast rather than copying the
|
||||
state to avoid the overhead of deep cloning large graphs. Callers that need
|
||||
state to avoid the overhead of deep cloning large graphs. Callers that need
|
||||
defensive copies should do so within their helper implementation.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(state: dict[str, Any], *args: Any, **kwargs: Any) -> ReturnT:
|
||||
def wrapper(
|
||||
state: dict[str, object], *args: object, **kwargs: object
|
||||
) -> ReturnT:
|
||||
return func(cast(StateT, state), *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -8,15 +8,19 @@ nodes and tools in the Business Buddy framework.
|
||||
import asyncio
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import ParamSpec, TypeVar, TypedDict, cast
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class NodeMetric(TypedDict):
|
||||
"""Type definition for node metrics."""
|
||||
|
||||
@@ -29,7 +33,7 @@ class NodeMetric(TypedDict):
|
||||
last_error: str | None
|
||||
|
||||
|
||||
def _log_execution_start(node_name: str, context: dict[str, Any]) -> float:
|
||||
def _log_execution_start(node_name: str, context: Mapping[str, object]) -> float:
|
||||
"""Log the start of node execution and return start time.
|
||||
|
||||
Args:
|
||||
@@ -51,7 +55,9 @@ def _log_execution_start(node_name: str, context: dict[str, Any]) -> float:
|
||||
return start_time
|
||||
|
||||
|
||||
def _log_execution_success(node_name: str, start_time: float, context: dict[str, Any]) -> None:
|
||||
def _log_execution_success(
|
||||
node_name: str, start_time: float, context: Mapping[str, object]
|
||||
) -> None:
|
||||
"""Log successful node execution.
|
||||
|
||||
Args:
|
||||
@@ -71,7 +77,10 @@ def _log_execution_success(node_name: str, start_time: float, context: dict[str,
|
||||
|
||||
|
||||
def _log_execution_error(
|
||||
node_name: str, start_time: float, error: Exception, context: dict[str, Any]
|
||||
node_name: str,
|
||||
start_time: float,
|
||||
error: Exception,
|
||||
context: Mapping[str, object],
|
||||
) -> None:
|
||||
"""Log node execution error.
|
||||
|
||||
@@ -95,7 +104,7 @@ def _log_execution_error(
|
||||
|
||||
def log_node_execution(
|
||||
node_name: str | None = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
|
||||
"""Log node execution with timing and context.
|
||||
|
||||
This decorator automatically logs entry, exit, and timing information
|
||||
@@ -108,16 +117,18 @@ def log_node_execution(
|
||||
Decorated function with logging
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
|
||||
actual_node_name = node_name or func.__name__
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
context = _extract_context_from_args(args, kwargs)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
context = _extract_context_from_args(
|
||||
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
|
||||
)
|
||||
start_time = _log_execution_start(actual_node_name, context)
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
|
||||
_log_execution_success(actual_node_name, start_time, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -125,12 +136,14 @@ def log_node_execution(
|
||||
raise
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
context = _extract_context_from_args(args, kwargs)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
context = _extract_context_from_args(
|
||||
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
|
||||
)
|
||||
start_time = _log_execution_start(actual_node_name, context)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
result = cast(Callable[P, R], func)(*args, **kwargs)
|
||||
_log_execution_success(actual_node_name, start_time, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -138,12 +151,16 @@ def log_node_execution(
|
||||
raise
|
||||
|
||||
# Return appropriate wrapper based on function type
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
|
||||
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _initialize_metric(state: dict[str, Any] | None, metric_name: str) -> NodeMetric | None:
|
||||
def _initialize_metric(
|
||||
state: MutableMapping[str, object] | None, metric_name: str
|
||||
) -> NodeMetric | None:
|
||||
"""Initialize or retrieve a metric from state.
|
||||
|
||||
Args:
|
||||
@@ -156,24 +173,51 @@ def _initialize_metric(state: dict[str, Any] | None, metric_name: str) -> NodeMe
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
if "metrics" not in state:
|
||||
state["metrics"] = {}
|
||||
metrics_value = state.get("metrics")
|
||||
if isinstance(metrics_value, MutableMapping):
|
||||
metrics: MutableMapping[str, object] = metrics_value
|
||||
elif isinstance(metrics_value, Mapping):
|
||||
metrics = dict(metrics_value)
|
||||
state["metrics"] = metrics
|
||||
else:
|
||||
metrics = {}
|
||||
state["metrics"] = metrics
|
||||
|
||||
metrics = state["metrics"]
|
||||
raw_entry = metrics.get(metric_name)
|
||||
if isinstance(raw_entry, Mapping):
|
||||
entry_dict: dict[str, object] = dict(raw_entry)
|
||||
else:
|
||||
try:
|
||||
entry_dict = dict(raw_entry) if raw_entry is not None else {}
|
||||
except Exception:
|
||||
entry_dict = {}
|
||||
|
||||
if metric_name not in metrics:
|
||||
metrics[metric_name] = NodeMetric(
|
||||
count=0,
|
||||
success_count=0,
|
||||
failure_count=0,
|
||||
total_duration_ms=0.0,
|
||||
avg_duration_ms=0.0,
|
||||
last_execution=None,
|
||||
last_error=None,
|
||||
)
|
||||
def _as_int(value: object) -> int:
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
if isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
return 0
|
||||
|
||||
metric = cast("NodeMetric", metrics[metric_name])
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
def _as_float(value: object) -> float:
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
return 0.0
|
||||
|
||||
metric: NodeMetric = NodeMetric(
|
||||
count=_as_int(entry_dict.get("count")),
|
||||
success_count=_as_int(entry_dict.get("success_count")),
|
||||
failure_count=_as_int(entry_dict.get("failure_count")),
|
||||
total_duration_ms=_as_float(entry_dict.get("total_duration_ms")),
|
||||
avg_duration_ms=0.0,
|
||||
last_execution=str(entry_dict.get("last_execution"))
|
||||
if isinstance(entry_dict.get("last_execution"), str)
|
||||
else None,
|
||||
last_error=str(entry_dict.get("last_error"))
|
||||
if entry_dict.get("last_error") is not None
|
||||
else None,
|
||||
)
|
||||
metrics[metric_name] = metric
|
||||
return metric
|
||||
|
||||
|
||||
@@ -187,11 +231,13 @@ def _update_metric_success(metric: NodeMetric | None, elapsed_ms: float) -> None
|
||||
if metric is None:
|
||||
return
|
||||
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
metric["success_count"] = (metric["success_count"] or 0) + 1
|
||||
metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms
|
||||
count = metric["count"] or 1
|
||||
metric["avg_duration_ms"] = (metric["total_duration_ms"] or 0.0) / count
|
||||
metric["last_execution"] = datetime.now(UTC).isoformat()
|
||||
metric["last_error"] = None
|
||||
|
||||
|
||||
def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error: Exception) -> None:
|
||||
@@ -205,6 +251,7 @@ def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error:
|
||||
if metric is None:
|
||||
return
|
||||
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
metric["failure_count"] = (metric["failure_count"] or 0) + 1
|
||||
metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms
|
||||
count = metric["count"] or 1
|
||||
@@ -215,7 +262,7 @@ def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error:
|
||||
|
||||
def track_metrics(
|
||||
metric_name: str,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
|
||||
"""Track metrics for node execution.
|
||||
|
||||
This decorator updates state with performance metrics including
|
||||
@@ -228,15 +275,19 @@ def track_metrics(
|
||||
Decorated function with metric tracking
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
state = (
|
||||
cast(MutableMapping[str, object], args[0])
|
||||
if args and isinstance(args[0], MutableMapping)
|
||||
else None
|
||||
)
|
||||
metric = _initialize_metric(state, metric_name)
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
_update_metric_success(metric, elapsed_ms)
|
||||
return result
|
||||
@@ -246,13 +297,17 @@ def track_metrics(
|
||||
raise
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
state = (
|
||||
cast(MutableMapping[str, object], args[0])
|
||||
if args and isinstance(args[0], MutableMapping)
|
||||
else None
|
||||
)
|
||||
metric = _initialize_metric(state, metric_name)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
result = cast(Callable[P, R], func)(*args, **kwargs)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
_update_metric_success(metric, elapsed_ms)
|
||||
return result
|
||||
@@ -261,7 +316,9 @@ def track_metrics(
|
||||
_update_metric_failure(metric, elapsed_ms, e)
|
||||
raise
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
|
||||
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -269,10 +326,10 @@ def track_metrics(
|
||||
def _handle_error(
|
||||
error: Exception,
|
||||
func_name: str,
|
||||
args: tuple[Any, ...],
|
||||
error_handler: Callable[[Exception], Any] | None,
|
||||
fallback_value: Any
|
||||
) -> Any:
|
||||
args: tuple[object, ...],
|
||||
error_handler: Callable[[Exception], None] | None,
|
||||
fallback_value: object | None,
|
||||
) -> object:
|
||||
"""Common error handling logic for both sync and async functions.
|
||||
|
||||
Args:
|
||||
@@ -297,12 +354,20 @@ def _handle_error(
|
||||
error_handler(error)
|
||||
|
||||
# Update state with error if available
|
||||
state = args[0] if args and isinstance(args[0], dict) else None
|
||||
if state and "errors" in state:
|
||||
if not isinstance(state["errors"], list):
|
||||
state["errors"] = []
|
||||
state = (
|
||||
cast(MutableMapping[str, object], args[0])
|
||||
if args and isinstance(args[0], MutableMapping)
|
||||
else None
|
||||
)
|
||||
if state is not None:
|
||||
errors_value = state.get("errors")
|
||||
if isinstance(errors_value, list):
|
||||
errors_list = errors_value
|
||||
else:
|
||||
errors_list = []
|
||||
state["errors"] = errors_list
|
||||
|
||||
state["errors"].append(
|
||||
errors_list.append(
|
||||
{
|
||||
"node": func_name,
|
||||
"error": str(error),
|
||||
@@ -314,13 +379,14 @@ def _handle_error(
|
||||
# Return fallback value or re-raise
|
||||
if fallback_value is not None:
|
||||
return fallback_value
|
||||
else:
|
||||
raise
|
||||
|
||||
raise
|
||||
|
||||
|
||||
def handle_errors(
|
||||
error_handler: Callable[[Exception], Any] | None = None, fallback_value: Any = None
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
error_handler: Callable[[Exception], None] | None = None,
|
||||
fallback_value: R | None = None,
|
||||
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
|
||||
"""Handle errors with standardized error handling in nodes.
|
||||
|
||||
This decorator provides consistent error handling with optional
|
||||
@@ -334,22 +400,38 @@ def handle_errors(
|
||||
Decorated function with error handling
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
|
||||
result = _handle_error(
|
||||
e,
|
||||
func.__name__,
|
||||
cast(tuple[object, ...], args),
|
||||
error_handler,
|
||||
fallback_value,
|
||||
)
|
||||
return cast(R, result)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
return cast(Callable[P, R], func)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return _handle_error(e, func.__name__, args, error_handler, fallback_value)
|
||||
result = _handle_error(
|
||||
e,
|
||||
func.__name__,
|
||||
cast(tuple[object, ...], args),
|
||||
error_handler,
|
||||
fallback_value,
|
||||
)
|
||||
return cast(R, result)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
|
||||
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -358,7 +440,7 @@ def retry_on_failure(
|
||||
max_attempts: int = 3,
|
||||
backoff_seconds: float = 1.0,
|
||||
exponential_backoff: bool = True,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
|
||||
"""Retry node execution on failure.
|
||||
|
||||
Args:
|
||||
@@ -370,14 +452,14 @@ def retry_on_failure(
|
||||
Decorated function with retry logic
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
last_exception = None
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
@@ -406,12 +488,12 @@ def retry_on_failure(
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
last_exception = None
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
return cast(Callable[P, R], func)(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
@@ -432,21 +514,22 @@ def retry_on_failure(
|
||||
f"{func.__name__}: {str(e)}"
|
||||
)
|
||||
|
||||
if last_exception:
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unexpected error in retry logic for {func.__name__}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Unexpected error in retry logic for {func.__name__}"
|
||||
)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
|
||||
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_context_from_args(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
args: tuple[object, ...], kwargs: Mapping[str, object]
|
||||
) -> dict[str, object]:
|
||||
"""Extract context information from function arguments.
|
||||
|
||||
Looks for RunnableConfig in args/kwargs and extracts relevant context.
|
||||
@@ -458,44 +541,35 @@ def _extract_context_from_args(
|
||||
Returns:
|
||||
Dictionary with extracted context (run_id, user_id, etc.)
|
||||
"""
|
||||
context = {}
|
||||
context: dict[str, object] = {}
|
||||
|
||||
# Check if we can extract RunnableConfig-like data
|
||||
try:
|
||||
# Check kwargs for config
|
||||
if "config" in kwargs:
|
||||
config = kwargs["config"]
|
||||
# Check if config looks like a RunnableConfig (duck typing for TypedDict)
|
||||
if (
|
||||
isinstance(config, dict)
|
||||
and any(
|
||||
key in config
|
||||
for key in [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"run_name",
|
||||
"configurable",
|
||||
]
|
||||
)
|
||||
and "metadata" in config
|
||||
and isinstance(config["metadata"], dict)
|
||||
):
|
||||
context |= config["metadata"]
|
||||
def _sanitize_mapping(meta: Mapping[str, object]) -> dict[str, object]:
|
||||
cleaned: dict[str, object] = {}
|
||||
for key, value in meta.items():
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
cleaned[key] = value
|
||||
else:
|
||||
cleaned[key] = str(value)
|
||||
return cleaned
|
||||
|
||||
# Check args for RunnableConfig
|
||||
for arg in args:
|
||||
# Check if arg looks like a RunnableConfig (duck typing for TypedDict)
|
||||
if isinstance(arg, dict) and any(
|
||||
key in arg
|
||||
for key in ["tags", "metadata", "callbacks", "run_name", "configurable"]
|
||||
):
|
||||
if "metadata" in arg and isinstance(arg["metadata"], dict):
|
||||
context |= arg["metadata"]
|
||||
break
|
||||
except ImportError:
|
||||
# LangChain not available, skip RunnableConfig extraction
|
||||
pass
|
||||
# Prefer explicit config kwarg when available
|
||||
if "config" in kwargs:
|
||||
config = kwargs["config"]
|
||||
if isinstance(config, Mapping):
|
||||
metadata = config.get("metadata")
|
||||
if isinstance(metadata, Mapping):
|
||||
context.update(_sanitize_mapping(metadata))
|
||||
|
||||
# Fallback: inspect args for RunnableConfig-like payloads
|
||||
for arg in args:
|
||||
if not isinstance(arg, Mapping):
|
||||
continue
|
||||
keys = set(arg.keys())
|
||||
if {"metadata"}.issubset(keys) and keys & {"tags", "callbacks", "run_name", "configurable"}:
|
||||
metadata = arg.get("metadata")
|
||||
if isinstance(metadata, Mapping):
|
||||
context.update(_sanitize_mapping(metadata))
|
||||
break
|
||||
|
||||
return context
|
||||
|
||||
@@ -505,7 +579,7 @@ def standard_node(
|
||||
node_name: str | None = None,
|
||||
metric_name: str | None = None,
|
||||
retry_attempts: int = 0,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
|
||||
"""Composite decorator applying standard cross-cutting concerns.
|
||||
|
||||
This decorator combines logging, metrics, error handling, and retries
|
||||
@@ -520,9 +594,9 @@ def standard_node(
|
||||
Decorated function with all standard concerns applied
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
|
||||
# Apply decorators in order (innermost to outermost)
|
||||
decorated = func
|
||||
decorated: Callable[P, Awaitable[R] | R] = func
|
||||
|
||||
# Add retry if requested
|
||||
if retry_attempts > 0:
|
||||
@@ -545,7 +619,7 @@ def standard_node(
|
||||
return decorator
|
||||
|
||||
|
||||
def route_error_severity(state: dict[str, Any]) -> str:
|
||||
def route_error_severity(state: Mapping[str, object]) -> str:
|
||||
"""Route based on error severity in state.
|
||||
|
||||
This function examines the state for error information and routes
|
||||
@@ -560,7 +634,7 @@ def route_error_severity(state: dict[str, Any]) -> str:
|
||||
# Check for errors in state
|
||||
if "errors" in state and state["errors"]:
|
||||
errors = state["errors"]
|
||||
if isinstance(errors, list) and errors:
|
||||
if isinstance(errors, Sequence) and errors:
|
||||
# Count critical errors
|
||||
critical_count = 0
|
||||
for error in errors:
|
||||
@@ -606,7 +680,7 @@ def route_error_severity(state: dict[str, Any]) -> str:
|
||||
return "END"
|
||||
|
||||
|
||||
def route_llm_output(state: dict[str, Any]) -> str:
|
||||
def route_llm_output(state: Mapping[str, object]) -> str:
|
||||
"""Route based on LLM output analysis.
|
||||
|
||||
This function examines LLM output in the state and determines
|
||||
|
||||
@@ -1,30 +1,29 @@
|
||||
"""Centralized graph builder to reduce duplication in graph creation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Generic, Sequence, TypeVar, Union, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
|
||||
from langgraph.cache.base import BaseCache
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.types import All
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Type variable for state types - using Any for maximum compatibility with LangGraph
|
||||
StateT = TypeVar("StateT", bound=Any)
|
||||
|
||||
# Type aliases for clarity - LangGraph nodes can be sync or async with various signatures
|
||||
|
||||
NodeFunction = Any # Maximum flexibility for LangGraph node functions
|
||||
RouterFunction = Callable[[Any], str] # Routers should return strings for LangGraph compatibility
|
||||
ConditionalMapping = dict[str, str]
|
||||
"""Centralized graph builder to reduce duplication in graph creation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
|
||||
from langgraph.cache.base import BaseCache
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph as CompiledGraph, RetryPolicy
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.types import All
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Type aliases for clarity - LangGraph nodes can be sync or async with various signatures
|
||||
RuntimeState = dict[str, object]
|
||||
StateT = TypeVar("StateT", bound=RuntimeState)
|
||||
NodeFunction = Callable[..., object | Awaitable[object]]
|
||||
RouterFunction = Callable[[Mapping[str, object]], str]
|
||||
ConditionalMapping = dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,10 +32,10 @@ class NodeConfig:
|
||||
|
||||
func: NodeFunction
|
||||
defer: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
input_schema: type[Any] | None = None
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None
|
||||
cache_policy: CachePolicy | None = None
|
||||
metadata: dict[str, object] | None = None
|
||||
input_schema: type[object] | None = None
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None
|
||||
cache_policy: CachePolicy | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -55,29 +54,29 @@ class ConditionalEdgeConfig:
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphBuilderConfig(Generic[StateT]):
|
||||
class GraphBuilderConfig(Generic[StateT]):
|
||||
"""Configuration for building a StateGraph."""
|
||||
|
||||
state_class: type[StateT]
|
||||
context_class: type[Any] | None = None
|
||||
input_schema: type[Any] | None = None
|
||||
output_schema: type[Any] | None = None
|
||||
nodes: dict[str, NodeFunction | NodeConfig] = field(default_factory=dict)
|
||||
edges: list[EdgeConfig] = field(default_factory=list)
|
||||
conditional_edges: list[ConditionalEdgeConfig] = field(default_factory=list)
|
||||
checkpointer: BaseCheckpointSaver[Any] | None = None
|
||||
cache: BaseCache[Any] | None = None
|
||||
store: BaseStore[Any] | None = None
|
||||
interrupt_before: All | list[str] | None = None
|
||||
interrupt_after: All | list[str] | None = None
|
||||
debug: bool = False
|
||||
name: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
context_class: type[object] | None = None
|
||||
input_schema: type[object] | None = None
|
||||
output_schema: type[object] | None = None
|
||||
nodes: dict[str, NodeFunction | NodeConfig] = field(default_factory=dict)
|
||||
edges: list[EdgeConfig] = field(default_factory=list)
|
||||
conditional_edges: list[ConditionalEdgeConfig] = field(default_factory=list)
|
||||
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None
|
||||
cache: BaseCache[object] | None = None
|
||||
store: BaseStore[object] | None = None
|
||||
interrupt_before: All | list[str] | None = None
|
||||
interrupt_after: All | list[str] | None = None
|
||||
debug: bool = False
|
||||
name: str | None = None
|
||||
metadata: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
def build_graph_from_config(
|
||||
config: GraphBuilderConfig[StateT],
|
||||
) -> CompiledStateGraph[StateT]:
|
||||
def build_graph_from_config(
|
||||
config: GraphBuilderConfig[StateT],
|
||||
) -> CompiledGraph[StateT]:
|
||||
"""Build a StateGraph from configuration.
|
||||
|
||||
This function reduces duplication by centralizing the graph building logic.
|
||||
@@ -154,7 +153,7 @@ def build_graph_from_config(
|
||||
logger.debug(f"Added conditional edge from: {cond_edge.source}")
|
||||
|
||||
# Compile with optional runtime configuration
|
||||
compiled = builder.compile(
|
||||
compiled = builder.compile(
|
||||
checkpointer=config.checkpointer,
|
||||
cache=config.cache,
|
||||
store=config.store,
|
||||
@@ -169,38 +168,45 @@ def build_graph_from_config(
|
||||
setattr(compiled.builder, "config", config)
|
||||
|
||||
logger.info("Graph compilation completed successfully")
|
||||
return compiled
|
||||
return compiled
|
||||
|
||||
|
||||
class GraphBuilder(Generic[StateT]):
|
||||
"""Fluent API for building graphs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_class: type[StateT],
|
||||
*,
|
||||
context_class: type[Any] | None = None,
|
||||
input_schema: type[Any] | None = None,
|
||||
output_schema: type[Any] | None = None,
|
||||
):
|
||||
"""Initialize the builder with a state class."""
|
||||
self.config = GraphBuilderConfig(
|
||||
state_class=state_class,
|
||||
context_class=context_class,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
state_class: type[StateT],
|
||||
*,
|
||||
context_class: type[object] | None = None,
|
||||
input_schema: type[object] | None = None,
|
||||
output_schema: type[object] | None = None,
|
||||
):
|
||||
"""Initialize the builder with a state class."""
|
||||
if not isinstance(state_class, type): # pragma: no cover - defensive guard
|
||||
raise TypeError("GraphBuilder requires a state class type.")
|
||||
if not issubclass(state_class, Mapping) and not hasattr(state_class, "get"):
|
||||
raise TypeError(
|
||||
"GraphBuilder requires a mapping-compatible state class that exposes a 'get' method."
|
||||
)
|
||||
|
||||
self.config = GraphBuilderConfig(
|
||||
state_class=state_class,
|
||||
context_class=context_class,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
name: str,
|
||||
func: NodeFunction,
|
||||
*,
|
||||
defer: bool = False,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
input_schema: type[Any] | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
name: str,
|
||||
func: NodeFunction,
|
||||
*,
|
||||
defer: bool = False,
|
||||
metadata: dict[str, object] | None = None,
|
||||
input_schema: type[object] | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
) -> "GraphBuilder[StateT]":
|
||||
"""Add a node to the graph with modern LangGraph options."""
|
||||
|
||||
@@ -231,40 +237,40 @@ class GraphBuilder(Generic[StateT]):
|
||||
)
|
||||
return self
|
||||
|
||||
def with_checkpointer(
|
||||
self, checkpointer: BaseCheckpointSaver[Any]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
def with_checkpointer(
|
||||
self, checkpointer: BaseCheckpointSaver[RuntimeState]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
"""Set the checkpointer for the graph."""
|
||||
self.config.checkpointer = checkpointer
|
||||
return self
|
||||
|
||||
def with_context(
|
||||
self, context_class: type[Any]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
def with_context(
|
||||
self, context_class: type[object]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
"""Define the context schema for runtime access."""
|
||||
self.config.context_class = context_class
|
||||
return self
|
||||
|
||||
def with_input_schema(
|
||||
self, input_schema: type[Any]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
def with_input_schema(
|
||||
self, input_schema: type[object]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
"""Define the input schema for the graph."""
|
||||
self.config.input_schema = input_schema
|
||||
return self
|
||||
|
||||
def with_output_schema(
|
||||
self, output_schema: type[Any]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
def with_output_schema(
|
||||
self, output_schema: type[object]
|
||||
) -> "GraphBuilder[StateT]":
|
||||
"""Define the output schema for the graph."""
|
||||
self.config.output_schema = output_schema
|
||||
return self
|
||||
|
||||
def with_cache(self, cache: BaseCache[Any]) -> "GraphBuilder[StateT]":
|
||||
def with_cache(self, cache: BaseCache[object]) -> "GraphBuilder[StateT]":
|
||||
"""Attach a LangGraph cache implementation."""
|
||||
self.config.cache = cache
|
||||
return self
|
||||
|
||||
def with_store(self, store: BaseStore[Any]) -> "GraphBuilder[StateT]":
|
||||
def with_store(self, store: BaseStore[object]) -> "GraphBuilder[StateT]":
|
||||
"""Attach a LangGraph store implementation."""
|
||||
self.config.store = store
|
||||
return self
|
||||
@@ -290,31 +296,31 @@ class GraphBuilder(Generic[StateT]):
|
||||
self.config.debug = enabled
|
||||
return self
|
||||
|
||||
def with_metadata(self, **kwargs: Any) -> "GraphBuilder[StateT]":
|
||||
"""Add metadata to the graph."""
|
||||
self.config.metadata.update(kwargs)
|
||||
return self
|
||||
def with_metadata(self, **kwargs: object) -> "GraphBuilder[StateT]":
|
||||
"""Add metadata to the graph."""
|
||||
self.config.metadata.update(kwargs)
|
||||
return self
|
||||
|
||||
def build(self) -> CompiledStateGraph[StateT]:
|
||||
def build(self) -> CompiledGraph[StateT]:
|
||||
"""Build and compile the graph."""
|
||||
return build_graph_from_config(self.config)
|
||||
|
||||
|
||||
def create_simple_linear_graph(
|
||||
state_class: type[StateT],
|
||||
nodes: list[tuple[str, NodeFunction] | tuple[str, NodeFunction, dict[str, Any]] | tuple[str, NodeFunction, NodeConfig]],
|
||||
checkpointer: BaseCheckpointSaver[Any] | None = None,
|
||||
*,
|
||||
context_class: type[Any] | None = None,
|
||||
input_schema: type[Any] | None = None,
|
||||
output_schema: type[Any] | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
store: BaseStore[Any] | None = None,
|
||||
def create_simple_linear_graph(
|
||||
state_class: type[StateT],
|
||||
nodes: list[tuple[str, NodeFunction] | tuple[str, NodeFunction, dict[str, object]] | tuple[str, NodeFunction, NodeConfig]],
|
||||
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None,
|
||||
*,
|
||||
context_class: type[object] | None = None,
|
||||
input_schema: type[object] | None = None,
|
||||
output_schema: type[object] | None = None,
|
||||
cache: BaseCache[object] | None = None,
|
||||
store: BaseStore[object] | None = None,
|
||||
name: str | None = None,
|
||||
interrupt_before: All | list[str] | None = None,
|
||||
interrupt_after: All | list[str] | None = None,
|
||||
debug: bool = False,
|
||||
) -> CompiledStateGraph[StateT]:
|
||||
) -> CompiledGraph[StateT]:
|
||||
"""Create a simple linear graph where nodes execute in sequence.
|
||||
|
||||
Args:
|
||||
@@ -348,21 +354,21 @@ def create_simple_linear_graph(
|
||||
|
||||
# Add all nodes
|
||||
for entry in nodes:
|
||||
if len(entry) == 2:
|
||||
name, func = entry
|
||||
node_kwargs: dict[str, Any] = {}
|
||||
else:
|
||||
name, func, extras = entry
|
||||
if isinstance(extras, NodeConfig):
|
||||
node_kwargs = {
|
||||
"defer": extras.defer,
|
||||
if len(entry) == 2:
|
||||
name, func = entry
|
||||
node_kwargs: dict[str, object] = {}
|
||||
else:
|
||||
name, func, extras = entry
|
||||
if isinstance(extras, NodeConfig):
|
||||
node_kwargs = {
|
||||
"defer": extras.defer,
|
||||
"metadata": extras.metadata,
|
||||
"input_schema": extras.input_schema,
|
||||
"retry_policy": extras.retry_policy,
|
||||
"cache_policy": extras.cache_policy,
|
||||
}
|
||||
else:
|
||||
node_kwargs = cast(dict[str, Any], extras)
|
||||
node_kwargs = cast(dict[str, object], extras)
|
||||
builder.add_node(name, func, **node_kwargs)
|
||||
|
||||
# Connect nodes linearly
|
||||
@@ -388,30 +394,30 @@ def create_simple_linear_graph(
|
||||
return builder.build()
|
||||
|
||||
|
||||
def create_branching_graph(
|
||||
state_class: type[StateT],
|
||||
initial_node: tuple[str, NodeFunction],
|
||||
router: RouterFunction,
|
||||
branches: dict[
|
||||
str,
|
||||
list[
|
||||
tuple[str, NodeFunction]
|
||||
| tuple[str, NodeFunction, dict[str, Any]]
|
||||
| tuple[str, NodeFunction, NodeConfig]
|
||||
],
|
||||
],
|
||||
checkpointer: BaseCheckpointSaver[Any] | None = None,
|
||||
*,
|
||||
context_class: type[Any] | None = None,
|
||||
input_schema: type[Any] | None = None,
|
||||
output_schema: type[Any] | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
store: BaseStore[Any] | None = None,
|
||||
def create_branching_graph(
|
||||
state_class: type[StateT],
|
||||
initial_node: tuple[str, NodeFunction],
|
||||
router: RouterFunction,
|
||||
branches: dict[
|
||||
str,
|
||||
list[
|
||||
tuple[str, NodeFunction]
|
||||
| tuple[str, NodeFunction, dict[str, object]]
|
||||
| tuple[str, NodeFunction, NodeConfig]
|
||||
],
|
||||
],
|
||||
checkpointer: BaseCheckpointSaver[RuntimeState] | None = None,
|
||||
*,
|
||||
context_class: type[object] | None = None,
|
||||
input_schema: type[object] | None = None,
|
||||
output_schema: type[object] | None = None,
|
||||
cache: BaseCache[object] | None = None,
|
||||
store: BaseStore[object] | None = None,
|
||||
name: str | None = None,
|
||||
interrupt_before: All | list[str] | None = None,
|
||||
interrupt_after: All | list[str] | None = None,
|
||||
debug: bool = False,
|
||||
) -> CompiledStateGraph[StateT]:
|
||||
) -> CompiledGraph[StateT]:
|
||||
"""Create a graph with branching logic.
|
||||
|
||||
Args:
|
||||
@@ -445,21 +451,21 @@ def create_branching_graph(
|
||||
|
||||
# Add nodes in this branch
|
||||
for entry in branch_nodes:
|
||||
if len(entry) == 2:
|
||||
name, func = entry
|
||||
node_kwargs: dict[str, Any] = {}
|
||||
else:
|
||||
name, func, extras = entry
|
||||
if isinstance(extras, NodeConfig):
|
||||
node_kwargs = {
|
||||
"defer": extras.defer,
|
||||
if len(entry) == 2:
|
||||
name, func = entry
|
||||
node_kwargs: dict[str, object] = {}
|
||||
else:
|
||||
name, func, extras = entry
|
||||
if isinstance(extras, NodeConfig):
|
||||
node_kwargs = {
|
||||
"defer": extras.defer,
|
||||
"metadata": extras.metadata,
|
||||
"input_schema": extras.input_schema,
|
||||
"retry_policy": extras.retry_policy,
|
||||
"cache_policy": extras.cache_policy,
|
||||
}
|
||||
else:
|
||||
node_kwargs = cast(dict[str, Any], extras)
|
||||
node_kwargs = cast(dict[str, object], extras)
|
||||
builder.add_node(name, func, **node_kwargs)
|
||||
|
||||
# Connect nodes within branch
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
"""Graph configuration utilities for LangGraph integration.
|
||||
|
||||
This module provides utilities for configuring graphs with RunnableConfig,
|
||||
enabling consistent configuration injection across all nodes and tools.
|
||||
This module centralises helpers that wrap LangGraph nodes with
|
||||
``RunnableConfig`` metadata so graph execution receives application
|
||||
configuration, service factories, and runtime overrides consistently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from copy import copy
|
||||
from typing import TypeVar, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from .runnable_config import ConfigurationProvider, create_runnable_config
|
||||
|
||||
|
||||
StateT = TypeVar("StateT", bound=dict[str, object])
|
||||
NodeCallable = Callable[..., object | Awaitable[object]]
|
||||
|
||||
|
||||
def configure_graph_with_injection(
|
||||
graph_builder: StateGraph[Any],
|
||||
app_config: Any,
|
||||
service_factory: Any | None = None,
|
||||
**config_overrides: Any,
|
||||
) -> StateGraph[Any]:
|
||||
graph_builder: StateGraph[StateT],
|
||||
app_config: object,
|
||||
service_factory: object | None = None,
|
||||
**config_overrides: object,
|
||||
) -> StateGraph[StateT]:
|
||||
"""Configure a graph builder with dependency injection.
|
||||
|
||||
This function wraps all nodes in the graph to automatically inject
|
||||
@@ -44,25 +50,51 @@ def configure_graph_with_injection(
|
||||
# Get all nodes that have been added
|
||||
nodes = graph_builder.nodes
|
||||
|
||||
# Wrap each node to inject configuration
|
||||
for node_name, node_func in nodes.items():
|
||||
# Skip if already configured
|
||||
if hasattr(node_func, "_runnable_config_injected"):
|
||||
# Plan node replacements without mutating while iterating
|
||||
replacements: list[tuple[str, object]] = []
|
||||
for node_name, node_entry in list(nodes.items()):
|
||||
target = _resolve_node_callable(node_entry)
|
||||
if target is None or hasattr(target, "_runnable_config_injected"):
|
||||
continue
|
||||
|
||||
# Create wrapper that injects config
|
||||
if callable(node_func):
|
||||
callable_node = cast(Callable[..., Any] | Callable[..., Awaitable[Any]], node_func)
|
||||
wrapped_node = create_config_injected_node(callable_node, base_config)
|
||||
# Replace the node
|
||||
graph_builder.nodes[node_name] = wrapped_node
|
||||
wrapped_node = create_config_injected_node(target, base_config)
|
||||
|
||||
if hasattr(node_entry, "func"):
|
||||
new_entry = copy(node_entry)
|
||||
setattr(new_entry, "func", wrapped_node)
|
||||
replacements.append((node_name, new_entry))
|
||||
else:
|
||||
replacements.append((node_name, wrapped_node))
|
||||
|
||||
for node_name, new_entry in replacements:
|
||||
try:
|
||||
graph_builder.add_node(node_name, new_entry) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to inject configuration for node '{node_name}'. "
|
||||
"Ensure injection occurs before graph compilation."
|
||||
) from exc
|
||||
|
||||
return graph_builder
|
||||
|
||||
|
||||
def _resolve_node_callable(node_entry: object) -> NodeCallable | None:
|
||||
"""Extract the callable associated with a node entry."""
|
||||
|
||||
if hasattr(node_entry, "func"):
|
||||
func = getattr(node_entry, "func")
|
||||
if callable(func):
|
||||
return cast(NodeCallable, func)
|
||||
|
||||
if callable(node_entry):
|
||||
return cast(NodeCallable, node_entry)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_config_injected_node(
|
||||
node_func: Callable[..., Any] | Callable[..., Awaitable[Any]], base_config: RunnableConfig
|
||||
) -> Any:
|
||||
node_func: NodeCallable, base_config: RunnableConfig
|
||||
) -> object:
|
||||
"""Create a node wrapper that injects RunnableConfig.
|
||||
|
||||
Args:
|
||||
@@ -76,42 +108,64 @@ def create_config_injected_node(
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
sig = inspect.signature(node_func)
|
||||
expects_config = "config" in sig.parameters
|
||||
accepts_config = "config" in sig.parameters
|
||||
is_coroutine = asyncio.iscoroutinefunction(node_func)
|
||||
|
||||
def _merge_config(runtime: RunnableConfig | Mapping[str, object] | None) -> RunnableConfig:
|
||||
provider = ConfigurationProvider(base_config)
|
||||
if runtime is None:
|
||||
return provider.to_runnable_config()
|
||||
if isinstance(runtime, RunnableConfig):
|
||||
return provider.merge_with(runtime).to_runnable_config()
|
||||
if isinstance(runtime, Mapping):
|
||||
runtime_config = RunnableConfig()
|
||||
for key, value in runtime.items():
|
||||
runtime_config[key] = value
|
||||
return provider.merge_with(runtime_config).to_runnable_config()
|
||||
return provider.to_runnable_config()
|
||||
|
||||
def _inject_kwargs(
|
||||
args: tuple[object, ...],
|
||||
kwargs: dict[str, object | None],
|
||||
cfg: RunnableConfig,
|
||||
) -> tuple[tuple[object, ...], dict[str, object]]:
|
||||
if accepts_config and "config" not in kwargs:
|
||||
updated_kwargs: dict[str, object] = dict(kwargs)
|
||||
updated_kwargs["config"] = cfg
|
||||
return args, updated_kwargs
|
||||
return args, cast(dict[str, object], kwargs)
|
||||
|
||||
if is_coroutine:
|
||||
|
||||
if expects_config:
|
||||
# Node already expects config, wrap to provide it
|
||||
@wraps(node_func)
|
||||
async def config_aware_wrapper(
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
# noqa: ARG001
|
||||
) -> Any:
|
||||
# Merge base config with runtime config
|
||||
if config:
|
||||
provider = ConfigurationProvider(base_config)
|
||||
merged_provider = provider.merge_with(config)
|
||||
merged_config = merged_provider.to_runnable_config()
|
||||
else:
|
||||
merged_config = base_config
|
||||
async def wrapper(*args: object, **kwargs: object) -> object:
|
||||
cfg_arg = kwargs.get("config")
|
||||
merged = _merge_config(
|
||||
cast(RunnableConfig | Mapping[str, object] | None, cfg_arg)
|
||||
)
|
||||
args2, kwargs2 = _inject_kwargs(args, kwargs, merged)
|
||||
return await cast(Awaitable[object], node_func(*args2, **kwargs2))
|
||||
|
||||
# Call original node with merged config
|
||||
if asyncio.iscoroutinefunction(node_func):
|
||||
coro_result = node_func(state, config=merged_config)
|
||||
return await cast(Awaitable[Any], coro_result)
|
||||
else:
|
||||
return node_func(state, config=merged_config)
|
||||
|
||||
wrapped: Any = RunnableLambda(config_aware_wrapper).with_config(base_config)
|
||||
else:
|
||||
# Node doesn't expect config, just wrap with config
|
||||
wrapped = RunnableLambda(node_func).with_config(base_config)
|
||||
|
||||
# Mark as injected to avoid double wrapping
|
||||
wrapped._runnable_config_injected = True
|
||||
@wraps(node_func)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
cfg_arg = kwargs.get("config")
|
||||
merged = _merge_config(
|
||||
cast(RunnableConfig | Mapping[str, object] | None, cfg_arg)
|
||||
)
|
||||
args2, kwargs2 = _inject_kwargs(args, kwargs, merged)
|
||||
return node_func(*args2, **kwargs2)
|
||||
|
||||
return wrapped
|
||||
try:
|
||||
wrapper.__dict__.update(copy(getattr(node_func, "__dict__", {}))) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
setattr(wrapper, "_runnable_config_injected", True)
|
||||
runnable = RunnableLambda(wrapper).with_config(base_config)
|
||||
setattr(runnable, "_runnable_config_injected", True)
|
||||
return runnable
|
||||
|
||||
|
||||
def update_node_to_use_config(
|
||||
@@ -137,42 +191,25 @@ def update_node_to_use_config(
|
||||
|
||||
sig = inspect.signature(node_func)
|
||||
|
||||
# Check if already accepts config
|
||||
# If the node already consumes a config argument we return it untouched so
|
||||
# callers keep the original implementation.
|
||||
if "config" in sig.parameters:
|
||||
return node_func
|
||||
|
||||
@wraps(node_func)
|
||||
async def wrapper(
|
||||
state: dict[str, object], config: RunnableConfig | None = None
|
||||
# noqa: ARG001
|
||||
) -> object:
|
||||
# Call original without config (for backward compatibility)
|
||||
async def wrapper(state: dict[str, object], config: RunnableConfig) -> object:
|
||||
if asyncio.iscoroutinefunction(node_func):
|
||||
coro_result = node_func(state)
|
||||
return await cast(Awaitable[Any], coro_result)
|
||||
else:
|
||||
return node_func(state)
|
||||
return await cast(Awaitable[object], coro_result)
|
||||
return node_func(state)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def extract_config_from_state(state: dict[str, object]) -> object | None:
|
||||
"""Extract AppConfig from state for backward compatibility.
|
||||
"""Return the runnable configuration stored on the state mapping."""
|
||||
|
||||
Args:
|
||||
state: The graph state dictionary.
|
||||
|
||||
Returns:
|
||||
The AppConfig if found in state, None otherwise.
|
||||
"""
|
||||
# Check common locations
|
||||
if "config" in state and isinstance(state["config"], dict):
|
||||
# Try to reconstruct AppConfig from dict
|
||||
try:
|
||||
# Import would need to be dynamic to avoid circular dependencies
|
||||
# This is a placeholder implementation
|
||||
return state["config"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return state.get("app_config")
|
||||
cfg = state.get("config")
|
||||
if isinstance(cfg, Mapping):
|
||||
return dict(cfg)
|
||||
return cfg
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""RunnableConfig utilities for LangGraph integration.
|
||||
"""Utilities for working with LangChain's :class:`RunnableConfig`."""
|
||||
|
||||
This module provides utilities for working with LangChain's RunnableConfig
|
||||
to enable configuration injection and dependency management in LangGraph workflows.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeVar, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import TypeVar, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -26,28 +25,35 @@ class ConfigurationProvider:
|
||||
"""
|
||||
self._config = config or RunnableConfig()
|
||||
|
||||
def _get_configurable(self) -> dict[str, Any]:
|
||||
def _get_configurable(self) -> Mapping[str, object] | object:
|
||||
"""Get the configurable section from config.
|
||||
|
||||
Handles both dict-like and attribute-like access to RunnableConfig.
|
||||
|
||||
Returns:
|
||||
The configurable dictionary.
|
||||
The configurable mapping or the original value when not mapping.
|
||||
"""
|
||||
return self._config.get("configurable", {})
|
||||
configurable = self._config.get("configurable", {})
|
||||
if isinstance(configurable, Mapping):
|
||||
return dict(configurable)
|
||||
return configurable
|
||||
|
||||
def get_metadata(self, key: str, default: Any = None) -> Any:
|
||||
"""Get metadata value from the configuration.
|
||||
def _get_mapping(
|
||||
self, key: str, config: RunnableConfig | None = None
|
||||
) -> dict[str, object]:
|
||||
"""Safely extract a mapping from the runnable configuration."""
|
||||
|
||||
Args:
|
||||
key: The metadata key to retrieve.
|
||||
default: Default value if key not found.
|
||||
source = (config or self._config).get(key, {})
|
||||
if isinstance(source, Mapping):
|
||||
return dict(source)
|
||||
return {}
|
||||
|
||||
Returns:
|
||||
The metadata value or default.
|
||||
"""
|
||||
def get_metadata(self, key: str, default: object | None = None) -> object | None:
|
||||
"""Get metadata value from the configuration."""
|
||||
metadata = self._config.get("metadata", {})
|
||||
return metadata.get(key, default)
|
||||
if isinstance(metadata, Mapping):
|
||||
return dict(metadata).get(key, default)
|
||||
return default
|
||||
|
||||
def get_run_id(self) -> str | None:
|
||||
"""Get the current run ID from configuration.
|
||||
@@ -76,23 +82,27 @@ class ConfigurationProvider:
|
||||
session_id = self.get_metadata("session_id")
|
||||
return session_id if isinstance(session_id, str) else None
|
||||
|
||||
def get_app_config(self) -> Any | None:
|
||||
def get_app_config(self) -> object | None:
|
||||
"""Get the application configuration object.
|
||||
|
||||
Returns:
|
||||
The app configuration if available, None otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
return configurable.get("app_config")
|
||||
if isinstance(configurable, Mapping):
|
||||
return configurable.get("app_config")
|
||||
return None
|
||||
|
||||
def get_service_factory(self) -> Any | None:
|
||||
def get_service_factory(self) -> object | None:
|
||||
"""Get the service factory instance.
|
||||
|
||||
Returns:
|
||||
The service factory if available, None otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
return configurable.get("service_factory")
|
||||
if isinstance(configurable, Mapping):
|
||||
return configurable.get("service_factory")
|
||||
return None
|
||||
|
||||
def get_llm_profile(self) -> str:
|
||||
"""Get the LLM profile to use.
|
||||
@@ -101,7 +111,10 @@ class ConfigurationProvider:
|
||||
The LLM profile name, defaults to "large".
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
profile = configurable.get("llm_profile_override", "large")
|
||||
if isinstance(configurable, Mapping):
|
||||
profile = configurable.get("llm_profile_override", "large")
|
||||
else:
|
||||
profile = "large"
|
||||
return profile if isinstance(profile, str) else "large"
|
||||
|
||||
def get_temperature_override(self) -> float | None:
|
||||
@@ -111,8 +124,10 @@ class ConfigurationProvider:
|
||||
The temperature override if set, None otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
value = configurable.get("temperature_override")
|
||||
return float(value) if isinstance(value, int | float) else None
|
||||
if isinstance(configurable, Mapping):
|
||||
value = configurable.get("temperature_override")
|
||||
return float(value) if isinstance(value, int | float) else None
|
||||
return None
|
||||
|
||||
def get_max_tokens_override(self) -> int | None:
|
||||
"""Get max tokens override for LLM calls.
|
||||
@@ -121,6 +136,8 @@ class ConfigurationProvider:
|
||||
The max tokens override if set, None otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
if not isinstance(configurable, Mapping):
|
||||
return None
|
||||
value = configurable.get("max_tokens_override")
|
||||
try:
|
||||
return int(value) if isinstance(value, int | float) else None
|
||||
@@ -134,7 +151,10 @@ class ConfigurationProvider:
|
||||
True if streaming is enabled, False otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
value = configurable.get("streaming_enabled", False)
|
||||
if isinstance(configurable, Mapping):
|
||||
value = configurable.get("streaming_enabled", False)
|
||||
else:
|
||||
value = False
|
||||
return bool(value)
|
||||
|
||||
def is_metrics_enabled(self) -> bool:
|
||||
@@ -144,7 +164,10 @@ class ConfigurationProvider:
|
||||
True if metrics are enabled, False otherwise.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
value = configurable.get("metrics_enabled", True)
|
||||
if isinstance(configurable, Mapping):
|
||||
value = configurable.get("metrics_enabled", True)
|
||||
else:
|
||||
value = True
|
||||
return bool(value)
|
||||
|
||||
def get_custom_value(self, key: str, default: T | None = None) -> T | None:
|
||||
@@ -158,8 +181,10 @@ class ConfigurationProvider:
|
||||
The configuration value or default.
|
||||
"""
|
||||
configurable = self._get_configurable()
|
||||
result = configurable.get(key, default)
|
||||
return cast("T", result) if result is not None else default
|
||||
if isinstance(configurable, Mapping):
|
||||
result = configurable.get(key, default)
|
||||
return cast(T, result) if result is not None else default
|
||||
return default
|
||||
|
||||
def merge_with(self, other: RunnableConfig) -> "ConfigurationProvider":
|
||||
"""Create a new provider by merging with another config.
|
||||
@@ -174,13 +199,13 @@ class ConfigurationProvider:
|
||||
merged = RunnableConfig()
|
||||
|
||||
# Copy metadata
|
||||
self_metadata = self._config.get("metadata", {})
|
||||
other_metadata = other.get("metadata", {})
|
||||
self_metadata = self._get_mapping("metadata")
|
||||
other_metadata = self._get_mapping("metadata", other)
|
||||
merged["metadata"] = {**self_metadata, **other_metadata}
|
||||
|
||||
# Copy configurable
|
||||
self_configurable = self._config.get("configurable", {})
|
||||
other_configurable = other.get("configurable", {})
|
||||
self_configurable = self._get_mapping("configurable")
|
||||
other_configurable = self._get_mapping("configurable", other)
|
||||
merged["configurable"] = {**self_configurable, **other_configurable}
|
||||
|
||||
# Copy other attributes
|
||||
@@ -202,7 +227,10 @@ class ConfigurationProvider:
|
||||
|
||||
@classmethod
|
||||
def from_app_config(
|
||||
cls, app_config: Any, service_factory: Any | None = None, **metadata: Any
|
||||
cls,
|
||||
app_config: object,
|
||||
service_factory: object | None = None,
|
||||
**metadata: object,
|
||||
) -> "ConfigurationProvider":
|
||||
"""Create a provider from app configuration.
|
||||
|
||||
@@ -217,21 +245,21 @@ class ConfigurationProvider:
|
||||
config = RunnableConfig()
|
||||
|
||||
# Set configurable values
|
||||
configurable_dict = {
|
||||
configurable_dict: dict[str, object | None] = {
|
||||
"app_config": app_config,
|
||||
"service_factory": service_factory,
|
||||
}
|
||||
config["configurable"] = configurable_dict
|
||||
|
||||
# Set metadata
|
||||
config["metadata"] = metadata
|
||||
config["metadata"] = dict(metadata)
|
||||
|
||||
return cls(config)
|
||||
|
||||
|
||||
def create_runnable_config(
|
||||
app_config: Any | None = None,
|
||||
service_factory: Any | None = None,
|
||||
app_config: object | None = None,
|
||||
service_factory: object | None = None,
|
||||
llm_profile_override: str | None = None,
|
||||
temperature_override: float | None = None,
|
||||
max_tokens_override: int | None = None,
|
||||
@@ -240,7 +268,7 @@ def create_runnable_config(
|
||||
run_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
**custom_values: Any,
|
||||
**custom_values: object,
|
||||
) -> RunnableConfig:
|
||||
"""Create a RunnableConfig with common settings.
|
||||
|
||||
@@ -263,7 +291,7 @@ def create_runnable_config(
|
||||
config = RunnableConfig()
|
||||
|
||||
# Set configurable values
|
||||
configurable = {
|
||||
configurable: dict[str, object] = {
|
||||
"metrics_enabled": metrics_enabled,
|
||||
"streaming_enabled": streaming_enabled,
|
||||
**custom_values,
|
||||
|
||||
@@ -15,23 +15,27 @@ used during development and testing. For production, consider:
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar, cast
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
Iterator,
|
||||
ItemsView,
|
||||
KeysView,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Sequence,
|
||||
ValuesView,
|
||||
)
|
||||
from types import MappingProxyType
|
||||
from typing import cast
|
||||
|
||||
try: # pragma: no cover - pandas is optional in lightweight test environments
|
||||
import pandas as pd # type: ignore
|
||||
except ModuleNotFoundError: # pragma: no cover - executed when pandas isn't installed
|
||||
pd = None # type: ignore[assignment]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from biz_bud.core.errors import ImmutableStateError, StateValidationError
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., object])
|
||||
|
||||
|
||||
def _states_equal(state1: Any, state2: Any) -> bool:
|
||||
def _states_equal(state1: object, state2: object) -> bool:
|
||||
"""Compare two states safely, handling DataFrames and other complex objects.
|
||||
|
||||
Args:
|
||||
@@ -44,7 +48,7 @@ def _states_equal(state1: Any, state2: Any) -> bool:
|
||||
if type(state1) is not type(state2):
|
||||
return False
|
||||
|
||||
if isinstance(state1, dict) and isinstance(state2, dict):
|
||||
if isinstance(state1, Mapping) and isinstance(state2, Mapping):
|
||||
if set(state1.keys()) != set(state2.keys()):
|
||||
return False
|
||||
|
||||
@@ -53,7 +57,9 @@ def _states_equal(state1: Any, state2: Any) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
elif isinstance(state1, (list, tuple)):
|
||||
elif isinstance(state1, Sequence) and isinstance(state2, Sequence) and not isinstance(
|
||||
state1, (str, bytes, bytearray)
|
||||
):
|
||||
if len(state1) != len(state2):
|
||||
return False
|
||||
return all(_states_equal(a, b) for a, b in zip(state1, state2))
|
||||
@@ -90,221 +96,92 @@ def _states_equal(state1: Any, state2: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ImmutableDict:
|
||||
"""An immutable dictionary that prevents modifications.
|
||||
class ImmutableDict(Mapping[str, object]):
|
||||
"""An immutable mapping that prevents modifications."""
|
||||
|
||||
This class wraps a regular dict but prevents any modifications,
|
||||
raising ImmutableStateError on mutation attempts.
|
||||
"""
|
||||
def __init__(
|
||||
self, initial: Mapping[str, object] | None = None, **kwargs: object
|
||||
) -> None:
|
||||
data: dict[str, object] = dict(initial or {})
|
||||
if kwargs:
|
||||
data.update(kwargs)
|
||||
self._data = copy.deepcopy(data)
|
||||
self._proxy = MappingProxyType(self._data)
|
||||
|
||||
_data: dict[str, Any]
|
||||
_frozen: bool
|
||||
def _raise_mutation_error(self) -> None:
|
||||
raise ImmutableStateError(
|
||||
"Cannot modify immutable state. Create a new state object instead."
|
||||
)
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize ImmutableDict.
|
||||
def __getitem__(self, key: str) -> object:
|
||||
return self._proxy[key]
|
||||
|
||||
Args:
|
||||
*args: Positional arguments passed to dict constructor.
|
||||
**kwargs: Keyword arguments passed to dict constructor.
|
||||
"""
|
||||
self._data = dict(*args, **kwargs)
|
||||
self._frozen = True
|
||||
|
||||
def _check_frozen(self) -> None:
|
||||
if self._frozen:
|
||||
raise ImmutableStateError(
|
||||
"Cannot modify immutable state. Create a new state object instead."
|
||||
)
|
||||
|
||||
# Read-only dict interface
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Get item by key.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key.
|
||||
"""
|
||||
return self._data[key]
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
"""Check if key is in the dictionary.
|
||||
|
||||
Args:
|
||||
key: The key to check for.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
return key in self._data
|
||||
|
||||
def __iter__(self) -> Any:
|
||||
"""Return an iterator over the dictionary keys.
|
||||
|
||||
Returns:
|
||||
An iterator over the keys.
|
||||
"""
|
||||
return iter(self._data)
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._proxy)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of items in the dictionary.
|
||||
return len(self._proxy)
|
||||
|
||||
Returns:
|
||||
The number of key-value pairs.
|
||||
"""
|
||||
return len(self._data)
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self._proxy
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the dictionary.
|
||||
return f"ImmutableDict({dict(self._proxy)!r})"
|
||||
|
||||
Returns:
|
||||
A string representation of the ImmutableDict.
|
||||
"""
|
||||
return f"ImmutableDict({self._data!r})"
|
||||
def get(self, key: str, default: object | None = None) -> object | None:
|
||||
return self._proxy.get(key, default)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get value by key with optional default.
|
||||
def keys(self) -> KeysView[str]:
|
||||
return self._proxy.keys()
|
||||
|
||||
Args:
|
||||
key: The key to retrieve.
|
||||
default: Default value if key is not found.
|
||||
def values(self) -> ValuesView[object]:
|
||||
return self._proxy.values()
|
||||
|
||||
Returns:
|
||||
The value associated with the key or default.
|
||||
"""
|
||||
return self._data.get(key, default)
|
||||
def items(self) -> ItemsView[str, object]:
|
||||
return self._proxy.items()
|
||||
|
||||
def keys(self) -> Any:
|
||||
"""Return a view of the dictionary keys.
|
||||
def copy(self) -> Mapping[str, object]:
|
||||
return MappingProxyType(copy.deepcopy(dict(self._proxy)))
|
||||
|
||||
Returns:
|
||||
A view object of the dictionary keys.
|
||||
"""
|
||||
return self._data.keys()
|
||||
# Mutation methods intentionally raise errors
|
||||
def __setitem__(self, key: str, value: object) -> None: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
|
||||
def values(self) -> Any:
|
||||
"""Return a view of the dictionary values.
|
||||
def __delitem__(self, key: str) -> None: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
|
||||
Returns:
|
||||
A view object of the dictionary values.
|
||||
"""
|
||||
return self._data.values()
|
||||
def pop(self, key: str, default: object | None = None) -> object: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
raise AssertionError("pop should never return")
|
||||
|
||||
def items(self) -> Any:
|
||||
"""Return a view of the dictionary key-value pairs.
|
||||
def popitem(self) -> tuple[str, object]: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
raise AssertionError("popitem should never return")
|
||||
|
||||
Returns:
|
||||
A view object of the dictionary items.
|
||||
"""
|
||||
return self._data.items()
|
||||
def clear(self) -> None: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
|
||||
def copy(self) -> dict[str, Any]:
|
||||
"""Create a shallow copy of the dictionary.
|
||||
def update(self, *args: object, **kwargs: object) -> None: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
|
||||
Returns:
|
||||
A new dictionary with the same key-value pairs.
|
||||
"""
|
||||
return self._data.copy()
|
||||
|
||||
# Mutation methods that raise errors
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""Prevent item assignment (raises ImmutableStateError).
|
||||
|
||||
Args:
|
||||
key: The key to set (will raise error).
|
||||
value: The value to set (will raise error).
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Prevent item deletion (raises ImmutableStateError).
|
||||
|
||||
Args:
|
||||
key: The key to delete (will raise error).
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
|
||||
def pop(self, key: str, default: Any = None) -> Any:
|
||||
"""Prevent popping items (raises ImmutableStateError).
|
||||
|
||||
Args:
|
||||
key: The key to pop (will raise error).
|
||||
default: Default value (will raise error before using).
|
||||
|
||||
Returns:
|
||||
Never returns as it always raises an error.
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
return None # Never reached
|
||||
|
||||
def popitem(self) -> tuple[str, Any]:
|
||||
"""Prevent popping items (raises ImmutableStateError).
|
||||
|
||||
Returns:
|
||||
Never returns as it always raises an error.
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
return ("", None) # Never reached
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Prevent clearing the dictionary (raises ImmutableStateError).
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
|
||||
def update(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Prevent updating the dictionary (raises ImmutableStateError).
|
||||
|
||||
Args:
|
||||
*args: Arguments to update with (will raise error).
|
||||
**kwargs: Keyword arguments to update with (will raise error).
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
|
||||
def setdefault(self, key: str, default: Any = None) -> Any:
|
||||
"""Prevent setting default values (raises ImmutableStateError).
|
||||
|
||||
Args:
|
||||
key: The key to set default for (will raise error).
|
||||
default: The default value (will raise error before using).
|
||||
|
||||
Returns:
|
||||
Never returns as it always raises an error.
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: Always, as state is immutable.
|
||||
"""
|
||||
self._check_frozen()
|
||||
return None # Never reached
|
||||
def setdefault(self, key: str, default: object | None = None) -> object: # pragma: no cover
|
||||
self._raise_mutation_error()
|
||||
raise AssertionError("setdefault should never return")
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, object]) -> ImmutableDict:
|
||||
"""Support deepcopy by creating a new immutable dict with copied data."""
|
||||
# Deep copy the internal data
|
||||
new_data = {key: copy.deepcopy(value, memo) for key, value in self.items()}
|
||||
obj_id = id(self)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id] # type: ignore[return-value]
|
||||
|
||||
# Create a new ImmutableDict with the copied data
|
||||
return ImmutableDict(new_data)
|
||||
placeholder = ImmutableDict({})
|
||||
memo[obj_id] = placeholder
|
||||
new_data = {key: copy.deepcopy(value, memo) for key, value in self._proxy.items()}
|
||||
result = ImmutableDict(new_data)
|
||||
memo[obj_id] = result
|
||||
return result
|
||||
|
||||
|
||||
def create_immutable_state(state: dict[str, Any]) -> ImmutableDict:
|
||||
def create_immutable_state(state: Mapping[str, object]) -> ImmutableDict:
|
||||
"""Create an immutable version of a state dictionary.
|
||||
|
||||
Args:
|
||||
@@ -314,12 +191,12 @@ def create_immutable_state(state: dict[str, Any]) -> ImmutableDict:
|
||||
An ImmutableDict that prevents modifications.
|
||||
"""
|
||||
# Deep copy to prevent references to mutable objects
|
||||
return ImmutableDict(copy.deepcopy(state))
|
||||
return ImmutableDict(copy.deepcopy(dict(state)))
|
||||
|
||||
|
||||
def update_state_immutably(
|
||||
current_state: dict[str, Any], updates: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
current_state: Mapping[str, object], updates: Mapping[str, object]
|
||||
) -> dict[str, object]:
|
||||
"""Create a new state with updates applied immutably.
|
||||
|
||||
This function creates a new state object by merging the current state
|
||||
@@ -334,35 +211,60 @@ def update_state_immutably(
|
||||
"""
|
||||
# Deep copy the current state into a regular dict
|
||||
# If it's an ImmutableDict, convert to regular dict first
|
||||
new_state: dict[str, Any]
|
||||
if isinstance(current_state, ImmutableDict):
|
||||
new_state = {key: copy.deepcopy(value) for key, value in current_state.items()}
|
||||
else:
|
||||
new_state = copy.deepcopy(current_state)
|
||||
base_state = dict(current_state)
|
||||
new_state: dict[str, object] = copy.deepcopy(base_state)
|
||||
|
||||
def _as_dict(obj: object) -> dict[str, object] | None:
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
if isinstance(obj, Mapping):
|
||||
return {key: obj[key] for key in obj.keys()}
|
||||
return None
|
||||
|
||||
def _deep_merge(
|
||||
target: dict[str, object], source: Mapping[str, object]
|
||||
) -> dict[str, object]:
|
||||
for merge_key, merge_value in source.items():
|
||||
existing_sub = _as_dict(target.get(merge_key))
|
||||
incoming_sub = _as_dict(merge_value)
|
||||
if existing_sub is not None and incoming_sub is not None:
|
||||
target[merge_key] = _deep_merge(
|
||||
copy.deepcopy(existing_sub), incoming_sub
|
||||
)
|
||||
else:
|
||||
target[merge_key] = copy.deepcopy(merge_value)
|
||||
return target
|
||||
|
||||
# Apply updates
|
||||
for key, value in updates.items():
|
||||
# Check for replacement marker
|
||||
if isinstance(value, tuple) and len(value) == 2 and value[0] == "__REPLACE__":
|
||||
# Force replacement, ignore existing value
|
||||
new_state[key] = value[1]
|
||||
elif (
|
||||
replace_tuple = (
|
||||
isinstance(value, tuple)
|
||||
and len(value) == 2
|
||||
and value[0] == "__REPLACE__"
|
||||
)
|
||||
if replace_tuple:
|
||||
new_state[key] = copy.deepcopy(value[1])
|
||||
continue
|
||||
|
||||
existing_dict = _as_dict(new_state.get(key))
|
||||
incoming_dict = _as_dict(value)
|
||||
if existing_dict is not None and incoming_dict is not None:
|
||||
new_state[key] = _deep_merge(
|
||||
copy.deepcopy(existing_dict), incoming_dict
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
key in new_state
|
||||
and isinstance(new_state[key], list)
|
||||
and isinstance(value, list)
|
||||
):
|
||||
# For lists, create a new list (don't extend in place)
|
||||
new_state[key] = new_state[key] + value
|
||||
elif (
|
||||
key in new_state
|
||||
and isinstance(new_state[key], dict)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
# For dicts, merge recursively
|
||||
new_state[key] = {**new_state[key], **value}
|
||||
else:
|
||||
# Direct assignment for other types
|
||||
new_state[key] = value
|
||||
merged_list = copy.deepcopy(new_state[key])
|
||||
merged_list.extend(copy.deepcopy(value))
|
||||
new_state[key] = merged_list
|
||||
continue
|
||||
|
||||
new_state[key] = copy.deepcopy(value)
|
||||
|
||||
return new_state
|
||||
|
||||
@@ -400,7 +302,7 @@ def ensure_immutable_node(
|
||||
return await result if inspect.iscoroutine(result) else result
|
||||
|
||||
state = args[0]
|
||||
if not isinstance(state, dict):
|
||||
if not isinstance(state, MutableMapping):
|
||||
result = node_func(*args, **kwargs)
|
||||
return await result if inspect.iscoroutine(result) else result
|
||||
|
||||
@@ -436,7 +338,7 @@ def ensure_immutable_node(
|
||||
return node_func(*args, **kwargs)
|
||||
|
||||
state = args[0]
|
||||
if not isinstance(state, dict):
|
||||
if not isinstance(state, MutableMapping):
|
||||
return node_func(*args, **kwargs)
|
||||
|
||||
# PERFORMANCE WARNING: Deep copying state before/after execution
|
||||
@@ -487,24 +389,17 @@ class StateUpdater:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, base_state: dict[str, Any] | Any):
|
||||
def __init__(self, base_state: Mapping[str, object]):
|
||||
"""Initialize with a base state.
|
||||
|
||||
Args:
|
||||
base_state: The starting state.
|
||||
"""
|
||||
# Handle both regular dicts and ImmutableDict
|
||||
if isinstance(base_state, ImmutableDict):
|
||||
# Convert ImmutableDict to regular dict for internal use
|
||||
self._state: dict[str, Any] = {}
|
||||
for key, value in base_state.items():
|
||||
self._state[key] = copy.deepcopy(value)
|
||||
else:
|
||||
# Regular dict can be deepcopied normally
|
||||
self._state = copy.deepcopy(base_state)
|
||||
self._updates: dict[str, Any] = {}
|
||||
self._state: dict[str, object] = dict(base_state)
|
||||
self._updates: dict[str, object] = {}
|
||||
|
||||
def set(self, key: str, value: Any) -> StateUpdater:
|
||||
def set(self, key: str, value: object) -> StateUpdater:
|
||||
"""Set a value in the state.
|
||||
|
||||
Args:
|
||||
@@ -517,7 +412,7 @@ class StateUpdater:
|
||||
self._updates[key] = value
|
||||
return self
|
||||
|
||||
def replace(self, key: str, value: Any) -> StateUpdater:
|
||||
def replace(self, key: str, value: object) -> StateUpdater:
|
||||
"""Replace a value in the state, ignoring existing value.
|
||||
|
||||
This forces replacement even for dicts and lists, unlike set()
|
||||
@@ -533,7 +428,7 @@ class StateUpdater:
|
||||
self._updates[key] = ("__REPLACE__", value)
|
||||
return self
|
||||
|
||||
def append(self, key: str, value: Any) -> StateUpdater:
|
||||
def append(self, key: str, value: object) -> StateUpdater:
|
||||
"""Append to a list in the state.
|
||||
|
||||
Args:
|
||||
@@ -560,7 +455,7 @@ class StateUpdater:
|
||||
|
||||
return self
|
||||
|
||||
def extend(self, key: str, values: list[Any]) -> StateUpdater:
|
||||
def extend(self, key: str, values: Sequence[object]) -> StateUpdater:
|
||||
"""Extend a list in the state.
|
||||
|
||||
Args:
|
||||
@@ -583,11 +478,11 @@ class StateUpdater:
|
||||
self._updates[key] = list(current_list)
|
||||
|
||||
if isinstance(self._updates[key], list):
|
||||
self._updates[key].extend(values)
|
||||
self._updates[key].extend(list(values))
|
||||
|
||||
return self
|
||||
|
||||
def merge(self, key: str, updates: dict[str, Any]) -> StateUpdater:
|
||||
def merge(self, key: str, updates: Mapping[str, object]) -> StateUpdater:
|
||||
"""Merge updates into a dict in the state.
|
||||
|
||||
Args:
|
||||
@@ -610,7 +505,7 @@ class StateUpdater:
|
||||
self._updates[key] = dict(current_dict)
|
||||
|
||||
if isinstance(self._updates[key], dict):
|
||||
self._updates[key].update(updates)
|
||||
self._updates[key].update(dict(updates))
|
||||
|
||||
return self
|
||||
|
||||
@@ -636,7 +531,7 @@ class StateUpdater:
|
||||
self._updates[key] = current_value + amount
|
||||
return self
|
||||
|
||||
def build(self) -> dict[str, Any]:
|
||||
def build(self) -> dict[str, object]:
|
||||
"""Build the final state with all updates applied.
|
||||
|
||||
Returns:
|
||||
@@ -645,7 +540,7 @@ class StateUpdater:
|
||||
return update_state_immutably(self._state, self._updates)
|
||||
|
||||
|
||||
def validate_state_schema(state: dict[str, Any], schema: type) -> None:
|
||||
def validate_state_schema(state: Mapping[str, object], schema: type) -> None:
|
||||
"""Validate that a state conforms to a schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -347,7 +347,7 @@ CleanupFunctionWithArgs = Callable[..., Awaitable[None]]
|
||||
|
||||
|
||||
# Error handling types to break circular imports with core.errors
|
||||
class ErrorDetails(TypedDict):
|
||||
class ErrorDetails(TypedDict):
|
||||
"""Detailed error information."""
|
||||
|
||||
type: str
|
||||
@@ -355,7 +355,7 @@ class ErrorDetails(TypedDict):
|
||||
severity: str
|
||||
category: str
|
||||
timestamp: str
|
||||
context: dict[str, Any]
|
||||
context: dict[str, object]
|
||||
traceback: str | None
|
||||
retry_exhausted: NotRequired[bool] # Added for retry handling
|
||||
notification_sent: NotRequired[bool] # Added for notification tracking
|
||||
@@ -365,7 +365,7 @@ class ErrorDetails(TypedDict):
|
||||
escalated: NotRequired[bool] # Whether error has been escalated
|
||||
|
||||
|
||||
class ErrorInfo(TypedDict):
|
||||
class ErrorInfo(TypedDict):
|
||||
"""Error information structure for state management."""
|
||||
|
||||
message: str
|
||||
@@ -374,7 +374,8 @@ class ErrorInfo(TypedDict):
|
||||
severity: str
|
||||
category: str
|
||||
timestamp: str
|
||||
context: dict[str, Any]
|
||||
context: dict[str, object]
|
||||
details: NotRequired[dict[str, object]]
|
||||
traceback: str | None
|
||||
|
||||
|
||||
|
||||
@@ -1,333 +1,201 @@
|
||||
"""Backward compatibility layer for URL processing.
|
||||
"""Deprecated bridge to the modern URL processing capability layer.
|
||||
|
||||
⚠️ DEPRECATED: This module is deprecated and will be removed in a future version.
|
||||
Please use biz_bud.tools.capabilities.url_processing instead.
|
||||
This module used to provide a large backward-compatibility surface for the
|
||||
legacy URLProcessor implementation. As part of the LangGraph v1 migration we
|
||||
no longer ship stub fallbacks – the new tool-backed service is always required.
|
||||
|
||||
This module provides backward compatibility imports for the URL processing
|
||||
system that has been migrated to the tools layer. Existing imports will
|
||||
continue to work while displaying deprecation warnings.
|
||||
|
||||
Migration Guide:
|
||||
Old import:
|
||||
from biz_bud.core.url_processing import URLProcessor
|
||||
|
||||
New import:
|
||||
from biz_bud.tools.capabilities.url_processing import (
|
||||
validate_url, normalize_url, discover_urls, process_urls_batch
|
||||
)
|
||||
|
||||
Instead of:
|
||||
processor = URLProcessor()
|
||||
result = await processor.validate_url("https://example.com")
|
||||
|
||||
Use:
|
||||
result = await validate_url("https://example.com")
|
||||
Only a thin wrapper remains so existing imports continue to work while emitting
|
||||
clear deprecation guidance. Callers should migrate to
|
||||
``biz_bud.tools.capabilities.url_processing`` as soon as possible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from collections.abc import Coroutine, Mapping
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
# Issue deprecation warning
|
||||
warnings.warn(
|
||||
"biz_bud.core.url_processing is deprecated. "
|
||||
"Please use biz_bud.tools.capabilities.url_processing instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
|
||||
# Try to import from new location first, fall back to old location
|
||||
try:
|
||||
# Import new implementations for backward compatibility
|
||||
if TYPE_CHECKING: # pragma: no cover - used for static analysis only
|
||||
from biz_bud.tools.capabilities.url_processing.service import URLProcessingService
|
||||
|
||||
# Service is imported above for direct usage
|
||||
# Legacy URLProcessor wrapper class
|
||||
class URLProcessor:
|
||||
"""Legacy URLProcessor for backward compatibility."""
|
||||
|
||||
def __init__(self, config=None):
|
||||
warnings.warn(
|
||||
"URLProcessor is deprecated. Use URL processing tools instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
self.config = config
|
||||
self._service: URLProcessingService | None = None
|
||||
|
||||
async def _get_service(self) -> URLProcessingService:
|
||||
"""Get or create URL processing service."""
|
||||
if self._service is None:
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
factory = await get_global_factory()
|
||||
self._service = await factory.get_service(URLProcessingService)
|
||||
assert self._service is not None # For type checker
|
||||
return self._service
|
||||
|
||||
async def validate_url(self, url: str):
|
||||
"""Validate URL using new service."""
|
||||
service = await self._get_service()
|
||||
result = await service.validate_url(url)
|
||||
# Convert dataclass to dict for backward compatibility
|
||||
from dataclasses import asdict, is_dataclass
|
||||
return asdict(result) if is_dataclass(result) else result
|
||||
|
||||
async def normalize_url(self, url: str) -> str:
|
||||
"""Normalize URL using new service."""
|
||||
service = await self._get_service()
|
||||
return service.normalize_url(url)
|
||||
|
||||
async def discover_urls(self, base_url: str):
|
||||
"""Discover URLs using new service."""
|
||||
service = await self._get_service()
|
||||
result = await service.discover_urls(base_url)
|
||||
return result.discovered_urls if hasattr(result, 'discovered_urls') else [base_url]
|
||||
|
||||
async def deduplicate_urls(self, urls: list[str]):
|
||||
"""Deduplicate URLs using new service."""
|
||||
service = await self._get_service()
|
||||
return await service.deduplicate_urls(urls)
|
||||
|
||||
async def process_urls(self, urls: list[str]):
|
||||
"""Process URLs using new service."""
|
||||
service = await self._get_service()
|
||||
return await service.process_urls_batch(urls)
|
||||
|
||||
new_tools_available = True
|
||||
|
||||
except ImportError:
|
||||
# Fall back to old implementation if new tools not available
|
||||
new_tools_available = False
|
||||
|
||||
# Re-export key utilities for backward compatibility
|
||||
from ..utils.url_analyzer import (
|
||||
URLAnalysisResult,
|
||||
URLType,
|
||||
analyze_url_type,
|
||||
get_url_type,
|
||||
is_git_repo_url,
|
||||
is_pdf_url,
|
||||
is_valid_url,
|
||||
normalize_url,
|
||||
)
|
||||
from ..utils.url_normalizer import URLNormalizer
|
||||
|
||||
# Legacy classes for backward compatibility
|
||||
class URLProcessorConfig:
|
||||
"""Legacy config class."""
|
||||
|
||||
class ValidationLevel:
|
||||
"""Legacy validation level enum."""
|
||||
BASIC = "basic"
|
||||
STANDARD = "standard"
|
||||
STRICT = "strict"
|
||||
|
||||
class URLDiscoverer:
|
||||
"""Legacy discoverer class."""
|
||||
|
||||
class URLFilter:
|
||||
"""Legacy filter class."""
|
||||
|
||||
class URLValidator:
|
||||
"""Legacy validator class."""
|
||||
|
||||
# Legacy exception classes
|
||||
class URLProcessingError(Exception):
|
||||
"""Base URL processing error."""
|
||||
|
||||
class URLValidationError(URLProcessingError):
|
||||
"""URL validation error."""
|
||||
|
||||
class URLNormalizationError(URLProcessingError):
|
||||
"""URL normalization error."""
|
||||
|
||||
class URLDiscoveryError(URLProcessingError):
|
||||
"""URL discovery error."""
|
||||
|
||||
class URLFilterError(URLProcessingError):
|
||||
"""URL filter error."""
|
||||
|
||||
class URLDeduplicationError(URLProcessingError):
|
||||
"""URL deduplication error."""
|
||||
|
||||
class URLCacheError(URLProcessingError):
|
||||
"""URL cache error."""
|
||||
|
||||
# Legacy data model classes
|
||||
class ProcessedURL:
|
||||
"""Legacy processed URL model."""
|
||||
|
||||
class URLAnalysis:
|
||||
"""Legacy URL analysis model."""
|
||||
|
||||
class ValidationResult:
|
||||
"""Legacy validation result model."""
|
||||
|
||||
class DiscoveryResult:
|
||||
"""Legacy discovery result model."""
|
||||
|
||||
class FilterResult:
|
||||
"""Legacy filter result model."""
|
||||
|
||||
class DeduplicationResult:
|
||||
"""Legacy deduplication result model."""
|
||||
|
||||
# URLProcessor is defined above in the try block
|
||||
|
||||
# Version information
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Business Buddy Team"
|
||||
|
||||
# Public API - what gets imported with "from url_processing import *"
|
||||
if new_tools_available:
|
||||
__all__ = [
|
||||
# Main processor class (deprecated)
|
||||
"URLProcessor",
|
||||
# New service class
|
||||
"URLProcessingService",
|
||||
# Module metadata
|
||||
"__version__",
|
||||
"__author__",
|
||||
# Convenience functions
|
||||
"validate_url",
|
||||
"normalize_url_simple",
|
||||
]
|
||||
else:
|
||||
__all__ = [
|
||||
# Main processor class
|
||||
"URLProcessor",
|
||||
# Component classes
|
||||
"URLValidator",
|
||||
"URLFilter",
|
||||
"URLDiscoverer",
|
||||
# Configuration
|
||||
"URLProcessorConfig",
|
||||
"ValidationLevel",
|
||||
# Data models
|
||||
"ProcessedURL",
|
||||
"URLAnalysis",
|
||||
"ValidationResult",
|
||||
"DiscoveryResult",
|
||||
"FilterResult",
|
||||
"DeduplicationResult",
|
||||
# Exceptions
|
||||
"URLProcessingError",
|
||||
"URLValidationError",
|
||||
"URLNormalizationError",
|
||||
"URLDiscoveryError",
|
||||
"URLFilterError",
|
||||
"URLDeduplicationError",
|
||||
"URLCacheError",
|
||||
# Backward compatibility exports
|
||||
"analyze_url_type",
|
||||
"normalize_url",
|
||||
"is_valid_url",
|
||||
"is_git_repo_url",
|
||||
"is_pdf_url",
|
||||
"get_url_type",
|
||||
"URLType",
|
||||
"URLAnalysisResult",
|
||||
"URLNormalizer",
|
||||
# Module metadata
|
||||
"__version__",
|
||||
"__author__",
|
||||
]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
warnings.warn(
|
||||
"biz_bud.core.url_processing is deprecated. "
|
||||
"Use biz_bud.tools.capabilities.url_processing instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
# Module-level convenience functions for backward compatibility
|
||||
if new_tools_available:
|
||||
def validate_url(url: str, strict: bool = False) -> bool:
|
||||
"""Quick URL validation using new service (deprecated).
|
||||
def _run_sync(coro: Coroutine[object, object, _T]) -> _T:
|
||||
"""Execute an async operation from a synchronous helper safely."""
|
||||
|
||||
⚠️ DEPRECATED: Use biz_bud.tools.capabilities.url_processing.validate_url instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"validate_url from core.url_processing is deprecated. "
|
||||
"Use biz_bud.tools.capabilities.url_processing.validate_url instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No loop running in this thread; safe to create one
|
||||
return asyncio.run(coro)
|
||||
|
||||
if loop.is_running():
|
||||
raise RuntimeError(
|
||||
"Synchronous URL processing helpers cannot run inside an active "
|
||||
"event loop. Call the async variants from "
|
||||
"biz_bud.tools.capabilities.url_processing instead."
|
||||
)
|
||||
import asyncio
|
||||
|
||||
async def _validate():
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
from biz_bud.tools.capabilities.url_processing.service import (
|
||||
URLProcessingService,
|
||||
)
|
||||
factory = await get_global_factory()
|
||||
service = await factory.get_service(URLProcessingService)
|
||||
result = await service.validate_url(url, "strict" if strict else "standard")
|
||||
return result.is_valid if hasattr(result, 'is_valid') else False
|
||||
# A loop exists but isn't running in this thread – submit the coroutine
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
return future.result()
|
||||
|
||||
return asyncio.run(_validate())
|
||||
|
||||
def normalize_url_simple(url: str) -> str:
|
||||
"""Quick URL normalization using new service (deprecated).
|
||||
def _convert_dataclass(result: object) -> object:
|
||||
"""Convert dataclass results into plain dictionaries recursively."""
|
||||
|
||||
⚠️ DEPRECATED: Use biz_bud.tools.capabilities.url_processing.normalize_url instead.
|
||||
"""
|
||||
if is_dataclass(result):
|
||||
return asdict(result)
|
||||
return result
|
||||
|
||||
|
||||
class URLProcessor:
|
||||
"""Deprecated wrapper that now forwards to URLProcessingService."""
|
||||
|
||||
def __init__(self, config: Mapping[str, object] | None = None) -> None:
|
||||
warnings.warn(
|
||||
"normalize_url_simple from core.url_processing is deprecated. "
|
||||
"Use biz_bud.tools.capabilities.url_processing.normalize_url instead.",
|
||||
"URLProcessor is deprecated. Use URL processing tools instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
stacklevel=2,
|
||||
)
|
||||
import asyncio
|
||||
self._config = config or {}
|
||||
self._service: "URLProcessingService | None" = None
|
||||
|
||||
async def _normalize():
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
async def _get_service(self) -> "URLProcessingService":
|
||||
if self._service is None:
|
||||
from biz_bud.tools.capabilities.url_processing.service import (
|
||||
URLProcessingService,
|
||||
URLProcessingService as _Service,
|
||||
)
|
||||
|
||||
factory = await get_global_factory()
|
||||
service = await factory.get_service(URLProcessingService)
|
||||
return service.normalize_url(url)
|
||||
self._service = await factory.get_service(_Service)
|
||||
return self._service
|
||||
|
||||
return asyncio.run(_normalize())
|
||||
async def validate_url(
|
||||
self, url: str, *, level: str | None = None, provider: str | None = None
|
||||
) -> dict[str, object]:
|
||||
service = await self._get_service()
|
||||
provider_name = provider or level or self._config.get("validation_level") or "standard"
|
||||
result = await service.validate_url(url, provider_name)
|
||||
converted = _convert_dataclass(result)
|
||||
return converted if isinstance(converted, dict) else {"result": converted}
|
||||
|
||||
else:
|
||||
def validate_url(url: str, strict: bool = False) -> bool:
|
||||
"""Quick URL validation using default settings.
|
||||
async def normalize_url(self, url: str, *, provider: str | None = None) -> str:
|
||||
service = await self._get_service()
|
||||
provider_name = provider or self._config.get("normalization_provider")
|
||||
return service.normalize_url(url, provider_name)
|
||||
|
||||
Args:
|
||||
url: URL to validate
|
||||
strict: Whether to use strict validation
|
||||
async def discover_urls(
|
||||
self, base_url: str, *, provider: str | None = None, max_results: int | None = None
|
||||
) -> list[str]:
|
||||
service = await self._get_service()
|
||||
provider_name = provider or self._config.get("discovery_provider")
|
||||
result = await service.discover_urls(base_url, provider_name)
|
||||
urls = list(result.discovered_urls)
|
||||
if max_results is not None and len(urls) > max_results:
|
||||
return urls[:max_results]
|
||||
return urls
|
||||
|
||||
Returns:
|
||||
True if URL is valid
|
||||
async def deduplicate_urls(
|
||||
self, urls: list[str], *, provider: str | None = None
|
||||
) -> list[str]:
|
||||
service = await self._get_service()
|
||||
provider_name = provider or self._config.get("deduplication_provider")
|
||||
return await service.deduplicate_urls(urls, provider_name)
|
||||
|
||||
Example:
|
||||
>>> validate_url("https://example.com")
|
||||
True
|
||||
>>> validate_url("not-a-url")
|
||||
False
|
||||
"""
|
||||
from ..utils.url_analyzer import is_valid_url
|
||||
return is_valid_url(url, strict=strict)
|
||||
|
||||
def normalize_url_simple(url: str) -> str:
|
||||
"""Quick URL normalization using default settings.
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Example:
|
||||
>>> normalize_url_simple("HTTP://Example.COM/Path?param=1#fragment")
|
||||
"https://example.com/Path?param=1"
|
||||
"""
|
||||
from ..utils.url_analyzer import normalize_url
|
||||
return normalize_url(url)
|
||||
async def process_urls(
|
||||
self, urls: list[str], **options: object
|
||||
) -> dict[str, object]:
|
||||
service = await self._get_service()
|
||||
result = await service.process_urls_batch(
|
||||
urls,
|
||||
validation_provider=options.get("validation_provider"),
|
||||
normalization_provider=options.get("normalization_provider"),
|
||||
enable_deduplication=bool(options.get("enable_deduplication", True)),
|
||||
deduplication_provider=options.get("deduplication_provider"),
|
||||
max_concurrent=options.get("max_concurrent"),
|
||||
timeout=options.get("timeout"),
|
||||
)
|
||||
converted = _convert_dataclass(result)
|
||||
return converted if isinstance(converted, dict) else {"result": converted}
|
||||
|
||||
|
||||
# Module initialization
|
||||
def _initialize_module() -> None:
|
||||
"""Initialize the URL processing module."""
|
||||
# Module initialization logic can go here
|
||||
# For now, just validate that required dependencies are available
|
||||
async def _fetch_service() -> "URLProcessingService":
|
||||
from biz_bud.tools.capabilities.url_processing.service import (
|
||||
URLProcessingService as _Service,
|
||||
)
|
||||
|
||||
factory = await get_global_factory()
|
||||
return await factory.get_service(_Service)
|
||||
|
||||
|
||||
# Run initialization when module is imported
|
||||
_initialize_module()
|
||||
def validate_url(url: str, strict: bool = False) -> bool:
|
||||
"""Synchronous helper that proxies to the async validation tool."""
|
||||
|
||||
warnings.warn(
|
||||
"validate_url from biz_bud.core.url_processing is deprecated. "
|
||||
"Use biz_bud.tools.capabilities.url_processing.validate_url instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async def _validate() -> bool:
|
||||
level = "strict" if strict else "standard"
|
||||
service = await _fetch_service()
|
||||
result = await service.validate_url(url, level)
|
||||
return bool(getattr(result, "is_valid", False))
|
||||
|
||||
return _run_sync(_validate())
|
||||
|
||||
|
||||
def normalize_url_simple(url: str) -> str:
|
||||
"""Synchronous helper for quick URL normalization."""
|
||||
|
||||
warnings.warn(
|
||||
"normalize_url_simple from biz_bud.core.url_processing is deprecated. "
|
||||
"Use biz_bud.tools.capabilities.url_processing.normalize_url instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async def _normalize() -> str:
|
||||
service = await _fetch_service()
|
||||
return await service.normalize_url(url, None)
|
||||
|
||||
return _run_sync(_normalize())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"URLProcessor",
|
||||
"URLProcessingService",
|
||||
"__version__",
|
||||
"__author__",
|
||||
"validate_url",
|
||||
"normalize_url_simple",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> object:
|
||||
if name == "URLProcessingService":
|
||||
from biz_bud.tools.capabilities.url_processing.service import (
|
||||
URLProcessingService as _Service,
|
||||
)
|
||||
|
||||
globals()[name] = _Service
|
||||
return _Service
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -1,160 +1,201 @@
|
||||
"""Helper functions for graph creation and state initialization.
|
||||
|
||||
This module consolidates common patterns used in graph creation to reduce
|
||||
code duplication and improve maintainability.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def process_state_query(
|
||||
query: str | None,
|
||||
messages: list[dict[str, Any]] | None,
|
||||
state_update: dict[str, Any] | None,
|
||||
default_query: str
|
||||
) -> str:
|
||||
"""Process and extract query from various sources.
|
||||
|
||||
Args:
|
||||
query: Direct query string
|
||||
messages: Message history to extract query from
|
||||
state_update: State update that may contain query
|
||||
default_query: Default query to use if none found
|
||||
|
||||
Returns:
|
||||
Processed query string
|
||||
"""
|
||||
# Extract query from messages if not provided
|
||||
if query is None and messages:
|
||||
# Look for the last human/user message
|
||||
for msg in reversed(messages):
|
||||
# Handle different message formats
|
||||
role = msg.get("role")
|
||||
msg_type = msg.get("type")
|
||||
if role in ("user", "human") or msg_type == "human":
|
||||
query = msg.get("content", "")
|
||||
break
|
||||
|
||||
return query or default_query
|
||||
|
||||
|
||||
def format_raw_input(
|
||||
raw_input: str | dict[str, Any] | None,
|
||||
user_query: str
|
||||
) -> tuple[str, str]:
|
||||
"""Format raw input into a consistent string format.
|
||||
|
||||
Args:
|
||||
raw_input: Raw input data (string, dict, or None).
|
||||
If a dict, must be JSON serializable.
|
||||
user_query: User query to use as default
|
||||
|
||||
Returns:
|
||||
Tuple of (raw_input_str, extracted_query)
|
||||
|
||||
Raises:
|
||||
TypeError: If raw_input is a dict but not JSON serializable.
|
||||
"""
|
||||
if raw_input is None:
|
||||
return f'{{"query": "{user_query}"}}', user_query
|
||||
|
||||
if isinstance(raw_input, dict):
|
||||
# If raw_input has a query field, use it
|
||||
extracted_query = raw_input.get("query", user_query)
|
||||
# Avoid json.dumps for simple cases
|
||||
if len(raw_input) == 1 and "query" in raw_input:
|
||||
return f'{{"query": "{raw_input["query"]}"}}', extracted_query
|
||||
else:
|
||||
try:
|
||||
return json.dumps(raw_input), extracted_query
|
||||
except (TypeError, ValueError) as e:
|
||||
# Fallback: show error message in string
|
||||
raw_input_str = f"<non-serializable dict: {e}>"
|
||||
return raw_input_str, extracted_query
|
||||
|
||||
# For unsupported types (already checked that raw_input is str above)
|
||||
raw_input_str = f"<unsupported type: {type(raw_input).__name__}>"
|
||||
return raw_input_str, user_query
|
||||
|
||||
|
||||
def extract_state_update_data(
|
||||
state_update: dict[str, Any] | None,
|
||||
messages: list[dict[str, Any]] | None,
|
||||
raw_input: str | dict[str, Any] | None,
|
||||
thread_id: str | None
|
||||
) -> tuple[list[dict[str, Any]] | None, str | dict[str, Any] | None, str | None]:
|
||||
"""Extract data from state update if provided.
|
||||
|
||||
Args:
|
||||
state_update: State update dictionary from LangGraph API
|
||||
messages: Existing messages
|
||||
raw_input: Existing raw input
|
||||
thread_id: Existing thread ID
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, raw_input, thread_id)
|
||||
"""
|
||||
if not state_update:
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
# Extract messages if present
|
||||
if "messages" in state_update and not messages:
|
||||
messages = state_update["messages"]
|
||||
|
||||
# Extract other fields if present
|
||||
if "raw_input" in state_update and not raw_input:
|
||||
raw_input = state_update["raw_input"]
|
||||
if "thread_id" in state_update and not thread_id:
|
||||
thread_id = state_update["thread_id"]
|
||||
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
|
||||
def create_initial_state_dict(
    raw_input_str: str,
    user_query: str,
    messages: list[dict[str, Any]],
    thread_id: str,
    config_dict: dict[str, Any]
) -> dict[str, Any]:
    """Create the initial state dictionary.

    Args:
        raw_input_str: Formatted raw input string
        user_query: User query
        messages: Message history
        thread_id: Thread identifier
        config_dict: Configuration dictionary

    Returns:
        Initial state dictionary
    """
    # Caller-supplied values first.
    state: dict[str, Any] = {
        "raw_input": raw_input_str,
        "parsed_input": {
            "raw_payload": {"query": user_query},
            "user_query": user_query,
        },
        "messages": messages,
        "initial_input": {"query": user_query},
        "thread_id": thread_id,
        "config": config_dict,
    }
    # Then the neutral bookkeeping defaults every run starts with.
    state.update(
        input_metadata={},
        context={},
        status="pending",
        errors=[],
        run_metadata={},
        is_last_step=False,
        final_result=None,
    )
    return state
|
||||
"""Helper functions for graph creation and state initialization.
|
||||
|
||||
This module consolidates common patterns used in graph creation to reduce
|
||||
code duplication and improve maintainability.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def process_state_query(
|
||||
query: str | None,
|
||||
messages: Sequence[Mapping[str, object]] | None,
|
||||
state_update: Mapping[str, object] | None,
|
||||
default_query: str,
|
||||
) -> str:
|
||||
"""Process and extract query from various sources.
|
||||
|
||||
Args:
|
||||
query: Direct query string
|
||||
messages: Message history to extract query from
|
||||
state_update: State update that may contain query
|
||||
default_query: Default query to use if none found
|
||||
|
||||
Returns:
|
||||
Processed query string
|
||||
"""
|
||||
if query is None and messages:
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, Mapping):
|
||||
continue
|
||||
|
||||
role = msg.get("role")
|
||||
msg_type = msg.get("type")
|
||||
if role not in ("user", "human") and msg_type != "human":
|
||||
continue
|
||||
|
||||
content = msg.get("content", "")
|
||||
extracted: str | None = None
|
||||
|
||||
if isinstance(content, str):
|
||||
extracted = content
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
parts.append(part)
|
||||
elif isinstance(part, Mapping):
|
||||
text_value = part.get("text") or part.get("content")
|
||||
if isinstance(text_value, str):
|
||||
parts.append(text_value)
|
||||
if parts:
|
||||
extracted = " ".join(parts)
|
||||
elif isinstance(content, Mapping):
|
||||
text_value = content.get("text") or content.get("content")
|
||||
if isinstance(text_value, str):
|
||||
extracted = text_value
|
||||
|
||||
if extracted:
|
||||
query = extracted
|
||||
break
|
||||
|
||||
return query or default_query
|
||||
|
||||
|
||||
def format_raw_input(
|
||||
raw_input: str | Mapping[str, object] | None,
|
||||
user_query: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Format raw input into a consistent string format.
|
||||
|
||||
Args:
|
||||
raw_input: Raw input data (string, dict, or None).
|
||||
If a dict, must be JSON serializable.
|
||||
user_query: User query to use as default
|
||||
|
||||
Returns:
|
||||
Tuple of (raw_input_str, extracted_query)
|
||||
|
||||
Raises:
|
||||
TypeError: If raw_input is a dict but not JSON serializable.
|
||||
"""
|
||||
if raw_input is None:
|
||||
return json.dumps({"query": user_query}), user_query
|
||||
|
||||
if isinstance(raw_input, str):
|
||||
return json.dumps({"query": raw_input}), raw_input
|
||||
|
||||
if isinstance(raw_input, Mapping):
|
||||
raw_input_dict = dict(raw_input)
|
||||
# If raw_input has a query field, use it
|
||||
extracted_query = str(raw_input_dict.get("query", user_query))
|
||||
# Avoid json.dumps for simple cases
|
||||
if len(raw_input_dict) == 1 and "query" in raw_input_dict:
|
||||
return json.dumps({"query": raw_input_dict["query"]}), extracted_query
|
||||
else:
|
||||
try:
|
||||
return json.dumps(raw_input_dict), extracted_query
|
||||
except (TypeError, ValueError) as e:
|
||||
fallback = {
|
||||
"error": f"non-serializable dict: {e}",
|
||||
"query": extracted_query,
|
||||
}
|
||||
return json.dumps(fallback), extracted_query
|
||||
|
||||
fallback = {
|
||||
"error": f"unsupported type: {type(raw_input).__name__}",
|
||||
"query": user_query,
|
||||
}
|
||||
return json.dumps(fallback), user_query
|
||||
|
||||
|
||||
def extract_state_update_data(
|
||||
state_update: Mapping[str, object] | None,
|
||||
messages: list[dict[str, object]] | None,
|
||||
raw_input: str | Mapping[str, object] | None,
|
||||
thread_id: str | None,
|
||||
) -> tuple[
|
||||
list[dict[str, object]] | None,
|
||||
str | Mapping[str, object] | None,
|
||||
str | None,
|
||||
]:
|
||||
"""Extract data from state update if provided.
|
||||
|
||||
Args:
|
||||
state_update: State update dictionary from LangGraph API
|
||||
messages: Existing messages
|
||||
raw_input: Existing raw input
|
||||
thread_id: Existing thread ID
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, raw_input, thread_id)
|
||||
"""
|
||||
if not state_update:
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
# Extract messages if present
|
||||
if "messages" in state_update and not messages:
|
||||
messages_value = state_update["messages"]
|
||||
if isinstance(messages_value, list):
|
||||
coerced_messages: list[dict[str, object]] = []
|
||||
for item in messages_value:
|
||||
if isinstance(item, Mapping):
|
||||
coerced_messages.append(dict(item))
|
||||
if coerced_messages:
|
||||
messages = coerced_messages
|
||||
|
||||
# Extract other fields if present
|
||||
if "raw_input" in state_update and not raw_input:
|
||||
raw_input = state_update["raw_input"]
|
||||
if "thread_id" in state_update and not thread_id:
|
||||
thread_id = state_update["thread_id"]
|
||||
|
||||
return messages, raw_input, thread_id
|
||||
|
||||
|
||||
def create_initial_state_dict(
    raw_input_str: str,
    user_query: str,
    messages: list[dict[str, object]],
    thread_id: str,
    config_dict: Mapping[str, object],
) -> dict[str, object]:
    """Create the initial state dictionary.

    Args:
        raw_input_str: Formatted raw input string
        user_query: User query
        messages: Message history
        thread_id: Thread identifier
        config_dict: Configuration dictionary

    Returns:
        Initial state dictionary
    """
    # Caller-supplied values first.
    state: dict[str, object] = {
        "raw_input": raw_input_str,
        "parsed_input": {
            "raw_payload": {"query": user_query},
            "user_query": user_query,
        },
        "messages": messages,
        "initial_input": {"query": user_query},
        "thread_id": thread_id,
        # Snapshot the mapping so later state mutation cannot leak back
        # into the caller's config object.
        "config": dict(config_dict),
    }
    # Then the neutral bookkeeping defaults every run starts with.
    state.update(
        input_metadata={},
        context={},
        status="pending",
        errors=[],
        run_metadata={},
        is_last_step=False,
        final_result=None,
    )
    return state
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Validation decorators for functions."""
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec, TypeVar, cast
|
||||
|
||||
from ..errors import ValidationError
|
||||
|
||||
@@ -52,7 +52,7 @@ def validate_args(
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
def wrapper(*args: object, **kwargs: object) -> T:
|
||||
# Get function signature
|
||||
import inspect
|
||||
|
||||
@@ -98,7 +98,7 @@ def validate_return(
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
def wrapper(*args: object, **kwargs: object) -> T:
|
||||
result = func(*args, **kwargs)
|
||||
is_valid, error_msg = validator(result)
|
||||
if not is_valid:
|
||||
@@ -126,7 +126,7 @@ def validate_not_none(
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
def wrapper(*args: object, **kwargs: object) -> T:
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(func)
|
||||
@@ -168,7 +168,7 @@ def validate_types(
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
def wrapper(*args: object, **kwargs: object) -> T:
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@@ -15,8 +15,8 @@ import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from typing import ParamSpec, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel, ValidationError, create_model
|
||||
|
||||
@@ -28,7 +28,8 @@ F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
# Type alias for state dictionaries that can contain any type
|
||||
# (needed for validation framework)
|
||||
StateDict = dict[str, Any]
|
||||
StateDict = dict[str, object]
|
||||
GraphFactory = Callable[..., Awaitable[object]]
|
||||
|
||||
# Registry to track validated functions
|
||||
_VALIDATED_FUNCTIONS: set[str] = set()
|
||||
@@ -54,10 +55,13 @@ RECOMMENDED_GRAPH_METHODS: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
FieldTypes = Mapping[str, type[object]]
|
||||
|
||||
|
||||
def create_validation_model(
|
||||
name: str,
|
||||
required_fields: dict[str, type[Any]],
|
||||
optional_fields: dict[str, type[Any]] | None = None,
|
||||
required_fields: FieldTypes,
|
||||
optional_fields: FieldTypes | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic model for node input/output validation.
|
||||
|
||||
@@ -69,14 +73,14 @@ def create_validation_model(
|
||||
Returns:
|
||||
A Pydantic model class with the specified fields and validation rules
|
||||
"""
|
||||
field_definitions: dict[str, tuple[Any, Any]] = {
|
||||
field_definitions: dict[str, tuple[type[object], object]] = {
|
||||
field_name: (field_type, ...)
|
||||
for field_name, field_type in required_fields.items()
|
||||
}
|
||||
if optional_fields:
|
||||
for field_name, field_type in optional_fields.items():
|
||||
field_definitions[field_name] = (field_type, None)
|
||||
model: type[BaseModel] = create_model(name, **field_definitions) # type: ignore
|
||||
model = cast(type[BaseModel], create_model(name, **field_definitions))
|
||||
return model
|
||||
|
||||
|
||||
@@ -96,13 +100,13 @@ class PydanticValidator:
|
||||
raise NodeValidationError(str(e), validation_type=self.phase) from e
|
||||
|
||||
|
||||
def is_validated(func: Callable[..., Any]) -> bool:
|
||||
def is_validated(func: Callable[..., object]) -> bool:
|
||||
"""Check if a function has been validated."""
|
||||
func_id: str = f"{func.__module__}.{func.__name__}"
|
||||
return func_id in _VALIDATED_FUNCTIONS
|
||||
|
||||
|
||||
def mark_validated(func: Callable[..., Any]) -> None:
|
||||
def mark_validated(func: Callable[..., object]) -> None:
|
||||
"""Mark a function as validated."""
|
||||
func_id: str = f"{func.__module__}.{func.__name__}"
|
||||
_VALIDATED_FUNCTIONS.add(func_id)
|
||||
@@ -209,11 +213,8 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
result = await async_func(state, *args, **kwargs)
|
||||
result_dict: StateDict = result
|
||||
try:
|
||||
pydantic_validator = cast(
|
||||
"Callable[[dict[str, Any]], BaseModel]", validator
|
||||
)
|
||||
validated_result: BaseModel = pydantic_validator(result_dict)
|
||||
return validated_result.model_dump()
|
||||
validated_result = validator(result_dict)
|
||||
return cast(StateDict, validated_result.model_dump())
|
||||
except NodeValidationError as e:
|
||||
source: str = f"{func.__module__}.{func.__name__}"
|
||||
return add_exception_error(result_dict, e, source)
|
||||
@@ -230,11 +231,8 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
result = sync_func(state, *args, **kwargs)
|
||||
result_dict: StateDict = result
|
||||
try:
|
||||
pydantic_validator = cast(
|
||||
"Callable[[dict[str, Any]], BaseModel]", validator
|
||||
)
|
||||
validated_result: BaseModel = pydantic_validator(result_dict)
|
||||
return validated_result.model_dump()
|
||||
validated_result = validator(result_dict)
|
||||
return cast(StateDict, validated_result.model_dump())
|
||||
except NodeValidationError as e:
|
||||
source: str = f"{func.__module__}.{func.__name__}"
|
||||
return add_exception_error(result_dict, e, source)
|
||||
@@ -346,7 +344,7 @@ async def ensure_graph_compatibility(
|
||||
|
||||
|
||||
async def validate_all_graphs(
|
||||
graph_functions: dict[str, Callable[[], Any]],
|
||||
graph_functions: Mapping[str, GraphFactory],
|
||||
) -> bool:
|
||||
"""Validate all graph creation functions.
|
||||
|
||||
@@ -362,9 +360,7 @@ async def validate_all_graphs(
|
||||
# Use inspect to determine if function takes arguments
|
||||
sig = inspect.signature(func)
|
||||
if len(sig.parameters) == 0:
|
||||
# Cast to the expected coroutine function type
|
||||
coro_func = cast("Callable[[], Awaitable[object]]", func)
|
||||
graph = await coro_func()
|
||||
graph = await cast("Callable[[], Awaitable[object]]", func)()
|
||||
else:
|
||||
# Skip functions that require arguments for now
|
||||
all_valid = False
|
||||
|
||||
@@ -6,15 +6,23 @@ including interrupt-based validation, state initialization helpers, and proper e
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, cast
|
||||
import copy
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Literal, Protocol, TypeVar, cast
|
||||
|
||||
try: # pragma: no cover - defensive import guard
|
||||
import json
|
||||
except Exception: # pragma: no cover - fallback when json is unavailable
|
||||
json = None # type: ignore[assignment]
|
||||
|
||||
from biz_bud.core.langgraph.state_immutability import StateUpdater
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Mapping
|
||||
|
||||
# Remove TypedDict bound constraint as it's not allowed
|
||||
T = TypeVar("T")
|
||||
# Remove TypedDict bound constraint as it's not allowed
|
||||
T = TypeVar("T")
|
||||
|
||||
StateMapping = Mapping[str, object]
|
||||
StateDict = dict[str, object]
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
@@ -25,7 +33,7 @@ class ValidationError(Exception):
|
||||
message: str,
|
||||
field: str | None = None,
|
||||
expected_type: type | None = None,
|
||||
actual_value: Any = None,
|
||||
actual_value: object | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.field = field
|
||||
@@ -33,24 +41,24 @@ class ValidationError(Exception):
|
||||
self.actual_value = actual_value
|
||||
|
||||
|
||||
class StateValidator(Protocol):
|
||||
class StateValidator(Protocol):
|
||||
"""Protocol for state validators."""
|
||||
|
||||
def __call__(self, state: Mapping[str, Any]) -> bool:
|
||||
def __call__(self, state: StateMapping) -> bool:
|
||||
"""Validate state and return True if valid."""
|
||||
...
|
||||
|
||||
|
||||
class StateInitializer(Protocol):
|
||||
class StateInitializer(Protocol):
|
||||
"""Protocol for state initializers."""
|
||||
|
||||
def __call__(self, **kwargs: Any) -> dict[str, Any]:
|
||||
def __call__(self, **kwargs: object) -> StateDict:
|
||||
"""Initialize state with proper defaults."""
|
||||
...
|
||||
|
||||
|
||||
def validate_required_fields(
|
||||
state: Mapping[str, Any],
|
||||
def validate_required_fields(
|
||||
state: StateMapping,
|
||||
required_fields: set[str],
|
||||
) -> None:
|
||||
"""Validate that all required fields are present in state.
|
||||
@@ -68,8 +76,8 @@ def validate_required_fields(
|
||||
)
|
||||
|
||||
|
||||
def validate_field_type(
|
||||
state: Mapping[str, Any],
|
||||
def validate_field_type(
|
||||
state: StateMapping,
|
||||
field_name: str,
|
||||
expected_type: type,
|
||||
) -> None:
|
||||
@@ -97,11 +105,11 @@ def validate_field_type(
|
||||
)
|
||||
|
||||
|
||||
def create_state_validator(
|
||||
required_fields: set[str] | None = None,
|
||||
type_checks: dict[str, type] | None = None,
|
||||
custom_validators: list[StateValidator] | None = None,
|
||||
) -> StateValidator:
|
||||
def create_state_validator(
|
||||
required_fields: set[str] | None = None,
|
||||
type_checks: Mapping[str, type[object]] | None = None,
|
||||
custom_validators: Sequence[StateValidator] | None = None,
|
||||
) -> StateValidator:
|
||||
"""Create a composite state validator.
|
||||
|
||||
Args:
|
||||
@@ -112,7 +120,7 @@ def create_state_validator(
|
||||
Returns:
|
||||
A validator function that can be used in LangGraph nodes
|
||||
"""
|
||||
def validator(state: Mapping[str, Any]) -> bool:
|
||||
def validator(state: StateMapping) -> bool:
|
||||
try:
|
||||
# Check required fields
|
||||
if required_fields:
|
||||
@@ -137,10 +145,10 @@ def create_state_validator(
|
||||
return validator
|
||||
|
||||
|
||||
def create_validation_node(
|
||||
validator: StateValidator,
|
||||
error_message: str = "State validation failed",
|
||||
) -> Callable[[Mapping[str, Any]], dict[str, Any]]:
|
||||
def create_validation_node(
|
||||
validator: StateValidator,
|
||||
error_message: str = "State validation failed",
|
||||
) -> Callable[[StateMapping], StateDict]:
|
||||
"""Create a LangGraph node that validates state and handles errors.
|
||||
|
||||
Args:
|
||||
@@ -150,7 +158,7 @@ def create_validation_node(
|
||||
Returns:
|
||||
A LangGraph node function that validates state
|
||||
"""
|
||||
def validation_node(state: Mapping[str, Any]) -> dict[str, Any]:
|
||||
def validation_node(state: StateMapping) -> StateDict:
|
||||
"""Validate state and update error status if validation fails."""
|
||||
is_valid = validator(state)
|
||||
|
||||
@@ -163,25 +171,25 @@ def create_validation_node(
|
||||
"timestamp": None, # Could add timestamp here
|
||||
})
|
||||
|
||||
return {
|
||||
"errors": errors,
|
||||
"status": "error",
|
||||
"is_output_valid": False,
|
||||
"validation_issues": [error_message],
|
||||
}
|
||||
|
||||
return {
|
||||
"is_output_valid": True,
|
||||
"validation_issues": [],
|
||||
}
|
||||
return {
|
||||
"errors": errors,
|
||||
"status": "error",
|
||||
"is_output_valid": False,
|
||||
"validation_issues": [error_message],
|
||||
}
|
||||
|
||||
return {
|
||||
"is_output_valid": True,
|
||||
"validation_issues": [],
|
||||
}
|
||||
|
||||
return validation_node
|
||||
|
||||
|
||||
def create_interrupt_validation_node(
|
||||
validator: StateValidator,
|
||||
interrupt_message: str = "Human validation required",
|
||||
) -> Callable[[Mapping[str, Any]], dict[str, Any]]:
|
||||
def create_interrupt_validation_node(
|
||||
validator: StateValidator,
|
||||
interrupt_message: str = "Human validation required",
|
||||
) -> Callable[[StateMapping], StateDict]:
|
||||
"""Create a LangGraph node that validates and interrupts for human feedback.
|
||||
|
||||
Args:
|
||||
@@ -191,7 +199,7 @@ def create_interrupt_validation_node(
|
||||
Returns:
|
||||
A LangGraph node function that validates and may interrupt
|
||||
"""
|
||||
def interrupt_validation_node(state: Mapping[str, Any]) -> dict[str, Any]:
|
||||
def interrupt_validation_node(state: StateMapping) -> StateDict:
|
||||
"""Validate state and interrupt if validation fails."""
|
||||
is_valid = validator(state)
|
||||
|
||||
@@ -211,11 +219,11 @@ def create_interrupt_validation_node(
|
||||
return interrupt_validation_node
|
||||
|
||||
|
||||
def initialize_base_state(
|
||||
thread_id: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
status: Literal["pending", "running", "success", "error", "interrupted"] = "pending",
|
||||
) -> dict[str, Any]:
|
||||
def initialize_base_state(
|
||||
thread_id: str,
|
||||
config: Mapping[str, object] | None = None,
|
||||
status: Literal["pending", "running", "success", "error", "interrupted"] = "pending",
|
||||
) -> StateDict:
|
||||
"""Initialize base state with proper defaults for LangGraph workflows.
|
||||
|
||||
Args:
|
||||
@@ -229,7 +237,7 @@ def initialize_base_state(
|
||||
return {
|
||||
"messages": [],
|
||||
"errors": [],
|
||||
"config": config or {},
|
||||
"config": dict(config) if config is not None else {},
|
||||
"thread_id": thread_id,
|
||||
"status": status,
|
||||
"is_output_valid": None,
|
||||
@@ -237,9 +245,9 @@ def initialize_base_state(
|
||||
}
|
||||
|
||||
|
||||
def initialize_search_state(
|
||||
base_state: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
def initialize_search_state(
|
||||
base_state: StateDict | None = None,
|
||||
) -> StateDict:
|
||||
"""Initialize search-related state fields with proper defaults.
|
||||
|
||||
Args:
|
||||
@@ -248,7 +256,7 @@ def initialize_search_state(
|
||||
Returns:
|
||||
State dictionary with search fields initialized
|
||||
"""
|
||||
state = base_state.copy() if base_state else {}
|
||||
state: StateDict = base_state.copy() if base_state else {}
|
||||
|
||||
search_defaults = {
|
||||
"search_results": [],
|
||||
@@ -263,9 +271,9 @@ def initialize_search_state(
|
||||
return updater.build()
|
||||
|
||||
|
||||
def initialize_analysis_state(
|
||||
base_state: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
def initialize_analysis_state(
|
||||
base_state: StateDict | None = None,
|
||||
) -> StateDict:
|
||||
"""Initialize analysis-related state fields with proper defaults.
|
||||
|
||||
Args:
|
||||
@@ -274,7 +282,7 @@ def initialize_analysis_state(
|
||||
Returns:
|
||||
State dictionary with analysis fields initialized
|
||||
"""
|
||||
state = base_state.copy() if base_state else {}
|
||||
state: StateDict = base_state.copy() if base_state else {}
|
||||
|
||||
analysis_defaults = {
|
||||
"visualizations": [],
|
||||
@@ -286,11 +294,11 @@ def initialize_analysis_state(
|
||||
return updater.build()
|
||||
|
||||
|
||||
def create_safe_state_initializer(
|
||||
state_type: type[T],
|
||||
default_values: dict[str, Any] | None = None,
|
||||
required_fields: set[str] | None = None,
|
||||
) -> Callable[..., T]:
|
||||
def create_safe_state_initializer(
|
||||
state_type: type[T],
|
||||
default_values: Mapping[str, object] | None = None,
|
||||
required_fields: set[str] | None = None,
|
||||
) -> Callable[..., T]:
|
||||
"""Create a safe state initializer that prevents KeyError crashes.
|
||||
|
||||
Args:
|
||||
@@ -301,13 +309,15 @@ def create_safe_state_initializer(
|
||||
Returns:
|
||||
A function that safely initializes the state type
|
||||
"""
|
||||
def initializer(**kwargs: Any) -> T:
|
||||
"""Initialize state with safe defaults and validation."""
|
||||
# Start with provided defaults
|
||||
state_dict = (default_values or {}).copy()
|
||||
|
||||
# Update with provided kwargs
|
||||
state_dict.update(kwargs)
|
||||
def initializer(**kwargs: object) -> T:
|
||||
"""Initialize state with safe defaults and validation."""
|
||||
|
||||
base_defaults: dict[str, object] = {}
|
||||
if default_values is not None:
|
||||
base_defaults = copy.deepcopy(dict(default_values))
|
||||
|
||||
state_dict: dict[str, object] = base_defaults
|
||||
state_dict.update(kwargs)
|
||||
|
||||
# Validate required fields if specified
|
||||
if required_fields:
|
||||
@@ -319,14 +329,14 @@ def create_safe_state_initializer(
|
||||
return initializer
|
||||
|
||||
|
||||
def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
|
||||
def create_business_buddy_state_initializer() -> Callable[..., StateDict]:
|
||||
"""Create a safe initializer for BusinessBuddyState with all defaults.
|
||||
|
||||
Returns:
|
||||
A function that initializes BusinessBuddyState with safe defaults
|
||||
"""
|
||||
# Define comprehensive defaults for all list fields to prevent KeyErrors
|
||||
defaults = {
|
||||
defaults: StateDict = {
|
||||
# Core required fields (should be provided by caller)
|
||||
"messages": [],
|
||||
"errors": [],
|
||||
@@ -412,10 +422,10 @@ def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
|
||||
|
||||
required_fields = {"thread_id", "status", "config"}
|
||||
|
||||
def initializer(**kwargs: Any) -> dict[str, Any]:
|
||||
def initializer(**kwargs: object) -> StateDict:
|
||||
"""Initialize BusinessBuddyState with comprehensive defaults."""
|
||||
# Start with defaults
|
||||
state_dict = defaults.copy()
|
||||
state_dict: StateDict = defaults.copy()
|
||||
|
||||
# Update with provided values
|
||||
state_dict.update(kwargs)
|
||||
@@ -428,7 +438,9 @@ def create_business_buddy_state_initializer() -> Callable[..., dict[str, Any]]:
|
||||
return initializer
|
||||
|
||||
|
||||
def create_state_access_helper(state: dict[str, Any]) -> Callable[..., Any]:
|
||||
def create_state_access_helper(
|
||||
state: StateMapping,
|
||||
) -> Callable[[str, object | None], object | None]:
|
||||
"""Create a safe state accessor that prevents KeyError crashes.
|
||||
|
||||
Args:
|
||||
@@ -437,17 +449,17 @@ def create_state_access_helper(state: dict[str, Any]) -> Callable[..., Any]:
|
||||
Returns:
|
||||
A function that safely gets values with defaults
|
||||
"""
|
||||
def safe_get(key: str, default: Any = None) -> Any:
|
||||
def safe_get(key: str, default: object | None = None) -> object | None:
|
||||
"""Safely get a value from state with default fallback."""
|
||||
return state.get(key, default)
|
||||
|
||||
return safe_get
|
||||
|
||||
|
||||
def validate_state_structure(
|
||||
state: dict[str, Any],
|
||||
expected_structure: dict[str, type],
|
||||
) -> list[str]:
|
||||
def validate_state_structure(
|
||||
state: StateMapping,
|
||||
expected_structure: Mapping[str, type[object]],
|
||||
) -> list[str]:
|
||||
"""Validate that state has the expected structure and types.
|
||||
|
||||
Args:
|
||||
@@ -473,10 +485,10 @@ def validate_state_structure(
|
||||
return errors
|
||||
|
||||
|
||||
def create_state_transition_validator(
|
||||
pre_transition_checks: list[StateValidator] | None = None,
|
||||
post_transition_checks: list[StateValidator] | None = None,
|
||||
) -> Callable[[dict[str, Any], dict[str, Any]], tuple[bool, list[str]]]:
|
||||
def create_state_transition_validator(
|
||||
pre_transition_checks: Sequence[StateValidator] | None = None,
|
||||
post_transition_checks: Sequence[StateValidator] | None = None,
|
||||
) -> Callable[[StateMapping, StateMapping], tuple[bool, list[str]]]:
|
||||
"""Create a validator for state transitions.
|
||||
|
||||
Args:
|
||||
@@ -486,10 +498,10 @@ def create_state_transition_validator(
|
||||
Returns:
|
||||
A function that validates state transitions
|
||||
"""
|
||||
def validate_transition(
|
||||
old_state: dict[str, Any],
|
||||
new_state: dict[str, Any]
|
||||
) -> tuple[bool, list[str]]:
|
||||
def validate_transition(
|
||||
old_state: StateMapping,
|
||||
new_state: StateMapping,
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""Validate a state transition.
|
||||
|
||||
Returns:
|
||||
@@ -514,7 +526,7 @@ def create_state_transition_validator(
|
||||
return validate_transition
|
||||
|
||||
|
||||
def validate_message_list(messages: Any) -> bool:
|
||||
def validate_message_list(messages: object) -> bool:
|
||||
"""Validate that messages field contains proper LangChain message objects.
|
||||
|
||||
Args:
|
||||
@@ -542,7 +554,7 @@ def create_message_validator() -> StateValidator:
|
||||
Returns:
|
||||
A validator that checks message field integrity
|
||||
"""
|
||||
def message_validator(state: Mapping[str, Any]) -> bool:
|
||||
def message_validator(state: StateMapping) -> bool:
|
||||
"""Validate message fields in state."""
|
||||
if "messages" in state:
|
||||
return validate_message_list(state["messages"])
|
||||
@@ -551,80 +563,235 @@ def create_message_validator() -> StateValidator:
|
||||
return message_validator
|
||||
|
||||
|
||||
# Custom reducer functions for LangGraph state management
|
||||
|
||||
def unique_add(left: list[Any] | None, right: list[Any] | Any | None) -> list[Any]:
|
||||
"""Add items to list while avoiding duplicates.
|
||||
|
||||
Args:
|
||||
left: Existing list
|
||||
right: Items to add (can be single item or list)
|
||||
|
||||
Returns:
|
||||
Combined list without duplicates
|
||||
"""
|
||||
if left is None:
|
||||
left = []
|
||||
|
||||
# Convert single item to list
|
||||
if not isinstance(right, list):
|
||||
right = [right] if right is not None else []
|
||||
|
||||
# Combine and deduplicate while preserving order
|
||||
seen = set(left)
|
||||
result = list(left)
|
||||
|
||||
for item in right:
|
||||
if item not in seen:
|
||||
result.append(item)
|
||||
seen.add(item)
|
||||
|
||||
return result
|
||||
# Custom reducer functions for LangGraph state management
|
||||
|
||||
|
||||
def _as_object_list(value: Sequence[object] | object | None) -> list[object]:
|
||||
"""Normalize a value into a list while treating strings as scalars."""
|
||||
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
return list(value)
|
||||
|
||||
return [value]
|
||||
|
||||
|
||||
def unique_add(
|
||||
left: Sequence[object] | None, right: Sequence[object] | object | None
|
||||
) -> list[object]:
|
||||
"""Add items to list while avoiding duplicates."""
|
||||
existing = list(left) if left is not None else []
|
||||
additions = _as_object_list(right)
|
||||
|
||||
def _dedup_key(
|
||||
item: object,
|
||||
*,
|
||||
_seen_ids: set[int] | None = None,
|
||||
_depth: int = 0,
|
||||
) -> object:
|
||||
if _seen_ids is None:
|
||||
_seen_ids = set()
|
||||
if _depth > 10:
|
||||
return ("max_depth", type(item))
|
||||
|
||||
obj_id = id(item)
|
||||
if obj_id in _seen_ids:
|
||||
return ("cycle", type(item))
|
||||
_seen_ids.add(obj_id)
|
||||
|
||||
try:
|
||||
hash(item)
|
||||
except TypeError:
|
||||
if isinstance(item, Mapping):
|
||||
key = (
|
||||
"mapping",
|
||||
type(item),
|
||||
tuple(
|
||||
sorted(
|
||||
(
|
||||
key,
|
||||
_dedup_key(value, _seen_ids=_seen_ids, _depth=_depth + 1),
|
||||
)
|
||||
for key, value in item.items()
|
||||
)
|
||||
),
|
||||
)
|
||||
elif isinstance(item, Sequence) and not isinstance(item, (str, bytes, bytearray)):
|
||||
key = (
|
||||
"sequence",
|
||||
type(item),
|
||||
tuple(
|
||||
_dedup_key(value, _seen_ids=_seen_ids, _depth=_depth + 1)
|
||||
for value in item
|
||||
),
|
||||
)
|
||||
else:
|
||||
attrs: dict[str, str] = {}
|
||||
for attr in ("message", "code", "category"):
|
||||
try:
|
||||
value = getattr(item, attr)
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
if value is not None:
|
||||
attrs[attr] = repr(value)
|
||||
if attrs:
|
||||
key = (
|
||||
"object_attrs",
|
||||
type(item),
|
||||
tuple(sorted(attrs.items())),
|
||||
)
|
||||
else:
|
||||
key = ("identity", type(item), id(item))
|
||||
else:
|
||||
key = ("hashable", type(item), item)
|
||||
finally:
|
||||
_seen_ids.discard(obj_id)
|
||||
|
||||
return key
|
||||
|
||||
seen = {_dedup_key(item) for item in existing}
|
||||
|
||||
for item in additions:
|
||||
key = _dedup_key(item)
|
||||
if key not in seen:
|
||||
existing.append(item)
|
||||
seen.add(key)
|
||||
|
||||
return existing
|
||||
|
||||
|
||||
def error_merge(left: Any, right: Any) -> list[Any]:
    """Merge error collections, dropping entries with duplicate messages.

    ErrorInfo-style dicts are identified by their ``"message"`` value; any
    other value is identified by its ``str()`` form.

    Args:
        left: Existing errors (``None``, a single error, or a list).
        right: New errors to append (``None``, a single error, or a list).

    Returns:
        A list containing ``left``'s errors plus any ``right`` errors whose
        message has not been seen yet.
    """

    def _as_list(value: Any) -> list[Any]:
        """Normalize ``None``/scalar/list inputs into a list."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _dedup_message(error: Any) -> str:
        """Derive the dedup key: dict message when present, else str()."""
        return error.get("message", "") if isinstance(error, dict) else str(error)

    merged = list(_as_list(left))
    seen_messages = {_dedup_message(error) for error in merged}

    for error in _as_list(right):
        message = _dedup_message(error)
        if message not in seen_messages:
            merged.append(error)
            seen_messages.add(message)

    return merged
|
||||
def error_merge(
|
||||
left: Sequence[object] | object | None,
|
||||
right: Sequence[object] | object | None,
|
||||
*,
|
||||
max_errors: int = 1000,
|
||||
_max_depth: int = 5,
|
||||
_max_items: int = 50,
|
||||
) -> list[object]:
|
||||
"""Merge error lists while avoiding duplicates using robust keys with predictable performance."""
|
||||
|
||||
capacity = max(1, int(max_errors)) if isinstance(max_errors, int) and max_errors > 0 else 1000
|
||||
existing = _as_object_list(left)
|
||||
if len(existing) > capacity:
|
||||
existing = existing[-capacity:]
|
||||
additions = _as_object_list(right)
|
||||
|
||||
def _safe_json(obj: object) -> str:
|
||||
if json is None: # pragma: no cover - json import failed
|
||||
return repr(obj)
|
||||
try:
|
||||
class SafeEncoder(json.JSONEncoder): # type: ignore[name-defined]
|
||||
def default(self, o: object) -> object: # type: ignore[override]
|
||||
try:
|
||||
return str(o)
|
||||
except Exception:
|
||||
return repr(o)
|
||||
|
||||
return json.dumps(obj, sort_keys=True, ensure_ascii=False, cls=SafeEncoder) # type: ignore[name-defined]
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
def _stable_repr(value: object, depth: int = 0) -> str:
|
||||
if depth >= _max_depth:
|
||||
return f"<max_depth:{type(value).__name__}>"
|
||||
try:
|
||||
if isinstance(value, Mapping):
|
||||
items = list(value.items())[:_max_items]
|
||||
try:
|
||||
normalized = {str(k): _stable_repr(v, depth + 1) for k, v in items}
|
||||
except Exception:
|
||||
normalized = {str(k): repr(v) for k, v in items}
|
||||
return _safe_json(dict(sorted(normalized.items(), key=lambda kv: kv[0])))
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
seq = list(value)[:_max_items]
|
||||
try:
|
||||
normalized_list = [_stable_repr(v, depth + 1) for v in seq]
|
||||
except Exception:
|
||||
normalized_list = [repr(v) for v in seq]
|
||||
return _safe_json(normalized_list)
|
||||
return repr(value)
|
||||
except Exception:
|
||||
return repr(value)
|
||||
|
||||
def _exception_key(exc: BaseException) -> tuple[object, ...]:
|
||||
try:
|
||||
args_repr = tuple(str(arg) for arg in getattr(exc, "args", ()))
|
||||
except Exception:
|
||||
args_repr = ()
|
||||
code = getattr(exc, "code", None)
|
||||
message = getattr(exc, "message", None)
|
||||
return (
|
||||
"exception",
|
||||
exc.__class__.__name__,
|
||||
args_repr,
|
||||
str(code) if code is not None else None,
|
||||
str(message) if message is not None else None,
|
||||
)
|
||||
|
||||
def _message_key(item: object, depth: int = 0) -> tuple[object, ...]:
|
||||
if depth >= _max_depth:
|
||||
return ("truncated", type(item).__name__)
|
||||
if isinstance(item, Mapping):
|
||||
code = item.get("code")
|
||||
message = item.get("message")
|
||||
detail = item.get("detail") or item.get("details")
|
||||
category = item.get("category")
|
||||
if code is None and message is None and detail is None and category is None:
|
||||
truncated_items = tuple(
|
||||
sorted(
|
||||
(str(k), _stable_repr(v, depth + 1))
|
||||
for k, v in list(item.items())[:_max_items]
|
||||
)
|
||||
)
|
||||
return ("mapping_fallback", type(item).__name__, _safe_json(truncated_items))
|
||||
return (
|
||||
"mapping",
|
||||
type(item).__name__,
|
||||
str(code) if code is not None else None,
|
||||
str(message) if message is not None else None,
|
||||
str(detail) if detail is not None else None,
|
||||
str(category) if category is not None else None,
|
||||
)
|
||||
if isinstance(item, BaseException):
|
||||
return _exception_key(item)
|
||||
return ("other", item.__class__.__name__, _stable_repr(item, depth + 1))
|
||||
|
||||
result: deque[object] = deque()
|
||||
key_order: deque[tuple[object, ...]] = deque()
|
||||
key_counts: dict[tuple[object, ...], int] = {}
|
||||
|
||||
for existing_item in existing:
|
||||
key = _message_key(existing_item)
|
||||
if key in key_counts:
|
||||
continue
|
||||
result.append(existing_item)
|
||||
key_order.append(key)
|
||||
key_counts[key] = key_counts.get(key, 0) + 1
|
||||
|
||||
for item in additions:
|
||||
key = _message_key(item)
|
||||
if key in key_counts:
|
||||
continue
|
||||
result.append(item)
|
||||
key_order.append(key)
|
||||
key_counts[key] = 1
|
||||
if len(result) > capacity:
|
||||
evict_count = len(result) - capacity
|
||||
for _ in range(evict_count):
|
||||
evicted_key = key_order.popleft()
|
||||
result.popleft()
|
||||
count = key_counts.get(evicted_key)
|
||||
if count is None:
|
||||
continue
|
||||
if count <= 1:
|
||||
del key_counts[evicted_key]
|
||||
else:
|
||||
key_counts[evicted_key] = count - 1
|
||||
|
||||
return list(result)
|
||||
|
||||
|
||||
def status_override(left: str, right: str | None) -> str:
|
||||
@@ -640,7 +807,7 @@ def status_override(left: str, right: str | None) -> str:
|
||||
return right if right is not None else left
|
||||
|
||||
|
||||
def config_merge(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
|
||||
def config_merge(left: "StateDict", right: "Mapping[str, object]") -> "StateDict":
    """Recursively merge *right* into a shallow copy of *left*.

    When both sides hold a dict under the same key, the two dicts are merged
    depth-first; otherwise the value from ``right`` wins.  Neither input is
    mutated.

    Args:
        left: Base configuration dictionary.
        right: Overriding configuration values.

    Returns:
        A new dictionary containing the deep-merged configuration.
    """
    merged = left.copy()

    for key, incoming in right.items():
        current = merged.get(key)
        if isinstance(current, dict) and isinstance(incoming, dict):
            # Depth-first merge of nested configuration sections.
            merged[key] = config_merge(current, incoming)
        else:
            merged[key] = incoming

    return merged
||||
|
||||
|
||||
def safe_list_add(
|
||||
left: Sequence[object] | None,
|
||||
right: Sequence[object] | object | None,
|
||||
) -> list[object]:
|
||||
"""Safely add items to a list, handling None values gracefully.
|
||||
|
||||
Args:
|
||||
@@ -675,19 +846,16 @@ def safe_list_add(left: list[Any] | None, right: list[Any] | Any | None) -> list
|
||||
Returns:
|
||||
Combined list with proper None handling
|
||||
"""
|
||||
if left is None:
|
||||
left = []
|
||||
|
||||
if right is None:
|
||||
return list(left)
|
||||
|
||||
if not isinstance(right, list):
|
||||
right = [right]
|
||||
|
||||
return list(left) + list(right)
|
||||
|
||||
|
||||
def dict_list_merge(left: list[dict[str, Any]] | None, right: list[dict[str, Any]] | dict[str, Any] | None) -> list[dict[str, Any]]:
|
||||
existing = list(left) if left is not None else []
|
||||
additions = _as_object_list(right)
|
||||
|
||||
return existing + additions
|
||||
|
||||
|
||||
def dict_list_merge(
|
||||
left: Sequence[Mapping[str, object]] | None,
|
||||
right: Sequence[Mapping[str, object]] | Mapping[str, object] | None,
|
||||
) -> list[Mapping[str, object]]:
|
||||
"""Merge lists of dictionaries, useful for complex state objects.
|
||||
|
||||
Args:
|
||||
@@ -697,13 +865,10 @@ def dict_list_merge(left: list[dict[str, Any]] | None, right: list[dict[str, Any
|
||||
Returns:
|
||||
Combined list of dictionaries
|
||||
"""
|
||||
if left is None:
|
||||
left = []
|
||||
|
||||
if right is None:
|
||||
return list(left)
|
||||
|
||||
if isinstance(right, dict):
|
||||
right = [right]
|
||||
|
||||
return list(left) + list(right)
|
||||
existing = list(left) if left is not None else []
|
||||
|
||||
if right is None:
|
||||
return existing
|
||||
|
||||
additions = [right] if isinstance(right, Mapping) else list(right)
|
||||
return existing + additions
|
||||
|
||||
@@ -25,8 +25,8 @@ The Business Buddy framework provides extensive utilities through the core packa
|
||||
- **`normalize_errors_to_list`**: Ensure errors are always list format
|
||||
|
||||
### Type Compatibility
|
||||
- **`create_type_safe_wrapper`**: Wrap functions for LangGraph type safety
|
||||
- **`wrap_for_langgraph`**: Decorator for type-safe conditional edges
|
||||
- **Mapping-based routers**: `route_error_severity` and `route_llm_output` accept LangGraph state directly
|
||||
- **`StateProtocol` utilities**: Build custom routers that interoperate with both dict and TypedDict states
|
||||
|
||||
### Caching Infrastructure
|
||||
- **`LLMCache`**: Specialized cache for LLM responses
|
||||
@@ -126,38 +126,34 @@ def create_state_from_input(user_input: str | dict) -> dict:
|
||||
return state
|
||||
```
|
||||
|
||||
### Type-Safe Wrappers for LangGraph
|
||||
### Typed Routers for LangGraph
|
||||
|
||||
```python
|
||||
from biz_bud.core.langgraph import (
|
||||
create_type_safe_wrapper,
|
||||
wrap_for_langgraph,
|
||||
route_error_severity,
|
||||
route_llm_output
|
||||
)
|
||||
from biz_bud.core.langgraph import route_error_severity, route_llm_output
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
# Create type-safe wrappers for routing functions
|
||||
safe_error_router = create_type_safe_wrapper(route_error_severity)
|
||||
safe_llm_router = create_type_safe_wrapper(route_llm_output)
|
||||
|
||||
# Or use decorator pattern
|
||||
@wrap_for_langgraph(dict)
|
||||
def custom_router(state: MyTypedState) -> str:
|
||||
"""Custom routing logic with automatic type casting."""
|
||||
return route_error_severity(state)
|
||||
|
||||
# Use in graph construction
|
||||
# Router functions operate on ``Mapping[str, object]`` so TypedDict states work
|
||||
# with LangGraph without additional adapters.
|
||||
builder = StateGraph(MyTypedState)
|
||||
builder.add_conditional_edges(
|
||||
"process_node",
|
||||
safe_error_router, # No type errors!
|
||||
route_error_severity,
|
||||
{
|
||||
"retry": "process_node",
|
||||
"error": "error_handler",
|
||||
"continue": "next_node"
|
||||
}
|
||||
)
|
||||
|
||||
builder.add_conditional_edges(
|
||||
"call_model",
|
||||
route_llm_output,
|
||||
{
|
||||
"tool_executor": "tools",
|
||||
"output": "END",
|
||||
"error_handling": "error_handler"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Async/Sync Factory Pattern
|
||||
@@ -592,7 +588,6 @@ from biz_bud.core.langgraph import (
|
||||
standard_node,
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
create_type_safe_wrapper,
|
||||
route_error_severity,
|
||||
route_llm_output
|
||||
)
|
||||
@@ -609,10 +604,6 @@ class ProductionState(InputState):
|
||||
results: Annotated[list[dict], operator.add] = []
|
||||
confidence_score: float = 0.0
|
||||
|
||||
# Create type-safe routers
|
||||
safe_error_router = create_type_safe_wrapper(route_error_severity)
|
||||
safe_llm_router = create_type_safe_wrapper(route_llm_output)
|
||||
|
||||
@standard_node
|
||||
@handle_errors
|
||||
@log_node_execution
|
||||
@@ -670,7 +661,7 @@ async def create_production_graph_with_caching():
|
||||
# Use type-safe routing
|
||||
builder.add_conditional_edges(
|
||||
"process",
|
||||
safe_error_router,
|
||||
route_error_severity,
|
||||
{
|
||||
"retry": "process",
|
||||
"error": "error_handler",
|
||||
@@ -900,7 +891,7 @@ async def process_large_dataset(items: list[dict]) -> list[dict]:
|
||||
7. **Reuse Edge Helpers**: Use project's routing utilities to prevent duplication
|
||||
8. **Leverage Core Utilities**:
|
||||
- `generate_config_hash` for cache keys
|
||||
- `create_type_safe_wrapper` for LangGraph compatibility
|
||||
- Mapping-based routers (`route_error_severity`, `route_llm_output`) for LangGraph compatibility
|
||||
- `normalize_errors_to_list` for consistent error handling
|
||||
- `GraphCache` for thread-safe graph caching
|
||||
- `handle_sync_async_context` for context-aware execution
|
||||
@@ -942,8 +933,6 @@ from biz_bud.core.utils import normalize_errors_to_list
|
||||
|
||||
# Type compatibility
|
||||
from biz_bud.core.langgraph import (
|
||||
create_type_safe_wrapper,
|
||||
wrap_for_langgraph,
|
||||
standard_node,
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
|
||||
@@ -7,10 +7,11 @@ data visualization, trend analysis, and business insights generation.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import TypedDict, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CachePolicy, RetryPolicy
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from biz_bud.core.edge_helpers.error_handling import handle_error
|
||||
@@ -23,9 +24,6 @@ from biz_bud.core.langgraph.graph_builder import (
|
||||
)
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
# Import consolidated nodes
|
||||
# Import analysis-specific nodes from local module
|
||||
from biz_bud.graphs.analysis.nodes import (
|
||||
@@ -49,8 +47,8 @@ class AnalysisGraphInput(BaseModel):
|
||||
"""Input schema for the analysis graph."""
|
||||
|
||||
task: str = Field(description="The analysis task or request")
|
||||
data: dict[str, Any] | None = Field(
|
||||
default=None, description="Data to analyze (DataFrames, dicts, etc.)"
|
||||
data: object | None = Field(
|
||||
default=None, description="Data to analyze (DataFrames, mappings, etc.)"
|
||||
)
|
||||
include_visualizations: bool = Field(
|
||||
default=True, description="Whether to generate visualizations"
|
||||
@@ -69,11 +67,11 @@ class AnalysisGraphContext(TypedDict, total=False):
|
||||
class AnalysisGraphOutput(TypedDict, total=False):
|
||||
"""Output schema describing the terminal payload from the analysis graph."""
|
||||
|
||||
report: dict[str, Any] | str | None
|
||||
analysis_results: dict[str, Any] | None
|
||||
visualizations: list[dict[str, Any]]
|
||||
report: dict[str, object] | str | None
|
||||
analysis_results: dict[str, object] | None
|
||||
visualizations: list[dict[str, object]]
|
||||
status: str
|
||||
errors: list[dict[str, Any]]
|
||||
errors: list[dict[str, object]]
|
||||
|
||||
|
||||
# Graph metadata for dynamic discovery
|
||||
@@ -127,7 +125,7 @@ _handle_analysis_errors = handle_error(
|
||||
)
|
||||
|
||||
|
||||
def create_analysis_graph() -> "CompiledStateGraph":
|
||||
def create_analysis_graph() -> CompiledStateGraph[AnalysisState]:
|
||||
"""Create the data analysis workflow graph.
|
||||
|
||||
This graph implements a comprehensive analysis workflow:
|
||||
@@ -263,7 +261,9 @@ def create_analysis_graph() -> "CompiledStateGraph":
|
||||
return build_graph_from_config(config)
|
||||
|
||||
|
||||
def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def analysis_graph_factory(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[AnalysisState]:
|
||||
"""Create analysis graph for graph-as-tool pattern.
|
||||
|
||||
This factory function follows the standard pattern for graphs
|
||||
@@ -288,7 +288,9 @@ def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
|
||||
|
||||
# Async factory function for LangGraph API
|
||||
async def analysis_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
async def analysis_graph_factory_async(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[AnalysisState]:
|
||||
"""Async wrapper for analysis_graph_factory to avoid blocking calls."""
|
||||
import asyncio
|
||||
return await asyncio.to_thread(analysis_graph_factory, config)
|
||||
@@ -301,9 +303,9 @@ analysis_graph = create_analysis_graph()
|
||||
# Helper function for easy invocation
|
||||
async def analyze_data(
|
||||
task: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
data: object | None = None,
|
||||
include_visualizations: bool = True,
|
||||
config: dict[str, Any] | None = None,
|
||||
config: Mapping[str, object] | None = None,
|
||||
) -> AnalysisState:
|
||||
"""Analyze data using the analysis workflow.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Mapping
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
try: # pragma: no cover - optional checkpoint backend
|
||||
@@ -10,6 +10,7 @@ try: # pragma: no cover - optional checkpoint backend
|
||||
except ModuleNotFoundError: # pragma: no cover - fallback when postgres extra missing
|
||||
PostgresSaver = None # type: ignore[assignment]
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from biz_bud.core.edge_helpers.core import create_bool_router, create_enum_router
|
||||
from biz_bud.core.edge_helpers.error_handling import handle_error
|
||||
@@ -20,9 +21,6 @@ from biz_bud.core.langgraph.graph_builder import (
|
||||
build_graph_from_config,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from biz_bud.nodes.error_handling import (
|
||||
error_analyzer_node,
|
||||
error_interceptor_node,
|
||||
@@ -64,7 +62,7 @@ GRAPH_METADATA = {
|
||||
|
||||
def create_error_handling_graph(
|
||||
checkpointer: PostgresSaver | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> CompiledStateGraph[ErrorHandlingState]:
|
||||
"""Create the error handling agent graph.
|
||||
|
||||
This graph can be used as a subgraph in any BizBud workflow
|
||||
@@ -141,7 +139,7 @@ def check_recovery_success(state: ErrorHandlingState) -> bool:
|
||||
return bool(state.get("recovery_success", False))
|
||||
|
||||
|
||||
def check_for_errors(state: dict[str, Any]) -> str:
|
||||
def check_for_errors(state: Mapping[str, object]) -> str:
|
||||
"""Compatibility function that checks for errors in state."""
|
||||
errors = state.get("errors", [])
|
||||
status = state.get("status", "")
|
||||
@@ -160,8 +158,8 @@ def check_error_recovery(state: ErrorHandlingState) -> str:
|
||||
|
||||
|
||||
def add_error_handling_to_graph(
|
||||
main_graph: "StateGraph",
|
||||
error_handler: "CompiledStateGraph",
|
||||
main_graph: StateGraph,
|
||||
error_handler: CompiledStateGraph[ErrorHandlingState],
|
||||
nodes_to_protect: list[str],
|
||||
error_node_name: str = "handle_error",
|
||||
next_node_mapping: dict[str, str] | None = None,
|
||||
@@ -230,7 +228,7 @@ def create_error_handling_config(
|
||||
retry_max_delay: int = 60,
|
||||
enable_llm_analysis: bool = True,
|
||||
recovery_timeout: int = 300,
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
"""Create error handling configuration.
|
||||
|
||||
This is a public helper function that creates standardized error handling
|
||||
@@ -300,7 +298,9 @@ def create_error_handling_config(
|
||||
}
|
||||
|
||||
|
||||
def error_handling_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def error_handling_graph_factory(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[ErrorHandlingState]:
|
||||
"""Create error handling graph for LangGraph API.
|
||||
|
||||
Args:
|
||||
@@ -314,7 +314,9 @@ def error_handling_graph_factory(config: RunnableConfig) -> "CompiledStateGraph"
|
||||
|
||||
|
||||
# Async factory function for LangGraph API
|
||||
async def error_handling_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
async def error_handling_graph_factory_async(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[ErrorHandlingState]:
|
||||
"""Async wrapper for error_handling_graph_factory to avoid blocking calls."""
|
||||
import asyncio
|
||||
return await asyncio.to_thread(error_handling_graph_factory, config)
|
||||
|
||||
@@ -17,7 +17,8 @@ The main graph represents a sophisticated agent workflow that can:
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
@@ -25,7 +26,11 @@ if TYPE_CHECKING:
|
||||
# Import existing lazy loading infrastructure
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
|
||||
from langgraph.graph.state import (
|
||||
CachePolicy,
|
||||
CompiledStateGraph as CompiledGraph,
|
||||
RetryPolicy,
|
||||
)
|
||||
|
||||
from biz_bud.core.caching import InMemoryCache
|
||||
from biz_bud.core.cleanup_registry import get_cleanup_registry
|
||||
@@ -38,10 +43,8 @@ from biz_bud.core.config.schemas import AppConfig
|
||||
from biz_bud.core.edge_helpers import create_enum_router, detect_errors_list
|
||||
from biz_bud.core.langgraph import (
|
||||
NodeConfig,
|
||||
create_type_safe_wrapper,
|
||||
handle_errors,
|
||||
log_node_execution,
|
||||
route_error_severity,
|
||||
route_llm_output,
|
||||
standard_node,
|
||||
)
|
||||
@@ -65,8 +68,8 @@ from biz_bud.nodes import call_model_node, parse_and_validate_initial_payload
|
||||
from biz_bud.services.factory import get_global_factory
|
||||
from biz_bud.states.base import InputState
|
||||
|
||||
# Type variable for state types
|
||||
StateType = TypeVar('StateType')
|
||||
StateMapping = Mapping[str, object] | InputState
|
||||
JsonDict = dict[str, object]
|
||||
|
||||
# Get logger instance
|
||||
logger = get_logger(__name__)
|
||||
@@ -90,7 +93,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
|
||||
"""
|
||||
|
||||
def create_sync_factory():
|
||||
def sync_factory(config: RunnableConfig) -> CompiledStateGraph:
|
||||
def sync_factory(config: RunnableConfig) -> CompiledGraph:
|
||||
"""Create graph for LangGraph API with RunnableConfig (optimized)."""
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -109,7 +112,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
|
||||
return sync_factory
|
||||
|
||||
def create_async_factory():
|
||||
async def async_factory(config: RunnableConfig) -> CompiledStateGraph:
|
||||
async def async_factory(config: RunnableConfig) -> CompiledGraph:
|
||||
"""Create graph for LangGraph API with RunnableConfig (async optimized)."""
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -130,7 +133,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
|
||||
return create_sync_factory(), create_async_factory()
|
||||
|
||||
|
||||
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledStateGraph:
|
||||
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledGraph:
|
||||
"""Handle sync/async context detection for graph creation.
|
||||
|
||||
This function uses centralized context detection from core.networking
|
||||
@@ -156,7 +159,7 @@ def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceF
|
||||
|
||||
|
||||
# Enhanced input validation using edge helpers
|
||||
def _validate_parsed_input(state: dict[str, Any] | InputState) -> bool:
|
||||
def _validate_parsed_input(state: StateMapping) -> bool:
|
||||
"""Enhanced validation for parsed input content.
|
||||
|
||||
This helper checks both the presence and validity of parsed input data
|
||||
@@ -171,7 +174,7 @@ def _validate_parsed_input(state: dict[str, Any] | InputState) -> bool:
|
||||
return bool(user_query and user_query.strip())
|
||||
|
||||
# Create comprehensive input validation router using edge helpers
|
||||
def _enhanced_input_check(state: dict[str, Any] | InputState) -> str:
|
||||
def _enhanced_input_check(state: StateMapping) -> str:
|
||||
"""Enhanced input validation that combines error detection with content validation.
|
||||
|
||||
Uses the Business Buddy pattern of checking both errors and input validity
|
||||
@@ -195,13 +198,8 @@ check_initial_input_simple = detect_errors_list(
|
||||
)
|
||||
|
||||
|
||||
# Use type-safe wrappers from core.langgraph
|
||||
|
||||
route_error_severity_wrapper = create_type_safe_wrapper(route_error_severity)
|
||||
route_llm_output_wrapper = create_type_safe_wrapper(route_llm_output)
|
||||
|
||||
# Replace manual routing with edge helper-based recovery routing
|
||||
def _determine_recovery_action(state: dict[str, Any]) -> str:
|
||||
def _determine_recovery_action(state: Mapping[str, object]) -> str:
|
||||
"""Determine recovery action based on error handler decisions.
|
||||
|
||||
This function consolidates the recovery logic and provides a clear
|
||||
@@ -230,7 +228,7 @@ route_error_recovery = create_enum_router(
|
||||
)
|
||||
|
||||
# Enhanced router that sets the recovery action in state
|
||||
def _enhanced_error_recovery(state: dict[str, Any]) -> str:
|
||||
def _enhanced_error_recovery(state: Mapping[str, object]) -> str:
|
||||
"""Enhanced error recovery that uses edge helpers pattern."""
|
||||
recovery_action = _determine_recovery_action(state)
|
||||
|
||||
@@ -273,83 +271,205 @@ GRAPH_METADATA = {
|
||||
@standard_node()
|
||||
@handle_errors()
|
||||
@log_node_execution("search")
|
||||
async def search(state: Any) -> Any: # noqa: ANN401
|
||||
async def search(state: StateMapping) -> JsonDict:
|
||||
"""Execute web search using the unified search tool with proper error handling."""
|
||||
from collections.abc import Mapping as _Mapping
|
||||
|
||||
from biz_bud.tools.capabilities.search.tool import web_search
|
||||
|
||||
# Extract search query from state
|
||||
search_query = None
|
||||
messages = state.get("messages", [])
|
||||
state_mapping = state if isinstance(state, Mapping) else {}
|
||||
|
||||
search_query: str | None = None
|
||||
messages_value = state_mapping.get("messages", [])
|
||||
messages: list[object] = []
|
||||
if isinstance(messages_value, Sequence) and not isinstance(
|
||||
messages_value, (str, bytes, bytearray)
|
||||
):
|
||||
messages = list(messages_value)
|
||||
|
||||
# Look for the most recent tool call or human message to extract query
|
||||
for message in reversed(messages):
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
# Extract query from tool call arguments
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.get("name") == "web_search" or "search" in tool_call.get("name", ""):
|
||||
args = tool_call.get("args", {})
|
||||
search_query = args.get("query") or args.get("search_query")
|
||||
break
|
||||
if search_query:
|
||||
tool_calls: object | None = None
|
||||
if isinstance(message, Mapping):
|
||||
tool_calls = message.get("tool_calls")
|
||||
if tool_calls is None:
|
||||
additional_kwargs = message.get("additional_kwargs")
|
||||
if isinstance(additional_kwargs, Mapping):
|
||||
tool_calls = additional_kwargs.get("tool_calls")
|
||||
if tool_calls is None:
|
||||
try:
|
||||
tool_calls = getattr(message, "tool_calls")
|
||||
except Exception:
|
||||
tool_calls = None
|
||||
tool_query_found = False
|
||||
if isinstance(tool_calls, Sequence) and not isinstance(
|
||||
tool_calls, (str, bytes, bytearray)
|
||||
) and tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, _Mapping):
|
||||
continue
|
||||
name = str(tool_call.get("name", ""))
|
||||
if name == "web_search" or "search" in name:
|
||||
args = tool_call.get("args", {}) or {}
|
||||
if not isinstance(args, _Mapping):
|
||||
args = {}
|
||||
candidate = (
|
||||
args.get("query")
|
||||
or args.get("search_query")
|
||||
or args.get("q")
|
||||
or args.get("keywords")
|
||||
)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
search_query = candidate.strip()
|
||||
tool_query_found = True
|
||||
break
|
||||
if tool_query_found:
|
||||
break
|
||||
elif hasattr(message, "content") and message.content:
|
||||
# Fall back to using message content as query
|
||||
search_query = str(message.content)[:200] # Limit query length
|
||||
# Do not fall back to message content when tool calls were present
|
||||
continue
|
||||
content: object | None = None
|
||||
if isinstance(message, Mapping):
|
||||
content = message.get("content")
|
||||
if content is None:
|
||||
try:
|
||||
content = getattr(message, "content")
|
||||
except Exception:
|
||||
content = None
|
||||
if isinstance(content, str) and content.strip():
|
||||
search_query = content.strip()
|
||||
break
|
||||
|
||||
# Default fallback query if none found
|
||||
if not search_query:
|
||||
search_query = state.get("user_query", "business intelligence search")
|
||||
if not isinstance(search_query, str) or not search_query.strip():
|
||||
user_query_value = state_mapping.get("user_query")
|
||||
if isinstance(user_query_value, str) and user_query_value.strip():
|
||||
search_query = user_query_value.strip()
|
||||
else:
|
||||
search_query = "business intelligence search"
|
||||
|
||||
logger.info(f"Executing web search for query: {search_query}")
|
||||
normalized_query = "".join(ch for ch in search_query if ch.isprintable())
|
||||
normalized_query = normalized_query.replace("\n", " ").replace("\r", " ").strip()
|
||||
safe_query = normalized_query[:200] if normalized_query else "business intelligence search"
|
||||
|
||||
def _mask_for_log(query: str) -> str:
|
||||
cleaned = "".join(ch for ch in query if ch.isprintable()).replace("\n", " ").replace("\r", " ")
|
||||
trimmed = cleaned.strip()
|
||||
return trimmed if len(trimmed) <= 32 else f"{trimmed[:29]}..."
|
||||
|
||||
logger.info("Executing web search for query: %s", _mask_for_log(safe_query))
|
||||
|
||||
status = state_mapping.get("status", "running")
|
||||
validation_issues_value = state_mapping.get("validation_issues")
|
||||
if isinstance(validation_issues_value, Sequence) and not isinstance(
|
||||
validation_issues_value, (str, bytes, bytearray)
|
||||
):
|
||||
validation_issues = list(validation_issues_value)
|
||||
else:
|
||||
validation_issues = []
|
||||
|
||||
try:
|
||||
# Use the web_search tool with proper tool invocation
|
||||
# Since web_search is decorated with @tool, we need to invoke it properly
|
||||
search_results = await web_search.ainvoke({
|
||||
"query": search_query,
|
||||
"provider": None, # Auto-select best available provider
|
||||
"max_results": 5 # Reasonable number of results for context
|
||||
raw_results = await web_search.ainvoke({
|
||||
"query": safe_query,
|
||||
"provider": None,
|
||||
"max_results": 5,
|
||||
})
|
||||
|
||||
# Format search results for the state
|
||||
if search_results:
|
||||
# Create a summary of search results
|
||||
result_summary = f"Found {len(search_results)} search results for '{search_query}':\n\n"
|
||||
for i, result in enumerate(search_results[:3], 1): # Show top 3 results
|
||||
result_summary += f"{i}. {result.get('title', 'No title')}\n"
|
||||
result_summary += f" URL: {result.get('url', 'No URL')}\n"
|
||||
if result.get('snippet'):
|
||||
result_summary += f" Summary: {result.get('snippet')[:150]}...\n"
|
||||
result_summary += "\n"
|
||||
results_list: list[dict[str, object]] = []
|
||||
MAX_FIELD_LENGTH = 500
|
||||
|
||||
# Update state with search results
|
||||
state_update = {
|
||||
**state,
|
||||
"search_results": search_results,
|
||||
"final_response": result_summary,
|
||||
"is_last_step": True
|
||||
def _normalize_text(value: object, default: str) -> str:
|
||||
text = value if isinstance(value, str) else (str(value) if value is not None else default)
|
||||
cleaned = "".join(ch for ch in text if ch.isprintable()).strip()
|
||||
if not cleaned:
|
||||
cleaned = default
|
||||
return cleaned[:MAX_FIELD_LENGTH]
|
||||
|
||||
def _normalize_url(value: object) -> str:
|
||||
raw = value if isinstance(value, str) else (str(value) if value is not None else "")
|
||||
cleaned = "".join(ch for ch in raw if ch.isprintable()).strip()
|
||||
return cleaned[:MAX_FIELD_LENGTH] if cleaned else ""
|
||||
|
||||
if isinstance(raw_results, Sequence) and not isinstance(
|
||||
raw_results, (str, bytes, bytearray)
|
||||
):
|
||||
for item in raw_results:
|
||||
if isinstance(item, _Mapping):
|
||||
title = _normalize_text(item.get("title"), "No title")
|
||||
url = _normalize_url(item.get("url")) or "No URL"
|
||||
snippet_source = item.get("snippet") or item.get("description")
|
||||
else:
|
||||
title = _normalize_text(getattr(item, "title", None), "No title")
|
||||
url = _normalize_url(getattr(item, "url", None)) or "No URL"
|
||||
snippet_source = getattr(item, "snippet", None) or getattr(
|
||||
item, "description", None
|
||||
)
|
||||
snippet = _normalize_text(snippet_source, "")
|
||||
|
||||
if title or url or snippet:
|
||||
results_list.append({"title": title, "url": url, "snippet": snippet})
|
||||
|
||||
if results_list:
|
||||
result_summary_lines: list[str] = []
|
||||
for index, result in enumerate(results_list[:3], start=1):
|
||||
title = result.get("title", "No title")
|
||||
url = result.get("url", "No URL")
|
||||
snippet = result.get("snippet") or result.get("description")
|
||||
result_summary_lines.append(f"{index}. {title}")
|
||||
result_summary_lines.append(f" URL: {url}")
|
||||
if isinstance(snippet, str) and snippet:
|
||||
safe_snippet = "".join(ch for ch in snippet if ch.isprintable()).strip()
|
||||
result_summary_lines.append(f" Summary: {safe_snippet[:150]}...")
|
||||
result_summary_lines.append("")
|
||||
|
||||
state_update: JsonDict = {
|
||||
"search_query": safe_query,
|
||||
"search_results": list(results_list),
|
||||
"final_response": "\n".join(result_summary_lines).strip(),
|
||||
"is_last_step": True,
|
||||
"status": "success" if status != "error" else status,
|
||||
"validation_issues": validation_issues,
|
||||
}
|
||||
|
||||
logger.info(f"Web search completed successfully with {len(search_results)} results")
|
||||
|
||||
logger.info(
|
||||
"Web search completed successfully with %s results",
|
||||
len(results_list),
|
||||
)
|
||||
else:
|
||||
# No results found
|
||||
masked_query = _mask_for_log(safe_query)
|
||||
state_update = {
|
||||
**state,
|
||||
"final_response": f"No search results found for query: '{search_query}'. Please try a different search term.",
|
||||
"is_last_step": True
|
||||
"search_query": safe_query,
|
||||
"final_response": (
|
||||
"No search results found for query: "
|
||||
f"'{masked_query}'. Please try a different search term."
|
||||
),
|
||||
"is_last_step": True,
|
||||
"status": "success" if status != "error" else status,
|
||||
"validation_issues": validation_issues,
|
||||
}
|
||||
logger.warning(f"No search results found for query: {search_query}")
|
||||
logger.warning("No search results found for query: %s", masked_query)
|
||||
|
||||
except Exception as e:
|
||||
# Handle search errors gracefully
|
||||
logger.error(f"Web search failed for query '{search_query}': {e}")
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
masked_query = _mask_for_log(safe_query)
|
||||
logger.error("Web search failed for query '%s': %s", masked_query, exc)
|
||||
errors_value = state_mapping.get("errors")
|
||||
if isinstance(errors_value, Sequence) and not isinstance(
|
||||
errors_value, (str, bytes, bytearray)
|
||||
):
|
||||
errors_list = list(errors_value)
|
||||
elif errors_value is None:
|
||||
errors_list = []
|
||||
else:
|
||||
errors_list = [errors_value]
|
||||
errors_list.append(f"Search error: {exc}")
|
||||
if len(errors_list) > 1000:
|
||||
errors_list = errors_list[-1000:]
|
||||
validation_issues.append("search_failed")
|
||||
state_update = {
|
||||
**state,
|
||||
"final_response": f"Search service temporarily unavailable. Error: {str(e)}",
|
||||
"search_query": safe_query,
|
||||
"final_response": (
|
||||
"Search service temporarily unavailable. Error: %s" % (exc,)
|
||||
),
|
||||
"is_last_step": True,
|
||||
"errors": state.get("errors", []) + [f"Search error: {str(e)}"]
|
||||
"errors": errors_list,
|
||||
"status": "error",
|
||||
"validation_issues": validation_issues,
|
||||
}
|
||||
|
||||
return state_update
|
||||
@@ -357,7 +477,7 @@ async def search(state: Any) -> Any: # noqa: ANN401
|
||||
|
||||
|
||||
|
||||
def create_graph() -> CompiledStateGraph:
|
||||
def create_graph() -> CompiledGraph:
|
||||
"""Build and compile the complete StateGraph for Business Buddy agent execution.
|
||||
|
||||
This function constructs the main workflow graph that serves as the execution
|
||||
@@ -443,7 +563,7 @@ def create_graph() -> CompiledStateGraph:
|
||||
),
|
||||
ConditionalEdgeConfig(
|
||||
"call_model_node",
|
||||
route_llm_output_wrapper,
|
||||
route_llm_output,
|
||||
{
|
||||
"tool_executor": "tools",
|
||||
"output": "END",
|
||||
@@ -470,7 +590,7 @@ def create_graph() -> CompiledStateGraph:
|
||||
|
||||
async def create_graph_with_services(
|
||||
app_config: AppConfig, service_factory: "ServiceFactory"
|
||||
) -> CompiledStateGraph:
|
||||
) -> CompiledGraph:
|
||||
"""Create graph with service factory injection using caching.
|
||||
|
||||
Args:
|
||||
@@ -562,8 +682,8 @@ async def _get_or_create_service_factory_async(config_hash: str, app_config: App
|
||||
|
||||
|
||||
async def create_graph_with_overrides_async(
|
||||
config: dict[str, Any],
|
||||
) -> CompiledStateGraph:
|
||||
config: Mapping[str, object],
|
||||
) -> CompiledGraph:
|
||||
"""Async version of create_graph_with_overrides.
|
||||
|
||||
Same functionality as create_graph_with_overrides but uses async config loading
|
||||
@@ -634,7 +754,7 @@ def _load_config_with_logging() -> AppConfig:
|
||||
|
||||
|
||||
|
||||
_graph_cache_manager = InMemoryCache[CompiledStateGraph](max_size=100)
|
||||
_graph_cache_manager = InMemoryCache[CompiledGraph](max_size=100)
|
||||
_graph_creation_locks: dict[str, asyncio.Lock] = {}
|
||||
_locks_lock = asyncio.Lock() # Lock for managing the locks dict
|
||||
_graph_cache_lock = asyncio.Lock() # Keep for backward compatibility
|
||||
@@ -646,7 +766,7 @@ cleanup_registry.register_cleanup("graph_cache", _graph_cache_manager.clear)
|
||||
# Configuration cache using lazy loader
|
||||
_config_loader = create_lazy_loader(_load_config_with_logging)
|
||||
|
||||
_state_template_cache: dict[str, Any] | None = None
|
||||
_state_template_cache: JsonDict | None = None
|
||||
_state_template_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@@ -670,7 +790,7 @@ async def get_cached_graph(
|
||||
config_hash: str = "default",
|
||||
service_factory: "ServiceFactory | None" = None,
|
||||
use_caching: bool = True
|
||||
) -> CompiledStateGraph:
|
||||
) -> CompiledGraph:
|
||||
"""Get cached compiled graph with optional service injection using GraphCache.
|
||||
|
||||
Args:
|
||||
@@ -682,7 +802,7 @@ async def get_cached_graph(
|
||||
Compiled and cached graph instance
|
||||
"""
|
||||
# Use GraphCache for thread-safe caching
|
||||
async def build_graph_for_cache() -> CompiledStateGraph:
|
||||
async def build_graph_for_cache() -> CompiledGraph:
|
||||
logger.info(f"Creating new graph instance for config: {config_hash}")
|
||||
|
||||
# Get or create service factory
|
||||
@@ -728,7 +848,7 @@ async def get_cached_graph(
|
||||
),
|
||||
ConditionalEdgeConfig(
|
||||
"call_model_node",
|
||||
route_llm_output_wrapper,
|
||||
route_llm_output,
|
||||
{
|
||||
"tool_executor": "tools",
|
||||
"output": "END",
|
||||
@@ -798,7 +918,7 @@ async def get_cached_graph(
|
||||
return graph
|
||||
|
||||
|
||||
def get_graph() -> CompiledStateGraph:
|
||||
def get_graph() -> CompiledGraph:
|
||||
"""Get the singleton graph instance (backward compatibility)."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
@@ -816,9 +936,9 @@ def get_graph() -> CompiledStateGraph:
|
||||
|
||||
|
||||
# For backward compatibility - direct access (lazy initialization)
|
||||
_module_graph: CompiledStateGraph | None = None
|
||||
_module_graph: CompiledGraph | None = None
|
||||
|
||||
def get_module_graph() -> CompiledStateGraph:
|
||||
def get_module_graph() -> CompiledGraph:
|
||||
"""Get module-level graph instance (lazy initialization)."""
|
||||
global _module_graph
|
||||
if _module_graph is None:
|
||||
@@ -882,7 +1002,7 @@ def reset_caches() -> None:
|
||||
logger.info("Reset graph and state caches using CleanupRegistry")
|
||||
|
||||
|
||||
def get_cache_stats() -> dict[str, Any]:
|
||||
def get_cache_stats() -> JsonDict:
|
||||
"""Get cache statistics for monitoring using InMemoryCache.
|
||||
|
||||
Returns:
|
||||
@@ -916,12 +1036,12 @@ debug_highlight("Cache infrastructure: graphs, service factories, state template
|
||||
|
||||
|
||||
# Create lazy config access for backward compatibility
|
||||
async def get_config_dict() -> dict[str, Any]:
|
||||
async def get_config_dict() -> JsonDict:
|
||||
"""Get the config dictionary lazily (async version)."""
|
||||
app_config = await get_app_config()
|
||||
debug_highlight(f"[GRAPH_INIT] app_config: {app_config}", category="GRAPH_INIT")
|
||||
# Ensure the config has the required structure
|
||||
config_dict: dict[str, Any] = {
|
||||
config_dict: JsonDict = {
|
||||
**app_config.model_dump(), # Include all existing app_config as dict
|
||||
"llm_config": app_config.llm_config.model_dump()
|
||||
if app_config.llm_config
|
||||
@@ -942,12 +1062,12 @@ async def get_config_dict() -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
def get_config_dict_sync() -> dict[str, Any]:
|
||||
def get_config_dict_sync() -> JsonDict:
|
||||
"""Get the config dictionary synchronously (backward compatibility)."""
|
||||
app_config = get_app_config_sync()
|
||||
debug_highlight(f"[GRAPH_INIT] app_config: {app_config}", category="GRAPH_INIT")
|
||||
# Ensure the config has the required structure
|
||||
config_dict: dict[str, Any] = {
|
||||
config_dict: JsonDict = {
|
||||
**app_config.model_dump(), # Include all existing app_config as dict
|
||||
"llm_config": app_config.llm_config.model_dump()
|
||||
if app_config.llm_config
|
||||
@@ -968,7 +1088,7 @@ def get_config_dict_sync() -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
async def get_state_template() -> dict[str, Any]:
|
||||
async def get_state_template() -> JsonDict:
|
||||
"""Get cached state template for fast state creation.
|
||||
|
||||
Returns:
|
||||
@@ -997,10 +1117,10 @@ async def get_state_template() -> dict[str, Any]:
|
||||
|
||||
async def create_initial_state(
|
||||
query: str | None = None,
|
||||
raw_input: str | dict[str, Any] | None = None,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
raw_input: str | Mapping[str, object] | None = None,
|
||||
messages: list[dict[str, object]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
state_update: dict[str, Any] | None = None,
|
||||
state_update: Mapping[str, object] | None = None,
|
||||
) -> InputState:
|
||||
"""Create initial state with dynamic user input (optimized async version).
|
||||
|
||||
@@ -1052,10 +1172,10 @@ async def create_initial_state(
|
||||
|
||||
def create_initial_state_sync(
|
||||
query: str | None = None,
|
||||
raw_input: str | dict[str, Any] | None = None,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
raw_input: str | Mapping[str, object] | None = None,
|
||||
messages: list[dict[str, object]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
state_update: dict[str, Any] | None = None,
|
||||
state_update: Mapping[str, object] | None = None,
|
||||
) -> InputState:
|
||||
"""Create initial state synchronously (backward compatibility).
|
||||
|
||||
@@ -1101,8 +1221,8 @@ async def get_initial_state_async() -> InputState:
|
||||
|
||||
def run_graph(
|
||||
query: str | None = None,
|
||||
raw_input: str | dict[str, Any] | None = None,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
raw_input: str | Mapping[str, object] | None = None,
|
||||
messages: list[dict[str, object]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
) -> InputState:
|
||||
"""Execute the Business Buddy agent graph with optional custom input.
|
||||
@@ -1136,8 +1256,8 @@ def run_graph(
|
||||
|
||||
async def run_graph_async(
|
||||
query: str | None = None,
|
||||
raw_input: str | dict[str, Any] | None = None,
|
||||
messages: list[dict[str, Any]] | None = None,
|
||||
raw_input: str | Mapping[str, object] | None = None,
|
||||
messages: list[dict[str, object]] | None = None,
|
||||
thread_id: str | None = None,
|
||||
config_hash: str = "default",
|
||||
) -> InputState:
|
||||
|
||||
@@ -60,14 +60,14 @@ from biz_bud.tools.capabilities.external.paperless.tool import (
|
||||
from biz_bud.tools.capabilities.search.tool import list_search_providers, web_search
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Module-level caches for performance optimization
|
||||
_compiled_graph_cache: dict[str, "CompiledStateGraph"] = {}
|
||||
_compiled_graph_cache: dict[str, "CompiledGraph"] = {}
|
||||
_global_factory: ServiceFactory | None = None
|
||||
_llm_cache: LLMCache[str] | None = None
|
||||
_cache_ttl = 300 # 5 minutes default TTL for LLM responses
|
||||
@@ -676,7 +676,7 @@ def should_continue(state: dict[str, Any]) -> str:
|
||||
|
||||
def create_paperless_agent(
|
||||
config: dict[str, Any] | str | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> "CompiledGraph":
|
||||
"""Create a Paperless agent using Business Buddy patterns with caching.
|
||||
|
||||
This creates an agent that can bind tools to the LLM and execute them
|
||||
|
||||
@@ -44,7 +44,7 @@ from biz_bud.states.base import BaseState
|
||||
from biz_bud.states.receipt import ReceiptState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -162,7 +162,7 @@ def _create_receipt_processing_graph_internal(
|
||||
config: dict[str, Any] | None = None,
|
||||
app_config: object | None = None,
|
||||
service_factory: object | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
) -> CompiledGraph:
|
||||
"""Internal function to create a focused receipt processing graph using ReceiptState."""
|
||||
# Define nodes
|
||||
nodes = {
|
||||
@@ -202,7 +202,7 @@ def _create_receipt_processing_graph_internal(
|
||||
return compiled_graph
|
||||
|
||||
|
||||
def create_receipt_processing_graph(config: RunnableConfig) -> CompiledStateGraph:
|
||||
def create_receipt_processing_graph(config: RunnableConfig) -> CompiledGraph:
|
||||
"""Create a focused receipt processing graph for LangGraph API.
|
||||
|
||||
Args:
|
||||
@@ -218,7 +218,7 @@ def create_receipt_processing_graph_direct(
|
||||
config: dict[str, Any] | None = None,
|
||||
app_config: object | None = None,
|
||||
service_factory: object | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
) -> CompiledGraph:
|
||||
"""Create a focused receipt processing graph for direct usage.
|
||||
|
||||
Args:
|
||||
@@ -236,7 +236,7 @@ def create_paperless_graph(
|
||||
config: dict[str, Any] | None = None,
|
||||
app_config: object | None = None,
|
||||
service_factory: object | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
) -> CompiledGraph:
|
||||
"""Create the standardized Paperless NGX document management graph.
|
||||
|
||||
This graph provides intelligent document management workflows with:
|
||||
@@ -402,7 +402,7 @@ def create_paperless_graph(
|
||||
return compiled_graph
|
||||
|
||||
|
||||
def paperless_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
|
||||
def paperless_graph_factory(config: RunnableConfig) -> CompiledGraph:
|
||||
"""Create Paperless graph for LangGraph API.
|
||||
|
||||
Args:
|
||||
@@ -421,7 +421,7 @@ async def paperless_graph_factory_async(config: RunnableConfig) -> Any: # noqa:
|
||||
return await asyncio.to_thread(paperless_graph_factory, config)
|
||||
|
||||
|
||||
def receipt_processing_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
|
||||
def receipt_processing_graph_factory(config: RunnableConfig) -> CompiledGraph:
|
||||
"""Create receipt processing graph for LangGraph API.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Literal, cast
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Literal, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CachePolicy, RetryPolicy
|
||||
@@ -43,10 +44,10 @@ from biz_bud.graphs.rag.nodes.scraping import (
|
||||
scrape_status_summary_node,
|
||||
)
|
||||
from biz_bud.nodes import finalize_status_node, preserve_url_fields_node
|
||||
from biz_bud.states.url_to_rag import URLToRAGState
|
||||
from biz_bud.states.url_to_rag import StatePayload, URLToRAGState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
|
||||
@@ -55,7 +56,7 @@ class URLToRAGGraphInput(TypedDict, total=False):
|
||||
|
||||
url: str
|
||||
input_url: str
|
||||
config: dict[str, Any]
|
||||
config: StatePayload
|
||||
collection_name: str | None
|
||||
force_refresh: bool
|
||||
|
||||
@@ -65,10 +66,10 @@ class URLToRAGGraphOutput(TypedDict, total=False):
|
||||
|
||||
status: Literal["pending", "running", "success", "error"]
|
||||
error: str | None
|
||||
r2r_info: dict[str, Any] | None
|
||||
scraped_content: list[dict[str, Any]]
|
||||
r2r_info: StatePayload | None
|
||||
scraped_content: list[StatePayload]
|
||||
repomix_output: str | None
|
||||
upload_tracker: dict[str, Any] | None
|
||||
upload_tracker: StatePayload | None
|
||||
|
||||
|
||||
class URLToRAGGraphContext(TypedDict, total=False):
|
||||
@@ -76,7 +77,7 @@ class URLToRAGGraphContext(TypedDict, total=False):
|
||||
|
||||
service_factory: "ServiceFactory" | None
|
||||
request_id: str | None
|
||||
metadata: dict[str, Any]
|
||||
metadata: StatePayload
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -105,7 +106,7 @@ GRAPH_METADATA = {
|
||||
|
||||
def _create_initial_state(
|
||||
url: str,
|
||||
config: dict[str, Any],
|
||||
config: StatePayload,
|
||||
collection_name: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> URLToRAGState:
|
||||
@@ -190,7 +191,7 @@ _should_scrape_or_skip = create_list_length_router(
|
||||
_should_process_next_url = create_bool_router("finalize", "check_duplicate", "batch_complete")
|
||||
|
||||
|
||||
def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> "CompiledStateGraph":
|
||||
def create_url_to_r2r_graph(config: StatePayload | None = None) -> "CompiledGraph":
|
||||
"""Create the URL to R2R processing graph with iterative URL processing.
|
||||
|
||||
This graph processes URLs one at a time through the complete pipeline,
|
||||
@@ -464,7 +465,7 @@ def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> "CompiledSt
|
||||
|
||||
|
||||
# Factory function for LangGraph API
|
||||
def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
def url_to_r2r_graph_factory(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Create URL to R2R graph for LangGraph API with RunnableConfig."""
|
||||
# Use centralized config resolution to handle all overrides at entry point
|
||||
# Resolve configuration with any RunnableConfig overrides (sync version)
|
||||
@@ -481,7 +482,7 @@ def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
service_factory = ServiceFactory(app_config)
|
||||
|
||||
# Convert AppConfig to dict format expected by the graph
|
||||
config_dict = app_config.model_dump()
|
||||
config_dict: StatePayload = app_config.model_dump()
|
||||
|
||||
# Inject service factory into config for nodes to access
|
||||
config_dict["service_factory"] = service_factory
|
||||
@@ -490,7 +491,7 @@ def url_to_r2r_graph_factory(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
|
||||
|
||||
# Async factory function for LangGraph API
|
||||
async def url_to_r2r_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
async def url_to_r2r_graph_factory_async(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Async wrapper for url_to_r2r_graph_factory to avoid blocking calls."""
|
||||
import asyncio
|
||||
return await asyncio.to_thread(url_to_r2r_graph_factory, config)
|
||||
@@ -503,7 +504,7 @@ url_to_r2r_graph = create_url_to_r2r_graph
|
||||
# Usage example
|
||||
async def _process_url_to_r2r(
|
||||
url: str,
|
||||
config: dict[str, Any],
|
||||
config: StatePayload,
|
||||
collection_name: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> URLToRAGState:
|
||||
@@ -555,10 +556,10 @@ async def _process_url_to_r2r(
|
||||
|
||||
async def _stream_url_to_r2r(
|
||||
url: str,
|
||||
config: dict[str, Any],
|
||||
config: StatePayload,
|
||||
collection_name: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> AsyncGenerator[dict[str, Any], None]:
|
||||
) -> AsyncGenerator[StatePayload, None]:
|
||||
"""Process a URL and upload to R2R, yielding streaming updates.
|
||||
|
||||
Args:
|
||||
@@ -603,8 +604,8 @@ async def _stream_url_to_r2r(
|
||||
|
||||
async def _process_url_to_r2r_with_streaming(
|
||||
url: str,
|
||||
config: dict[str, Any],
|
||||
on_update: Callable[[dict[str, Any]], None] | None = None,
|
||||
config: StatePayload,
|
||||
on_update: Callable[[StatePayload], None] | None = None,
|
||||
collection_name: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> URLToRAGState:
|
||||
@@ -665,7 +666,7 @@ stream_url_to_r2r = _stream_url_to_r2r
|
||||
process_url_to_r2r_with_streaming = _process_url_to_r2r_with_streaming
|
||||
|
||||
|
||||
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Create URL to RAG graph for graph-as-tool pattern.
|
||||
|
||||
This factory function follows the standard pattern for graphs
|
||||
|
||||
@@ -7,7 +7,7 @@ and various vector store integrations.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -34,10 +34,23 @@ firecrawl_scrape_node = None
|
||||
firecrawl_batch_scrape_node = None
|
||||
|
||||
|
||||
StatePayload = dict[str, object]
|
||||
|
||||
|
||||
def _coerce_sequence(value: object) -> Sequence[Mapping[str, object]]:
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
items: list[Mapping[str, object]] = []
|
||||
for element in value:
|
||||
if isinstance(element, Mapping):
|
||||
items.append(element)
|
||||
return items
|
||||
return []
|
||||
|
||||
|
||||
@standard_node(node_name="vector_store_upload", metric_name="vector_upload")
|
||||
async def vector_store_upload_node(
|
||||
state: dict[str, Any], config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
state: Mapping[str, object], config: RunnableConfig
|
||||
) -> StatePayload:
|
||||
"""Upload prepared content to vector store.
|
||||
|
||||
This node handles the upload of prepared content to various vector stores
|
||||
@@ -52,17 +65,17 @@ async def vector_store_upload_node(
|
||||
"""
|
||||
info_highlight("Uploading content to vector store...", category="VectorUpload")
|
||||
|
||||
prepared_content = state.get("rag_prepared_content", [])
|
||||
prepared_content = _coerce_sequence(state.get("rag_prepared_content"))
|
||||
if not prepared_content:
|
||||
warning_highlight("No prepared content to upload", category="VectorUpload")
|
||||
return {"upload_results": {"documents_uploaded": 0}}
|
||||
|
||||
collection_name = state.get("collection_name", "default")
|
||||
collection_name = str(state.get("collection_name", "default"))
|
||||
|
||||
try:
|
||||
# Simulate upload (real implementation would use actual vector store service)
|
||||
uploaded_count = 0
|
||||
failed_uploads = []
|
||||
failed_uploads: list[Mapping[str, object]] = []
|
||||
|
||||
for item in prepared_content:
|
||||
if not item.get("ready_for_upload"):
|
||||
@@ -77,7 +90,7 @@ async def vector_store_upload_node(
|
||||
category="VectorUpload",
|
||||
)
|
||||
|
||||
upload_results = {
|
||||
upload_results: StatePayload = {
|
||||
"documents_uploaded": uploaded_count,
|
||||
"failed_uploads": failed_uploads,
|
||||
"collection_name": collection_name,
|
||||
@@ -109,8 +122,8 @@ async def vector_store_upload_node(
|
||||
|
||||
@standard_node(node_name="git_repo_processor", metric_name="git_processing")
|
||||
async def process_git_repository_node(
|
||||
state: dict[str, Any], config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
state: Mapping[str, object], config: RunnableConfig
|
||||
) -> StatePayload:
|
||||
"""Process Git repository for RAG ingestion.
|
||||
|
||||
This node handles the special case of processing Git repositories,
|
||||
@@ -125,10 +138,10 @@ async def process_git_repository_node(
|
||||
"""
|
||||
info_highlight("Processing Git repository...", category="GitProcessor")
|
||||
|
||||
if not state.get("is_git_repo"):
|
||||
return state
|
||||
if not bool(state.get("is_git_repo")):
|
||||
return {}
|
||||
|
||||
input_url = state.get("input_url", "")
|
||||
input_url = str(state.get("input_url", ""))
|
||||
|
||||
try:
|
||||
# Use repomix if available
|
||||
|
||||
@@ -8,7 +8,8 @@ from biz_bud.core for optimal performance and maintainability.
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -32,15 +33,16 @@ from biz_bud.core.langgraph.graph_builder import (
|
||||
NodeConfig,
|
||||
build_graph_from_config,
|
||||
)
|
||||
from langgraph.graph.state import CachePolicy, RetryPolicy
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
|
||||
|
||||
from biz_bud.core.utils import create_lazy_loader
|
||||
from biz_bud.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.pregel import Pregel
|
||||
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
|
||||
from biz_bud.core.config.loader import load_config, load_config_async
|
||||
from biz_bud.core.config.schemas import AppConfig
|
||||
|
||||
@@ -48,15 +50,11 @@ from biz_bud.core.config.schemas import AppConfig
|
||||
from biz_bud.core.edge_helpers import create_conditional_langgraph_command_router
|
||||
|
||||
# Import research-specific nodes from local module
|
||||
from biz_bud.graphs.research.nodes import (
|
||||
derive_research_query_node,
|
||||
)
|
||||
from biz_bud.graphs.research.nodes import derive_research_query_node
|
||||
from biz_bud.graphs.research.nodes import (
|
||||
synthesize_search_results as synthesize_research_results_node,
|
||||
)
|
||||
from biz_bud.graphs.research.nodes import (
|
||||
validate_research_synthesis_node,
|
||||
)
|
||||
from biz_bud.graphs.research.nodes import validate_research_synthesis_node
|
||||
|
||||
# Import consolidated core nodes
|
||||
from biz_bud.nodes import (
|
||||
@@ -65,16 +63,12 @@ from biz_bud.nodes import (
|
||||
research_web_search_node,
|
||||
semantic_extract_node,
|
||||
)
|
||||
from biz_bud.states.base import StateMetadata
|
||||
from biz_bud.states.research import ResearchState
|
||||
|
||||
# Import nodes that haven't been consolidated yet
|
||||
try:
|
||||
from biz_bud.graphs.rag.nodes.rag_enhance import rag_enhance_node
|
||||
from biz_bud.nodes.validation.human_feedback import human_feedback_node
|
||||
except ImportError:
|
||||
# These will be moved to appropriate graph modules later
|
||||
rag_enhance_node = None
|
||||
human_feedback_node = None
|
||||
from biz_bud.graphs.rag.nodes.rag_enhance import rag_enhance_node
|
||||
from biz_bud.nodes.validation.human_feedback import human_feedback_node
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -93,17 +87,17 @@ class ResearchGraphOutput(TypedDict, total=False):
|
||||
|
||||
status: Literal["pending", "running", "complete", "error"]
|
||||
synthesis: str
|
||||
sources: list[dict[str, Any]]
|
||||
validation_summary: dict[str, Any]
|
||||
human_feedback: dict[str, Any]
|
||||
sources: list[StateMetadata]
|
||||
validation_summary: StateMetadata
|
||||
human_feedback: StateMetadata
|
||||
|
||||
|
||||
class ResearchGraphContext(TypedDict, total=False):
|
||||
"""Optional runtime context injected into research graph executions."""
|
||||
|
||||
service_factory: Any | None
|
||||
service_factory: "ServiceFactory" | None
|
||||
request_id: str | None
|
||||
metadata: dict[str, Any]
|
||||
metadata: StateMetadata
|
||||
|
||||
# Graph metadata for dynamic discovery
|
||||
GRAPH_METADATA = {
|
||||
@@ -172,17 +166,6 @@ def _get_cached_config() -> AppConfig:
|
||||
return load_config()
|
||||
|
||||
|
||||
|
||||
|
||||
# Placeholder node for missing imports
|
||||
async def _placeholder_node(
|
||||
state: dict[str, Any], config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Provide placeholder node for features not yet migrated."""
|
||||
logger.warning("Placeholder node called - feature not yet migrated")
|
||||
return state
|
||||
|
||||
|
||||
# Create string-based router for conditional edges compatibility
|
||||
def _search_retry_router(state: ResearchState) -> str:
|
||||
"""Route search results based on retry logic.
|
||||
@@ -321,7 +304,7 @@ _handle_synthesis_errors = handle_error(
|
||||
)
|
||||
|
||||
|
||||
def _prepare_synthesis_routing(state: ResearchState) -> dict[str, Any]:
|
||||
def _prepare_synthesis_routing(state: ResearchState) -> dict[str, object]:
|
||||
"""Prepare state for synthesis quality routing.
|
||||
|
||||
This helper adds synthesis_length to state for threshold routing.
|
||||
@@ -338,7 +321,7 @@ def _prepare_synthesis_routing(state: ResearchState) -> dict[str, Any]:
|
||||
return {"synthesis_length": synthesis_length}
|
||||
|
||||
|
||||
def _prepare_search_results(state: ResearchState) -> dict[str, Any]:
|
||||
def _prepare_search_results(state: ResearchState) -> dict[str, object]:
|
||||
"""Prepare search results state for Command-based routing.
|
||||
|
||||
Args:
|
||||
@@ -358,7 +341,7 @@ def _prepare_search_results(state: ResearchState) -> dict[str, Any]:
|
||||
|
||||
def create_research_graph(
|
||||
checkpointer: PostgresSaver | None = None,
|
||||
) -> "CompiledStateGraph":
|
||||
) -> CompiledStateGraph[ResearchState]:
|
||||
"""Create the consolidated research workflow graph.
|
||||
|
||||
This graph uses consolidated nodes, edge helper factories, and global
|
||||
@@ -389,7 +372,7 @@ def create_research_graph(
|
||||
cache_policy=CachePolicy(ttl=900),
|
||||
),
|
||||
"rag_enhance": NodeConfig(
|
||||
func=rag_enhance_node or _placeholder_node,
|
||||
func=rag_enhance_node,
|
||||
metadata={
|
||||
"category": "augmentation",
|
||||
"description": "Enhance the research scope with retrieval augmented prompts",
|
||||
@@ -448,7 +431,7 @@ def create_research_graph(
|
||||
retry_policy=RetryPolicy(max_attempts=2),
|
||||
),
|
||||
"human_feedback": NodeConfig(
|
||||
func=human_feedback_node or _placeholder_node,
|
||||
func=human_feedback_node,
|
||||
metadata={
|
||||
"category": "feedback",
|
||||
"description": "Escalate to human feedback when automated validation fails",
|
||||
@@ -494,7 +477,7 @@ def create_research_graph(
|
||||
]
|
||||
|
||||
# Wrapper function to extract route from Command object for LangGraph compatibility
|
||||
def _synthesis_quality_route_extractor(state: dict[str, Any]) -> str:
|
||||
def _synthesis_quality_route_extractor(state: Mapping[str, object]) -> str:
|
||||
"""Extract route string from Command object for LangGraph compatibility."""
|
||||
command = _synthesis_quality_router(state)
|
||||
# Handle the case where goto might be a Send object or sequence
|
||||
@@ -567,7 +550,9 @@ def create_research_graph(
|
||||
|
||||
|
||||
# Factory function for LangGraph API
|
||||
def research_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def research_graph_factory(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[ResearchState]:
|
||||
"""Create research graph for LangGraph API with RunnableConfig.
|
||||
|
||||
This factory extracts configuration from RunnableConfig and sets up
|
||||
@@ -595,14 +580,18 @@ def research_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
|
||||
|
||||
# Async factory function for LangGraph API
|
||||
async def research_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
|
||||
async def research_graph_factory_async(
|
||||
config: RunnableConfig,
|
||||
) -> CompiledStateGraph[ResearchState]:
|
||||
"""Async wrapper for research_graph_factory to avoid blocking calls."""
|
||||
import asyncio
|
||||
return await asyncio.to_thread(research_graph_factory, config)
|
||||
|
||||
|
||||
# Async factory function that follows Business Buddy patterns
|
||||
async def create_research_graph_async(config: Any = None) -> "CompiledStateGraph":
|
||||
async def create_research_graph_async(
|
||||
config: RunnableConfig | None = None,
|
||||
) -> CompiledStateGraph[ResearchState]:
|
||||
"""Create research graph using async patterns with service factory integration.
|
||||
|
||||
This function follows Business Buddy architectural patterns for async
|
||||
@@ -688,7 +677,7 @@ def get_research_graph(
|
||||
|
||||
# Usage example for testing
|
||||
async def process_research_query(
|
||||
query: str, config: dict[str, Any] | None = None, derive_query: bool = True
|
||||
query: str, config: dict[str, object] | None = None, derive_query: bool = True
|
||||
) -> ResearchState:
|
||||
"""Process a research query using the consolidated graph.
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from biz_bud.nodes import parse_and_validate_initial_payload
|
||||
from biz_bud.states.base import BaseState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -282,7 +282,7 @@ async def finalize_scraping(
|
||||
# --- Graph Construction ---
|
||||
|
||||
|
||||
def create_scraping_graph() -> "CompiledStateGraph":
|
||||
def create_scraping_graph() -> "CompiledGraph":
|
||||
"""Create the web scraping workflow graph.
|
||||
|
||||
This graph demonstrates parallel URL processing using the Send API,
|
||||
@@ -361,7 +361,7 @@ GRAPH_METADATA = {
|
||||
|
||||
|
||||
# Factory function
|
||||
def scraping_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
|
||||
def scraping_graph_factory(config: RunnableConfig) -> "CompiledGraph":
|
||||
"""Create scraping graph for LangGraph API."""
|
||||
return create_scraping_graph()
|
||||
|
||||
|
||||
@@ -7,73 +7,69 @@ suggesting recovery actions, and updating workflow state to reflect error status
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from biz_bud.core.langgraph import StateUpdater, ensure_immutable_node, standard_node
|
||||
from biz_bud.logging import error_highlight, get_logger, info_highlight
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from biz_bud.core import ErrorInfo, ErrorRecoveryTypedDict
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TypedDict, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
from biz_bud.core import ErrorInfo, ErrorRecoveryTypedDict
|
||||
from biz_bud.core.langgraph import StateUpdater, ensure_immutable_node, standard_node
|
||||
from biz_bud.logging import error_highlight, info_highlight
|
||||
|
||||
WorkflowState = MutableMapping[str, object]
|
||||
ErrorMapping = Mapping[str, object]
|
||||
|
||||
|
||||
class ValidationErrorSummary(TypedDict):
    """Structured summary returned when validation fails."""

    # Overall outcome label, e.g. "failed_validation" or a more specific
    # "failed_<phase>" variant chosen from the error distribution.
    validation_status: str
    # Human-readable descriptions of every validation issue encountered.
    failure_reason: list[str]
    # Total number of validation-phase errors that were aggregated.
    error_count: int
    # Count of errors keyed by the workflow phase that produced them.
    error_distribution: dict[str, int]
    # Suggested next step for the operator (revise content / human review).
    recommendation: str
|
||||
|
||||
|
||||
@standard_node(node_name="handle_graph_error", metric_name="error_handling")
|
||||
@ensure_immutable_node
|
||||
async def handle_graph_error(
|
||||
state: dict[str, Any], config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Central error handler for the workflow graph.
|
||||
state: WorkflowState, config: RunnableConfig
|
||||
) -> WorkflowState:
|
||||
"""Central error handler for the workflow graph."""
|
||||
|
||||
Args:
|
||||
state (BusinessBuddyState): The current workflow state containing error information.
|
||||
errors_raw = state.get("errors")
|
||||
errors: list[ErrorInfo] = []
|
||||
if isinstance(errors_raw, Sequence) and not isinstance(
|
||||
errors_raw, (str, bytes, bytearray)
|
||||
):
|
||||
for error in errors_raw:
|
||||
if isinstance(error, Mapping):
|
||||
errors.append(cast(ErrorInfo, error))
|
||||
|
||||
Returns:
|
||||
BusinessBuddyState: The updated state with error status, recovery actions, and workflow status.
|
||||
|
||||
This node function:
|
||||
- Logs all errors found in the state.
|
||||
- Determines if any errors are critical (preventing workflow continuation).
|
||||
- Suggests recovery actions for non-critical errors based on error phase.
|
||||
- Updates the state with error recovery actions and workflow status.
|
||||
- Handles both critical and non-critical errors, marking the workflow as failed or recovered.
|
||||
|
||||
"""
|
||||
# Get errors from state with proper typing
|
||||
errors_raw = state.get("errors", []) or []
|
||||
errors: Sequence[ErrorInfo] = cast(
|
||||
"Sequence[ErrorInfo]", errors_raw if isinstance(errors_raw, list) else []
|
||||
)
|
||||
if not errors:
|
||||
info_highlight("No errors found in state.")
|
||||
return state # No errors to handle
|
||||
return state
|
||||
|
||||
info_highlight(f"Handling {len(errors)} error(s) found in state.")
|
||||
# Initialize variables to prevent potential UnboundLocalError
|
||||
error_detail: ErrorInfo = cast("ErrorInfo", {})
|
||||
phase = ""
|
||||
error_msg = ""
|
||||
|
||||
for error_detail in cast("list[ErrorInfo]", errors):
|
||||
for error_detail in errors:
|
||||
phase = error_detail.get("phase", "unknown")
|
||||
error_msg = error_detail.get("error", "unknown error")
|
||||
error_highlight(f"Error in phase '{phase}': {error_msg}")
|
||||
|
||||
# Define phases considered critical where failure prevents continuation
|
||||
critical_phases = ["input", "initialization", "config_load", "authentication"]
|
||||
critical_phases: tuple[str, ...] = (
|
||||
"input",
|
||||
"initialization",
|
||||
"config_load",
|
||||
"authentication",
|
||||
)
|
||||
|
||||
has_critical_error: bool = any(
|
||||
error_detail.get("phase", "") in critical_phases
|
||||
for error_detail in cast("list[ErrorInfo]", errors)
|
||||
has_critical_error = any(
|
||||
error_detail.get("phase", "") in critical_phases for error_detail in errors
|
||||
)
|
||||
|
||||
if has_critical_error:
|
||||
error_highlight("Critical error detected, cannot continue workflow.")
|
||||
# Use StateUpdater for immutable updates
|
||||
updater = StateUpdater(state)
|
||||
return (
|
||||
updater.set("critical_error", True)
|
||||
@@ -82,9 +78,8 @@ async def handle_graph_error(
|
||||
.build()
|
||||
)
|
||||
|
||||
# Attempt to define recovery actions for non-critical errors
|
||||
recovery_actions: list[ErrorRecoveryTypedDict] = []
|
||||
for error_detail in cast("list[ErrorInfo]", errors):
|
||||
for error_detail in errors:
|
||||
phase = error_detail.get("phase", "unknown")
|
||||
error_msg = error_detail.get("error", "")
|
||||
action: ErrorRecoveryTypedDict = {
|
||||
@@ -92,7 +87,6 @@ async def handle_graph_error(
|
||||
"description": f"Logged error: {error_msg}",
|
||||
}
|
||||
|
||||
# Simple phase-based recovery suggestions (expand as needed)
|
||||
if phase in ["search_query", "web_search"]:
|
||||
action = {
|
||||
"action": "retry_with_different_query",
|
||||
@@ -111,7 +105,10 @@ async def handle_graph_error(
|
||||
elif phase in ["synthesis", "validation", "interpret", "report"]:
|
||||
action = {
|
||||
"action": "simplify_or_request_review",
|
||||
"description": f"Consider simplifying scope or requesting human review due to error in {phase}.",
|
||||
"description": (
|
||||
"Consider simplifying scope or requesting human review due to error in "
|
||||
f"{phase}."
|
||||
),
|
||||
}
|
||||
elif phase == "human_feedback":
|
||||
action = {
|
||||
@@ -126,16 +123,16 @@ async def handle_graph_error(
|
||||
|
||||
recovery_actions.append(action)
|
||||
|
||||
# Use StateUpdater for immutable updates
|
||||
updater = StateUpdater(state)
|
||||
current_status = state.get("workflow_status", "recovered")
|
||||
|
||||
# Track retry attempts
|
||||
retry_count = state.get("retry_count", 0)
|
||||
if isinstance(retry_count, int):
|
||||
retry_count += 1
|
||||
else:
|
||||
retry_count = 1
|
||||
raw_retry = state.get("retry_count", 0)
|
||||
try:
|
||||
retry_count_int = int(raw_retry) # type: ignore[arg-type]
|
||||
if retry_count_int < 0:
|
||||
retry_count_int = 0
|
||||
except (TypeError, ValueError):
|
||||
retry_count_int = 0
|
||||
retry_count = retry_count_int + 1
|
||||
|
||||
new_state = (
|
||||
updater.set("error_recovery", recovery_actions)
|
||||
@@ -151,145 +148,98 @@ async def handle_graph_error(
|
||||
return new_state
|
||||
|
||||
|
||||
@standard_node(
|
||||
node_name="handle_validation_failure", metric_name="validation_error_handling"
|
||||
)
|
||||
@standard_node(node_name="handle_validation_failure", metric_name="validation_error_handling")
|
||||
@ensure_immutable_node
|
||||
async def handle_validation_failure(
|
||||
state: dict[str, Any], config: RunnableConfig | None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle validation failures. Updates 'validation_error_summary' and 'validation_passed' in the workflow state.
|
||||
state: WorkflowState, config: RunnableConfig | None
|
||||
) -> WorkflowState:
|
||||
"""Handle validation failures."""
|
||||
|
||||
Args:
|
||||
state (BusinessBuddyState): The current workflow state containing validation results and errors.
|
||||
errors_value = state.get("errors")
|
||||
errors: Sequence[ErrorMapping]
|
||||
if isinstance(errors_value, Sequence):
|
||||
errors = tuple(
|
||||
cast(ErrorMapping, error)
|
||||
for error in errors_value
|
||||
if isinstance(error, Mapping)
|
||||
)
|
||||
else:
|
||||
errors = ()
|
||||
|
||||
Returns:
|
||||
BusinessBuddyState: The updated state with validation error summary, pass/fail status, and workflow status.
|
||||
|
||||
This node function:
|
||||
- Extracts validation errors and results from the state.
|
||||
- Checks for failures in fact checking, logic validation, consistency, and human feedback.
|
||||
- Aggregates issues and determines the overall validation status.
|
||||
- Updates the state with a summary of validation failures and recommendations.
|
||||
- Handles and logs any errors encountered during validation.
|
||||
|
||||
"""
|
||||
from typing import cast
|
||||
|
||||
# Cast state to dict for dynamic access
|
||||
state_dict = cast("dict[str, object]", state)
|
||||
|
||||
errors_raw = state_dict.get("errors", []) or []
|
||||
errors: Sequence[dict[str, object]] = (
|
||||
errors_raw if isinstance(errors_raw, list) else []
|
||||
)
|
||||
validation_phases: list[str] = [
|
||||
validation_phases: tuple[str, ...] = (
|
||||
"criteria",
|
||||
"fact_check",
|
||||
"logic_check",
|
||||
"consistency_check",
|
||||
"human_feedback",
|
||||
"validate",
|
||||
]
|
||||
)
|
||||
validation_errors: list[dict[str, object]] = []
|
||||
for err in errors or []:
|
||||
if isinstance(err, dict) and err.get("phase", "") in validation_phases:
|
||||
# Convert to dict[str, object] by ensuring all values are objects
|
||||
obj_err: dict[str, object] = {k: str(v) for k, v in err.items()}
|
||||
for err in errors:
|
||||
if err.get("phase", "") in validation_phases:
|
||||
obj_err: dict[str, object] = {str(k): str(v) for k, v in err.items()}
|
||||
validation_errors.append(obj_err)
|
||||
# Get values from state with proper typing
|
||||
fact_check_raw = state_dict.get("fact_check_results")
|
||||
fact_check_results: dict[str, object] | None = (
|
||||
fact_check_raw if isinstance(fact_check_raw, dict) else None
|
||||
|
||||
fact_check_raw = state.get("fact_check_results")
|
||||
fact_check_results: Mapping[str, object] | None = (
|
||||
fact_check_raw if isinstance(fact_check_raw, Mapping) else None
|
||||
)
|
||||
|
||||
logic_raw = state_dict.get("logic_validation")
|
||||
logic_validation: dict[str, object] | None = (
|
||||
logic_raw if isinstance(logic_raw, dict) else None
|
||||
logic_raw = state.get("logic_validation")
|
||||
logic_validation: Mapping[str, object] | None = (
|
||||
logic_raw if isinstance(logic_raw, Mapping) else None
|
||||
)
|
||||
|
||||
consistency_raw = state_dict.get("consistency_check")
|
||||
consistency_check: dict[str, object] | None = (
|
||||
consistency_raw if isinstance(consistency_raw, dict) else None
|
||||
consistency_raw = state.get("consistency_check")
|
||||
consistency_check: Mapping[str, object] | None = (
|
||||
consistency_raw if isinstance(consistency_raw, Mapping) else None
|
||||
)
|
||||
|
||||
feedback_raw = state_dict.get("human_feedback")
|
||||
human_feedback: dict[str, object] | None = (
|
||||
feedback_raw if isinstance(feedback_raw, dict) else None
|
||||
feedback_raw = state.get("human_feedback")
|
||||
human_feedback: Mapping[str, object] | None = (
|
||||
feedback_raw if isinstance(feedback_raw, Mapping) else None
|
||||
)
|
||||
|
||||
valid_raw = state_dict.get("is_output_valid")
|
||||
valid_raw = state.get("is_output_valid")
|
||||
is_output_valid: bool | None = valid_raw if isinstance(valid_raw, bool) else None
|
||||
|
||||
failed: bool = False
|
||||
failed = False
|
||||
issues: list[str] = [
|
||||
str(err.get("error", "Unknown validation issue")) for err in validation_errors
|
||||
]
|
||||
error_counts: dict[str, int] = {}
|
||||
for err in validation_errors:
|
||||
phase = str(err.get("phase", "unknown"))
|
||||
error_counts[phase] = cast("dict[str, int]", error_counts).get(phase, 0) + 1
|
||||
error_counts[phase] = error_counts.get(phase, 0) + 1
|
||||
|
||||
def check_score(
|
||||
result: dict[str, object] | None, label: str, threshold: float
|
||||
) -> None:
|
||||
"""Check if a validation result's score is below a threshold and update failure state.
|
||||
|
||||
Args:
|
||||
result: The validation result, or None if not available.
|
||||
label: The label for the validation type (e.g., 'Fact check').
|
||||
threshold: The minimum acceptable score.
|
||||
|
||||
Side Effects:
|
||||
Modifies the enclosing function's 'failed' and 'issues' variables if the score is too low.
|
||||
|
||||
"""
|
||||
def check_score(result: Mapping[str, object] | None, label: str, threshold: float) -> None:
|
||||
nonlocal failed, issues
|
||||
if result is not None:
|
||||
score_raw = result.get("score", 1.0)
|
||||
score = float(score_raw) if isinstance(score_raw, (int, float)) else 1.0
|
||||
if score < threshold:
|
||||
failed = True
|
||||
issues.append(f"{label} score too low: {score}")
|
||||
issues_raw = result.get("issues")
|
||||
if isinstance(issues_raw, list):
|
||||
issues.extend(
|
||||
str(issue) for issue in issues_raw if issue is not None
|
||||
)
|
||||
if result is None:
|
||||
return
|
||||
|
||||
score_raw = result.get("score", 1.0)
|
||||
score = float(score_raw) if isinstance(score_raw, (int, float)) else 1.0
|
||||
if score < threshold:
|
||||
failed = True
|
||||
issues.append(f"{label} score too low: {score}")
|
||||
issues_raw = result.get("issues")
|
||||
if isinstance(issues_raw, Sequence) and not isinstance(
|
||||
issues_raw, (str, bytes, bytearray)
|
||||
):
|
||||
issues.extend(str(issue) for issue in issues_raw if issue is not None)
|
||||
|
||||
if is_output_valid is False:
|
||||
failed = True
|
||||
# Safely handle validation_issues which might be None or not a list
|
||||
validation_issues = state_dict.get("validation_issues")
|
||||
if isinstance(validation_issues, list):
|
||||
issues.extend(
|
||||
str(issue) for issue in validation_issues if issue is not None
|
||||
)
|
||||
validation_issues = state.get("validation_issues")
|
||||
if isinstance(validation_issues, Sequence):
|
||||
issues.extend(str(issue) for issue in validation_issues if issue is not None)
|
||||
|
||||
def safe_check_score(
|
||||
result: dict[str, object] | None, label: str, threshold: float
|
||||
) -> None:
|
||||
"""Safely check score with proper type handling.
|
||||
check_score(fact_check_results, "Fact check", 0.6)
|
||||
check_score(logic_validation, "Logic validation", 0.6)
|
||||
check_score(consistency_check, "Consistency", 0.6)
|
||||
check_score(human_feedback, "Human feedback", 0.5)
|
||||
|
||||
Args:
|
||||
result: The validation result to check
|
||||
label: Label for the validation type
|
||||
threshold: Minimum acceptable score
|
||||
|
||||
"""
|
||||
if result is not None:
|
||||
check_score(result, label, threshold)
|
||||
else:
|
||||
check_score(None, label, threshold)
|
||||
|
||||
# Check all validation results
|
||||
safe_check_score(fact_check_results, "Fact check", 0.6)
|
||||
safe_check_score(logic_validation, "Logic validation", 0.6)
|
||||
safe_check_score(consistency_check, "Consistency", 0.6)
|
||||
safe_check_score(human_feedback, "Human feedback", 0.5)
|
||||
|
||||
def get_safe_score(result: dict[str, object] | None) -> float:
|
||||
"""Safely extract score from validation result."""
|
||||
def get_safe_score(result: Mapping[str, object] | None) -> float:
|
||||
if result is None:
|
||||
return 1.0
|
||||
score_raw = result.get("score", 1.0)
|
||||
@@ -304,7 +254,7 @@ async def handle_validation_failure(
|
||||
.build()
|
||||
)
|
||||
|
||||
status: str = "failed_validation"
|
||||
status = "failed_validation"
|
||||
if error_counts.get("fact_check", 0) > 0 or (
|
||||
fact_check_results and get_safe_score(fact_check_results) < 0.6
|
||||
):
|
||||
@@ -322,26 +272,15 @@ async def handle_validation_failure(
|
||||
):
|
||||
status = "failed_human_feedback"
|
||||
|
||||
class ValidationErrorSummaryTypedDict(dict[str, object]):
|
||||
"""TypedDict for validation error summary structure."""
|
||||
|
||||
validation_status: str
|
||||
failure_reason: list[str]
|
||||
error_count: int
|
||||
error_distribution: dict[str, int]
|
||||
recommendation: str
|
||||
|
||||
summary_dict = {
|
||||
summary: ValidationErrorSummary = {
|
||||
"validation_status": status,
|
||||
"failure_reason": issues,
|
||||
"error_count": len(validation_errors),
|
||||
"error_distribution": error_counts,
|
||||
"recommendation": "Revise content based on validation feedback or request human review.",
|
||||
}
|
||||
summary = cast("ValidationErrorSummaryTypedDict", summary_dict)
|
||||
error_highlight(f"Validation failed. Summary: {summary}")
|
||||
|
||||
# Use StateUpdater for immutable updates
|
||||
updater = StateUpdater(state)
|
||||
return (
|
||||
updater.set("validation_error_summary", summary)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Error analyzer node for classifying errors and determining recovery strategies."""
|
||||
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -112,6 +113,9 @@ def _rule_based_analysis(error: ErrorInfo, context: ErrorContext) -> ErrorAnalys
|
||||
error_type = error.get("error_type", "unknown")
|
||||
error_message = error.get("message", "")
|
||||
error_category = error.get("category", ErrorCategory.UNKNOWN)
|
||||
raw_details = error.get("details")
|
||||
if isinstance(raw_details, Mapping):
|
||||
error_category = raw_details.get("category", error_category)
|
||||
|
||||
# Analyze based on error category
|
||||
if error_category == ErrorCategory.LLM:
|
||||
|
||||
@@ -11,6 +11,7 @@ from biz_bud.core.langgraph import ConfigurationProvider, standard_node
|
||||
from biz_bud.logging import get_logger
|
||||
from biz_bud.prompts.error_handling import ERROR_SUMMARY_PROMPT, USER_GUIDANCE_PROMPT
|
||||
from biz_bud.states.error_handling import ErrorHandlingState
|
||||
from biz_bud.services.llm.client import LangchainLLMClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -128,7 +129,9 @@ async def _generate_resolution_steps(
|
||||
if service_factory is None:
|
||||
raise ConfigurationError("ServiceFactory not found in RunnableConfig")
|
||||
|
||||
llm_client = await service_factory.get_llm_client()
|
||||
llm_client = cast(
|
||||
LangchainLLMClient, await service_factory.get_llm_client()
|
||||
)
|
||||
|
||||
# Prepare context for guidance generation
|
||||
context = {
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any
|
||||
from biz_bud.nodes.url_processing._typing import StateMapping
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -30,8 +30,8 @@ from .scrape_url import scrape_url_node
|
||||
|
||||
@standard_node(node_name="discover_urls", metric_name="url_discovery")
|
||||
async def discover_urls_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None
|
||||
) -> dict[str, Any]:
|
||||
state: StateMapping, config: RunnableConfig | None
|
||||
) -> dict[str, object]:
|
||||
"""Discover URLs from a website through sitemaps and crawling.
|
||||
|
||||
⚠️ DEPRECATED: This function is deprecated. Use biz_bud.nodes.url_processing.discover_urls_node instead.
|
||||
|
||||
103
src/biz_bud/nodes/url_processing/_typing.py
Normal file
103
src/biz_bud/nodes/url_processing/_typing.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Shared typing helpers for URL processing nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
StateMapping = Mapping[str, object]
|
||||
|
||||
|
||||
def coerce_str(value: object | None) -> str | None:
|
||||
"""Return ``value`` if it is a string, otherwise ``None``."""
|
||||
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def coerce_bool(value: object | None, default: bool = False) -> bool:
|
||||
"""Coerce arbitrary objects into booleans with a default."""
|
||||
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lowered = value.strip().lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def coerce_int(value: object | None, default: int) -> int:
|
||||
"""Return an integer when possible, otherwise the provided default."""
|
||||
|
||||
if isinstance(value, bool):
|
||||
return default
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
def coerce_float(value: object | None, default: float = 0.0) -> float:
|
||||
"""Return a floating-point number when possible."""
|
||||
|
||||
if isinstance(value, bool):
|
||||
return default
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
def coerce_str_list(value: object | None) -> list[str]:
|
||||
"""Create a list of strings from an arbitrary iterable value."""
|
||||
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
if isinstance(value, Iterable):
|
||||
return [item for item in value if isinstance(item, str)]
|
||||
return []
|
||||
|
||||
|
||||
def coerce_object_dict(value: object | None) -> dict[str, object]:
|
||||
"""Convert arbitrary mapping-like objects into ``dict[str, object]``."""
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return {str(key): item for key, item in value.items()}
|
||||
return {}
|
||||
|
||||
|
||||
def coerce_object_list(value: object | None) -> list[dict[str, object]]:
|
||||
"""Convert an iterable of mappings into concrete dictionaries."""
|
||||
|
||||
if isinstance(value, Iterable) and not isinstance(value, (str, bytes, bytearray)):
|
||||
result: list[dict[str, object]] = []
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
result.append({str(key): item_value for key, item_value in item.items()})
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
__all__ = [
|
||||
"StateMapping",
|
||||
"coerce_bool",
|
||||
"coerce_float",
|
||||
"coerce_int",
|
||||
"coerce_object_dict",
|
||||
"coerce_object_list",
|
||||
"coerce_str",
|
||||
"coerce_str_list",
|
||||
]
|
||||
@@ -6,7 +6,15 @@ using the Business Buddy URL processing tools.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from biz_bud.nodes.url_processing._typing import (
|
||||
StateMapping,
|
||||
coerce_bool,
|
||||
coerce_float,
|
||||
coerce_int,
|
||||
coerce_object_dict,
|
||||
coerce_str,
|
||||
coerce_str_list,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -20,9 +28,9 @@ from biz_bud.tools.capabilities.url_processing import (
|
||||
|
||||
|
||||
@standard_node(node_name="discover_urls", metric_name="url_discovery")
|
||||
async def discover_urls_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None
|
||||
) -> dict[str, Any]:
|
||||
async def discover_urls_node(
|
||||
state: StateMapping, config: RunnableConfig | None
|
||||
) -> dict[str, object]:
|
||||
"""Discover URLs from a website using URL processing tools.
|
||||
|
||||
This node discovers URLs from a base URL using various discovery methods
|
||||
@@ -48,13 +56,17 @@ async def discover_urls_node(
|
||||
"""
|
||||
info_highlight("Starting URL discovery...", category="URLDiscovery")
|
||||
|
||||
base_url = state.get("base_url") or state.get("input_url") or state.get("url", "")
|
||||
if not base_url:
|
||||
error_highlight("No base URL provided for discovery", category="URLDiscovery")
|
||||
return {
|
||||
"discovered_urls": [],
|
||||
"discovery_summary": {"total_discovered": 0, "base_url": None, "error": "No base URL provided"},
|
||||
"errors": [
|
||||
base_url_value = (
|
||||
coerce_str(state.get("base_url"))
|
||||
or coerce_str(state.get("input_url"))
|
||||
or coerce_str(state.get("url"))
|
||||
)
|
||||
if not base_url_value:
|
||||
error_highlight("No base URL provided for discovery", category="URLDiscovery")
|
||||
return {
|
||||
"discovered_urls": [],
|
||||
"discovery_summary": {"total_discovered": 0, "base_url": None, "error": "No base URL provided"},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message="No base URL provided for discovery",
|
||||
node="discover_urls",
|
||||
@@ -64,9 +76,10 @@ async def discover_urls_node(
|
||||
],
|
||||
}
|
||||
|
||||
discovery_provider = state.get("discovery_provider", "comprehensive")
|
||||
max_results = state.get("max_results", 1000)
|
||||
detailed_results = state.get("detailed_results", False)
|
||||
base_url = base_url_value
|
||||
discovery_provider = coerce_str(state.get("discovery_provider")) or "comprehensive"
|
||||
max_results = coerce_int(state.get("max_results"), 1000)
|
||||
detailed_results = coerce_bool(state.get("detailed_results"))
|
||||
|
||||
debug_highlight(
|
||||
f"Discovering URLs from {base_url} using {discovery_provider} method",
|
||||
@@ -74,43 +87,46 @@ async def discover_urls_node(
|
||||
)
|
||||
|
||||
try:
|
||||
if detailed_results:
|
||||
# Get detailed discovery information
|
||||
result = await discover_urls_detailed.ainvoke({
|
||||
"base_url": base_url,
|
||||
"provider": discovery_provider
|
||||
})
|
||||
|
||||
discovered_urls = result.get("discovered_urls", [])
|
||||
|
||||
# Apply max results limit
|
||||
if max_results and len(discovered_urls) > max_results:
|
||||
discovered_urls = discovered_urls[:max_results]
|
||||
|
||||
discovery_details = {
|
||||
"base_url": result.get("base_url"),
|
||||
"total_discovered": result.get("total_discovered", 0),
|
||||
"discovery_method": result.get("discovery_method"),
|
||||
"sitemap_urls": result.get("sitemap_urls", []),
|
||||
"robots_txt_urls": result.get("robots_txt_urls", []),
|
||||
"html_links": result.get("html_links", []),
|
||||
"is_successful": result.get("is_successful", False),
|
||||
"discovery_time": result.get("discovery_time", 0.0),
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
discovery_summary = {
|
||||
"total_discovered": len(discovered_urls),
|
||||
"base_url": base_url,
|
||||
"discovery_method": result.get("discovery_method"),
|
||||
"discovery_time": result.get("discovery_time", 0.0),
|
||||
"limited_to": len(discovered_urls) if max_results and len(discovered_urls) >= max_results else None,
|
||||
}
|
||||
|
||||
info_highlight(
|
||||
f"Detailed URL discovery completed: {len(discovered_urls)} URLs found using {result.get('discovery_method')} method",
|
||||
category="URLDiscovery"
|
||||
)
|
||||
if detailed_results:
|
||||
# Get detailed discovery information
|
||||
result = await discover_urls_detailed.ainvoke(
|
||||
{"base_url": base_url, "provider": discovery_provider}
|
||||
)
|
||||
|
||||
result_data = coerce_object_dict(result)
|
||||
discovered_urls = coerce_str_list(result_data.get("discovered_urls"))
|
||||
|
||||
# Apply max results limit
|
||||
if max_results and len(discovered_urls) > max_results:
|
||||
discovered_urls = discovered_urls[:max_results]
|
||||
|
||||
discovery_method = coerce_str(result_data.get("discovery_method"))
|
||||
|
||||
discovery_details = {
|
||||
"base_url": coerce_str(result_data.get("base_url")) or base_url,
|
||||
"total_discovered": coerce_int(result_data.get("total_discovered"), len(discovered_urls)),
|
||||
"discovery_method": discovery_method,
|
||||
"sitemap_urls": coerce_str_list(result_data.get("sitemap_urls")),
|
||||
"robots_txt_urls": coerce_str_list(result_data.get("robots_txt_urls")),
|
||||
"html_links": coerce_str_list(result_data.get("html_links")),
|
||||
"is_successful": coerce_bool(result_data.get("is_successful")),
|
||||
"discovery_time": coerce_float(result_data.get("discovery_time")),
|
||||
"metadata": coerce_object_dict(result_data.get("metadata")),
|
||||
}
|
||||
|
||||
discovery_summary = {
|
||||
"total_discovered": len(discovered_urls),
|
||||
"base_url": base_url,
|
||||
"discovery_method": discovery_method,
|
||||
"discovery_time": coerce_float(result_data.get("discovery_time")),
|
||||
"limited_to": len(discovered_urls) if max_results and len(discovered_urls) >= max_results else None,
|
||||
}
|
||||
|
||||
info_highlight(
|
||||
"Detailed URL discovery completed: "
|
||||
f"{len(discovered_urls)} URLs found using {discovery_method or 'unknown'} method",
|
||||
category="URLDiscovery"
|
||||
)
|
||||
|
||||
return {
|
||||
"discovered_urls": discovered_urls,
|
||||
@@ -118,13 +134,17 @@ async def discover_urls_node(
|
||||
"discovery_summary": discovery_summary,
|
||||
}
|
||||
|
||||
else:
|
||||
# Simple discovery
|
||||
discovered_urls = await discover_urls.ainvoke({
|
||||
"base_url": base_url,
|
||||
"provider": discovery_provider,
|
||||
"max_results": max_results
|
||||
})
|
||||
else:
|
||||
# Simple discovery
|
||||
discovered_urls = coerce_str_list(
|
||||
await discover_urls.ainvoke(
|
||||
{
|
||||
"base_url": base_url,
|
||||
"provider": discovery_provider,
|
||||
"max_results": max_results,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
discovery_summary = {
|
||||
"total_discovered": len(discovered_urls),
|
||||
@@ -143,22 +163,33 @@ async def discover_urls_node(
|
||||
"discovery_summary": discovery_summary,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_highlight(f"URL discovery failed: {e}", category="URLDiscovery")
|
||||
return {
|
||||
"discovered_urls": [base_url], # Fallback to base URL
|
||||
"discovery_summary": {
|
||||
"total_discovered": 1,
|
||||
"base_url": base_url,
|
||||
"error": str(e),
|
||||
"fallback_used": True,
|
||||
},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=f"URL discovery failed: {str(e)}",
|
||||
node="discover_urls",
|
||||
severity="warning",
|
||||
category="discovery_error",
|
||||
)
|
||||
],
|
||||
except Exception as e:
|
||||
error_highlight(f"URL discovery failed: {e}", category="URLDiscovery")
|
||||
safe_base_url = coerce_str(base_url) or None
|
||||
fallback_urls = [url for url in [safe_base_url] if isinstance(url, str) and url]
|
||||
error_type = type(e).__name__
|
||||
return {
|
||||
"discovered_urls": fallback_urls,
|
||||
"discovery_summary": {
|
||||
"total_discovered": len(fallback_urls),
|
||||
"base_url": safe_base_url,
|
||||
"error": f"{error_type}: {str(e)}",
|
||||
"fallback_used": True,
|
||||
"status": "error",
|
||||
"provider": discovery_provider,
|
||||
"requested_max_results": max_results,
|
||||
},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=f"{error_type}: {str(e)}",
|
||||
node="discover_urls",
|
||||
severity="error",
|
||||
category="discovery_error",
|
||||
context={
|
||||
"base_url": safe_base_url,
|
||||
"provider": discovery_provider,
|
||||
"max_results": max_results,
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -6,7 +6,15 @@ using the Business Buddy URL processing tools.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from biz_bud.nodes.url_processing._typing import (
|
||||
StateMapping,
|
||||
coerce_bool,
|
||||
coerce_int,
|
||||
coerce_object_dict,
|
||||
coerce_object_list,
|
||||
coerce_str,
|
||||
coerce_str_list,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -16,9 +24,9 @@ from biz_bud.logging import debug_highlight, error_highlight, info_highlight
|
||||
|
||||
|
||||
@standard_node(node_name="process_urls", metric_name="url_processing")
|
||||
async def process_urls_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None
|
||||
) -> dict[str, Any]:
|
||||
async def process_urls_node(
|
||||
state: StateMapping, config: RunnableConfig | None
|
||||
) -> dict[str, object]:
|
||||
"""Process multiple URLs using URL processing tools.
|
||||
|
||||
This node processes a batch of URLs using various URL processing
|
||||
@@ -46,55 +54,64 @@ async def process_urls_node(
|
||||
|
||||
info_highlight("Starting batch URL processing...", category="URLProcessing")
|
||||
|
||||
urls = state.get("urls", [])
|
||||
if not urls:
|
||||
info_highlight("No URLs to process", category="URLProcessing")
|
||||
return {
|
||||
"processed_urls": [],
|
||||
"processing_summary": {"total": 0, "successful": 0, "failed": 0},
|
||||
}
|
||||
urls = coerce_str_list(state.get("urls"))
|
||||
if not urls:
|
||||
info_highlight("No URLs to process", category="URLProcessing")
|
||||
return {
|
||||
"processed_urls": [],
|
||||
"processing_summary": {"total": 0, "successful": 0, "failed": 0},
|
||||
}
|
||||
|
||||
processing_operations = coerce_str_list(state.get("processing_operations"))
|
||||
if not processing_operations:
|
||||
processing_operations = ["validate", "normalize"]
|
||||
|
||||
validation_level = coerce_str(state.get("validation_level")) or "standard"
|
||||
max_concurrent = coerce_int(state.get("max_concurrent"), 5)
|
||||
|
||||
processing_operations = state.get("processing_operations", ["validate", "normalize"])
|
||||
validation_level = state.get("validation_level", "standard")
|
||||
max_concurrent = state.get("max_concurrent", 5)
|
||||
|
||||
debug_highlight(
|
||||
f"Processing {len(urls)} URLs with operations: {processing_operations}",
|
||||
category="URLProcessing"
|
||||
)
|
||||
total_urls = len(urls)
|
||||
|
||||
debug_highlight(
|
||||
f"Processing {total_urls} URLs with operations: {processing_operations}",
|
||||
category="URLProcessing"
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the batch processing tool
|
||||
# Call the LangChain tool using invoke
|
||||
input_dict = {
|
||||
"urls": urls,
|
||||
"validation_level": validation_level,
|
||||
"normalization_provider": None,
|
||||
"enable_deduplication": True,
|
||||
"deduplication_provider": None,
|
||||
"max_concurrent": max_concurrent,
|
||||
"timeout": 30.0
|
||||
}
|
||||
result = await process_urls_batch.ainvoke(input_dict)
|
||||
input_dict: dict[str, object] = {
|
||||
"urls": urls,
|
||||
"validation_level": validation_level,
|
||||
"normalization_provider": None,
|
||||
"enable_deduplication": True,
|
||||
"deduplication_provider": None,
|
||||
"max_concurrent": max_concurrent,
|
||||
"timeout": 30.0,
|
||||
}
|
||||
result = await process_urls_batch.ainvoke(input_dict)
|
||||
|
||||
result_data = coerce_object_dict(result)
|
||||
processed_urls = coerce_object_list(result_data.get("results"))
|
||||
|
||||
successful_count = sum(
|
||||
1 for url_result in processed_urls if coerce_bool(url_result.get("success"))
|
||||
)
|
||||
failed_count = len(processed_urls) - successful_count
|
||||
success_rate = (successful_count / total_urls) * 100 if total_urls else 0.0
|
||||
|
||||
processing_summary = {
|
||||
"total": total_urls,
|
||||
"successful": successful_count,
|
||||
"failed": failed_count,
|
||||
"success_rate": success_rate,
|
||||
"operations_performed": processing_operations,
|
||||
}
|
||||
|
||||
processed_urls = result.get("results", [])
|
||||
successful_count = sum(bool(url_result.get("success", False))
|
||||
for url_result in processed_urls)
|
||||
failed_count = len(processed_urls) - successful_count
|
||||
|
||||
processing_summary = {
|
||||
"total": len(urls),
|
||||
"successful": successful_count,
|
||||
"failed": failed_count,
|
||||
"success_rate": (successful_count / len(urls)) * 100 if urls else 0,
|
||||
"operations_performed": processing_operations,
|
||||
}
|
||||
|
||||
info_highlight(
|
||||
f"URL processing completed: {successful_count}/{len(urls)} successful "
|
||||
f"({processing_summary['success_rate']:.1f}% success rate)",
|
||||
category="URLProcessing"
|
||||
)
|
||||
info_highlight(
|
||||
f"URL processing completed: {successful_count}/{len(urls)} successful "
|
||||
f"({success_rate:.1f}% success rate)",
|
||||
category="URLProcessing"
|
||||
)
|
||||
|
||||
return {
|
||||
"processed_urls": processed_urls,
|
||||
@@ -103,14 +120,14 @@ async def process_urls_node(
|
||||
|
||||
except Exception as e:
|
||||
error_highlight(f"URL processing node failed: {e}", category="URLProcessing")
|
||||
return {
|
||||
"processed_urls": [],
|
||||
"processing_summary": {"total": len(urls), "successful": 0, "failed": len(urls)},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=f"URL processing failed: {str(e)}",
|
||||
node="process_urls",
|
||||
severity="error",
|
||||
return {
|
||||
"processed_urls": [],
|
||||
"processing_summary": {"total": total_urls, "successful": 0, "failed": total_urls},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=f"URL processing failed: {str(e)}",
|
||||
node="process_urls",
|
||||
severity="error",
|
||||
category="processing_error",
|
||||
)
|
||||
],
|
||||
|
||||
@@ -6,7 +6,13 @@ using the Business Buddy URL processing tools.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from biz_bud.nodes.url_processing._typing import (
|
||||
StateMapping,
|
||||
coerce_bool,
|
||||
coerce_object_dict,
|
||||
coerce_str,
|
||||
coerce_str_list,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
@@ -17,9 +23,9 @@ from biz_bud.tools.capabilities.url_processing import validate_url
|
||||
|
||||
|
||||
@standard_node(node_name="validate_urls", metric_name="url_validation")
|
||||
async def validate_urls_node(
|
||||
state: dict[str, Any], config: RunnableConfig | None
|
||||
) -> dict[str, Any]:
|
||||
async def validate_urls_node(
|
||||
state: StateMapping, config: RunnableConfig | None
|
||||
) -> dict[str, object]:
|
||||
"""Validate URLs using URL processing tools.
|
||||
|
||||
This node validates a list of URLs using the URL processing validation
|
||||
@@ -45,46 +51,45 @@ async def validate_urls_node(
|
||||
"""
|
||||
info_highlight("Starting URL validation...", category="URLValidation")
|
||||
|
||||
urls = state.get("urls", [])
|
||||
if not urls:
|
||||
info_highlight("No URLs to validate", category="URLValidation")
|
||||
return {
|
||||
"validated_urls": [],
|
||||
"valid_urls": [],
|
||||
"invalid_urls": [],
|
||||
"validation_summary": {"total": 0, "valid": 0, "invalid": 0},
|
||||
}
|
||||
urls = coerce_str_list(state.get("urls"))
|
||||
if not urls:
|
||||
info_highlight("No URLs to validate", category="URLValidation")
|
||||
return {
|
||||
"validated_urls": [],
|
||||
"valid_urls": [],
|
||||
"invalid_urls": [],
|
||||
"validation_summary": {"total": 0, "valid": 0, "invalid": 0},
|
||||
}
|
||||
|
||||
validation_level = coerce_str(state.get("validation_level")) or "standard"
|
||||
validation_provider = coerce_str(state.get("validation_provider"))
|
||||
|
||||
validation_level = state.get("validation_level", "standard")
|
||||
validation_provider = state.get("validation_provider")
|
||||
total_urls = len(urls)
|
||||
|
||||
debug_highlight(
|
||||
f"Validating {total_urls} URLs with level: {validation_level}",
|
||||
category="URLValidation"
|
||||
)
|
||||
|
||||
debug_highlight(
|
||||
f"Validating {len(urls)} URLs with level: {validation_level}",
|
||||
category="URLValidation"
|
||||
)
|
||||
try:
|
||||
validated_urls: list[dict[str, object]] = []
|
||||
valid_urls: list[str] = []
|
||||
invalid_urls: list[str] = []
|
||||
|
||||
try:
|
||||
validated_urls = []
|
||||
valid_urls = []
|
||||
invalid_urls = []
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
result = await validate_url.ainvoke({
|
||||
"url": url,
|
||||
"level": validation_level,
|
||||
"provider": validation_provider
|
||||
})
|
||||
|
||||
validated_urls.append({
|
||||
"url": url,
|
||||
**result
|
||||
})
|
||||
|
||||
if result.get("is_valid", False):
|
||||
valid_urls.append(url)
|
||||
else:
|
||||
invalid_urls.append(url)
|
||||
for url in urls:
|
||||
try:
|
||||
result = await validate_url.ainvoke(
|
||||
{"url": url, "level": validation_level, "provider": validation_provider}
|
||||
)
|
||||
|
||||
result_data = coerce_object_dict(result)
|
||||
|
||||
validated_urls.append({"url": url, **result_data})
|
||||
|
||||
if coerce_bool(result_data.get("is_valid")):
|
||||
valid_urls.append(url)
|
||||
else:
|
||||
invalid_urls.append(url)
|
||||
|
||||
except Exception as e:
|
||||
error_highlight(
|
||||
@@ -99,18 +104,22 @@ async def validate_urls_node(
|
||||
})
|
||||
invalid_urls.append(url)
|
||||
|
||||
validation_summary = {
|
||||
"total": len(urls),
|
||||
"valid": len(valid_urls),
|
||||
"invalid": len(invalid_urls),
|
||||
"success_rate": (len(valid_urls) / len(urls)) * 100 if urls else 0,
|
||||
}
|
||||
|
||||
info_highlight(
|
||||
f"URL validation completed: {validation_summary['valid']}/{validation_summary['total']} valid "
|
||||
f"({validation_summary['success_rate']:.1f}% success rate)",
|
||||
category="URLValidation"
|
||||
)
|
||||
valid_count = len(valid_urls)
|
||||
invalid_count = len(invalid_urls)
|
||||
success_rate = (valid_count / total_urls) * 100 if total_urls else 0.0
|
||||
|
||||
validation_summary = {
|
||||
"total": total_urls,
|
||||
"valid": valid_count,
|
||||
"invalid": invalid_count,
|
||||
"success_rate": success_rate,
|
||||
}
|
||||
|
||||
info_highlight(
|
||||
f"URL validation completed: {valid_count}/{total_urls} valid "
|
||||
f"({success_rate:.1f}% success rate)",
|
||||
category="URLValidation"
|
||||
)
|
||||
|
||||
return {
|
||||
"validated_urls": validated_urls,
|
||||
@@ -121,11 +130,11 @@ async def validate_urls_node(
|
||||
|
||||
except Exception as e:
|
||||
error_highlight(f"URL validation node failed: {e}", category="URLValidation")
|
||||
return {
|
||||
"validated_urls": [],
|
||||
"valid_urls": [],
|
||||
"invalid_urls": urls, # Mark all as invalid on error
|
||||
"validation_summary": {"total": len(urls), "valid": 0, "invalid": len(urls)},
|
||||
return {
|
||||
"validated_urls": [],
|
||||
"valid_urls": [],
|
||||
"invalid_urls": list(urls), # Mark all as invalid on error
|
||||
"validation_summary": {"total": total_urls, "valid": 0, "invalid": total_urls},
|
||||
"errors": [
|
||||
create_error_info(
|
||||
message=f"URL validation failed: {str(e)}",
|
||||
|
||||
@@ -3,44 +3,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from operator import add
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict, TypeVar
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import add_messages
|
||||
from langgraph.managed import IsLastStep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
from langgraph.managed import IsLastStep
|
||||
from biz_bud.core import (
|
||||
ApiResponseTypedDict,
|
||||
ErrorInfo,
|
||||
InputMetadataTypedDict,
|
||||
ParsedInputTypedDict,
|
||||
ToolCallTypedDict,
|
||||
)
|
||||
|
||||
from biz_bud.core import (
|
||||
ApiResponseTypedDict,
|
||||
ErrorInfo,
|
||||
InputMetadataTypedDict,
|
||||
ParsedInputTypedDict,
|
||||
ToolCallTypedDict,
|
||||
)
|
||||
else:
|
||||
# Create dummy types for runtime
|
||||
Configuration = dict
|
||||
ApiResponseTypedDict = dict
|
||||
ErrorInfo = dict
|
||||
InputMetadataTypedDict = dict
|
||||
ParsedInputTypedDict = dict
|
||||
ToolCallTypedDict = dict
|
||||
IsLastStep = Any
|
||||
LangchainAIMessage = Any # Will be set after TYPE_CHECKING
|
||||
|
||||
# Alias for compatibility
|
||||
Message = Any # Will be set after TYPE_CHECKING
|
||||
|
||||
# The types previously here (AIMessage, IsLastStep, Configuration, utils)
|
||||
# were moved into TYPE_CHECKING to satisfy TCH linting rules.
|
||||
# If LangGraph has issues resolving these at runtime, specific noqa directives
|
||||
# for TCH rules (TCH001-TCH004) might be needed on those specific imports
|
||||
# if they are kept outside the TYPE_CHECKING block, or LangGraph's documentation
|
||||
# should be consulted for the preferred pattern.
|
||||
|
||||
T = TypeVar("T")
|
||||
StateMetadata = dict[str, object]
|
||||
|
||||
|
||||
class ContextTypedDict(TypedDict, total=False):
|
||||
@@ -50,13 +28,13 @@ class ContextTypedDict(TypedDict, total=False):
|
||||
task: str
|
||||
"""Task description or identifier."""
|
||||
|
||||
session_data: dict[str, Any]
|
||||
session_data: StateMetadata
|
||||
"""Session-specific data."""
|
||||
|
||||
user_preferences: dict[str, Any]
|
||||
user_preferences: StateMetadata
|
||||
"""User preferences and settings."""
|
||||
|
||||
workflow_metadata: dict[str, Any]
|
||||
workflow_metadata: StateMetadata
|
||||
"""Metadata about the workflow execution."""
|
||||
|
||||
unique_marker: str
|
||||
@@ -76,7 +54,7 @@ class InitialInputTypedDict(TypedDict, total=False):
|
||||
session_id: str
|
||||
"""Session identifier."""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
metadata: StateMetadata
|
||||
"""Additional metadata for the input."""
|
||||
|
||||
|
||||
@@ -103,7 +81,7 @@ class RunMetadataTypedDict(TypedDict, total=False):
|
||||
tags: list[str]
|
||||
"""Tags associated with the run."""
|
||||
|
||||
custom_metadata: dict[str, Any]
|
||||
custom_metadata: StateMetadata
|
||||
"""Custom metadata for the run."""
|
||||
|
||||
|
||||
@@ -127,7 +105,7 @@ class BaseStateRequired(TypedDict):
|
||||
initial_input: InitialInputTypedDict
|
||||
"""The original input dictionary that initiated the graph run. Preserved for context."""
|
||||
|
||||
config: dict[str, Any]
|
||||
config: StateMetadata
|
||||
"""Immutable application configuration (API keys, model settings, agent params, etc.).
|
||||
Passed during graph invocation. Provides context to all nodes."""
|
||||
|
||||
@@ -156,11 +134,11 @@ class BaseStateRequired(TypedDict):
|
||||
LangGraph managed field indicating the final step before recursion limit.
|
||||
|
||||
Note:
|
||||
The type `IsLastStep` is an opaque marker imported from `langgraph.managed` and is defined as `Any`.
|
||||
This is intentional, as LangGraph uses it as a special sentinel value rather than a concrete type.
|
||||
See: src/langgraph/managed.pyi
|
||||
LangGraph annotates this field with ``Annotated[bool, _IsLastStepSentinel()]`` to indicate when a
|
||||
graph execution is about to reach its recursion limit. The concrete value is a boolean flag but the
|
||||
annotation allows LangGraph to manage the lifecycle automatically.
|
||||
|
||||
[Source: LangGraph Docs, User Example]
|
||||
[Source: LangGraph Docs]
|
||||
"""
|
||||
|
||||
|
||||
@@ -190,10 +168,10 @@ class BaseStateOptional(TypedDict, total=False):
|
||||
assistant_message_for_history: LangchainAIMessage
|
||||
"""Assistant message to be appended to message history."""
|
||||
|
||||
claims_to_check: list[dict[str, Any]]
|
||||
claims_to_check: list[StateMetadata]
|
||||
"""Claims extracted from content for fact-checking."""
|
||||
|
||||
fact_check_results: dict[str, Any] | None
|
||||
fact_check_results: StateMetadata | None
|
||||
"""Results from fact-checking claims."""
|
||||
|
||||
is_output_valid: bool | None
|
||||
|
||||
@@ -4,9 +4,9 @@ This module defines the state structure for Buddy, the intelligent graph
|
||||
orchestrator that coordinates complex workflows across the Business Buddy system.
|
||||
"""
|
||||
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from biz_bud.states.base import BaseState
|
||||
from biz_bud.states.base import BaseState, StateMetadata
|
||||
from biz_bud.states.planner import ExecutionPlan, QueryStep
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class ExecutionRecord(TypedDict):
|
||||
start_time: float
|
||||
end_time: float
|
||||
status: Literal["running", "completed", "failed", "skipped"]
|
||||
result: Any | None
|
||||
result: object | None
|
||||
error: str | None
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class BuddyState(BaseState):
|
||||
|
||||
# Execution tracking
|
||||
execution_history: list[ExecutionRecord]
|
||||
intermediate_results: dict[str, Any] # step_id -> result mapping
|
||||
intermediate_results: dict[str, object] # step_id -> result mapping
|
||||
adaptation_count: int
|
||||
parallel_execution_enabled: bool
|
||||
completed_step_ids: list[str]
|
||||
@@ -75,8 +75,8 @@ class BuddyState(BaseState):
|
||||
complexity_reasoning: str # Why this complexity was chosen
|
||||
|
||||
# Synthesis data (for capability introspection and other direct synthesis)
|
||||
extracted_info: dict[str, Any] | None # Information extracted for synthesis
|
||||
sources: list[dict[str, Any]] | None # Source metadata for synthesis
|
||||
extracted_info: StateMetadata | None # Information extracted for synthesis
|
||||
sources: list[StateMetadata] | None # Source metadata for synthesis
|
||||
is_capability_introspection: (
|
||||
bool # Whether this is a capability introspection query
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ This module contains TypedDict definitions organized by business domains:
|
||||
- Catalog: Catalog research and component analysis types
|
||||
"""
|
||||
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
# =============================================================================
|
||||
# Core Domain Types
|
||||
@@ -45,7 +45,7 @@ class DataDict(TypedDict):
|
||||
|
||||
id: str
|
||||
type: str
|
||||
content: str | dict[str, Any] | list[Any]
|
||||
content: str | dict[str, object] | list[object]
|
||||
metadata: NotRequired[MetadataDict]
|
||||
|
||||
|
||||
@@ -247,7 +247,7 @@ class CatalogComponentResearchResult(TypedDict):
|
||||
item_id: str
|
||||
item_name: str
|
||||
search_query: NotRequired[str]
|
||||
component_research: NotRequired[dict[str, Any]]
|
||||
component_research: NotRequired[dict[str, object]]
|
||||
from_cache: NotRequired[bool]
|
||||
cache_age_days: NotRequired[int]
|
||||
error: NotRequired[str]
|
||||
@@ -263,7 +263,7 @@ class CatalogComponentResearchDict(TypedDict):
|
||||
cached_items: NotRequired[int]
|
||||
searched_items: NotRequired[int]
|
||||
research_results: NotRequired[list[CatalogComponentResearchResult]]
|
||||
metadata: NotRequired[dict[str, Any]]
|
||||
metadata: NotRequired[dict[str, object]]
|
||||
message: NotRequired[str]
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Error handling state definitions for the error recovery agent."""
|
||||
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from biz_bud.core import ErrorInfo
|
||||
|
||||
from .base import BaseState
|
||||
from .base import BaseState, StateMetadata
|
||||
|
||||
# Note: These TypedDict definitions are specific to the error handling state
|
||||
# and include additional fields required for state management.
|
||||
@@ -16,7 +16,7 @@ class ErrorContext(TypedDict):
|
||||
node_name: str
|
||||
graph_name: str
|
||||
timestamp: str
|
||||
input_state: dict[str, Any]
|
||||
input_state: StateMetadata
|
||||
execution_count: int
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class RecoveryAction(TypedDict):
|
||||
"""A recovery action to attempt."""
|
||||
|
||||
action_type: Literal["retry", "modify_input", "fallback", "skip", "abort"]
|
||||
parameters: dict[str, Any]
|
||||
parameters: StateMetadata
|
||||
priority: int
|
||||
expected_success_rate: float
|
||||
|
||||
@@ -44,7 +44,7 @@ class RecoveryResult(TypedDict, total=False):
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
new_state: dict[str, Any] # Optional
|
||||
new_state: StateMetadata # Optional
|
||||
duration_seconds: float # Optional
|
||||
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Literal, TypedDict, cast
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
from biz_bud.core import SearchResultTypedDict as SearchResult
|
||||
from biz_bud.core.types import ErrorInfo
|
||||
@@ -174,30 +174,72 @@ def create_message_only_state() -> MessageOnlyState:
|
||||
|
||||
# Example validation functions that work with focused states
|
||||
|
||||
def validate_search_state(state: SearchWorkflowState) -> bool:
|
||||
"""Validate that search state is properly formed.
|
||||
|
||||
Args:
|
||||
state: The search state to validate
|
||||
|
||||
Returns:
|
||||
True if state is valid, False otherwise
|
||||
"""
|
||||
# Required fields check
|
||||
return bool(state.get("status")) if state.get("thread_id") else False
|
||||
def validate_search_state(state: SearchWorkflowState) -> bool:
|
||||
"""Validate that search state is properly formed.
|
||||
|
||||
The modern LangGraph runtime is stricter about reducer inputs, so we make
|
||||
sure list-based fields are actually lists (and contain dictionaries for
|
||||
search results) before allowing the state to proceed. This mirrors
|
||||
LangGraph's runtime validation and keeps our focused states resilient
|
||||
across v0.x and v1.x releases.
|
||||
|
||||
Args:
|
||||
state: The search state to validate
|
||||
|
||||
Returns:
|
||||
True if state is valid, False otherwise
|
||||
"""
|
||||
|
||||
thread_id = state.get("thread_id")
|
||||
status = state.get("status")
|
||||
messages = state.get("messages")
|
||||
search_results = state.get("search_results")
|
||||
|
||||
if not isinstance(thread_id, str) or not thread_id.strip():
|
||||
return False
|
||||
|
||||
if status not in {"pending", "running", "success", "error", "interrupted"}:
|
||||
return False
|
||||
|
||||
if not isinstance(messages, list):
|
||||
return False
|
||||
|
||||
if not isinstance(search_results, list):
|
||||
return False
|
||||
|
||||
if any(not isinstance(result, dict) for result in search_results):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_validation_state(state: ValidationWorkflowState) -> bool:
|
||||
"""Validate that validation state is properly formed.
|
||||
|
||||
Args:
|
||||
state: The validation state to validate
|
||||
|
||||
Returns:
|
||||
True if state is valid, False otherwise
|
||||
"""
|
||||
# Must have content to validate
|
||||
return bool(state.get("thread_id")) if state.get("content") else False
|
||||
def validate_validation_state(state: ValidationWorkflowState) -> bool:
|
||||
"""Validate that validation state is properly formed.
|
||||
|
||||
Args:
|
||||
state: The validation state to validate
|
||||
|
||||
Returns:
|
||||
True if state is valid, False otherwise
|
||||
"""
|
||||
|
||||
thread_id = state.get("thread_id")
|
||||
content = state.get("content")
|
||||
issues = state.get("validation_issues")
|
||||
|
||||
if not isinstance(thread_id, str) or not thread_id.strip():
|
||||
return False
|
||||
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
|
||||
if issues is None:
|
||||
issues = []
|
||||
|
||||
if not isinstance(issues, list):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Export focused states and utilities
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from biz_bud.states.base import BaseState
|
||||
from biz_bud.states.base import BaseState, StateMetadata
|
||||
|
||||
|
||||
class RAGAgentStateRequired(TypedDict):
|
||||
@@ -25,7 +25,7 @@ class RAGAgentStateRequired(TypedDict):
|
||||
url_hash: str | None
|
||||
"""SHA256 hash of the URL for deduplication (first 16 chars)."""
|
||||
|
||||
existing_content: dict[str, Any] | None
|
||||
existing_content: StateMetadata | None
|
||||
"""Existing content metadata from knowledge stores."""
|
||||
|
||||
content_age_days: int | None
|
||||
@@ -38,14 +38,14 @@ class RAGAgentStateRequired(TypedDict):
|
||||
"""Human-readable reason for processing decision."""
|
||||
|
||||
# Processing parameters
|
||||
scrape_params: dict[str, Any]
|
||||
scrape_params: StateMetadata
|
||||
"""Parameters for web scraping (depth, patterns, selectors)."""
|
||||
|
||||
r2r_params: dict[str, Any]
|
||||
r2r_params: StateMetadata
|
||||
"""Parameters for R2R upload (chunking, metadata)."""
|
||||
|
||||
# Results
|
||||
processing_result: dict[str, Any] | None
|
||||
processing_result: StateMetadata | None
|
||||
"""Result from url_to_rag graph execution."""
|
||||
|
||||
rag_status: Literal["checking", "decided", "processing", "completed", "error"]
|
||||
@@ -62,7 +62,7 @@ class RAGAgentStateOptional(TypedDict, total=False):
|
||||
"""Optional fields for RAG agent workflow."""
|
||||
|
||||
# ReAct agent fields
|
||||
intermediate_steps: list[tuple[Any, str]]
|
||||
intermediate_steps: list[tuple[object, str]]
|
||||
"""Intermediate steps for ReAct agent execution."""
|
||||
|
||||
final_answer: str | None
|
||||
|
||||
@@ -2,20 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from biz_bud.states.base import BaseState
|
||||
from biz_bud.states.base import BaseState, StateMetadata
|
||||
from biz_bud.tools.clients.r2r import R2RSearchResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from biz_bud.tools.clients.r2r import R2RSearchResult
|
||||
else:
|
||||
# Runtime placeholders for type checking
|
||||
R2RSearchResult = Any
|
||||
|
||||
# Legacy types from deleted rag modules - using Any for now
|
||||
FilteredChunk = Any
|
||||
GenerationResult = Any
|
||||
RetrievalResult = Any
|
||||
FilteredChunk = StateMetadata
|
||||
GenerationResult = StateMetadata
|
||||
RetrievalResult = StateMetadata
|
||||
|
||||
|
||||
class RAGOrchestratorStateRequired(TypedDict):
|
||||
@@ -56,7 +50,7 @@ class RAGOrchestratorStateRequired(TypedDict):
|
||||
urls_to_ingest: list[str]
|
||||
"""URLs that need to be ingested."""
|
||||
|
||||
ingestion_results: dict[str, Any]
|
||||
ingestion_results: StateMetadata
|
||||
"""Results from the ingestion component."""
|
||||
|
||||
ingestion_status: Literal["pending", "processing", "completed", "failed", "skipped"]
|
||||
@@ -120,14 +114,14 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
|
||||
"""Timing information for each component."""
|
||||
|
||||
# Context and metadata
|
||||
user_context: dict[str, Any]
|
||||
user_context: StateMetadata
|
||||
"""Additional context provided by the user."""
|
||||
|
||||
previous_interactions: list[dict[str, Any]]
|
||||
previous_interactions: list[StateMetadata]
|
||||
"""History of previous interactions in this session."""
|
||||
|
||||
# Advanced retrieval options
|
||||
retrieval_filters: dict[str, Any]
|
||||
retrieval_filters: StateMetadata
|
||||
"""Filters to apply during retrieval."""
|
||||
|
||||
max_chunks: int
|
||||
@@ -147,10 +141,10 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
|
||||
"""Style for citations in the response."""
|
||||
|
||||
# Error handling and monitoring
|
||||
error_history: list[dict[str, Any]]
|
||||
error_history: list[StateMetadata]
|
||||
"""History of errors encountered during workflow."""
|
||||
|
||||
error_analysis: dict[str, Any]
|
||||
error_analysis: StateMetadata
|
||||
"""Analysis results from the error handling graph."""
|
||||
|
||||
should_retry_node: bool
|
||||
@@ -165,10 +159,10 @@ class RAGOrchestratorStateOptional(TypedDict, total=False):
|
||||
recovery_successful: bool
|
||||
"""Whether error recovery was successful."""
|
||||
|
||||
performance_metrics: dict[str, Any]
|
||||
performance_metrics: StateMetadata
|
||||
"""Performance metrics for monitoring."""
|
||||
|
||||
debug_info: dict[str, Any]
|
||||
debug_info: StateMetadata
|
||||
"""Debug information for troubleshooting."""
|
||||
|
||||
# Integration with legacy fields
|
||||
|
||||
@@ -7,7 +7,7 @@ replacing the monolithic BusinessBuddyState with focused, type-safe alternatives
|
||||
from __future__ import annotations
|
||||
|
||||
from operator import add
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Annotated, Literal, TypedDict
|
||||
|
||||
# Import from biz_bud.core for search results
|
||||
from biz_bud.core import SearchResultTypedDict as _SearchResult
|
||||
@@ -126,7 +126,7 @@ class ResearchStateOptional(TypedDict, total=False):
|
||||
scraped_results: dict[str, dict[str, str | None]]
|
||||
"""Scraped content from URLs, keyed by URL."""
|
||||
|
||||
semantic_extraction_results: dict[str, Any]
|
||||
semantic_extraction_results: dict[str, object]
|
||||
"""Results from semantic extraction including vector IDs."""
|
||||
|
||||
vector_ids: Annotated[list[str], add]
|
||||
|
||||
@@ -9,7 +9,7 @@ Pydantic models in biz_bud.states.validation_models which provide:
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Literal, TypedDict, TypeGuard
|
||||
from typing import Literal, TypedDict, TypeGuard
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
@@ -27,10 +27,10 @@ class ToolInputData(TypedDict, total=False):
|
||||
source_context: str
|
||||
|
||||
|
||||
class ToolOutputData(TypedDict, total=False):
|
||||
class ToolOutputData(TypedDict, total=False):
|
||||
"""Structured output data from tool execution."""
|
||||
|
||||
result: str | dict[str, Any] | list[Any]
|
||||
result: str | dict[str, object] | list[object]
|
||||
status: Literal["success", "error", "partial"]
|
||||
metadata: dict[str, str | int | bool]
|
||||
processing_time_ms: int
|
||||
@@ -260,7 +260,7 @@ class WebToolsStateTypedDict(SourceMetadataBase, ProcessingStatusBase, Configura
|
||||
|
||||
# Type guard functions for runtime type checking
|
||||
|
||||
def is_tool_input_data(obj: dict[str, Any]) -> TypeGuard[ToolInputData]:
|
||||
def is_tool_input_data(obj: dict[str, object]) -> TypeGuard[ToolInputData]:
|
||||
"""Type guard to check if object is valid ToolInputData."""
|
||||
return (
|
||||
isinstance(obj.get("operation", ""), str) and
|
||||
@@ -269,7 +269,7 @@ def is_tool_input_data(obj: dict[str, Any]) -> TypeGuard[ToolInputData]:
|
||||
)
|
||||
|
||||
|
||||
def is_tool_output_data(obj: dict[str, Any]) -> TypeGuard[ToolOutputData]:
|
||||
def is_tool_output_data(obj: dict[str, object]) -> TypeGuard[ToolOutputData]:
|
||||
"""Type guard to check if object is valid ToolOutputData."""
|
||||
return (
|
||||
obj.get("status") in ("success", "error", "partial") and
|
||||
@@ -278,7 +278,7 @@ def is_tool_output_data(obj: dict[str, Any]) -> TypeGuard[ToolOutputData]:
|
||||
)
|
||||
|
||||
|
||||
def is_search_result_dict(obj: dict[str, Any]) -> TypeGuard[SearchResultDict]:
|
||||
def is_search_result_dict(obj: dict[str, object]) -> TypeGuard[SearchResultDict]:
|
||||
"""Type guard to check if object is valid SearchResultDict."""
|
||||
return (
|
||||
isinstance(obj.get("title", ""), str) and
|
||||
@@ -290,7 +290,7 @@ isinstance(obj.get("title", ""), str) and
|
||||
)
|
||||
|
||||
|
||||
def is_scraped_content_dict(obj: dict[str, Any]) -> TypeGuard[ScrapedContentDict]:
|
||||
def is_scraped_content_dict(obj: dict[str, object]) -> TypeGuard[ScrapedContentDict]:
|
||||
"""Type guard to check if object is valid ScrapedContentDict."""
|
||||
return (
|
||||
isinstance(obj.get("url", ""), str) and
|
||||
@@ -301,7 +301,7 @@ isinstance(obj.get("url", ""), str) and
|
||||
)
|
||||
|
||||
|
||||
def is_extracted_statistic(obj: dict[str, Any]) -> TypeGuard[ExtractedStatistic]:
|
||||
def is_extracted_statistic(obj: dict[str, object]) -> TypeGuard[ExtractedStatistic]:
|
||||
"""Type guard to check if object is valid ExtractedStatistic."""
|
||||
return (
|
||||
isinstance(obj.get("name", ""), str) and
|
||||
@@ -313,7 +313,7 @@ isinstance(obj.get("name", ""), str) and
|
||||
)
|
||||
|
||||
|
||||
def is_extracted_fact(obj: dict[str, Any]) -> TypeGuard[ExtractedFact]:
|
||||
def is_extracted_fact(obj: dict[str, object]) -> TypeGuard[ExtractedFact]:
|
||||
"""Type guard to check if object is valid ExtractedFact."""
|
||||
return (
|
||||
isinstance(obj.get("claim", ""), str) and
|
||||
@@ -328,7 +328,7 @@ isinstance(obj.get("claim", ""), str) and
|
||||
|
||||
# Higher-level type guards for state classes
|
||||
|
||||
def is_tool_state_data(obj: dict[str, Any]) -> TypeGuard[ToolStateTypedDict]:
|
||||
def is_tool_state_data(obj: dict[str, object]) -> TypeGuard[ToolStateTypedDict]:
|
||||
"""Type guard to check if object has valid ToolState structure."""
|
||||
return (
|
||||
(obj.get("tool_name") is None or isinstance(obj.get("tool_name"), str)) and
|
||||
@@ -337,7 +337,7 @@ def is_tool_state_data(obj: dict[str, Any]) -> TypeGuard[ToolStateTypedDict]:
|
||||
)
|
||||
|
||||
|
||||
def is_search_state_data(obj: dict[str, Any]) -> TypeGuard[SearchStateTypedDict]:
|
||||
def is_search_state_data(obj: dict[str, object]) -> TypeGuard[SearchStateTypedDict]:
|
||||
"""Type guard to check if object has valid SearchState structure."""
|
||||
return (
|
||||
(obj.get("query") is None or isinstance(obj.get("query"), str)) and
|
||||
@@ -349,7 +349,7 @@ def is_search_state_data(obj: dict[str, Any]) -> TypeGuard[SearchStateTypedDict]
|
||||
)
|
||||
|
||||
|
||||
def is_scraping_state_data(obj: dict[str, Any]) -> TypeGuard[ScrapingStateTypedDict]:
|
||||
def is_scraping_state_data(obj: dict[str, object]) -> TypeGuard[ScrapingStateTypedDict]:
|
||||
"""Type guard to check if object has valid ScrapingState structure."""
|
||||
return (
|
||||
(obj.get("scraper_name") is None or
|
||||
@@ -362,7 +362,7 @@ def is_scraping_state_data(obj: dict[str, Any]) -> TypeGuard[ScrapingStateTypedD
|
||||
)
|
||||
|
||||
|
||||
def is_extraction_state_data(obj: dict[str, Any]) -> TypeGuard[ExtractionStateTypedDict]:
|
||||
def is_extraction_state_data(obj: dict[str, object]) -> TypeGuard[ExtractionStateTypedDict]:
|
||||
"""Type guard to check if object has valid ExtractionState structure."""
|
||||
return (
|
||||
(obj.get("extracted_statistics") is None or
|
||||
@@ -376,32 +376,32 @@ def is_extraction_state_data(obj: dict[str, Any]) -> TypeGuard[ExtractionStateTy
|
||||
|
||||
# Utility functions for runtime validation with type guards
|
||||
|
||||
def validate_and_cast_tool_input(obj: dict[str, Any]) -> ToolInputData | None:
|
||||
def validate_and_cast_tool_input(obj: dict[str, object]) -> ToolInputData | None:
|
||||
"""Validate and cast object to ToolInputData if valid, otherwise return None."""
|
||||
return obj if is_tool_input_data(obj) else None
|
||||
|
||||
|
||||
def validate_and_cast_search_result(obj: dict[str, Any]) -> SearchResultDict | None:
|
||||
def validate_and_cast_search_result(obj: dict[str, object]) -> SearchResultDict | None:
|
||||
"""Validate and cast object to SearchResultDict if valid, otherwise return None."""
|
||||
return obj if is_search_result_dict(obj) else None
|
||||
|
||||
|
||||
def validate_and_cast_scraped_content(obj: dict[str, Any]) -> ScrapedContentDict | None:
|
||||
def validate_and_cast_scraped_content(obj: dict[str, object]) -> ScrapedContentDict | None:
|
||||
"""Validate and cast object to ScrapedContentDict if valid, otherwise return None."""
|
||||
return obj if is_scraped_content_dict(obj) else None
|
||||
|
||||
|
||||
def filter_valid_search_results(results: list[dict[str, Any]]) -> list[SearchResultDict]:
|
||||
def filter_valid_search_results(results: list[dict[str, object]]) -> list[SearchResultDict]:
|
||||
"""Filter list to only include valid SearchResultDict objects."""
|
||||
return [result for result in results if is_search_result_dict(result)]
|
||||
|
||||
|
||||
def filter_valid_extracted_statistics(stats: list[dict[str, Any]]) -> list[ExtractedStatistic]:
|
||||
def filter_valid_extracted_statistics(stats: list[dict[str, object]]) -> list[ExtractedStatistic]:
|
||||
"""Filter list to only include valid ExtractedStatistic objects."""
|
||||
return [stat for stat in stats if is_extracted_statistic(stat)]
|
||||
|
||||
|
||||
def filter_valid_extracted_facts(facts: list[dict[str, Any]]) -> list[ExtractedFact]:
|
||||
def filter_valid_extracted_facts(facts: list[dict[str, object]]) -> list[ExtractedFact]:
|
||||
"""Filter list to only include valid ExtractedFact objects."""
|
||||
return [fact for fact in facts if is_extracted_fact(fact)]
|
||||
|
||||
@@ -449,7 +449,7 @@ def create_tool_input_data(
|
||||
def create_tool_output_data(
|
||||
status: Literal["success", "error", "partial"],
|
||||
*,
|
||||
result: str | dict[str, Any] | list[Any] | None = None,
|
||||
result: str | dict[str, object] | list[object] | None = None,
|
||||
metadata: dict[str, str | int | bool] | None = None,
|
||||
processing_time_ms: int = 0,
|
||||
warnings: list[str] | None = None,
|
||||
@@ -539,15 +539,15 @@ def create_search_result_dict(
|
||||
return search_result
|
||||
|
||||
|
||||
def create_scraped_content_dict(
|
||||
url: str,
|
||||
content: str,
|
||||
*,
|
||||
title: str | None = None,
|
||||
html: str | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
scrape_timestamp: str | None = None,
|
||||
content_length: int | None = None,
|
||||
def create_scraped_content_dict(
|
||||
url: str,
|
||||
content: str | None,
|
||||
*,
|
||||
title: str | None = None,
|
||||
html: str | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
scrape_timestamp: str | None = None,
|
||||
content_length: int | None = None,
|
||||
success: bool = True,
|
||||
error_message: str | None = None,
|
||||
) -> ScrapedContentDict:
|
||||
@@ -567,26 +567,30 @@ def create_scraped_content_dict(
|
||||
Returns:
|
||||
Validated ScrapedContentDict instance
|
||||
|
||||
Raises:
|
||||
ValueError: If url/content are empty, content_length is negative,
|
||||
or error_message is present when success is True
|
||||
"""
|
||||
if not url or url.isspace():
|
||||
raise ValidationError("URL cannot be empty or whitespace")
|
||||
if not content or content.isspace():
|
||||
raise ValidationError("Content cannot be empty or whitespace")
|
||||
if content_length is not None and content_length < 0:
|
||||
raise ValidationError("Content length cannot be negative")
|
||||
if not success and not error_message:
|
||||
error_message = "An unknown error occurred during scraping."
|
||||
if success and error_message:
|
||||
raise ValidationError("Error message should not be present when success is True")
|
||||
|
||||
scraped_content: ScrapedContentDict = {
|
||||
"url": url.strip(),
|
||||
"content": content,
|
||||
"success": success,
|
||||
}
|
||||
Raises:
|
||||
ValidationError: If url/content are empty, content_length is negative,
|
||||
or error_message is present when success is True
|
||||
"""
|
||||
if not url or url.isspace():
|
||||
raise ValidationError("URL cannot be empty or whitespace")
|
||||
if success:
|
||||
if not content or content.isspace():
|
||||
raise ValidationError("Content cannot be empty or whitespace when success is True")
|
||||
else:
|
||||
if not content or content.isspace():
|
||||
content = ""
|
||||
if content_length is not None and content_length < 0:
|
||||
raise ValidationError("Content length cannot be negative")
|
||||
if not success and not error_message:
|
||||
error_message = "An unknown error occurred during scraping."
|
||||
if success and error_message:
|
||||
raise ValidationError("Error message should not be present when success is True")
|
||||
|
||||
scraped_content: ScrapedContentDict = {
|
||||
"url": url.strip(),
|
||||
"content": content or "",
|
||||
"success": success,
|
||||
}
|
||||
|
||||
if title is not None:
|
||||
scraped_content["title"] = title
|
||||
|
||||
@@ -7,12 +7,10 @@ to handle accumulating fields correctly.
|
||||
from __future__ import annotations
|
||||
|
||||
from operator import add
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
# Import LangGraph functionality
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
# Import runtime types needed for forward references
|
||||
# These must be available at runtime for LangGraph's get_type_hints() calls
|
||||
@@ -61,6 +59,8 @@ from .catalog import CatalogIntelState
|
||||
# Types are now imported directly at runtime for LangGraph compatibility
|
||||
|
||||
|
||||
StatePayload = dict[str, object]
|
||||
|
||||
class Organization(TypedDict):
|
||||
"""Organization entity."""
|
||||
|
||||
@@ -126,10 +126,10 @@ class BaseStateOptional(TypedDict, total=False):
|
||||
persistence_error: str
|
||||
"""Error message if result persistence fails."""
|
||||
|
||||
claims_to_check: list[dict[str, Any]]
|
||||
claims_to_check: list[StatePayload]
|
||||
"""Claims extracted from content for fact-checking."""
|
||||
|
||||
fact_check_results: dict[str, Any]
|
||||
fact_check_results: StatePayload
|
||||
"""Results from fact-checking claims."""
|
||||
|
||||
is_output_valid: bool | None
|
||||
@@ -305,7 +305,7 @@ class ResearchStateOptional(TypedDict, total=False):
|
||||
scraped_results: dict[str, dict[str, str | None]]
|
||||
"""Scraped content from URLs, keyed by URL."""
|
||||
|
||||
semantic_extraction_results: dict[str, Any]
|
||||
semantic_extraction_results: StatePayload
|
||||
"""Results from semantic extraction including vector IDs."""
|
||||
|
||||
vector_ids: Annotated[list[str], add]
|
||||
@@ -325,7 +325,7 @@ class ResearchState(
|
||||
Combines base state with search and validation capabilities.
|
||||
"""
|
||||
|
||||
extracted_info: dict[str, Any]
|
||||
extracted_info: StatePayload
|
||||
"""Information extracted from search results in source_0, source_1, etc. format."""
|
||||
|
||||
synthesis: str
|
||||
@@ -382,7 +382,7 @@ class BusinessBuddyState(
|
||||
extracted_info: ExtractedInfoDict
|
||||
"""Information extracted from search results."""
|
||||
|
||||
extracted_content: dict[str, Any]
|
||||
extracted_content: StatePayload
|
||||
"""Content extracted from various sources (e.g., catalog data)."""
|
||||
|
||||
synthesis: str
|
||||
@@ -491,7 +491,7 @@ class BusinessBuddyState(
|
||||
catalog_component_research: CatalogComponentResearchDict
|
||||
"""Results from catalog component research."""
|
||||
|
||||
extracted_components: dict[str, Any]
|
||||
extracted_components: StatePayload
|
||||
"""Extracted components from catalog items."""
|
||||
|
||||
# Catalog intelligence fields
|
||||
@@ -501,22 +501,22 @@ class BusinessBuddyState(
|
||||
batch_component_queries: list[str]
|
||||
"""List of components for batch analysis."""
|
||||
|
||||
component_news_impact_reports: list[dict[str, Any]]
|
||||
component_news_impact_reports: list[StatePayload]
|
||||
"""Impact reports for components based on news/market conditions."""
|
||||
|
||||
catalog_optimization_suggestions: list[dict[str, Any]]
|
||||
catalog_optimization_suggestions: list[StatePayload]
|
||||
"""Optimization suggestions for catalog items."""
|
||||
|
||||
optimization_summary: dict[str, Any]
|
||||
optimization_summary: StatePayload
|
||||
"""Summary of optimization recommendations."""
|
||||
|
||||
catalog_items_linked_to_component: list[dict[str, Any]]
|
||||
catalog_items_linked_to_component: list[StatePayload]
|
||||
"""Catalog items affected by specific component."""
|
||||
|
||||
data_source_used: str | None
|
||||
"""Source of catalog data (database, yaml, default)."""
|
||||
|
||||
component_analytics: dict[str, Any]
|
||||
component_analytics: StatePayload
|
||||
"""Analytics and insights from component analysis."""
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
"""Minimal state for URL to R2R processing workflow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypedDict
|
||||
|
||||
# Import AnyMessage and add_messages for proper LangGraph support
|
||||
try:
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
except ImportError:
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
else:
|
||||
# Fallback for runtime if langchain is not installed
|
||||
AnyMessage = Any
|
||||
add_messages = list
|
||||
"""Minimal state for URL to R2R processing workflow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
StatePayload = dict[str, object]
|
||||
|
||||
|
||||
class URLToRAGState(TypedDict, total=False):
|
||||
@@ -33,7 +25,7 @@ class URLToRAGState(TypedDict, total=False):
|
||||
url: str
|
||||
"""Alternative URL field for compatibility."""
|
||||
|
||||
config: dict[str, Any]
|
||||
config: StatePayload
|
||||
"""Configuration containing API keys."""
|
||||
|
||||
# Processing fields
|
||||
@@ -46,17 +38,17 @@ class URLToRAGState(TypedDict, total=False):
|
||||
discovered_urls: list[str]
|
||||
"""List of URLs discovered during processing."""
|
||||
|
||||
scraped_content: list[dict[str, Any]]
|
||||
scraped_content: list[StatePayload]
|
||||
"""Scraped content from Firecrawl."""
|
||||
|
||||
processed_content: dict[str, Any]
|
||||
processed_content: StatePayload
|
||||
"""Processed content ready for upload."""
|
||||
|
||||
repomix_output: str | None
|
||||
"""Repomix output for git repositories."""
|
||||
|
||||
# R2R configuration
|
||||
r2r_info: dict[str, Any] | None
|
||||
r2r_info: StatePayload | None
|
||||
"""R2R upload information and metadata."""
|
||||
|
||||
# R2R fields
|
||||
@@ -78,7 +70,7 @@ class URLToRAGState(TypedDict, total=False):
|
||||
"""Status messages during processing."""
|
||||
|
||||
# For tests
|
||||
errors: list[dict[str, Any]]
|
||||
errors: list[StatePayload]
|
||||
"""Error list for tests."""
|
||||
|
||||
# Track processing state
|
||||
@@ -86,10 +78,10 @@ class URLToRAGState(TypedDict, total=False):
|
||||
"""Number of pages that have been processed by analyzer."""
|
||||
|
||||
# Upload tracking (optional)
|
||||
upload_tracker: dict[str, Any] | None
|
||||
upload_tracker: StatePayload | None
|
||||
"""Upload tracking summary with success/failure counts."""
|
||||
|
||||
upload_details: dict[str, Any] | None
|
||||
upload_details: StatePayload | None
|
||||
"""Detailed upload information per URL."""
|
||||
|
||||
# Additional fields that may be returned
|
||||
@@ -119,7 +111,7 @@ class URLToRAGState(TypedDict, total=False):
|
||||
existing_document_id: str | None
|
||||
"""Document ID if URL was already processed."""
|
||||
|
||||
existing_document_metadata: dict[str, Any] | None
|
||||
existing_document_metadata: StatePayload | None
|
||||
"""Metadata of existing document if found."""
|
||||
|
||||
skipped_urls_count: int
|
||||
@@ -157,7 +149,7 @@ class URLToRAGState(TypedDict, total=False):
|
||||
url_hash: str | None
|
||||
"""SHA256 hash of URL for efficient lookup."""
|
||||
|
||||
existing_content: dict[str, Any] | None
|
||||
existing_content: StatePayload | None
|
||||
"""Found content metadata for deduplication."""
|
||||
|
||||
content_age_days: int | None
|
||||
@@ -169,8 +161,8 @@ class URLToRAGState(TypedDict, total=False):
|
||||
processing_reason: str | None
|
||||
"""Human-readable explanation of processing decision."""
|
||||
|
||||
scrape_params: dict[str, Any]
|
||||
scrape_params: StatePayload
|
||||
"""Parameters for scraping operations."""
|
||||
|
||||
r2r_params: dict[str, Any]
|
||||
r2r_params: StatePayload
|
||||
"""Parameters for R2R processing."""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""State for URL to R2R workflow."""
|
||||
|
||||
from typing import Any, TypedDict
|
||||
"""State for URL to R2R workflow."""
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
|
||||
class R2RInfo(TypedDict, total=False):
|
||||
@@ -13,12 +15,12 @@ class R2RInfo(TypedDict, total=False):
|
||||
upload_timestamp: str
|
||||
|
||||
|
||||
class URLToRAGState(TypedDict, total=False):
|
||||
"""State for URL to R2R processing."""
|
||||
|
||||
url: str
|
||||
processed_content: dict[str, str | dict[str, str]]
|
||||
r2r_info: R2RInfo
|
||||
upload_complete: bool
|
||||
messages: list[Any]
|
||||
errors: list[dict[str, str]]
|
||||
class URLToRAGState(TypedDict, total=False):
|
||||
"""State for URL to R2R processing."""
|
||||
|
||||
url: str
|
||||
processed_content: dict[str, str | dict[str, str]]
|
||||
r2r_info: R2RInfo
|
||||
upload_complete: bool
|
||||
messages: list[AnyMessage]
|
||||
errors: list[dict[str, str]]
|
||||
|
||||
@@ -6,7 +6,7 @@ in tools.py, enabling runtime validation, type coercion, and error reporting.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
@@ -38,9 +38,9 @@ class ToolInputDataModel(BaseModel):
|
||||
class ToolOutputDataModel(BaseModel):
|
||||
"""Pydantic model for ToolOutputData validation."""
|
||||
|
||||
result: str | dict[str, Any] | list[Any] | None = Field(
|
||||
None, description="Tool execution result"
|
||||
)
|
||||
result: str | dict[str, object] | list[object] | None = Field(
|
||||
None, description="Tool execution result"
|
||||
)
|
||||
status: Literal["success", "error", "partial"] = Field(
|
||||
..., description="Execution status"
|
||||
)
|
||||
@@ -428,56 +428,56 @@ class WebToolsStateModel(SourceMetadataModel, ProcessingStatusModel, Configurati
|
||||
|
||||
# Validation helper functions
|
||||
|
||||
def validate_tool_input_data(data: dict[str, Any]) -> ToolInputDataModel:
|
||||
def validate_tool_input_data(data: dict[str, object]) -> ToolInputDataModel:
|
||||
"""Validate tool input data using Pydantic model."""
|
||||
return ToolInputDataModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_tool_output_data(data: dict[str, Any]) -> ToolOutputDataModel:
|
||||
def validate_tool_output_data(data: dict[str, object]) -> ToolOutputDataModel:
|
||||
"""Validate tool output data using Pydantic model."""
|
||||
return ToolOutputDataModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_search_result(data: dict[str, Any]) -> SearchResultDictModel:
|
||||
def validate_search_result(data: dict[str, object]) -> SearchResultDictModel:
|
||||
"""Validate search result data using Pydantic model."""
|
||||
return SearchResultDictModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_scraped_content(data: dict[str, Any]) -> ScrapedContentDictModel:
|
||||
def validate_scraped_content(data: dict[str, object]) -> ScrapedContentDictModel:
|
||||
"""Validate scraped content data using Pydantic model."""
|
||||
return ScrapedContentDictModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_extracted_statistic(data: dict[str, Any]) -> ExtractedStatisticModel:
|
||||
def validate_extracted_statistic(data: dict[str, object]) -> ExtractedStatisticModel:
|
||||
"""Validate extracted statistic data using Pydantic model."""
|
||||
return ExtractedStatisticModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_extracted_fact(data: dict[str, Any]) -> ExtractedFactModel:
|
||||
def validate_extracted_fact(data: dict[str, object]) -> ExtractedFactModel:
|
||||
"""Validate extracted fact data using Pydantic model."""
|
||||
return ExtractedFactModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_tool_state(data: dict[str, Any]) -> ToolStateModel:
|
||||
def validate_tool_state(data: dict[str, object]) -> ToolStateModel:
|
||||
"""Validate tool state data using Pydantic model."""
|
||||
return ToolStateModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_search_state(data: dict[str, Any]) -> SearchStateModel:
|
||||
def validate_search_state(data: dict[str, object]) -> SearchStateModel:
|
||||
"""Validate search state data using Pydantic model."""
|
||||
return SearchStateModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_scraping_state(data: dict[str, Any]) -> ScrapingStateModel:
|
||||
def validate_scraping_state(data: dict[str, object]) -> ScrapingStateModel:
|
||||
"""Validate scraping state data using Pydantic model."""
|
||||
return ScrapingStateModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_extraction_state(data: dict[str, Any]) -> ExtractionStateModel:
|
||||
def validate_extraction_state(data: dict[str, object]) -> ExtractionStateModel:
|
||||
"""Validate extraction state data using Pydantic model."""
|
||||
return ExtractionStateModel.model_validate(data)
|
||||
|
||||
|
||||
def validate_web_tools_state(data: dict[str, Any]) -> WebToolsStateModel:
|
||||
def validate_web_tools_state(data: dict[str, object]) -> WebToolsStateModel:
|
||||
"""Validate web tools state data using Pydantic model."""
|
||||
return WebToolsStateModel.model_validate(data)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Root pytest configuration with hierarchical fixtures."""
|
||||
|
||||
import asyncio
|
||||
import asyncio
|
||||
import inspect
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@@ -170,9 +170,27 @@ def _ensure_optional_dependency(package: str) -> None:
|
||||
_install_stub(package)
|
||||
|
||||
|
||||
_REQUIRED_DEPENDENCIES: tuple[str, ...] = ("langgraph", "langchain_core")
|
||||
|
||||
|
||||
def _require_dependency(package: str) -> None:
|
||||
"""Ensure a required dependency is importable for the test suite."""
|
||||
|
||||
try:
|
||||
importlib.import_module(package)
|
||||
except ModuleNotFoundError as exc: # pragma: no cover - defensive guard
|
||||
message = (
|
||||
f"The '{package}' package is required to run the test suite. "
|
||||
"Install the project dependencies with `pip install -r requirements.txt`."
|
||||
)
|
||||
raise RuntimeError(message) from exc
|
||||
|
||||
|
||||
for required_package in _REQUIRED_DEPENDENCIES:
|
||||
_require_dependency(required_package)
|
||||
|
||||
|
||||
for optional_package in (
|
||||
"langgraph",
|
||||
"langchain_core",
|
||||
"langchain_anthropic",
|
||||
"langchain_openai",
|
||||
"pydantic",
|
||||
@@ -278,10 +296,19 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
|
||||
if marker is None:
|
||||
return None
|
||||
|
||||
if pyfuncitem.config.pluginmanager.hasplugin("asyncio"):
|
||||
return None
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
kwargs = {name: pyfuncitem.funcargs[name] for name in pyfuncitem._fixtureinfo.argnames} # type: ignore[attr-defined]
|
||||
loop.run_until_complete(pyfuncitem.obj(**kwargs))
|
||||
kwargs = {
|
||||
name: pyfuncitem.funcargs[name]
|
||||
for name in pyfuncitem._fixtureinfo.argnames # type: ignore[attr-defined]
|
||||
}
|
||||
result = pyfuncitem.obj(**kwargs)
|
||||
if not inspect.isawaitable(result):
|
||||
return None
|
||||
loop.run_until_complete(result)
|
||||
finally:
|
||||
loop.close()
|
||||
return True
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from collections.abc import MutableMapping
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -12,7 +13,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from tests.helpers.mocks.mock_builders import MockLLMBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@@ -48,7 +49,7 @@ class TestAnalysisWorkflowE2E:
|
||||
return mock_service_factory
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self) -> dict[str, Any]:
|
||||
def sample_data(self) -> dict[str, object]:
|
||||
"""Sample data for analysis."""
|
||||
return {
|
||||
"sales_data": [
|
||||
@@ -146,7 +147,7 @@ class TestAnalysisWorkflowE2E:
|
||||
mock_call_model_node: AsyncMock,
|
||||
mock_get_global_factory: AsyncMock,
|
||||
mock_analyze_data: AsyncMock,
|
||||
sample_data: dict[str, Any],
|
||||
sample_data: dict[str, object],
|
||||
mock_data_analyzer: AsyncMock,
|
||||
mock_llm_analyzer: AsyncMock,
|
||||
) -> None:
|
||||
@@ -204,7 +205,7 @@ class TestAnalysisWorkflowE2E:
|
||||
# Create analysis graph (simplified for test) - using optimized version
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
# Initial state with data
|
||||
initial_state = {
|
||||
@@ -246,8 +247,8 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
# Mock the LLM to handle the error
|
||||
async def mock_llm_response(
|
||||
state: dict[str, Any], config=None
|
||||
) -> dict[str, Any]:
|
||||
state: MutableMapping[str, object], config=None
|
||||
) -> MutableMapping[str, object]:
|
||||
# Return a simple error acknowledgment
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
@@ -268,7 +269,7 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
# Invalid data
|
||||
invalid_data = {
|
||||
@@ -313,7 +314,7 @@ class TestAnalysisWorkflowE2E:
|
||||
mock_get_global_factory: AsyncMock,
|
||||
mock_create_viz: AsyncMock,
|
||||
mock_analyze_data: AsyncMock,
|
||||
sample_data: dict[str, Any],
|
||||
sample_data: dict[str, object],
|
||||
mock_data_analyzer: AsyncMock,
|
||||
mock_llm_analyzer: AsyncMock,
|
||||
) -> None:
|
||||
@@ -373,7 +374,7 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
initial_state = {
|
||||
"messages": [
|
||||
@@ -409,7 +410,7 @@ class TestAnalysisWorkflowE2E:
|
||||
mock_call_model_node: AsyncMock,
|
||||
mock_get_global_factory: AsyncMock,
|
||||
mock_analyze_data: AsyncMock,
|
||||
sample_data: dict[str, Any],
|
||||
sample_data: dict[str, object],
|
||||
mock_data_analyzer: AsyncMock,
|
||||
) -> None:
|
||||
"""Test analysis workflow that includes strategic planning."""
|
||||
@@ -484,7 +485,7 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
initial_state = {
|
||||
"messages": [
|
||||
@@ -621,7 +622,7 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
initial_state = {
|
||||
"messages": [HumanMessage(content="Analyze large transaction dataset")],
|
||||
@@ -723,7 +724,7 @@ class TestAnalysisWorkflowE2E:
|
||||
|
||||
from biz_bud.graphs.graph import get_graph
|
||||
|
||||
analysis_graph: CompiledStateGraph[Any] = get_graph()
|
||||
analysis_graph: CompiledGraph[MutableMapping[str, object]] = get_graph()
|
||||
|
||||
initial_state = {
|
||||
"messages": [HumanMessage(content="Compare 2023 vs 2024 performance")],
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
from biz_bud.states.error_handling import ErrorHandlingState
|
||||
from biz_bud.states.research import ResearchState
|
||||
@@ -18,7 +18,7 @@ class TestResearchGraph:
|
||||
from biz_bud.graphs.research.graph import create_research_graph
|
||||
|
||||
graph = create_research_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_research_graph_factory(self):
|
||||
@@ -27,7 +27,7 @@ class TestResearchGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = research_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_research_graph_factory_async(self):
|
||||
@@ -36,7 +36,7 @@ class TestResearchGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = await research_graph_factory_async(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_research_graph_has_expected_nodes(self):
|
||||
"""Test that research graph contains expected nodes."""
|
||||
@@ -58,7 +58,7 @@ class TestCatalogGraph:
|
||||
from biz_bud.graphs.catalog.graph import create_catalog_graph
|
||||
|
||||
graph = create_catalog_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_catalog_graph_factory(self):
|
||||
"""Test catalog graph factory function."""
|
||||
@@ -66,7 +66,7 @@ class TestCatalogGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = catalog_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_catalog_graph_factory_async(self):
|
||||
@@ -76,7 +76,7 @@ class TestCatalogGraph:
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
import asyncio
|
||||
graph = await asyncio.to_thread(catalog_graph_factory, config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestAnalysisGraph:
|
||||
@@ -87,7 +87,7 @@ class TestAnalysisGraph:
|
||||
from biz_bud.graphs.analysis.graph import create_analysis_graph
|
||||
|
||||
graph = create_analysis_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_analysis_graph_factory(self):
|
||||
"""Test analysis graph factory function."""
|
||||
@@ -95,7 +95,7 @@ class TestAnalysisGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = analysis_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_graph_factory_async(self):
|
||||
@@ -104,7 +104,7 @@ class TestAnalysisGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = await analysis_graph_factory_async(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestScrapingGraph:
|
||||
@@ -115,7 +115,7 @@ class TestScrapingGraph:
|
||||
from biz_bud.graphs.scraping.graph import create_scraping_graph
|
||||
|
||||
graph = create_scraping_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_scraping_graph_factory(self):
|
||||
"""Test scraping graph factory function."""
|
||||
@@ -123,7 +123,7 @@ class TestScrapingGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = scraping_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scraping_graph_factory_async(self):
|
||||
@@ -132,7 +132,7 @@ class TestScrapingGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = await scraping_graph_factory_async(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestRAGGraph:
|
||||
@@ -143,7 +143,7 @@ class TestRAGGraph:
|
||||
from biz_bud.graphs.rag.graph import create_url_to_r2r_graph
|
||||
|
||||
graph = create_url_to_r2r_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_rag_graph_with_config(self):
|
||||
"""Test RAG graph creation with config."""
|
||||
@@ -151,7 +151,7 @@ class TestRAGGraph:
|
||||
|
||||
config = {"test": "config"}
|
||||
graph = create_url_to_r2r_graph(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestPlannerGraph:
|
||||
@@ -162,7 +162,7 @@ class TestPlannerGraph:
|
||||
from biz_bud.graphs.planner import create_planner_graph
|
||||
|
||||
graph = create_planner_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_planner_graph_factory(self):
|
||||
"""Test planner graph factory function."""
|
||||
@@ -170,7 +170,7 @@ class TestPlannerGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = planner_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_planner_graph_factory_async(self):
|
||||
@@ -179,7 +179,7 @@ class TestPlannerGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = await planner_graph_factory_async(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestErrorHandlingGraph:
|
||||
@@ -190,7 +190,7 @@ class TestErrorHandlingGraph:
|
||||
from biz_bud.graphs.error_handling import create_error_handling_graph
|
||||
|
||||
graph = create_error_handling_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_error_handling_graph_factory(self):
|
||||
"""Test error handling graph factory function."""
|
||||
@@ -198,7 +198,7 @@ class TestErrorHandlingGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = error_handling_graph_factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_graph_factory_async(self):
|
||||
@@ -207,7 +207,7 @@ class TestErrorHandlingGraph:
|
||||
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = await error_handling_graph_factory_async(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestGraphStateCompatibility:
|
||||
@@ -316,7 +316,7 @@ class TestGraphBuilderPatternConsistency:
|
||||
]
|
||||
|
||||
for graph in graphs:
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
assert hasattr(graph, 'astream')
|
||||
|
||||
@@ -344,7 +344,7 @@ class TestGraphBuilderPatternConsistency:
|
||||
|
||||
for factory in factories:
|
||||
graph = factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestGraphExecutionCompatibility:
|
||||
@@ -493,7 +493,7 @@ class TestEndToEndIntegration:
|
||||
for name, creator in graph_creators:
|
||||
try:
|
||||
graph = creator()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
successful_graphs.append(name)
|
||||
except Exception as e:
|
||||
failed_graphs.append((name, str(e)))
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
|
||||
|
||||
|
||||
class TestGraphBuilderIntegration:
|
||||
@@ -13,7 +13,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.research.graph import create_research_graph
|
||||
|
||||
graph = create_research_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_catalog_graph_creation(self):
|
||||
@@ -21,7 +21,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.catalog.graph import create_catalog_graph
|
||||
|
||||
graph = create_catalog_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_analysis_graph_creation(self):
|
||||
@@ -29,7 +29,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.analysis.graph import create_analysis_graph
|
||||
|
||||
graph = create_analysis_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_scraping_graph_creation(self):
|
||||
@@ -37,7 +37,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.scraping.graph import create_scraping_graph
|
||||
|
||||
graph = create_scraping_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_rag_graph_creation(self):
|
||||
@@ -45,7 +45,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.rag.graph import create_url_to_r2r_graph
|
||||
|
||||
graph = create_url_to_r2r_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_planner_graph_creation(self):
|
||||
@@ -53,7 +53,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.planner import create_planner_graph
|
||||
|
||||
graph = create_planner_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_error_handling_graph_creation(self):
|
||||
@@ -61,7 +61,7 @@ class TestGraphBuilderIntegration:
|
||||
from biz_bud.graphs.error_handling import create_error_handling_graph
|
||||
|
||||
graph = create_error_handling_graph()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
@pytest.mark.integration
|
||||
@@ -91,7 +91,7 @@ class TestGraphBuilderIntegration:
|
||||
for name, creator in graph_creators:
|
||||
try:
|
||||
graph = creator()
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
successful_graphs.append(name)
|
||||
except Exception as e:
|
||||
@@ -132,7 +132,7 @@ class TestGraphBuilderIntegration:
|
||||
|
||||
for factory in factories:
|
||||
graph = factory(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
assert hasattr(graph, 'ainvoke')
|
||||
|
||||
def test_builder_pattern_consistency(self):
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Lightweight test-oriented stub of the LangGraph package."""
|
||||
|
||||
from .graph import END, START, StateGraph
|
||||
|
||||
__all__ = ["StateGraph", "START", "END"]
|
||||
6
tests/stubs/langgraph/cache/__init__.py
vendored
6
tests/stubs/langgraph/cache/__init__.py
vendored
@@ -1,6 +0,0 @@
|
||||
"""Cache subpackage for LangGraph test stubs."""
|
||||
|
||||
from .base import BaseCache
|
||||
from .memory import InMemoryCache
|
||||
|
||||
__all__ = ["BaseCache", "InMemoryCache"]
|
||||
22
tests/stubs/langgraph/cache/base.py
vendored
22
tests/stubs/langgraph/cache/base.py
vendored
@@ -1,22 +0,0 @@
|
||||
"""Cache protocol used by the LangGraph test stub."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseCache(Protocol):
|
||||
"""Protocol capturing the minimal cache interface exercised in tests."""
|
||||
|
||||
async def aget(self, key: str) -> Any: # pragma: no cover - interface only
|
||||
...
|
||||
|
||||
async def aset(self, key: str, value: Any) -> None: # pragma: no cover - interface only
|
||||
...
|
||||
|
||||
async def adelete(self, key: str) -> None: # pragma: no cover - interface only
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["BaseCache"]
|
||||
29
tests/stubs/langgraph/cache/memory.py
vendored
29
tests/stubs/langgraph/cache/memory.py
vendored
@@ -1,29 +0,0 @@
|
||||
"""Minimal in-memory cache implementation for tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""Dictionary-backed async cache used in unit tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: Dict[str, Any] = {}
|
||||
|
||||
async def aget(self, key: str) -> Optional[Any]:
|
||||
return self._store.get(key)
|
||||
|
||||
async def aset(self, key: str, value: Any) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def adelete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._store.clear()
|
||||
|
||||
|
||||
__all__ = ["InMemoryCache"]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Checkpoint base classes used by the LangGraph stub."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseCheckpointSaver(Protocol):
|
||||
"""Protocol capturing the limited API exercised by the tests."""
|
||||
|
||||
def save(self, state: Any) -> None: # pragma: no cover - interface only
|
||||
...
|
||||
|
||||
def load(self) -> Any: # pragma: no cover - interface only
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["BaseCheckpointSaver"]
|
||||
@@ -1,23 +0,0 @@
|
||||
"""In-memory checkpoint saver used by LangGraph tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseCheckpointSaver
|
||||
|
||||
|
||||
class InMemorySaver(BaseCheckpointSaver):
|
||||
"""Trivial checkpoint saver that stores the latest state in memory."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state: Any | None = None
|
||||
|
||||
def save(self, state: Any) -> None:
|
||||
self._state = state
|
||||
|
||||
def load(self) -> Any:
|
||||
return self._state
|
||||
|
||||
|
||||
__all__ = ["InMemorySaver"]
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Minimal subset of LangGraph graph primitives for unit testing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, Generic, List, Mapping, Sequence, Tuple, TypeVar, TYPE_CHECKING
|
||||
|
||||
from langgraph.cache.base import BaseCache
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.types import All
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - hinting only
|
||||
from langgraph.graph.state import CachePolicy, RetryPolicy
|
||||
else: # pragma: no cover - fallback for runtime without circular imports
|
||||
CachePolicy = Any # type: ignore[assignment]
|
||||
RetryPolicy = Any # type: ignore[assignment]
|
||||
|
||||
StateT = TypeVar("StateT")
|
||||
NodeCallable = Callable[[Any], Any]
|
||||
RouterCallable = Callable[[Any], str]
|
||||
|
||||
START: str = "__start__"
|
||||
END: str = "__end__"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NodeSpec:
|
||||
"""Lightweight representation of a configured LangGraph node."""
|
||||
|
||||
func: NodeCallable
|
||||
metadata: Dict[str, Any]
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
|
||||
cache_policy: CachePolicy | None
|
||||
defer: bool
|
||||
input_schema: type[Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ConditionalEdges:
|
||||
"""Internal representation of conditional routing information."""
|
||||
|
||||
source: str | object
|
||||
router: RouterCallable
|
||||
mapping: Mapping[str, str] | Sequence[str]
|
||||
|
||||
|
||||
class StateGraph(Generic[StateT]):
|
||||
"""Simplified builder that mimics the LangGraph public API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_schema: type[StateT],
|
||||
*,
|
||||
context_schema: type[Any] | None = None,
|
||||
input_schema: type[Any] | None = None,
|
||||
output_schema: type[Any] | None = None,
|
||||
) -> None:
|
||||
self.state_schema = state_schema
|
||||
self.context_schema = context_schema
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
self._nodes: Dict[str, NodeSpec] = {}
|
||||
self._edges: List[Tuple[str | object, str | object]] = []
|
||||
self._conditional_edges: List[_ConditionalEdges] = []
|
||||
self._entry_point: str | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph mutation helpers
|
||||
# ------------------------------------------------------------------
|
||||
def add_node(
|
||||
self,
|
||||
name: str,
|
||||
func: NodeCallable,
|
||||
*,
|
||||
metadata: Dict[str, Any] | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
defer: bool = False,
|
||||
input_schema: type[Any] | None = None,
|
||||
**_: Any,
|
||||
) -> None:
|
||||
self._nodes[name] = NodeSpec(
|
||||
func=func,
|
||||
metadata=dict(metadata or {}),
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
defer=defer,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
|
||||
def add_edge(self, source: str | object, target: str | object) -> None:
|
||||
self._edges.append((source, target))
|
||||
|
||||
def add_conditional_edges(
|
||||
self,
|
||||
source: str | object,
|
||||
router: RouterCallable,
|
||||
mapping: Mapping[str, str] | Sequence[str],
|
||||
) -> None:
|
||||
self._conditional_edges.append(
|
||||
_ConditionalEdges(source=source, router=router, mapping=mapping)
|
||||
)
|
||||
|
||||
def set_entry_point(self, node: str) -> None:
|
||||
self._entry_point = node
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compilation
|
||||
# ------------------------------------------------------------------
|
||||
def compile(
|
||||
self,
|
||||
*,
|
||||
checkpointer: BaseCheckpointSaver[Any] | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
store: BaseStore[Any] | None = None,
|
||||
interrupt_before: All | Sequence[str] | None = None,
|
||||
interrupt_after: All | Sequence[str] | None = None,
|
||||
debug: bool = False,
|
||||
name: str | None = None,
|
||||
) -> "CompiledStateGraph[StateT]":
|
||||
self._validate()
|
||||
|
||||
alias = name
|
||||
context_schema = self.context_schema
|
||||
input_schema = self.input_schema
|
||||
output_schema = self.output_schema
|
||||
if alias:
|
||||
if context_schema is not None:
|
||||
context_schema = _alias_schema(context_schema, f"{alias}_context")
|
||||
if input_schema is not None:
|
||||
input_schema = _alias_schema(input_schema, f"{alias}_input")
|
||||
if output_schema is not None:
|
||||
output_schema = _alias_schema(output_schema, f"{alias}_output")
|
||||
|
||||
compiled = CompiledStateGraph(
|
||||
builder=self,
|
||||
checkpointer=checkpointer,
|
||||
cache=cache,
|
||||
store=store,
|
||||
interrupt_before_nodes=_normalise_interrupts(interrupt_before),
|
||||
interrupt_after_nodes=_normalise_interrupts(interrupt_after),
|
||||
debug=debug,
|
||||
name=name,
|
||||
context_schema=context_schema,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
return compiled
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Introspection helpers used in tests
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def nodes(self) -> Mapping[str, NodeSpec]:
|
||||
return dict(self._nodes)
|
||||
|
||||
@property
|
||||
def edges(self) -> Sequence[Tuple[str | object, str | object]]:
|
||||
return list(self._edges)
|
||||
|
||||
@property
|
||||
def conditional_edges(self) -> Sequence[_ConditionalEdges]:
|
||||
return list(self._conditional_edges)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _validate(self) -> None:
|
||||
if not self._nodes:
|
||||
raise ValueError("Graph must have an entrypoint")
|
||||
|
||||
# Determine the effective entry point.
|
||||
entry_point = self._entry_point
|
||||
if entry_point is None:
|
||||
for source, target in self._edges:
|
||||
if source == START and isinstance(target, str):
|
||||
entry_point = target
|
||||
break
|
||||
if entry_point is None:
|
||||
raise ValueError("Graph must have an entrypoint")
|
||||
|
||||
# Validate that all referenced nodes exist (ignoring START/END sentinels).
|
||||
for source, target in self._edges:
|
||||
if source not in {START, END} and source not in self._nodes:
|
||||
raise ValueError("Found edge starting at unknown node: {0}".format(source))
|
||||
if target not in {START, END} and target not in self._nodes:
|
||||
raise ValueError("Found edge ending at unknown node: {0}".format(target))
|
||||
|
||||
for cond in self._conditional_edges:
|
||||
if cond.source not in {START, END} and cond.source not in self._nodes:
|
||||
raise ValueError(
|
||||
"Found conditional edge starting at unknown node: {0}".format(
|
||||
cond.source
|
||||
)
|
||||
)
|
||||
if isinstance(cond.mapping, Mapping):
|
||||
missing = [
|
||||
dest
|
||||
for dest in cond.mapping.values()
|
||||
if dest not in {END} and dest not in self._nodes
|
||||
]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Found conditional edge ending at unknown node: {0}".format(
|
||||
", ".join(missing)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _normalise_interrupts(value: All | Sequence[str] | None) -> List[str] | All | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
return list(value)
|
||||
return value
|
||||
|
||||
|
||||
def _alias_schema(schema: type[Any], alias: str) -> type[Any]:
|
||||
try:
|
||||
setattr(schema, "__name__", alias)
|
||||
except Exception: # pragma: no cover - defensive fallback
|
||||
pass
|
||||
return schema
|
||||
|
||||
|
||||
@dataclass
class CompiledStateGraph(Generic[StateT]):
    """Runtime representation returned by :meth:`StateGraph.compile`.

    Bundles the originating builder with the options supplied at compile
    time (checkpointing, caching, interrupts, and schema overrides).
    """

    # The graph definition this compiled instance was built from.
    builder: StateGraph[StateT]
    # Optional persistence/caching backends supplied at compile time.
    checkpointer: BaseCheckpointSaver[Any] | None = None
    cache: BaseCache[Any] | None = None
    store: BaseStore[Any] | None = None
    # Node names (or the All sentinel) recorded for interruption; defaults
    # to an empty list meaning "no interrupts".
    interrupt_before_nodes: List[str] | All | None = field(default_factory=list)
    interrupt_after_nodes: List[str] | All | None = field(default_factory=list)
    debug: bool = False
    name: str | None = None
    # Optional schema overrides, mirrored by the *Type properties below.
    context_schema: type[Any] | None = None
    input_schema: type[Any] | None = None
    output_schema: type[Any] | None = None

    @property
    def InputType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""

        return self.input_schema

    @property
    def OutputType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""

        return self.output_schema

    @property
    def ContextType(self) -> type[Any] | None:
        """Compatibility alias for LangGraph 1.x compiled graphs."""

        return self.context_schema


__all__ = ["StateGraph", "START", "END", "CompiledStateGraph", "NodeSpec"]
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Stubbed message utilities for LangGraph integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
def add_messages(state: dict[str, Any], messages: Iterable[Any]) -> dict[str, Any]:
    """Append *messages* to the state's ``messages`` list, in place.

    Creates an empty ``messages`` list when the key is absent. A
    pre-existing non-list value is left untouched. Returns *state* for
    chaining.
    """
    bucket = state.setdefault("messages", [])
    if isinstance(bucket, list):
        # In-place extension keeps any external references to the list valid.
        bucket += list(messages)
    return state


__all__ = ["add_messages"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""LangGraph graph state helpers used across the Biz Bud tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
|
||||
from . import CompiledStateGraph
|
||||
|
||||
|
||||
@dataclass(slots=True)
class RetryPolicy:
    """Stubbed retry policy with configurable attempt limits.

    Only the attribute surface is provided; no retry logic runs here.
    """

    # Maximum number of attempts; defaults to a single attempt.
    max_attempts: int = 1
    # Backoff knobs mirroring the real policy; not consumed by this stub.
    min_backoff: float | None = None
    max_backoff: float | None = None
    backoff_multiplier: float | None = None
    jitter: float | None = None

    def as_sequence(self) -> Sequence["RetryPolicy"]:
        """Return the policy as a singleton sequence for convenience."""

        return (self,)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CachePolicy:
|
||||
"""Minimal cache policy matching the attributes used in tests."""
|
||||
|
||||
namespace: str | None = None
|
||||
key: str | None = None
|
||||
ttl: float | None = None
|
||||
populate: bool = True
|
||||
|
||||
def describe(self) -> dict[str, Any]:
|
||||
"""Expose a serialisable representation for debugging."""
|
||||
|
||||
return {
|
||||
"namespace": self.namespace,
|
||||
"key": self.key,
|
||||
"ttl": self.ttl,
|
||||
"populate": self.populate,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["CachePolicy", "RetryPolicy", "CompiledStateGraph"]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Store base classes used by the LangGraph stub."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class BaseStore(Protocol):
    """Protocol capturing the minimal store interface used in tests.

    ``runtime_checkable`` allows ``isinstance`` checks against any object
    implementing ``put`` and ``get``.
    """

    def put(self, key: str, value: Any) -> None:  # pragma: no cover - interface only
        """Store *value* under *key*."""
        ...

    def get(self, key: str) -> Any:  # pragma: no cover - interface only
        """Return the value stored under *key*."""
        ...


__all__ = ["BaseStore"]
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Simple in-memory store implementation for tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseStore
|
||||
|
||||
|
||||
class InMemoryStore(BaseStore):
    """Minimal dictionary-backed store."""

    def __init__(self) -> None:
        # Backing mapping holding every stored value.
        self._store: Dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        """Store *value* under *key*, overwriting any previous entry."""
        self._store[key] = value

    def get(self, key: str) -> Optional[Any]:
        """Return the value for *key*, or ``None`` when absent."""
        try:
            return self._store[key]
        except KeyError:
            return None

    def clear(self) -> None:
        """Drop every stored entry."""
        self._store.clear()


__all__ = ["InMemoryStore"]
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Common type hints and helper payloads exposed by the LangGraph stub."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Generic, Literal, Mapping, TypeVar
|
||||
|
||||
All = Literal["*"]
|
||||
|
||||
TTarget = TypeVar("TTarget")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Command(Generic[TTarget]):
|
||||
"""Simplified representation of LangGraph's command-based routing payload."""
|
||||
|
||||
goto: TTarget | None = None
|
||||
update: Dict[str, Any] = field(default_factory=dict)
|
||||
graph: str | None = None
|
||||
|
||||
# Class-level sentinel used to signal returning to the parent graph.
|
||||
PARENT: str = "__parent__"
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a serialisable view of the command for debugging."""
|
||||
|
||||
return {"goto": self.goto, "update": dict(self.update), "graph": self.graph}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Send(Generic[TTarget]):
|
||||
"""Parallel dispatch payload used by LangGraph for fan-out patterns."""
|
||||
|
||||
target: TTarget
|
||||
state: Mapping[str, Any]
|
||||
|
||||
def with_updates(self, updates: Mapping[str, Any]) -> "Send[TTarget]":
|
||||
merged: Dict[str, Any] = dict(self.state)
|
||||
merged.update(dict(updates))
|
||||
return Send(target=self.target, state=merged)
|
||||
|
||||
|
||||
__all__ = ["All", "Command", "Send"]
|
||||
@@ -4,7 +4,7 @@ from typing import Any, TypedDict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph, RetryPolicy
|
||||
from langgraph.graph.state import CachePolicy, CompiledStateGraph as CompiledGraph, RetryPolicy
|
||||
from langgraph.cache.memory import InMemoryCache
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
@@ -160,7 +160,7 @@ class TestBuildGraphFromConfig:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_graph_with_conditional_edges(self, sample_node, sample_router):
|
||||
"""Test building graph with conditional edges."""
|
||||
@@ -189,7 +189,7 @@ class TestBuildGraphFromConfig:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_graph_with_entry_point(self, sample_node):
|
||||
"""Test building graph with custom entry point."""
|
||||
@@ -202,7 +202,7 @@ class TestBuildGraphFromConfig:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_graph_with_checkpointer(self, sample_node):
|
||||
"""Test building graph with checkpointer."""
|
||||
@@ -217,7 +217,7 @@ class TestBuildGraphFromConfig:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_graph_with_node_config_options(self, sample_node):
|
||||
"""GraphBuilderConfig should accept modern node configuration options."""
|
||||
@@ -238,7 +238,7 @@ class TestBuildGraphFromConfig:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
node = graph.builder.nodes["process"]
|
||||
assert node.metadata == {"role": "primary"}
|
||||
@@ -288,7 +288,7 @@ class TestGraphBuilder:
|
||||
.build()
|
||||
)
|
||||
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_fluent_api_with_conditional_edge(self, sample_node, sample_router):
|
||||
"""Test fluent API with conditional edges."""
|
||||
@@ -311,7 +311,7 @@ class TestGraphBuilder:
|
||||
.build()
|
||||
)
|
||||
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_fluent_api_with_metadata(self, sample_node):
|
||||
"""Test fluent API with metadata."""
|
||||
@@ -324,7 +324,7 @@ class TestGraphBuilder:
|
||||
.build()
|
||||
)
|
||||
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
# Note: LangGraph may not expose config on compiled graphs
|
||||
# The metadata is used internally during graph building
|
||||
|
||||
@@ -368,7 +368,7 @@ class TestHelperFunctions:
|
||||
nodes = [("step1", node1), ("step2", node2), ("step3", node3)]
|
||||
|
||||
graph = create_simple_linear_graph(TestState, nodes)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_create_simple_linear_graph_empty_nodes(self):
|
||||
"""Test creating linear graph with empty nodes raises error."""
|
||||
@@ -382,7 +382,7 @@ class TestHelperFunctions:
|
||||
nodes = [("step1", node1)]
|
||||
|
||||
graph = create_simple_linear_graph(TestState, nodes, mock_checkpointer)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_create_simple_linear_graph_with_node_options(self):
|
||||
"""Linear helper should accept per-node LangGraph options."""
|
||||
@@ -428,7 +428,7 @@ class TestHelperFunctions:
|
||||
sample_router,
|
||||
branches
|
||||
)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_create_branching_graph_with_checkpointer(self, sample_router):
|
||||
"""Test creating branching graph with checkpointer."""
|
||||
@@ -443,7 +443,7 @@ class TestHelperFunctions:
|
||||
branches,
|
||||
mock_checkpointer
|
||||
)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_create_branching_graph_empty_branches(self, sample_router):
|
||||
"""Test creating branching graph handles empty branches."""
|
||||
@@ -456,7 +456,7 @@ class TestHelperFunctions:
|
||||
sample_router,
|
||||
branches
|
||||
)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
@@ -500,7 +500,7 @@ class TestErrorHandling:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
|
||||
class TestIntegrationWithExistingGraphs:
|
||||
@@ -535,7 +535,7 @@ class TestIntegrationWithExistingGraphs:
|
||||
.build()
|
||||
)
|
||||
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_research_pattern_simulation(self):
|
||||
"""Test pattern similar to research graph refactoring."""
|
||||
@@ -577,7 +577,7 @@ class TestIntegrationWithExistingGraphs:
|
||||
)
|
||||
|
||||
graph = build_graph_from_config(config)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert isinstance(graph, CompiledGraph)
|
||||
|
||||
def test_fluent_api_with_runtime_options(self, sample_node):
|
||||
"""Fluent builder should expose new LangGraph runtime options."""
|
||||
|
||||
@@ -55,16 +55,17 @@ def base_error_state() -> ErrorHandlingState:
|
||||
input_state={},
|
||||
execution_count=0,
|
||||
),
|
||||
current_error={
|
||||
"message": "Test error",
|
||||
"error_type": "TestError",
|
||||
"node": "test_node",
|
||||
"severity": "error",
|
||||
"category": ErrorCategory.UNKNOWN.value,
|
||||
"timestamp": "2024-01-01T00:00:00",
|
||||
"context": {},
|
||||
"traceback": None,
|
||||
},
|
||||
current_error={
|
||||
"message": "Test error",
|
||||
"error_type": "TestError",
|
||||
"node": "test_node",
|
||||
"severity": "error",
|
||||
"category": ErrorCategory.UNKNOWN.value,
|
||||
"timestamp": "2024-01-01T00:00:00",
|
||||
"context": {},
|
||||
"details": {},
|
||||
"traceback": None,
|
||||
},
|
||||
attempted_actions=[],
|
||||
)
|
||||
|
||||
@@ -830,7 +831,7 @@ class TestConfigurationFunctions:
|
||||
config: RunnableConfig = {"configurable": {"test": "config"}}
|
||||
graph = error_handling_graph_factory(config)
|
||||
assert graph is not None
|
||||
# CompiledStateGraph has ainvoke method
|
||||
# CompiledGraph has ainvoke method
|
||||
assert hasattr(graph, "ainvoke")
|
||||
|
||||
# NOTE: get_next_node_function was removed - functionality moved to edge helpers
|
||||
|
||||
@@ -8,7 +8,8 @@ from langchain_core.messages import AIMessage, ToolMessage
|
||||
if TYPE_CHECKING:
|
||||
from biz_bud.states.tools import ToolState
|
||||
|
||||
from biz_bud.states.tools import create_scraped_content_dict
|
||||
from biz_bud.core.errors import ValidationError
|
||||
from biz_bud.states.tools import create_scraped_content_dict
|
||||
|
||||
|
||||
class TestToolState:
|
||||
@@ -499,41 +500,53 @@ class TestScrapedContentDictCreator:
|
||||
|
||||
assert result.get("error_message") == special_error
|
||||
|
||||
def test_create_scraped_content_dict_validates_empty_url(self):
|
||||
"""Test that empty URL raises ValueError."""
|
||||
with pytest.raises(ValueError, match="URL cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="", content="Content")
|
||||
|
||||
def test_create_scraped_content_dict_validates_whitespace_url(self):
|
||||
"""Test that whitespace-only URL raises ValueError."""
|
||||
with pytest.raises(ValueError, match="URL cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url=" ", content="Content")
|
||||
|
||||
def test_create_scraped_content_dict_validates_empty_content(self):
|
||||
"""Test that empty content raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Content cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="https://example.com", content="")
|
||||
|
||||
def test_create_scraped_content_dict_validates_whitespace_content(self):
|
||||
"""Test that whitespace-only content raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Content cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="https://example.com", content=" ")
|
||||
|
||||
def test_create_scraped_content_dict_validates_negative_content_length(self):
|
||||
"""Test that negative content length raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Content length cannot be negative"):
|
||||
create_scraped_content_dict(
|
||||
url="https://example.com",
|
||||
content="Content",
|
||||
content_length=-1
|
||||
)
|
||||
|
||||
def test_create_scraped_content_dict_validates_success_with_error_message(self):
|
||||
"""Test that success=True with error_message raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Error message should not be present when success is True"):
|
||||
create_scraped_content_dict(
|
||||
url="https://example.com",
|
||||
content="Content",
|
||||
success=True,
|
||||
error_message="This should not be here"
|
||||
)
|
||||
def test_create_scraped_content_dict_validates_empty_url(self):
|
||||
"""Test that empty URL raises ValidationError."""
|
||||
with pytest.raises(ValidationError, match="URL cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="", content="Content")
|
||||
|
||||
def test_create_scraped_content_dict_validates_whitespace_url(self):
|
||||
"""Test that whitespace-only URL raises ValidationError."""
|
||||
with pytest.raises(ValidationError, match="URL cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url=" ", content="Content")
|
||||
|
||||
def test_create_scraped_content_dict_validates_empty_content(self):
|
||||
"""Test that empty content raises ValidationError when success is True."""
|
||||
with pytest.raises(ValidationError, match="Content cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="https://example.com", content="")
|
||||
|
||||
def test_create_scraped_content_dict_validates_whitespace_content(self):
|
||||
"""Test that whitespace-only content raises ValidationError when success is True."""
|
||||
with pytest.raises(ValidationError, match="Content cannot be empty or whitespace"):
|
||||
create_scraped_content_dict(url="https://example.com", content=" ")
|
||||
|
||||
def test_create_scraped_content_dict_validates_negative_content_length(self):
|
||||
"""Test that negative content length raises ValidationError."""
|
||||
with pytest.raises(ValidationError, match="Content length cannot be negative"):
|
||||
create_scraped_content_dict(
|
||||
url="https://example.com",
|
||||
content="Content",
|
||||
content_length=-1
|
||||
)
|
||||
|
||||
def test_create_scraped_content_dict_validates_success_with_error_message(self):
|
||||
"""Test that success=True with error_message raises ValidationError."""
|
||||
with pytest.raises(ValidationError, match="Error message should not be present when success is True"):
|
||||
create_scraped_content_dict(
|
||||
url="https://example.com",
|
||||
content="Content",
|
||||
success=True,
|
||||
error_message="This should not be here"
|
||||
)
|
||||
|
||||
def test_create_scraped_content_dict_failure_normalizes_empty_content(self):
|
||||
"""Ensure failed scraping normalizes missing content to an empty string."""
|
||||
result = create_scraped_content_dict(
|
||||
url="https://example.com",
|
||||
content=" ",
|
||||
success=False,
|
||||
error_message="failure",
|
||||
)
|
||||
|
||||
assert result["content"] == ""
|
||||
assert result["success"] is False
|
||||
|
||||
Reference in New Issue
Block a user