Refactor type safety checks and enhance error handling across various modules

- Update typing in error handling and validation nodes to improve type safety.
- Refactor cache decorators for async compatibility and cleanup functionality.
- Enhance URL processing and validation logic with improved type checks.
- Centralize error handling and recovery mechanisms in nodes.
- Simplify and standardize function signatures across multiple modules for consistency.
- Resolve linting issues and ensure compliance with type safety standards.
This commit is contained in:
2025-09-28 13:45:52 -04:00
parent 80d5bc6c23
commit 7a84d75d8e
44 changed files with 3289 additions and 2054 deletions

1009
errors.md

File diff suppressed because it is too large Load Diff

View File

@@ -6,38 +6,37 @@ project_includes = [
"tests"
]
# Exclude directories
# Exclude directories - exclude everything except src/ and tests/
project_excludes = [
"build/",
"dist/",
".venv/",
"venv/",
".cenv/",
".venv-host/",
"**/__pycache__/",
"**/node_modules/",
"**/htmlcov/",
"**/prof/",
"**/.pytest_cache/",
"**/.mypy_cache/",
"**/.ruff_cache/",
"**/cassettes/",
"**/*.egg-info/",
".archive/",
"**/.archive/",
"cache/",
"examples/**",
".cenv/**",
".venv-host/**",
"**/.venv/**",
"**/venv/**",
"**/site-packages/**",
"**/lib/python*/**",
"**/bin/**",
"**/include/**",
"**/share/**",
".backup/**",
"**/.backup/**"
"__marimo__",
"cache",
"coverage_reports",
"docker",
"docs",
"examples",
"htmlcov",
"htmlcov_tools",
"logs",
"node_modules",
"prof",
"scripts",
"static",
"*.md",
"*.toml",
"*.txt",
"*.yml",
"*.yaml",
"*.json",
"*.lock",
"*.ini",
"*.conf",
"*.sh",
"*.xml",
"*.gpgsign",
"Makefile",
"Dockerfile*",
"nginx.conf",
"sonar-project.properties"
]
# Search paths for module resolution

View File

@@ -3,9 +3,10 @@
"src",
"tests"
],
"extraPaths": [
"src"
],
"extraPaths": [
"src",
"tests"
],
"exclude": [
"**/node_modules",
"**/__pycache__",

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from functools import wraps
from typing import ParamSpec, TypedDict, TypeVar, Unpack, cast
from typing import Any, ParamSpec, TypedDict, TypeVar, Unpack, cast
# Import error types from shared module to break circular imports
from biz_bud.core.types import ErrorInfo, JSONObject, JSONValue
@@ -923,7 +923,10 @@ def handle_errors(
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> object:
try:
return await async_func(*args, **kwargs)
return await cast(Callable[..., Awaitable[object]], async_func)(
*cast(tuple[Any, ...], args),
**cast(dict[str, object], kwargs),
)
except Exception as error: # pragma: no cover - defensive
if not any(isinstance(error, exc_type) for exc_type in catch_types):
raise
@@ -959,7 +962,10 @@ def handle_errors(
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> object:
try:
return sync_func(*args, **kwargs)
return cast(Callable[..., object], sync_func)(
*cast(tuple[Any, ...], args),
**cast(dict[str, object], kwargs),
)
except Exception as error: # pragma: no cover - defensive
if not any(isinstance(error, exc_type) for exc_type in catch_types):
raise
@@ -1091,16 +1097,9 @@ def create_error_info(
if category == "unknown" and error_type:
try:
# Create a mock exception to categorize
exception_class = None
# Try builtins first
if hasattr(__builtins__, error_type):
exception_class = getattr(__builtins__, error_type)
# Try standard library exceptions
import builtins
if not exception_class and hasattr(builtins, error_type):
exception_class = getattr(builtins, error_type)
exception_class = getattr(builtins, error_type, None)
# If still not found, try common exception types
if not exception_class:
@@ -1111,6 +1110,7 @@ def create_error_info(
'FileNotFoundError': FileNotFoundError,
'ValueError': ValueError,
'KeyError': KeyError,
'IndexError': IndexError,
'TypeError': TypeError,
'AttributeError': AttributeError,
}
@@ -1707,7 +1707,10 @@ def with_retry(
for attempt in range(max_attempts):
try:
return await async_func(*args, **kwargs)
return await cast(Callable[..., Awaitable[object]], async_func)(
*cast(tuple[Any, ...], args),
**cast(dict[str, object], kwargs),
)
except exceptions as exc:
last_error = exc
@@ -1733,7 +1736,10 @@ def with_retry(
for attempt in range(max_attempts):
try:
return sync_func(*args, **kwargs)
return cast(Callable[..., object], sync_func)(
*cast(tuple[Any, ...], args),
**cast(dict[str, object], kwargs),
)
except exceptions as e:
last_error = e

View File

@@ -10,7 +10,7 @@ import functools
import time
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from datetime import UTC, datetime
from typing import Any, ParamSpec, TypedDict, TypeVar, cast
from typing import Any, TypedDict, TypeVar, cast
from biz_bud.core.errors import RetryExecutionError
from biz_bud.logging import get_logger
@@ -18,8 +18,7 @@ from biz_bud.logging import get_logger
logger = get_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
class NodeMetric(TypedDict):
@@ -105,7 +104,7 @@ def _log_execution_error(
def log_node_execution(
node_name: str | None = None,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
) -> Callable[[CallableT], CallableT]:
"""Log node execution with timing and context.
This decorator automatically logs entry, exit, and timing information
@@ -118,18 +117,19 @@ def log_node_execution(
Decorated function with logging
"""
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
def decorator(func: CallableT) -> CallableT:
actual_node_name = node_name or func.__name__
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
context = _extract_context_from_args(
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
)
start_time = _log_execution_start(actual_node_name, context)
try:
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
async_func = cast(Callable[..., Awaitable[object]], func)
result = await async_func(*args, **kwargs)
_log_execution_success(actual_node_name, start_time, context)
return result
except Exception as e:
@@ -137,14 +137,15 @@ def log_node_execution(
raise
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
context = _extract_context_from_args(
cast(tuple[object, ...], args), cast(Mapping[str, object], kwargs)
)
start_time = _log_execution_start(actual_node_name, context)
try:
result = cast(Callable[P, R], func)(*args, **kwargs)
sync_func = cast(Callable[..., object], func)
result = sync_func(*args, **kwargs)
_log_execution_success(actual_node_name, start_time, context)
return result
except Exception as e:
@@ -153,8 +154,8 @@ def log_node_execution(
# Return appropriate wrapper based on function type
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return cast(CallableT, async_wrapper)
return cast(CallableT, sync_wrapper)
return decorator
@@ -267,7 +268,7 @@ def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error:
def track_metrics(
metric_name: str,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
) -> Callable[[CallableT], CallableT]:
"""Track metrics for node execution.
This decorator updates state with performance metrics including
@@ -280,11 +281,11 @@ def track_metrics(
Decorated function with metric tracking
"""
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
def decorator(func: CallableT) -> CallableT:
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
positional = cast(tuple[Any, ...], args)
positional = args
state = (
cast(MutableMapping[str, object], positional[0])
if positional and isinstance(positional[0], MutableMapping)
@@ -293,7 +294,8 @@ def track_metrics(
metric = _initialize_metric(state, metric_name)
try:
result = await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
async_func = cast(Callable[..., Awaitable[object]], func)
result = await async_func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
_update_metric_success(metric, elapsed_ms)
return result
@@ -303,9 +305,9 @@ def track_metrics(
raise
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
positional = cast(tuple[Any, ...], args)
positional = args
state = (
cast(MutableMapping[str, object], positional[0])
if positional and isinstance(positional[0], MutableMapping)
@@ -314,7 +316,8 @@ def track_metrics(
metric = _initialize_metric(state, metric_name)
try:
result = cast(Callable[P, R], func)(*args, **kwargs)
sync_func = cast(Callable[..., object], func)
result = sync_func(*args, **kwargs)
elapsed_ms = (time.time() - start_time) * 1000
_update_metric_success(metric, elapsed_ms)
return result
@@ -324,8 +327,8 @@ def track_metrics(
raise
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return cast(CallableT, async_wrapper)
return cast(CallableT, sync_wrapper)
return decorator
@@ -392,8 +395,8 @@ def _handle_error(
def handle_errors(
error_handler: Callable[[Exception], None] | None = None,
fallback_value: R | None = None,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
fallback_value: object | None = None,
) -> Callable[[CallableT], CallableT]:
"""Handle errors with standardized error handling in nodes.
This decorator provides consistent error handling with optional
@@ -407,11 +410,12 @@ def handle_errors(
Decorated function with error handling
"""
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
def decorator(func: CallableT) -> CallableT:
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
async_func = cast(Callable[..., Awaitable[object]], func)
return await async_func(*args, **kwargs)
except Exception as e:
result = _handle_error(
e,
@@ -420,12 +424,13 @@ def handle_errors(
error_handler,
fallback_value,
)
return cast(R, result)
return result
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return cast(Callable[P, R], func)(*args, **kwargs)
sync_func = cast(Callable[..., object], func)
return sync_func(*args, **kwargs)
except Exception as e:
result = _handle_error(
e,
@@ -434,11 +439,11 @@ def handle_errors(
error_handler,
fallback_value,
)
return cast(R, result)
return result
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return cast(CallableT, async_wrapper)
return cast(CallableT, sync_wrapper)
return decorator
@@ -447,7 +452,7 @@ def retry_on_failure(
max_attempts: int = 3,
backoff_seconds: float = 1.0,
exponential_backoff: bool = True,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
) -> Callable[[CallableT], CallableT]:
"""Retry node execution on failure.
Args:
@@ -459,14 +464,15 @@ def retry_on_failure(
Decorated function with retry logic
"""
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
def decorator(func: CallableT) -> CallableT:
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
last_exception: Exception | None = None
for attempt in range(max_attempts):
try:
return await cast(Callable[P, Awaitable[R]], func)(*args, **kwargs)
async_func = cast(Callable[..., Awaitable[object]], func)
return await async_func(*args, **kwargs)
except Exception as e:
last_exception = e
@@ -497,12 +503,13 @@ def retry_on_failure(
)
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
last_exception: Exception | None = None
for attempt in range(max_attempts):
try:
return cast(Callable[P, R], func)(*args, **kwargs)
sync_func = cast(Callable[..., object], func)
return sync_func(*args, **kwargs)
except Exception as e:
last_exception = e
@@ -532,8 +539,8 @@ def retry_on_failure(
)
if asyncio.iscoroutinefunction(func):
return cast(Callable[P, Awaitable[R] | R], async_wrapper)
return cast(Callable[P, Awaitable[R] | R], sync_wrapper)
return cast(CallableT, async_wrapper)
return cast(CallableT, sync_wrapper)
return decorator
@@ -590,7 +597,7 @@ def standard_node(
node_name: str | None = None,
metric_name: str | None = None,
retry_attempts: int = 0,
) -> Callable[[Callable[P, Awaitable[R] | R]], Callable[P, Awaitable[R] | R]]:
) -> Callable[[CallableT], CallableT]:
"""Composite decorator applying standard cross-cutting concerns.
This decorator combines logging, metrics, error handling, and retries
@@ -605,9 +612,9 @@ def standard_node(
Decorated function with all standard concerns applied
"""
def decorator(func: Callable[P, Awaitable[R] | R]) -> Callable[P, Awaitable[R] | R]:
def decorator(func: CallableT) -> CallableT:
# Apply decorators in order (innermost to outermost)
decorated: Callable[P, Awaitable[R] | R] = func
decorated: CallableT = func
# Add retry if requested
if retry_attempts > 0:

View File

@@ -26,12 +26,12 @@ from collections.abc import (
ValuesView,
)
from types import MappingProxyType
from typing import Any, NoReturn, cast
from typing import Any, NoReturn, TypeVar, cast
_MISSING = object()
DataFrameType: type[Any] | None
SeriesType: type[Any] | None
DataFrameType: type[object] | None
SeriesType: type[object] | None
pd: Any | None
try: # pragma: no cover - pandas is optional in lightweight test environments
import pandas as _pandas_module
@@ -42,10 +42,12 @@ except ModuleNotFoundError: # pragma: no cover - executed when pandas isn't ins
else:
pandas_module = cast(Any, _pandas_module)
pd = pandas_module
DataFrameType = cast(type[Any], pandas_module.DataFrame)
SeriesType = cast(type[Any], pandas_module.Series)
DataFrameType = cast(type[object], pandas_module.DataFrame)
SeriesType = cast(type[object], pandas_module.Series)
from biz_bud.core.errors import ImmutableStateError, StateValidationError
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def _states_equal(state1: object, state2: object) -> bool:
"""Compare two states safely, handling DataFrames and other complex objects.
@@ -168,6 +170,7 @@ class ImmutableDict(Mapping[str, object]):
def popitem(self) -> tuple[str, object]: # pragma: no cover
self._raise_mutation_error()
raise AssertionError('unreachable')
def clear(self) -> None: # pragma: no cover
self._raise_mutation_error()
@@ -303,8 +306,8 @@ def update_state_immutably(
def ensure_immutable_node(
node_func: Callable[..., object],
) -> Callable[..., object]:
node_func: CallableT,
) -> CallableT:
"""Ensure a node function treats state as immutable.
This decorator:
@@ -327,7 +330,7 @@ def ensure_immutable_node(
import inspect
@functools.wraps(node_func)
async def async_wrapper(*args: object, **kwargs: object) -> object:
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Handle async node function execution with state immutability."""
# Extract state from args (assuming it's the first argument)
if not args:
@@ -364,7 +367,7 @@ def ensure_immutable_node(
return result
@functools.wraps(node_func)
def sync_wrapper(*args: object, **kwargs: object) -> object:
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Handle sync node function execution with state immutability."""
# Extract state from args (assuming it's the first argument)
if not args:
@@ -398,9 +401,8 @@ def ensure_immutable_node(
# Return appropriate wrapper based on function type
if inspect.iscoroutinefunction(node_func):
return cast(Callable[..., object], async_wrapper)
else:
return cast(Callable[..., object], sync_wrapper)
return cast(CallableT, async_wrapper)
return cast(CallableT, sync_wrapper)
class StateUpdater:

View File

@@ -90,6 +90,7 @@ def _process_merge_strategy_fields(
) -> int:
if not merge_strategy:
return 0
merged_dict: dict[str, JSONValue] = merged
processed = 0
for field, strategy in merge_strategy.items():
values = [
@@ -132,14 +133,13 @@ def _process_merge_strategy_fields(
merged[field] = None
for v in values:
if isinstance(v, (int, float)) and not isinstance(v, bool):
existing_value = merged.get(field)
numeric_existing: int | float | None
if isinstance(existing_value, (int, float)) and not isinstance(
existing_value, bool
):
numeric_existing = cast(int | float, existing_value)
else:
numeric_existing = None
existing_value = merged_dict.get(field)
numeric_existing = (
existing_value
if isinstance(existing_value, (int, float))
and not isinstance(existing_value, bool)
else None
)
merged[field] = _handle_numeric_operation(
numeric_existing,
float(v),
@@ -149,14 +149,13 @@ def _process_merge_strategy_fields(
merged[field] = None
for v in values:
if isinstance(v, (int, float)) and not isinstance(v, bool):
existing_value = merged.get(field)
numeric_existing: int | float | None
if isinstance(existing_value, (int, float)) and not isinstance(
existing_value, bool
):
numeric_existing = cast(int | float, existing_value)
else:
numeric_existing = None
existing_value = merged_dict.get(field)
numeric_existing = (
existing_value
if isinstance(existing_value, (int, float))
and not isinstance(existing_value, bool)
else None
)
merged[field] = _handle_numeric_operation(
numeric_existing,
float(v),
@@ -173,6 +172,7 @@ def _process_remaining_fields(
merge_strategy: dict[str, str] | None,
seen_items: set[str],
) -> None:
merged_dict: dict[str, JSONValue] = merged
for r in results:
for field, value in r.items():
if field in seen_items and not isinstance(value, dict | list):
@@ -206,11 +206,11 @@ def _process_remaining_fields(
existing_dict = cast(dict[str, JSONValue], merged[field])
existing_dict.update(value)
elif isinstance(value, (int, float)) and not isinstance(value, bool):
existing_numeric = merged.get(field)
existing_numeric = merged_dict.get(field)
if isinstance(existing_numeric, (int, float)) and not isinstance(
existing_numeric, bool
):
merged[field] = cast(int | float, existing_numeric) + float(value)
merged[field] = existing_numeric + float(value)
seen_items.add(field)
elif isinstance(value, str):
if field in merged and isinstance(merged[field], str):

View File

@@ -10,7 +10,7 @@ This module demonstrates the implementation of LangGraph best practices includin
- Reusable subgraph pattern
"""
from typing import Annotated, Any, NotRequired, TypedDict
from typing import Annotated, Any, NotRequired, TypedDict, cast
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
@@ -102,8 +102,8 @@ async def research_web_search(
@standard_node(node_name="search_web", metric_name="web_search")
@ensure_immutable_node
async def search_web_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
state: ResearchSubgraphState, config: RunnableConfig | None = None
) -> ResearchSubgraphState:
"""Search the web for information related to the query.
Args:
@@ -116,7 +116,7 @@ async def search_web_node(
"""
query = state.get("query", "")
if not query:
return (
updated = (
StateUpdater(state)
.append(
"errors",
@@ -124,6 +124,7 @@ async def search_web_node(
)
.build()
)
return cast(ResearchSubgraphState, updated)
try:
# Simulate web search directly
@@ -143,24 +144,26 @@ async def search_web_node(
# Update state immutably
updater = StateUpdater(state)
return updater.set(
"search_results", result.get("results", [])
).build()
return cast(
ResearchSubgraphState,
updater.set("search_results", result.get("results", [])).build(),
)
except Exception as e:
logger.error(f"Search failed: {e}")
return (
updated = (
StateUpdater(state)
.append("errors", {"node": "search_web", "error": str(e), "phase": "search"})
.build()
)
return cast(ResearchSubgraphState, updated)
@standard_node(node_name="extract_facts", metric_name="fact_extraction")
@ensure_immutable_node
async def extract_facts_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
state: ResearchSubgraphState, config: RunnableConfig | None = None
) -> ResearchSubgraphState:
"""Extract facts from search results.
Args:
@@ -173,7 +176,7 @@ async def extract_facts_node(
"""
search_results = state.get("search_results", [])
if not search_results:
return StateUpdater(state).set("extracted_facts", []).build()
return cast(ResearchSubgraphState, StateUpdater(state).set("extracted_facts", []).build())
try:
facts = [
@@ -187,11 +190,11 @@ async def extract_facts_node(
]
# Update state immutably
updater = StateUpdater(state)
return updater.set("extracted_facts", facts).build()
return cast(ResearchSubgraphState, updater.set("extracted_facts", facts).build())
except Exception as e:
logger.error(f"Fact extraction failed: {e}")
return (
updated = (
StateUpdater(state)
.append(
"errors",
@@ -199,13 +202,14 @@ async def extract_facts_node(
)
.build()
)
return cast(ResearchSubgraphState, updated)
@standard_node(node_name="summarize_research", metric_name="research_summary")
@ensure_immutable_node
async def summarize_research_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
state: ResearchSubgraphState, config: RunnableConfig | None = None
) -> ResearchSubgraphState:
"""Summarize the research findings.
Args:
@@ -220,11 +224,12 @@ async def summarize_research_node(
query = state.get("query", "")
if not facts:
return (
updated = (
StateUpdater(state)
.set("research_summary", "No facts were extracted from the search results.")
.build()
)
return cast(ResearchSubgraphState, updated)
try:
# Mock summarization - replace with actual LLM summarization
@@ -237,11 +242,14 @@ async def summarize_research_node(
# Update state immutably
updater = StateUpdater(state)
return updater.set("research_summary", summary).set("confidence_score", confidence).build()
return cast(
ResearchSubgraphState,
updater.set("research_summary", summary).set("confidence_score", confidence).build(),
)
except Exception as e:
logger.error(f"Summarization failed: {e}")
return (
updated = (
StateUpdater(state)
.append(
"errors",
@@ -253,6 +261,7 @@ async def summarize_research_node(
)
.build()
)
return cast(ResearchSubgraphState, updated)
def create_research_subgraph(

View File

@@ -93,6 +93,44 @@ if TYPE_CHECKING: # pragma: no cover - static typing support
)
from biz_bud.nodes.validation.logic import validate_content_logic
_TYPE_CHECKING_EXPORTS: tuple[object, ...] = (
finalize_status_node,
format_output_node,
format_response_for_caller,
handle_graph_error,
handle_validation_failure,
parse_and_validate_initial_payload,
persist_results,
prepare_final_result,
preserve_url_fields_node,
error_analyzer_node,
error_interceptor_node,
recovery_executor_node,
user_guidance_node,
extract_key_information_node,
orchestrate_extraction_node,
semantic_extract_node,
NodeLLMConfigOverride,
call_model_node,
prepare_llm_messages_node,
update_message_history_node,
batch_process_urls_node,
route_url_node,
scrape_url_node,
cached_web_search_node,
research_web_search_node,
web_search_node,
discover_urls_node,
identify_claims_for_fact_checking,
perform_fact_check,
validate_content_output,
validate_content_logic,
prepare_human_feedback_request,
should_request_feedback,
human_feedback_node,
)
del _TYPE_CHECKING_EXPORTS
EXPORTS: dict[str, tuple[str, str]] = {
# Core nodes

View File

@@ -380,7 +380,7 @@ async def _llm_error_analysis(
return ErrorAnalysisDelta()
def _get_configurable_section(config: RunnableConfig | None) -> Mapping[str, object]:
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
if config is None:
return {}
raw_configurable = config.get("configurable")

View File

@@ -1,7 +1,7 @@
"""User guidance node for generating error resolution instructions."""
from collections.abc import Mapping
from typing import Literal, TypedDict
from typing import Literal, TypedDict, cast
from langchain_core.runnables import RunnableConfig
@@ -255,10 +255,15 @@ def _coerce_error_analysis(value: object) -> ErrorAnalysis | None:
return None
error_type = str(value.get("error_type", "unknown"))
criticality_raw = value.get("criticality", "medium")
criticality_value = value.get("criticality")
criticality: Literal["low", "medium", "high", "critical"]
if criticality_raw in {"low", "medium", "high", "critical"}: # type: ignore[comparison-overlap]
criticality = criticality_raw # type: ignore[assignment]
if isinstance(criticality_value, str) and criticality_value in {
"low",
"medium",
"high",
"critical",
}:
criticality = cast(Literal["low", "medium", "high", "critical"], criticality_value)
else:
criticality = "medium"
@@ -492,7 +497,7 @@ def _calculate_duration(state: ErrorHandlingState) -> float | None:
return None
def _get_configurable_section(config: RunnableConfig | None) -> Mapping[str, object]:
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
if config is None:
return {}
raw_configurable = config.get("configurable")

View File

@@ -151,7 +151,7 @@ def should_intercept_error(state: Mapping[str, object]) -> bool:
return bool(last_error and not last_error.get("handled", False))
def _get_configurable_section(config: RunnableConfig | None) -> Mapping[str, object]:
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
if config is None:
return {}
raw_configurable = config.get("configurable")

View File

@@ -526,10 +526,15 @@ def _coerce_error_analysis(value: object) -> ErrorAnalysis | None:
return None
error_type = str(value.get("error_type", "unknown"))
criticality_raw = value.get("criticality", "medium")
criticality_value = value.get("criticality")
criticality: Literal["low", "medium", "high", "critical"]
if criticality_raw in {"low", "medium", "high", "critical"}: # type: ignore[comparison-overlap]
criticality = criticality_raw # type: ignore[assignment]
if isinstance(criticality_value, str) and criticality_value in {
"low",
"medium",
"high",
"critical",
}:
criticality = cast(Literal["low", "medium", "high", "critical"], criticality_value)
else:
criticality = "medium"
@@ -636,7 +641,7 @@ def register_custom_recovery_action(
logger.info("Registered custom recovery action: %s", action_name)
def _get_configurable_section(config: RunnableConfig | None) -> Mapping[str, object]:
def _get_configurable_section(config: RunnableConfig | None = None) -> Mapping[str, object]:
if config is None:
return {}
raw_configurable = config.get("configurable")

View File

@@ -10,11 +10,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, cast
from biz_bud.core.langgraph import (
ConfigurationProvider,
StateUpdater,
ensure_immutable_node,
standard_node,
)
from biz_bud.core.types import create_error_info
from biz_bud.logging import get_logger, info_highlight, warning_highlight
from biz_bud.states.research import ResearchState
from .extractors import extract_batch_node
@@ -23,7 +25,6 @@ if TYPE_CHECKING:
from biz_bud.nodes.models import ExtractionResultModel
from biz_bud.services.factory import ServiceFactory
from biz_bud.states.research import ResearchState
logger = get_logger(__name__)
@@ -153,15 +154,18 @@ async def semantic_extract_node(
# Extract information using the refactored extractors
# Create temporary state for batch extraction node
batch_state = {
"content_batch": valid_content,
"query": query,
"chunk_size": 4000,
"chunk_overlap": 200,
"max_chunks": 5,
"max_concurrent": 3,
"verbose": True,
}
batch_state_mapping = (
StateUpdater(state)
.set("content_batch", valid_content)
.set("query", query)
.set("chunk_size", 4000)
.set("chunk_overlap", 200)
.set("max_chunks", 5)
.set("max_concurrent", 3)
.set("verbose", True)
.build()
)
batch_state = cast(ResearchState, batch_state_mapping)
batch_result = await _resolve_node_result(
extract_batch_node(batch_state, config)
)

View File

@@ -48,9 +48,7 @@ async def discover_urls_node(
info_highlight("Starting URL discovery...", category="URLDiscovery")
modern_result = cast(
JSONObject, await modern_discover_urls_node(state, config)
)
modern_result = await modern_discover_urls_node(state, config)
result_mapping = modern_result
discovered_urls = coerce_str_list(result_mapping.get("discovered_urls"))

View File

@@ -239,7 +239,7 @@ async def _get_validation_client(
node_name="identify_claims_for_fact_checking", metric_name="claim_identification"
)
async def identify_claims_for_fact_checking(
state: StateDict, config: RunnableConfig | None
state: StateDict, config: RunnableConfig | None = None
) -> StateDict:
"""Identify factual claims within content that require validation."""
@@ -331,7 +331,7 @@ async def identify_claims_for_fact_checking(
@standard_node(node_name="perform_fact_check", metric_name="fact_checking")
async def perform_fact_check(state: StateDict, config: RunnableConfig | None) -> StateDict:
async def perform_fact_check(state: StateDict, config: RunnableConfig | None = None) -> StateDict:
"""Validate the previously identified claims using an LLM."""
logger.info("Performing fact-checking on identified claims...")
@@ -464,7 +464,7 @@ async def perform_fact_check(state: StateDict, config: RunnableConfig | None) ->
@standard_node(node_name="validate_content_output", metric_name="content_validation")
async def validate_content_output(
state: StateDict, config: RunnableConfig | None
state: StateDict, config: RunnableConfig | None = None
) -> StateDict:
"""Perform final validation on generated output."""

View File

@@ -177,7 +177,7 @@ def _summarise_search_results(results: Sequence[SearchResultTypedDict]) -> list[
@standard_node(node_name="human_feedback_node", metric_name="human_feedback")
async def human_feedback_node(
state: BusinessBuddyState, config: RunnableConfig | None
state: BusinessBuddyState, config: RunnableConfig | None = None
) -> FeedbackUpdate: # pragma: no cover - execution halts via interrupt
"""Request and process human feedback via LangGraph interrupts."""
@@ -272,7 +272,7 @@ async def human_feedback_node(
node_name="prepare_human_feedback_request", metric_name="feedback_preparation"
)
async def prepare_human_feedback_request(
state: BusinessBuddyState, config: RunnableConfig | None
state: BusinessBuddyState, config: RunnableConfig | None = None
) -> FeedbackUpdate:
"""Prepare context and summary for human feedback."""
@@ -370,7 +370,7 @@ async def prepare_human_feedback_request(
@standard_node(node_name="apply_human_feedback", metric_name="feedback_application")
async def apply_human_feedback(
state: BusinessBuddyState, config: RunnableConfig | None
state: BusinessBuddyState, config: RunnableConfig | None = None
) -> FeedbackUpdate:
"""Apply human feedback to refine generated output."""

View File

@@ -135,7 +135,7 @@ def _coerce_string_list(value: object) -> list[str]:
@standard_node(node_name="validate_content_logic", metric_name="logic_validation")
@ensure_immutable_node
async def validate_content_logic(
state: StateDict, config: RunnableConfig | None
state: StateDict, config: RunnableConfig | None = None
) -> StateDict:
"""Validate the logical structure, reasoning, and consistency of content."""
@@ -172,8 +172,9 @@ async def validate_content_logic(
)
if "error" in response:
error_message = response.get("error")
raise ValidationError(str(error_message) if error_message is not None else "Unknown validation error")
error_value = cast(JSONValue | None, response.get("error"))
error_message = "Unknown validation error" if error_value is None else str(error_value)
raise ValidationError(error_message)
overall_score_raw = response.get("overall_score", 0)
score_value = 0.0

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json
from collections.abc import Awaitable, Callable, Iterable
from logging import Logger
from typing import Annotated, ClassVar, cast, override
from typing import Annotated, ClassVar, cast
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import ArgsSchema, BaseTool, tool

View File

@@ -2,7 +2,7 @@
import logging
import re
from typing import TypedDict, cast
from typing import TypedDict
from biz_bud.core.types import (
ReceiptCanonicalizationResultTypedDict,

View File

@@ -28,7 +28,7 @@ Example usage:
from __future__ import annotations
from collections.abc import Mapping
from typing import Literal, cast
from typing import Literal
from langchain_core.tools import tool
@@ -59,12 +59,42 @@ from .models import (
URLAnalysis,
URLProcessingRequest,
ValidationResult,
ValidationStatus,
)
from .service import URLProcessingService
logger = get_logger(__name__)
ValidationStatusLiteral = Literal["valid", "invalid", "timeout", "error", "blocked"]
ProcessingStatusLiteral = Literal["success", "failed", "skipped", "timeout"]
def _validation_status_literal(status: ValidationStatus) -> ValidationStatusLiteral:
if status is ValidationStatus.VALID:
return "valid"
if status is ValidationStatus.INVALID:
return "invalid"
if status is ValidationStatus.TIMEOUT:
return "timeout"
if status is ValidationStatus.ERROR:
return "error"
if status is ValidationStatus.BLOCKED:
return "blocked"
raise ValueError(f"Unexpected ValidationStatus: {status!r}")
def _processing_status_literal(status: ProcessingStatus) -> ProcessingStatusLiteral:
if status is ProcessingStatus.SUCCESS:
return "success"
if status is ProcessingStatus.FAILED:
return "failed"
if status is ProcessingStatus.SKIPPED:
return "skipped"
if status is ProcessingStatus.TIMEOUT:
return "timeout"
raise ValueError(f"Unexpected ProcessingStatus: {status!r}")
def _coerce_to_json_value(value: object) -> JSONValue:
if value is None or isinstance(value, (str, int, float, bool)):
return value
@@ -83,7 +113,7 @@ def _coerce_to_json_value(value: object) -> JSONValue:
def _coerce_to_json_object(value: Mapping[str, object] | None) -> JSONObject:
if value is None or not isinstance(value, Mapping):
if value is None:
return {}
json_obj: JSONObject = {}
for key, item in value.items():
@@ -98,7 +128,7 @@ def _validation_result_to_typed(result: ValidationResult | None) -> URLValidatio
typed: URLValidationResultTypedDict = {
"url": result.url,
"is_valid": result.is_valid,
"status": cast(Literal["valid", "invalid", "timeout", "error", "blocked"], result.status.value),
"status": _validation_status_literal(result.status),
"error_message": result.error_message or "",
"validation_level": result.validation_level,
"checks_performed": [str(item) for item in result.checks_performed],
@@ -144,7 +174,7 @@ def _processed_url_to_typed(result: ProcessedURL) -> URLProcessingResultItemType
"original_url": result.original_url,
"normalized_url": result.normalized_url,
"final_url": result.final_url,
"status": cast(Literal["success", "failed", "skipped", "timeout"], result.status.value),
"status": _processing_status_literal(result.status),
"analysis": analysis_typed,
"validation": validation_typed,
"is_valid": is_valid,

View File

@@ -482,10 +482,7 @@ def _coerce_kwargs(overrides: Mapping[str, JSONValue] | None) -> dict[str, JSONV
coerced: dict[str, JSONValue] = {}
for key, value in overrides.items():
if not isinstance(key, str):
coerced[str(key)] = value
else:
coerced[key] = value
coerced[str(key)] = value
return coerced

View File

@@ -681,13 +681,11 @@ class URLProcessingService(BaseService[URLProcessingServiceConfig]):
) from e
# Finalize metrics
metrics = result.metrics
if metrics is not None:
metrics.total_processing_time = time.time() - start_time
metrics.finish()
processing_time = metrics.total_processing_time
else:
processing_time = 0.0
metrics = result.metrics or ProcessingMetrics(total_urls=len(processed_urls))
result.metrics = metrics
metrics.total_processing_time = time.time() - start_time
metrics.finish()
processing_time = metrics.total_processing_time
logger.info(
f"Batch processing completed: {result.success_rate:.1f}% success rate, "
f"{len(result.results)} URLs processed in "

View File

@@ -155,20 +155,16 @@ def validate_list_field(
if item_type is None:
return cast(list[T], list(value))
allowed_types: tuple[type[object], ...]
if isinstance(item_type, tuple):
allowed_types = item_type
else:
allowed_types = (item_type,)
# Validate and filter items
validated_items: list[T] = []
for item in value:
if isinstance(item_type, tuple):
if any(isinstance(item, cast(type[object], t)) for t in item_type):
validated_items.append(cast(T, item))
else:
logger.warning(
f"Invalid {field_name} item type: {type(item)}, skipping"
)
continue
single_type = cast(type[object], item_type)
if isinstance(item, single_type):
if isinstance(item, allowed_types):
validated_items.append(cast(T, item))
else:
logger.warning(

View File

@@ -56,6 +56,7 @@ from biz_bud.nodes.search.orchestrator import optimized_search_node # noqa: E40
from biz_bud.services.llm.client import LangchainLLMClient # noqa: E402
from biz_bud.services.redis_backend import RedisCacheBackend # noqa: E402
from biz_bud.services.vector_store import VectorStore # noqa: E402
from tests.helpers.mock_helpers import invoke_async_maybe
class TestNetworkFailures:
@@ -445,11 +446,14 @@ class TestNetworkFailures:
}
# Search node should handle API unavailability gracefully
result = await optimized_search_node(
{"query": "test query", "search_queries": ["test query"]}, config=config
result = await invoke_async_maybe(
optimized_search_node,
{"query": "test query", "search_queries": ["test query"]},
config,
)
# Should return empty or error results, not raise exception
assert isinstance(result, dict)
assert "search_results" in result or "error" in result
@pytest.mark.asyncio
@@ -464,7 +468,8 @@ class TestNetworkFailures:
# The extract_from_content function should work even if scraping fails
# since it's given content directly
result = await extract_from_content(
result = await invoke_async_maybe(
extract_from_content,
content="<html><body>Test content</body></html>",
query="Test query",
url="https://example.com",

View File

@@ -1,14 +1,30 @@
"""Mock helpers for tests."""
from unittest.mock import AsyncMock, MagicMock
def create_mock_redis_client() -> MagicMock:
"""Create a mock Redis client for testing."""
mock_client = MagicMock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.get = AsyncMock(return_value=None)
mock_client.set = AsyncMock(return_value=True)
mock_client.delete = AsyncMock(return_value=1)
mock_client.close = AsyncMock()
"""Mock helpers for tests."""
import inspect
from collections.abc import Awaitable, Callable
from typing import TypeVar, cast
from unittest.mock import AsyncMock, MagicMock
T = TypeVar("T")
def create_mock_redis_client() -> MagicMock:
"""Create a mock Redis client for testing."""
mock_client = MagicMock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.get = AsyncMock(return_value=None)
mock_client.set = AsyncMock(return_value=True)
mock_client.delete = AsyncMock(return_value=1)
mock_client.close = AsyncMock()
return mock_client
async def invoke_async_maybe(
func: Callable[..., Awaitable[T] | T], *args: object, **kwargs: object
) -> T:
"""Invoke a callable and await the result when necessary."""
result = func(*args, **kwargs)
if inspect.isawaitable(result):
return await cast(Awaitable[T], result)
return cast(T, result)

View File

@@ -19,7 +19,7 @@ from biz_bud.core.utils.lazy_loader import AsyncSafeLazyLoader
class TestRegexPatternCaching:
"""Test regex pattern caching for performance."""
def test_regex_pattern_compilation_caching(self):
def test_regex_pattern_compilation_caching(self) -> None:
"""Test that regex patterns are cached for better performance."""
# Test caching of compiled regex patterns
pattern_cache: Dict[str, Pattern[str]] = {}
@@ -63,7 +63,7 @@ class TestRegexPatternCaching:
)
def test_cached_regex_patterns_functionality(
self, pattern: str, test_string: str, should_match: bool
):
) -> None:
"""Test that cached regex patterns work correctly."""
# Simple pattern cache implementation
pattern_cache: Dict[str, Pattern[str]] = {}
@@ -81,7 +81,7 @@ class TestRegexPatternCaching:
# Pattern should be cached
assert pattern in pattern_cache
def test_regex_performance_with_caching(self):
def test_regex_performance_with_caching(self) -> None:
"""Test performance improvement with regex caching."""
pattern_str = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
test_strings = [
@@ -119,7 +119,7 @@ class TestLazyLoadingPatterns:
# Test that AsyncSafeLazyLoader defers actual loading
loader_called = False
def expensive_loader():
def expensive_loader() -> dict[str, str]:
nonlocal loader_called
loader_called = True
return {"expensive": "data"}
@@ -193,7 +193,7 @@ class TestLazyLoadingPatterns:
# Simple async lazy loader implementation
class AsyncLazyLoader:
def __init__(self, loader_func):
def __init__(self, loader_func) -> None:
self._loader = loader_func
self._data: dict[str, Any] | None = None
self._loaded = False
@@ -224,7 +224,7 @@ class TestLazyLoadingPatterns:
"""Test that lazy loading improves memory efficiency."""
# Create large data that's expensive to generate
def create_large_data():
def create_large_data() -> list[int]:
return list(range(data_size))
# Without lazy loading - data is created immediately
@@ -258,7 +258,7 @@ class TestLazyLoadingPatterns:
class TestConfigurationSchemaPatterns:
"""Test configuration schema patterns with Pydantic."""
def test_pydantic_model_validation_performance(self):
def test_pydantic_model_validation_performance(self) -> None:
"""Test Pydantic model validation performance."""
from biz_bud.core.validation.pydantic_models import UserQueryModel
@@ -281,7 +281,7 @@ class TestConfigurationSchemaPatterns:
# Should be fast (less than 1 second for 100 validations)
assert validation_time < 1.0
def test_pydantic_schema_caching(self):
def test_pydantic_schema_caching(self) -> None:
"""Test that Pydantic schemas are cached for performance."""
from biz_bud.core.validation.pydantic_models import UserQueryModel
@@ -306,7 +306,7 @@ class TestConfigurationSchemaPatterns:
)
def test_configuration_validation_patterns(
self, config_type: str, test_data: Dict[str, Any]
):
) -> None:
"""Test various configuration validation patterns."""
from biz_bud.core.validation.pydantic_models import (
APIConfigModel,
@@ -335,7 +335,7 @@ class TestConfigurationSchemaPatterns:
has_attr and value_matches for has_attr, value_matches in attribute_checks
)
def test_configuration_error_handling_performance(self):
def test_configuration_error_handling_performance(self) -> None:
"""Test that configuration error handling is performant."""
from pydantic import ValidationError
@@ -349,7 +349,7 @@ class TestConfigurationSchemaPatterns:
{}, # Missing required fields
]
def validate_data(invalid_data):
def validate_data(invalid_data) -> bool:
try:
UserQueryModel(**invalid_data)
return False # Should not reach here
@@ -440,7 +440,7 @@ class TestConcurrencyPatterns:
# Should be reasonably fast
assert total_time < 1.0 # Less than 1 second for 200 operations
def test_synchronous_performance_patterns(self):
def test_synchronous_performance_patterns(self) -> None:
"""Test synchronous performance patterns."""
from biz_bud.core.caching.decorators import cache_sync
@@ -480,7 +480,7 @@ class TestConcurrencyPatterns:
class TestMemoryOptimizationPatterns:
"""Test memory optimization patterns."""
def test_weak_reference_patterns(self):
def test_weak_reference_patterns(self) -> None:
"""Test weak reference patterns for memory optimization."""
import weakref
@@ -489,7 +489,7 @@ class TestMemoryOptimizationPatterns:
weak_cache: Dict[str, Any] = {}
class ExpensiveObject:
def __init__(self, data: str):
def __init__(self, data: str) -> None:
self.data = data
def get_or_create_object(key: str) -> ExpensiveObject:
@@ -520,7 +520,7 @@ class TestMemoryOptimizationPatterns:
obj3 = get_or_create_object("test")
assert obj3.data == "data_for_test"
def test_memory_efficient_data_structures(self):
def test_memory_efficient_data_structures(self) -> None:
"""Test memory efficient data structures."""
from collections import deque

View File

@@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from biz_bud.graphs.analysis.nodes.plan import formulate_analysis_plan
from tests.helpers.mock_helpers import invoke_async_maybe
@pytest.mark.asyncio
@@ -30,7 +31,7 @@ async def test_formulate_analysis_plan_success(
analysis_state["context"] = {"analysis_goal": "Test goal"}
analysis_state["data"] = {"customers": {"type": "dataframe", "shape": [100, 5]}}
result = await cast(Any, formulate_analysis_plan)(analysis_state)
result = await invoke_async_maybe(formulate_analysis_plan, analysis_state)
assert "analysis_plan" in result
# Cast to dict to access dynamically added fields
result_dict = dict(result)
@@ -81,5 +82,5 @@ async def test_formulate_analysis_plan_llm_failure() -> None:
with patch(
"biz_bud.services.factory.get_global_factory", return_value=mock_factory
):
result = await cast(Any, formulate_analysis_plan)(cast("dict[str, Any]", state))
result = await invoke_async_maybe(formulate_analysis_plan, cast("dict[str, Any]", state))
assert "errors" in result

View File

@@ -7,6 +7,16 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage
from biz_bud.nodes.core.input import parse_and_validate_initial_payload
from tests.helpers.mock_helpers import invoke_async_maybe
async def invoke_input_node(state: dict[str, Any]) -> dict[str, Any]:
"""Run the input node and cast the result to a typed state dict."""
return cast(
dict[str, Any],
await invoke_async_maybe(parse_and_validate_initial_payload, state.copy(), None),
)
@pytest.fixture
@@ -125,9 +135,7 @@ async def test_standard_payload_with_query_and_metadata(
# Mock load_config to return the expected config
mock_load_config_async.return_value = mock_app_config
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
assert result["parsed_input"]["user_query"] == "What is the weather?"
assert result["input_metadata"]["session_id"] == "abc"
@@ -183,9 +191,7 @@ async def test_missing_or_empty_query_uses_default(
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_minimal
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
assert result["parsed_input"]["user_query"] == expected_query
assert result["messages"][-1]["content"] == expected_query
@@ -215,7 +221,7 @@ async def test_message_objects_are_normalized(
"parsed_input": {},
}
result = await cast(Any, parse_and_validate_initial_payload)(initial_state.copy(), None)
result = await invoke_input_node(initial_state)
messages = result["messages"]
assert isinstance(messages, list)
@@ -246,7 +252,7 @@ async def test_errors_are_normalized_to_json(
"parsed_input": {"raw_payload": {}},
}
result = await cast(Any, parse_and_validate_initial_payload)(initial_state.copy(), None)
result = await invoke_input_node(initial_state)
errors = result["errors"]
assert isinstance(errors, list)
@@ -290,9 +296,7 @@ async def test_existing_messages_in_state(
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_minimal
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
# Should append new message if not duplicate
assert result["messages"][-1]["content"] == "Continue"
@@ -302,9 +306,7 @@ async def test_existing_messages_in_state(
assert isinstance(initial_state["messages"], list)
assert all(isinstance(m, dict) for m in initial_state["messages"])
initial_state["messages"].append({"role": "user", "content": "Continue"})
result2 = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result2 = await invoke_input_node(initial_state)
assert result2["messages"][-1]["content"] == "Continue"
assert result2["messages"].count({"role": "user", "content": "Continue"}) == 1
@@ -337,9 +339,7 @@ async def test_missing_payload_fallbacks(
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_minimal
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
assert result["parsed_input"]["user_query"] == "Fallback Q"
assert result["input_metadata"]["session_id"] == "sid"
assert result["input_metadata"]["user_id"] == "uid"
@@ -376,9 +376,7 @@ async def test_metadata_extraction(
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_minimal
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
assert result["input_metadata"]["session_id"] == "sess"
assert result["input_metadata"].get("user_id") is None
@@ -414,9 +412,7 @@ async def test_config_merging(
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_custom
result = await cast(Any, parse_and_validate_initial_payload)(
initial_state.copy(), None
)
result = await invoke_input_node(initial_state)
assert result["config"]["DEFAULT_QUERY"] == "New"
assert result["config"]["extra"] == 42
assert (
@@ -446,7 +442,7 @@ async def test_no_parsed_input_or_initial_input_uses_fallback(
state: dict[str, Any] = {}
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_empty
result = await cast(Any, parse_and_validate_initial_payload)(state.copy(), None)
result = await invoke_input_node(state)
# Should use hardcoded fallback query
assert (
result["parsed_input"]["user_query"]
@@ -487,7 +483,7 @@ async def test_non_list_messages_are_ignored(
}
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_short
result = await cast(Any, parse_and_validate_initial_payload)(state.copy(), None)
result = await invoke_input_node(state)
# Should initialize messages with the user query only
assert result["messages"] == [{"role": "user", "content": "Q"}]
@@ -514,7 +510,7 @@ async def test_raw_payload_and_metadata_not_dicts(
}
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_short
result = await cast(Any, parse_and_validate_initial_payload)(state.copy(), None)
result = await invoke_input_node(state)
# Should fallback to default query, metadata extraction should not error
assert result["parsed_input"]["user_query"] == "D"
assert result["input_metadata"].get("session_id") is None
@@ -528,7 +524,7 @@ async def test_raw_payload_and_metadata_not_dicts(
# Reset mock to return updated config for second test
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_short
result2 = await cast(Any, parse_and_validate_initial_payload)(state2.copy(), None)
result2 = await invoke_input_node(state2)
# When payload validation fails due to invalid metadata, should fallback to default query
assert result2["parsed_input"]["user_query"] == "D"
assert result2["input_metadata"].get("session_id") is None
@@ -559,7 +555,7 @@ async def test_config_missing_and_loaded_config_empty(
}
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_empty
result = await cast(Any, parse_and_validate_initial_payload)(state.copy(), None)
result = await invoke_input_node(state)
# Should use hardcoded fallback query
assert (
result["parsed_input"]["user_query"]
@@ -597,7 +593,7 @@ async def test_non_string_query_is_coerced_to_string_or_default(
}
# Mock load_config to return a mock AppConfig object
mock_load_config_async.return_value = mock_app_config_short
result = await cast(Any, parse_and_validate_initial_payload)(state.copy(), None)
result = await invoke_input_node(state)
# If query is not a string, should fallback to default
assert result["parsed_input"]["user_query"] == "D"
assert result["messages"][-1]["content"] == "D"

File diff suppressed because it is too large Load Diff

View File

@@ -1,370 +1,405 @@
"""Unit tests for URL analyzer module."""
from unittest.mock import AsyncMock, patch
import pytest
from biz_bud.graphs.rag.nodes.scraping.url_analyzer import analyze_url_for_params_node
class TestAnalyzeURLForParamsNode:
"""Test the analyze_url_for_params_node function."""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"user_input, url, expected_max_pages, expected_max_depth, expected_rationale",
[
# Basic URLs with default values
(
"Extract information from this site",
"https://example.com",
20,
2,
"defaults",
),
# User specifies explicit values
(
"Crawl 50 pages with max depth of 3",
"https://example.com",
50,
3,
"explicit",
),
(
"Get 200 pages from this site",
"https://docs.example.com",
200,
2,
"explicit pages",
),
(
"Max depth of 5 for comprehensive crawl",
"https://site.com",
20,
5,
"explicit depth",
),
# Comprehensive crawl requests
("Crawl the entire site", "https://example.com", 200, 5, "comprehensive"),
(
"Get all pages from the whole site",
"https://docs.com",
200,
5,
"comprehensive",
),
# Documentation URLs
(
"Get API documentation",
"https://example.com/docs/api",
20,
2,
"documentation",
),
(
"Extract from documentation site",
"https://docs.example.com",
20,
2,
"documentation",
),
# Blog URLs
("Get blog posts", "https://example.com/blog", 20, 2, "blog"),
("Extract articles", "https://site.com/posts/2024", 20, 2, "blog"),
# Single page URLs
(
"Extract this page",
"https://example.com/page.html",
20,
2,
"single_page",
),
("Get this PDF content", "https://site.com/doc.pdf", 20, 2, "single_page"),
# GitHub repositories
(
"Analyze this repository",
"https://github.com/user/repo",
20,
2,
"repository",
),
# Empty or minimal input
("", "https://example.com", 20, 2, "no input"),
(None, "https://example.com", 20, 2, "no input"),
],
)
async def test_parameter_extraction_patterns(
self,
user_input: str | None,
url: str,
expected_max_pages: int,
expected_max_depth: int,
expected_rationale: str,
) -> None:
"""Test parameter extraction for various input patterns."""
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": user_input,
"input_url": url,
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
# Mock the LLM call to return None (forcing fallback logic)
with patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call:
mock_call.return_value = {"final_response": None}
result = await analyze_url_for_params_node(state)
assert "url_processing_params" in result
params = result["url_processing_params"]
# Check max_pages - should match expected values since we're testing fallback logic
assert params["max_pages"] == expected_max_pages
# Check max_depth - should match expected values since we're testing fallback logic
assert params["max_depth"] == expected_max_depth
@pytest.mark.asyncio
@pytest.mark.parametrize(
"llm_response, expected_params",
[
# Valid JSON response
(
'{"max_pages": 100, "max_depth": 3, "include_subdomains": true, "follow_external_links": false, "extract_metadata": true, "priority_paths": ["/docs", "/api"], "rationale": "Documentation site"}',
{
"max_pages": 100,
"max_depth": 3,
"include_subdomains": True,
"priority_paths": ["/docs", "/api"],
},
),
# JSON wrapped in markdown
(
'```json\n{"max_pages": 50, "max_depth": 2, "include_subdomains": false, "follow_external_links": false, "extract_metadata": true, "priority_paths": [], "rationale": "Blog site"}\n```',
{
"max_pages": 50,
"max_depth": 2,
"include_subdomains": False,
"priority_paths": [],
},
),
# Invalid values get clamped
(
'{"max_pages": 2000, "max_depth": 10, "include_subdomains": false, "follow_external_links": false, "extract_metadata": true, "priority_paths": [], "rationale": "Too high"}',
{"max_pages": 1000, "max_depth": 5, "include_subdomains": False},
),
# Missing fields use defaults
(
'{"max_pages": 30, "rationale": "Partial response"}',
{
"max_pages": 30,
"max_depth": 2,
"extract_metadata": True,
"priority_paths": [],
},
),
],
)
async def test_llm_response_parsing(
self, llm_response: str, expected_params: dict[str, str]
) -> None:
"""Test parsing of various LLM response formats."""
# Mock service factory to avoid global factory error
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": "Analyze this site",
"input_url": "https://example.com",
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory, # Provide mock service factory
}
with (
patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call,
patch("biz_bud.services.factory.get_global_factory") as mock_factory,
):
mock_call.return_value = {"final_response": llm_response}
mock_factory.return_value = mock_service_factory
result = await analyze_url_for_params_node(state)
assert "url_processing_params" in result
params = result["url_processing_params"]
# Assert all expected parameters match
assert all(params[key] == expected_value for key, expected_value in expected_params.items())
@pytest.mark.asyncio
async def test_error_handling(self) -> None:
"""Test error handling in URL analysis."""
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": "Analyze site",
"input_url": "https://example.com",
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
# Test LLM call failure
with patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call:
mock_call.side_effect = Exception("LLM API error")
result = await analyze_url_for_params_node(state)
# Should return default params on error
assert "url_processing_params" in result
params = result["url_processing_params"]
assert params["max_pages"] == 20
assert params["max_depth"] == 2
assert (
params["rationale"]
== "Using extracted parameters from user input or defaults"
)
@pytest.mark.asyncio
async def test_no_url_provided(self) -> None:
"""Test behavior when no URL is provided."""
state = {
"query": "Analyze something",
"input_url": "",
"messages": [],
"config": {},
"errors": [],
}
result = await analyze_url_for_params_node(state)
assert result == {"url_processing_params": None}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url, path, expected_url_type",
[
("https://example.com/docs/api", "/docs/api", "documentation"),
("https://docs.example.com", "/", "documentation"),
("https://example.com/blog/post", "/blog/post", "blog"),
("https://site.com/articles/2024", "/articles/2024", "blog"),
("https://example.com/file.pdf", "/file.pdf", "single_page"),
(
"https://example.com/deep/nested/path/file.html",
"/deep/nested/path/file.html",
"single_page",
),
("https://github.com/user/repo", "/user/repo", "repository"),
("https://example.com/random", "/random", "general"),
],
)
async def test_url_type_detection(
self, url: str, path: str, expected_url_type: str
) -> None:
"""Test URL type detection logic."""
state = {
"query": "Analyze this",
"input_url": url,
"messages": [],
"config": {},
"errors": [],
}
# We'll intercept the LLM call to check what URL type was detected
detected_url_type = None
# We need to mock the service factory to control the LLM call
# Mock the service factory to return a mock LLM client
mock_llm_client = AsyncMock()
mock_llm_client.llm_json.return_value = (
None # Return None to force fallback logic
)
mock_service_factory = AsyncMock()
mock_service_factory.get_service.return_value = mock_llm_client
# Capture what URL type was detected from the internal classification logic
detected_url_type = None
async def mock_call_model(state, config=None):
# Extract the prompt from the state to check the URL type classification
prompt = state["messages"][1].content
import re
match = re.search(r"URL Type: (\w+)", prompt)
nonlocal detected_url_type
detected_url_type = match.group(1) if match else None
return {"final_response": None} # Return None to trigger fallback logic
with patch(
"biz_bud.nodes.scraping.url_analyzer.call_model_node",
side_effect=mock_call_model,
):
await analyze_url_for_params_node(state)
# The test should verify that the URL classification worked correctly
# The detected_url_type comes from the internal classification logic
assert detected_url_type == expected_url_type
@pytest.mark.asyncio
async def test_context_building(self) -> None:
"""Test context building from state."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage
mock_service_factory = MagicMock()
state = {
"query": "Analyze docs",
"input_url": "https://example.com",
"messages": [
HumanMessage(content="First message"),
AIMessage(content="Response message"),
HumanMessage(
content="Second message with a very long content that should be truncated after 200 characters to avoid sending too much context to the LLM when analyzing URL parameters for optimal crawling settings"
),
],
"synthesis": "Previous synthesis result that is also very long and should be truncated",
"extracted_info": [{"item": 1}, {"item": 2}, {"item": 3}],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
context_captured = None
async def mock_call_model(state, config=None):
# Extract the context from the prompt
prompt = state["messages"][1].content
import re
match = re.search(r"Context: (.+?)\nURL Type:", prompt, re.DOTALL)
nonlocal context_captured
context_captured = match.group(1) if match else None
return {"final_response": None}
with patch(
"biz_bud.nodes.scraping.url_analyzer.call_model_node",
side_effect=mock_call_model,
):
await analyze_url_for_params_node(state)
assert context_captured is not None
assert "First message" in context_captured
assert "Response message" in context_captured
assert "..." in context_captured # Truncation indicator
assert "Previous synthesis" in context_captured
assert "3 items" in context_captured
"""Unit tests for URL analyzer module."""
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import AsyncMock, patch
import pytest
from langchain_core.runnables import RunnableConfig
from biz_bud.graphs.rag.nodes.scraping.url_analyzer import analyze_url_for_params_node
ANALYZE_URL_FOR_PARAMS_NODE = cast(
Callable[["URLToRAGState", RunnableConfig | None], Awaitable[dict[str, Any]]],
analyze_url_for_params_node,
)
def _node_config(overrides: dict[str, Any] | None = None) -> RunnableConfig:
base: dict[str, Any] = {"metadata": {"unit_test": "url-analyzer"}}
if overrides:
base.update(overrides)
return cast(RunnableConfig, base)
async def _run_url_analyzer(
state: dict[str, Any], config: RunnableConfig | None = None
) -> dict[str, Any]:
return await ANALYZE_URL_FOR_PARAMS_NODE(
cast("URLToRAGState", cast("Any", state)), config or _node_config()
)
if TYPE_CHECKING:
from biz_bud.states.url_to_rag import URLToRAGState
class TestAnalyzeURLForParamsNode:
"""Test the analyze_url_for_params_node function."""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"user_input, url, expected_max_pages, expected_max_depth, expected_rationale",
[
# Basic URLs with default values
(
"Extract information from this site",
"https://example.com",
20,
2,
"defaults",
),
# User specifies explicit values
(
"Crawl 50 pages with max depth of 3",
"https://example.com",
50,
3,
"explicit",
),
(
"Get 200 pages from this site",
"https://docs.example.com",
200,
2,
"explicit pages",
),
(
"Max depth of 5 for comprehensive crawl",
"https://site.com",
20,
5,
"explicit depth",
),
# Comprehensive crawl requests
("Crawl the entire site", "https://example.com", 200, 5, "comprehensive"),
(
"Get all pages from the whole site",
"https://docs.com",
200,
5,
"comprehensive",
),
# Documentation URLs
(
"Get API documentation",
"https://example.com/docs/api",
20,
2,
"documentation",
),
(
"Extract from documentation site",
"https://docs.example.com",
20,
2,
"documentation",
),
# Blog URLs
("Get blog posts", "https://example.com/blog", 20, 2, "blog"),
("Extract articles", "https://site.com/posts/2024", 20, 2, "blog"),
# Single page URLs
(
"Extract this page",
"https://example.com/page.html",
20,
2,
"single_page",
),
("Get this PDF content", "https://site.com/doc.pdf", 20, 2, "single_page"),
# GitHub repositories
(
"Analyze this repository",
"https://github.com/user/repo",
20,
2,
"repository",
),
# Empty or minimal input
("", "https://example.com", 20, 2, "no input"),
(None, "https://example.com", 20, 2, "no input"),
],
)
async def test_parameter_extraction_patterns(
self,
user_input: str | None,
url: str,
expected_max_pages: int,
expected_max_depth: int,
expected_rationale: str,
) -> None:
"""Test parameter extraction for various input patterns."""
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": user_input,
"input_url": url,
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
# Mock the LLM call to return None (forcing fallback logic)
with patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call:
mock_call.return_value = {"final_response": None}
result = await _run_url_analyzer(state)
assert "url_processing_params" in result
params = result["url_processing_params"]
assert isinstance(params, dict)
# Check max_pages - should match expected values since we're testing fallback logic
assert params["max_pages"] == expected_max_pages
# Check max_depth - should match expected values since we're testing fallback logic
assert params["max_depth"] == expected_max_depth
assert params["rationale"] == expected_rationale
@pytest.mark.asyncio
@pytest.mark.parametrize(
"llm_response, expected_params",
[
# Valid JSON response
(
'{"max_pages": 100, "max_depth": 3, "include_subdomains": true, "follow_external_links": false, "extract_metadata": true, "priority_paths": ["/docs", "/api"], "rationale": "Documentation site"}',
{
"max_pages": 100,
"max_depth": 3,
"include_subdomains": True,
"priority_paths": ["/docs", "/api"],
},
),
# JSON wrapped in markdown
(
'```json\n{"max_pages": 50, "max_depth": 2, "include_subdomains": false, "follow_external_links": false, "extract_metadata": true, "priority_paths": [], "rationale": "Blog site"}\n```',
{
"max_pages": 50,
"max_depth": 2,
"include_subdomains": False,
"priority_paths": [],
},
),
# Invalid values get clamped
(
'{"max_pages": 2000, "max_depth": 10, "include_subdomains": false, "follow_external_links": false, "extract_metadata": true, "priority_paths": [], "rationale": "Too high"}',
{"max_pages": 1000, "max_depth": 5, "include_subdomains": False},
),
# Missing fields use defaults
(
'{"max_pages": 30, "rationale": "Partial response"}',
{
"max_pages": 30,
"max_depth": 2,
"extract_metadata": True,
"priority_paths": [],
},
),
],
)
async def test_llm_response_parsing(
self, llm_response: str, expected_params: dict[str, object]
) -> None:
"""Test parsing of various LLM response formats."""
# Mock service factory to avoid global factory error
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": "Analyze this site",
"input_url": "https://example.com",
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory, # Provide mock service factory
}
with (
patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call,
patch("biz_bud.services.factory.get_global_factory") as mock_factory,
):
mock_call.return_value = {"final_response": llm_response}
mock_factory.return_value = mock_service_factory
result = await _run_url_analyzer(state)
assert "url_processing_params" in result
params = result["url_processing_params"]
assert isinstance(params, dict)
# Assert all expected parameters match
assert all(params[key] == expected_value for key, expected_value in expected_params.items())
@pytest.mark.asyncio
async def test_error_handling(self) -> None:
"""Test error handling in URL analysis."""
from unittest.mock import MagicMock
mock_service_factory = MagicMock()
state = {
"query": "Analyze site",
"input_url": "https://example.com",
"messages": [],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
# Test LLM call failure
with patch("biz_bud.nodes.scraping.url_analyzer.call_model_node") as mock_call:
mock_call.side_effect = Exception("LLM API error")
result = await _run_url_analyzer(state)
# Should return default params on error
assert "url_processing_params" in result
params = result["url_processing_params"]
assert isinstance(params, dict)
assert params["max_pages"] == 20
assert params["max_depth"] == 2
assert (
params["rationale"]
== "Using extracted parameters from user input or defaults"
)
@pytest.mark.asyncio
async def test_no_url_provided(self) -> None:
"""Test behavior when no URL is provided."""
state = {
"query": "Analyze something",
"input_url": "",
"messages": [],
"config": {},
"errors": [],
}
result = await _run_url_analyzer(state)
assert result == {"url_processing_params": None}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"url, path, expected_url_type",
[
("https://example.com/docs/api", "/docs/api", "documentation"),
("https://docs.example.com", "/", "documentation"),
("https://example.com/blog/post", "/blog/post", "blog"),
("https://site.com/articles/2024", "/articles/2024", "blog"),
("https://example.com/file.pdf", "/file.pdf", "single_page"),
(
"https://example.com/deep/nested/path/file.html",
"/deep/nested/path/file.html",
"single_page",
),
("https://github.com/user/repo", "/user/repo", "repository"),
("https://example.com/random", "/random", "general"),
],
)
async def test_url_type_detection(
self, url: str, path: str, expected_url_type: str
) -> None:
"""Test URL type detection logic."""
state = {
"query": "Analyze this",
"input_url": url,
"messages": [],
"config": {},
"errors": [],
}
# We'll intercept the LLM call to check what URL type was detected
detected_url_type = None
# We need to mock the service factory to control the LLM call
# Mock the service factory to return a mock LLM client
mock_llm_client = AsyncMock()
mock_llm_client.llm_json.return_value = (
None # Return None to force fallback logic
)
mock_service_factory = AsyncMock()
mock_service_factory.get_service.return_value = mock_llm_client
# Capture what URL type was detected from the internal classification logic
detected_url_type = None
async def mock_call_model(
state: dict[str, Any], config: RunnableConfig | None = None
) -> dict[str, object]:
# Extract the prompt from the state to check the URL type classification
prompt = state["messages"][1].content
import re
match = re.search(r"URL Type: (\w+)", prompt)
nonlocal detected_url_type
detected_url_type = match.group(1) if match else None
return {"final_response": None} # Return None to trigger fallback logic
with patch(
"biz_bud.nodes.scraping.url_analyzer.call_model_node",
side_effect=mock_call_model,
):
await _run_url_analyzer(state)
# The test should verify that the URL classification worked correctly
# The detected_url_type comes from the internal classification logic
assert detected_url_type == expected_url_type
@pytest.mark.asyncio
async def test_context_building(self) -> None:
"""Test context building from state."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage
mock_service_factory = MagicMock()
state = {
"query": "Analyze docs",
"input_url": "https://example.com",
"messages": [
HumanMessage(content="First message"),
AIMessage(content="Response message"),
HumanMessage(
content="Second message with a very long content that should be truncated after 200 characters to avoid sending too much context to the LLM when analyzing URL parameters for optimal crawling settings"
),
],
"synthesis": "Previous synthesis result that is also very long and should be truncated",
"extracted_info": [{"item": 1}, {"item": 2}, {"item": 3}],
"config": {},
"errors": [],
"service_factory": mock_service_factory,
}
context_captured = None
async def mock_call_model(
state: dict[str, Any], config: RunnableConfig | None = None
) -> dict[str, object]:
# Extract the context from the prompt
prompt = state["messages"][1].content
import re
match = re.search(r"Context: (.+?)\nURL Type:", prompt, re.DOTALL)
nonlocal context_captured
context_captured = match.group(1) if match else None
return {"final_response": None}
with patch(
"biz_bud.nodes.scraping.url_analyzer.call_model_node",
side_effect=mock_call_model,
):
await _run_url_analyzer(state)
assert context_captured is not None
assert "First message" in context_captured
assert "Response message" in context_captured
assert "..." in context_captured # Truncation indicator
assert "Previous synthesis" in context_captured
assert "3 items" in context_captured

View File

@@ -1,19 +1,81 @@
"""Unit tests for content validation node."""
from collections.abc import Awaitable, Callable
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.runnables import RunnableConfig
from biz_bud.core.types import StateDict
from biz_bud.nodes.validation.content import (
ClaimCheckTypedDict,
FactCheckResultsTypedDict,
identify_claims_for_fact_checking,
perform_fact_check,
validate_content_output,
)
IDENTIFY_CLAIMS_NODE = cast(
Callable[[StateDict, RunnableConfig | None], Awaitable[StateDict]],
identify_claims_for_fact_checking,
)
PERFORM_FACT_CHECK_NODE = cast(
Callable[[StateDict, RunnableConfig | None], Awaitable[StateDict]],
perform_fact_check,
)
VALIDATE_CONTENT_NODE = cast(
Callable[[StateDict, RunnableConfig | None], Awaitable[StateDict]],
validate_content_output,
)
def _node_config(overrides: dict[str, Any] | None = None) -> RunnableConfig:
base: dict[str, Any] = {"metadata": {"unit_test": "content-validation"}}
if overrides:
base.update(overrides)
return cast(RunnableConfig, base)
def _as_state(payload: dict[str, object]) -> StateDict:
return cast(StateDict, payload)
async def _identify(
state: dict[str, object], config: RunnableConfig | None = None
) -> StateDict:
return await IDENTIFY_CLAIMS_NODE(_as_state(state), config or _node_config())
async def _perform_fact_check(
state: dict[str, object], config: RunnableConfig | None = None
) -> StateDict:
return await PERFORM_FACT_CHECK_NODE(_as_state(state), config or _node_config())
async def _validate_content(
state: dict[str, object], config: RunnableConfig | None = None
) -> StateDict:
return await VALIDATE_CONTENT_NODE(_as_state(state), config or _node_config())
def _expect_fact_check_results(value: object) -> FactCheckResultsTypedDict:
assert isinstance(value, dict)
return cast(FactCheckResultsTypedDict, value)
def _expect_claims(value: object) -> list[dict[str, object]]:
assert isinstance(value, list)
return cast(list[dict[str, object]], value)
def _expect_issue_list(value: object) -> list[str]:
if isinstance(value, list):
return [item for item in value if isinstance(item, str)]
return []
@pytest.fixture
def minimal_state():
def minimal_state() -> dict[str, object]:
"""Create a minimal state for testing."""
return {
"messages": [],
@@ -26,9 +88,8 @@ def minimal_state():
"status": "running",
}
@pytest.fixture
def mock_service_factory():
def mock_service_factory() -> tuple[MagicMock, AsyncMock]:
"""Create a mock service factory with LLM client."""
factory = MagicMock()
llm_client = AsyncMock()
@@ -52,7 +113,6 @@ def mock_service_factory():
return factory, llm_client
@pytest.mark.asyncio
class TestIdentifyClaimsForFactChecking:
"""Test the identify_claims_for_fact_checking function."""
@@ -70,18 +130,12 @@ class TestIdentifyClaimsForFactChecking:
'["The Earth is round", "Water boils at 100°C at sea level"]'
)
result = await identify_claims_for_fact_checking(minimal_state)
result = await _identify(minimal_state)
assert "claims_to_check" in result
assert len(result.get("claims_to_check", [])) == 2
assert (
result.get("claims_to_check", [])[0]["claim_statement"]
== "The Earth is round"
)
assert (
result.get("claims_to_check", [])[1]["claim_statement"]
== "Water boils at 100°C at sea level"
)
claims = _expect_claims(result.get("claims_to_check", []))
assert len(claims) == 2
assert claims[0]["claim_statement"] == "The Earth is round"
assert claims[1]["claim_statement"] == "Water boils at 100°C at sea level"
async def test_identify_claims_from_research_summary(
self, minimal_state, mock_service_factory
@@ -93,14 +147,11 @@ class TestIdentifyClaimsForFactChecking:
llm_client.llm_chat.return_value = "AI market grew by 40% in 2023"
result = await identify_claims_for_fact_checking(minimal_state)
result = await _identify(minimal_state)
assert "claims_to_check" in result
assert len(result.get("claims_to_check", [])) == 1
assert (
result.get("claims_to_check", [])[0]["claim_statement"]
== "AI market grew by 40% in 2023"
)
claims = _expect_claims(result.get("claims_to_check", []))
assert len(claims) == 1
assert claims[0]["claim_statement"] == "AI market grew by 40% in 2023"
async def test_identify_claims_no_content(
self, minimal_state, mock_service_factory
@@ -109,12 +160,11 @@ class TestIdentifyClaimsForFactChecking:
factory, _llm_client = mock_service_factory
minimal_state["service_factory"] = factory
result = await identify_claims_for_fact_checking(minimal_state)
result = await _identify(minimal_state)
assert result.get("claims_to_check", []) == []
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
claims = _expect_claims(result.get("claims_to_check", []))
assert claims == []
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
assert fact_check_results["issues"] == ["No content provided"]
assert fact_check_results["score"] == 0.0
@@ -127,13 +177,13 @@ class TestIdentifyClaimsForFactChecking:
minimal_state["config"] = {} # No llm_config
minimal_state["service_factory"] = factory
result = await identify_claims_for_fact_checking(minimal_state)
result = await _identify(minimal_state)
assert result.get("claims_to_check", []) == []
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
assert "Error identifying claims:" in fact_check_results["issues"][0]
claims = _expect_claims(result.get("claims_to_check", []))
assert claims == []
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
issues = fact_check_results["issues"]
assert any(entry.startswith("Error identifying claims:") for entry in issues)
async def test_identify_claims_llm_error(self, minimal_state, mock_service_factory):
"""Test error handling when LLM call fails."""
@@ -143,21 +193,17 @@ class TestIdentifyClaimsForFactChecking:
llm_client.llm_chat.side_effect = Exception("LLM API error")
result = await identify_claims_for_fact_checking(minimal_state)
result = await _identify(minimal_state)
assert result.get("claims_to_check", []) == []
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
issues = cast("list[str]", fact_check_results["issues"])
assert "Error identifying claims" in issues[0]
claims = _expect_claims(result.get("claims_to_check", []))
assert claims == []
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
issues = fact_check_results["issues"]
assert any(entry.startswith("Error identifying claims") for entry in issues)
assert result.get("is_output_valid") is False
assert len(result["errors"]) > 0
@pytest.mark.asyncio
class TestPerformFactCheck:
"""Test the perform_fact_check function."""
errors = result.get("errors", [])
assert isinstance(errors, list)
assert len(errors) > 0
async def test_fact_check_success(self, minimal_state, mock_service_factory):
"""Test successful fact checking of claims."""
@@ -184,13 +230,11 @@ class TestPerformFactCheck:
},
]
result = await perform_fact_check(minimal_state)
result = await _perform_fact_check(minimal_state)
assert "fact_check_results" in result
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
assert len(fact_check_results["claims_checked"]) == 2
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
claims_checked: list[ClaimCheckTypedDict] = fact_check_results["claims_checked"]
assert len(claims_checked) == 2
assert fact_check_results["score"] == 9.5 # (9+10)/2
assert fact_check_results["issues"] == []
@@ -209,14 +253,13 @@ class TestPerformFactCheck:
"verification_notes": "Common misconception",
}
result = await perform_fact_check(minimal_state)
result = await _perform_fact_check(minimal_state)
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
assert fact_check_results["score"] == 1.0
assert len(fact_check_results["issues"]) == 1
assert "myth" in fact_check_results["issues"][0]
issues = fact_check_results["issues"]
assert len(issues) == 1
assert "myth" in issues[0]
async def test_fact_check_no_claims(self, minimal_state, mock_service_factory):
"""Test behavior when no claims to check."""
@@ -224,11 +267,9 @@ class TestPerformFactCheck:
minimal_state["claims_to_check"] = []
minimal_state["service_factory"] = factory
result = await perform_fact_check(minimal_state)
result = await _perform_fact_check(minimal_state)
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
assert fact_check_results["claims_checked"] == []
assert fact_check_results["issues"] == ["No claims to check"]
assert fact_check_results["score"] == 0.0
@@ -241,23 +282,15 @@ class TestPerformFactCheck:
llm_client.llm_json.side_effect = Exception("API timeout")
result = await perform_fact_check(minimal_state)
result = await _perform_fact_check(minimal_state)
fact_check_results = cast(
"dict[str, Any]", result.get("fact_check_results", {})
)
assert len(fact_check_results["claims_checked"]) == 1
claims_checked = cast("list[str]", fact_check_results["claims_checked"])
assert (
cast("dict[str, Any]", cast("dict[str, Any]", claims_checked[0])["result"])[
"accuracy"
]
== 1
)
assert (
"API timeout"
in cast("dict[str, Any]", result.get("fact_check_results", {}))["issues"][0]
)
fact_check_results = _expect_fact_check_results(result.get("fact_check_results", {}))
claims_checked: list[ClaimCheckTypedDict] = fact_check_results["claims_checked"]
assert len(claims_checked) == 1
claim_result = claims_checked[0]["result"]
assert claim_result["accuracy"] == 1
issues = fact_check_results["issues"]
assert any("API timeout" in issue for issue in issues)
@pytest.mark.asyncio
@@ -270,25 +303,21 @@ class TestValidateContentOutput:
"This is a sufficiently long and valid final output without any issues."
)
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
assert result.get("is_output_valid") is True
assert (
"validation_issues" not in result
or result.get("validation_issues", []) == []
)
issues = _expect_issue_list(result.get("validation_issues", []))
assert not issues
async def test_validate_output_too_short(self, minimal_state):
"""Test validation fails for short output."""
minimal_state["final_output"] = "Too short"
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
assert result.get("is_output_valid") is False
assert any(
"Output seems too short" in issue
for issue in result.get("validation_issues", [])
)
issues = _expect_issue_list(result.get("validation_issues", []))
assert any("Output seems too short" in issue for issue in issues)
async def test_validate_output_contains_error(self, minimal_state):
"""Test validation fails when output contains 'error'."""
@@ -296,13 +325,11 @@ class TestValidateContentOutput:
"This is a long output but contains an error message somewhere."
)
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
assert result.get("is_output_valid") is False
assert any(
"Output contains the word 'error'" in issue
for issue in result.get("validation_issues", [])
)
issues = _expect_issue_list(result.get("validation_issues", []))
assert any("Output contains the word 'error'" in issue for issue in issues)
async def test_validate_output_placeholder(self, minimal_state):
"""Test validation fails for placeholder text."""
@@ -310,24 +337,24 @@ class TestValidateContentOutput:
"This is a placeholder text that should be replaced with actual content."
)
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
assert result.get("is_output_valid") is False
issues = _expect_issue_list(result.get("validation_issues", []))
assert any(
"Output may contain unresolved placeholder text" in issue
for issue in result.get("validation_issues", [])
"Output may contain unresolved placeholder text" in issue for issue in issues
)
async def test_validate_no_output(self, minimal_state):
"""Test behavior when no final output exists."""
# Don't set final_output
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
assert result.get("is_output_valid") is None
issues = _expect_issue_list(result.get("validation_issues", []))
assert any(
"No final output generated for validation" in issue
for issue in result.get("validation_issues", [])
"No final output generated for validation" in issue for issue in issues
)
async def test_validate_already_invalid(self, minimal_state):
@@ -335,7 +362,7 @@ class TestValidateContentOutput:
minimal_state["final_output"] = "Some output"
minimal_state["is_output_valid"] = False # Explicitly set for this test
result = await validate_content_output(minimal_state)
result = await _validate_content(minimal_state)
# Should not change the is_output_valid status
assert result.get("is_output_valid") is False

View File

@@ -1,6 +1,9 @@
"""Unit tests for human feedback validation functionality."""
from typing import Any, cast
import pytest
from langchain_core.runnables import RunnableConfig
from biz_bud.nodes.validation import human_feedback
from biz_bud.states.unified import BusinessBuddyState
@@ -9,8 +12,6 @@ from biz_bud.states.unified import BusinessBuddyState
@pytest.fixture
def minimal_state() -> BusinessBuddyState:
"""Create a minimal state for testing human feedback validation."""
from typing import Any, cast
return cast(
"BusinessBuddyState",
{
@@ -47,7 +48,8 @@ async def test_should_request_feedback_success(minimal_state) -> None:
@pytest.mark.asyncio
async def test_prepare_human_feedback_request_success(minimal_state) -> None:
"""Test successful preparation of human feedback request."""
result = await human_feedback.prepare_human_feedback_request(minimal_state)
config = cast("RunnableConfig", {})
result = await human_feedback.prepare_human_feedback_request(minimal_state, config)
assert "human_feedback_context" in result
assert "human_feedback_summary" in result
assert "requires_interrupt" in result
@@ -58,7 +60,8 @@ async def test_prepare_human_feedback_request_error(minimal_state) -> None:
"""Test human feedback request preparation with error conditions."""
# The current implementation doesn't generate errors directly,
# so test the basic functionality
result = await human_feedback.prepare_human_feedback_request(minimal_state)
config = cast("RunnableConfig", {})
result = await human_feedback.prepare_human_feedback_request(minimal_state, config)
assert "human_feedback_context" in result
assert "human_feedback_summary" in result
assert "requires_interrupt" in result
@@ -68,14 +71,7 @@ async def test_prepare_human_feedback_request_error(minimal_state) -> None:
async def test_apply_human_feedback_success(minimal_state) -> None:
"""Test successful application of human feedback."""
# Test the apply_human_feedback function instead
result = await human_feedback.apply_human_feedback(minimal_state)
config = cast("RunnableConfig", {})
result = await human_feedback.apply_human_feedback(minimal_state, config)
# Check the expected return type from FeedbackUpdate
assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_should_apply_refinement_success(minimal_state) -> None:
"""Test successful refinement application validation."""
# Test the should_apply_refinement function
result = human_feedback.should_apply_refinement(minimal_state)
assert isinstance(result, bool)

View File

@@ -3,7 +3,7 @@
import asyncio
import gc
import weakref
from typing import Any, AsyncGenerator
from typing import AsyncGenerator
from unittest.mock import AsyncMock, patch
import pytest
@@ -52,17 +52,16 @@ class MockServiceConfig(BaseServiceConfig):
class MockService(BaseService[MockServiceConfig]):
"""Mock service for testing."""
def __init__(self, config: Any) -> None:
"""Initialize mock service with config."""
# For tests, accept anything as config but ensure proper parent initialization
validated_config = self._validate_config(config)
super().__init__(validated_config)
self._raw_config = config
def __init__(self, app_config: AppConfig) -> None:
"""Initialize mock service with a real AppConfig."""
super().__init__(app_config)
self._raw_config = app_config
self.initialized = False
self.cleaned_up = False
@classmethod
def _validate_config(cls, app_config: Any) -> MockServiceConfig:
def _validate_config(cls, app_config: AppConfig) -> MockServiceConfig:
"""Validate config - for tests, just return a mock config."""
return MockServiceConfig()

View File

@@ -24,25 +24,20 @@ class TestServiceConfig(BaseServiceConfig):
class TestService(BaseService[TestServiceConfig]):
"""Test service for dependency injection patterns."""
def __init__(self, config: Any) -> None:
def __init__(self, app_config: AppConfig) -> None:
"""Initialize test service with config."""
validated_config = self._validate_config(config)
super().__init__(validated_config)
super().__init__(app_config)
self.initialized = False
self.cleanup_called = False
self.start_time = time.time()
self.dependencies: dict[str, Any] = {}
@classmethod
def _validate_config(cls, app_config: Any) -> TestServiceConfig:
def _validate_config(cls, app_config: AppConfig) -> TestServiceConfig:
"""Validate and convert config to proper type."""
if isinstance(app_config, TestServiceConfig):
return app_config
elif isinstance(app_config, dict):
return TestServiceConfig(**app_config)
else:
# For testing, create a default config
return TestServiceConfig()
return TestServiceConfig()
async def initialize(self) -> None:
"""Initialize the service."""
@@ -66,18 +61,17 @@ class TestService(BaseService[TestServiceConfig]):
class DependentService(BaseService[TestServiceConfig]):
"""Service that depends on other services."""
def __init__(self, config: Any, test_service: TestService) -> None:
def __init__(self, app_config: AppConfig, test_service: TestService) -> None:
"""Initialize with dependency."""
validated_config = self._validate_config(config)
super().__init__(validated_config)
super().__init__(app_config)
self.test_service = test_service
self.initialized = False
@classmethod
def _validate_config(cls, app_config: Any) -> TestServiceConfig:
def _validate_config(cls, app_config: AppConfig) -> TestServiceConfig:
"""Validate config."""
if isinstance(app_config, TestServiceConfig):
return app_config
return TestServiceConfig()
async def initialize(self) -> None:
@@ -91,16 +85,17 @@ class DependentService(BaseService[TestServiceConfig]):
class SlowInitializingService(BaseService[TestServiceConfig]):
"""Service that takes time to initialize."""
def __init__(self, config: Any, init_delay: float = 0.1) -> None:
def __init__(self, app_config: AppConfig, init_delay: float = 0.1) -> None:
"""Initialize with configurable delay."""
validated_config = self._validate_config(config)
super().__init__(validated_config)
super().__init__(app_config)
self.init_delay = init_delay
self.initialized = False
@classmethod
def _validate_config(cls, app_config: Any) -> TestServiceConfig:
def _validate_config(cls, app_config: AppConfig) -> TestServiceConfig:
"""Validate config."""
return TestServiceConfig()
async def initialize(self) -> None:
@@ -113,15 +108,18 @@ class SlowInitializingService(BaseService[TestServiceConfig]):
class FailingService(BaseService[TestServiceConfig]):
"""Service that fails during initialization."""
def __init__(self, config: Any, failure_message: str = "Initialization failed") -> None:
def __init__(
self, app_config: AppConfig, failure_message: str = "Initialization failed"
) -> None:
"""Initialize with failure configuration."""
validated_config = self._validate_config(config)
super().__init__(validated_config)
super().__init__(app_config)
self.failure_message = failure_message
@classmethod
def _validate_config(cls, app_config: Any) -> TestServiceConfig:
def _validate_config(cls, app_config: AppConfig) -> TestServiceConfig:
"""Validate config."""
return TestServiceConfig()
async def initialize(self) -> None:
@@ -166,7 +164,7 @@ class TestServiceFactoryBasics:
"""Test basic service creation."""
# Mock the cleanup registry to return our test service
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -181,7 +179,7 @@ class TestServiceFactoryBasics:
async def test_service_singleton_behavior(self, service_factory):
"""Test that services are created as singletons."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -196,7 +194,7 @@ class TestServiceFactoryBasics:
async def test_service_registration(self, service_factory):
"""Test that services are properly registered."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -220,7 +218,7 @@ class TestConcurrencyAndRaceConditions:
call_count["value"] += 1
# Simulate some initialization time
await asyncio.sleep(0.01)
service = TestService({})
service = TestService(service_factory.config)
await service.initialize()
return service
@@ -250,7 +248,7 @@ class TestConcurrencyAndRaceConditions:
creation_order.append(f"start_{service_class.__name__}")
await asyncio.sleep(0.02) # Simulate work
creation_order.append(f"end_{service_class.__name__}")
service = TestService({})
service = TestService(service_factory.config)
await service.initialize()
return service
@@ -272,7 +270,7 @@ class TestConcurrencyAndRaceConditions:
async def test_initialization_tracking_cleanup(self, service_factory):
"""Test that initialization tracking is properly cleaned up."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -303,7 +301,7 @@ class TestErrorHandling:
"""Test handling of initialization timeouts."""
async def slow_create_service(service_class):
await asyncio.sleep(10) # Very slow initialization
return TestService({})
return TestService(service_factory.config)
with patch.object(service_factory._cleanup_registry, 'create_service', side_effect=slow_create_service):
# This should timeout in real scenarios
@@ -318,7 +316,7 @@ class TestErrorHandling:
async def cancellable_create_service(service_class):
try:
await asyncio.sleep(1) # Long operation
return TestService({})
return TestService(service_factory.config)
except asyncio.CancelledError:
# Simulate proper cleanup on cancellation
raise
@@ -336,7 +334,7 @@ class TestErrorHandling:
async def test_cleanup_tracking_error_recovery(self, service_factory):
"""Test recovery from cleanup tracking errors."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -354,7 +352,7 @@ class TestServiceLifecycle:
async def test_service_cleanup(self, service_factory):
"""Test service cleanup process."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -383,7 +381,7 @@ class TestServiceLifecycle:
async def test_service_memory_cleanup(self, service_factory):
"""Test that services are properly cleaned up from memory."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -415,12 +413,12 @@ class TestDependencyInjection:
# Setup mock to return different services based on class
def create_service_side_effect(service_class):
if service_class == TestService:
service = TestService({})
service = TestService(service_factory.config)
elif service_class == DependentService:
# This would normally be handled by the cleanup registry
# For testing, we simulate the dependency injection
test_service = TestService({})
service = DependentService({}, test_service)
test_service = TestService(service_factory.config)
service = DependentService(service_factory.config, test_service)
else:
raise ValueError(f"Unknown service class: {service_class}")
@@ -460,7 +458,7 @@ class TestPerformanceAndResourceManagement:
async def timed_create_service(service_class):
start_time = time.time()
service = TestService({})
service = TestService(service_factory.config)
await service.initialize()
creation_times.append(time.time() - start_time)
return service
@@ -538,7 +536,8 @@ class TestConfigurationIntegration:
async def test_config_propagation_to_services(self, service_factory):
"""Test that configuration is properly propagated to services."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({"test_value": "configured"})
test_service = TestService(service_factory.config)
test_service.config.test_value = "configured"
await test_service.initialize()
mock_create.return_value = test_service
@@ -564,7 +563,7 @@ class TestThreadSafetyAndAsyncPatterns:
try:
# Get services
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -578,7 +577,7 @@ class TestThreadSafetyAndAsyncPatterns:
async def test_service_access_after_cleanup(self, service_factory):
"""Test service access after factory cleanup."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service
@@ -592,7 +591,7 @@ class TestThreadSafetyAndAsyncPatterns:
# Getting service again should work (creates new instance)
# Reset the mock to return a new service
new_service = TestService({})
new_service = TestService(service_factory.config)
await new_service.initialize()
mock_create.return_value = new_service
@@ -624,7 +623,7 @@ class TestEdgeCasesAndErrorScenarios:
# First call fails
mock_create.side_effect = [
RuntimeError("First attempt failed"),
TestService({}) # Second attempt succeeds
TestService(service_factory.config), # Second attempt succeeds
]
# First attempt should fail
@@ -640,7 +639,7 @@ class TestEdgeCasesAndErrorScenarios:
async def test_factory_state_consistency(self, service_factory):
"""Test factory state consistency across operations."""
with patch.object(service_factory._cleanup_registry, 'create_service') as mock_create:
test_service = TestService({})
test_service = TestService(service_factory.config)
await test_service.initialize()
mock_create.return_value = test_service

View File

@@ -6,9 +6,27 @@ from biz_bud.tools.capabilities.extraction.core.types import (
FactTypedDict,
JsonDict,
JsonValue,
YearMentionTypedDict,
)
def _year_mention(
year: int,
context: str,
*,
value: int | None = None,
text: str | None = None,
) -> YearMentionTypedDict:
"""Build a YearMentionTypedDict with optional metadata."""
mention: YearMentionTypedDict = {"year": year, "context": context}
if value is not None:
mention["value"] = value
if text is not None:
mention["text"] = text
return mention
class TestJsonValueType:
"""Test JsonValue type definition and usage."""
@@ -290,18 +308,19 @@ class TestFactTypedDict:
def test_fact_typed_dict_with_year_mentioned(self):
"""Test FactTypedDict with year_mentioned field."""
year_data = [
{"year": 2023, "context": "fiscal year"},
{"year": 2024, "context": "projected"}
year_data: list[YearMentionTypedDict] = [
_year_mention(2023, "fiscal year"),
_year_mention(2024, "projected"),
]
fact: FactTypedDict = {"year_mentioned": year_data}
assert len(fact["year_mentioned"]) == 2
assert isinstance(fact["year_mentioned"], list)
mentions = fact.get("year_mentioned")
assert mentions is not None
assert len(mentions) == 2
first_year = fact["year_mentioned"][0]
assert isinstance(first_year, dict)
assert first_year["year"] == 2023
first_year = mentions[0]
assert first_year.get("year") == 2023
assert first_year.get("context") == "fiscal year"
def test_fact_typed_dict_with_source_quality(self):
"""Test FactTypedDict with source_quality field."""
@@ -546,38 +565,27 @@ class TestTypeEdgeCases:
def test_fact_typed_dict_with_complex_year_mentioned(self):
"""Test FactTypedDict with complex year_mentioned structure."""
complex_years = [
{
"year": 2023,
"context": "reporting period",
"confidence": 0.95,
"source": "document title"
},
{
"year": 2024,
"context": "projected",
"confidence": 0.7,
"source": "forecast section",
"notes": ["estimate", "subject to change"]
}
complex_years: list[YearMentionTypedDict] = [
_year_mention(2023, "reporting period", value=2023, text="document title"),
_year_mention(2024, "projected", value=2024, text="forecast section"),
]
fact: FactTypedDict = {
"fact": "Multi-year projection",
"year_mentioned": complex_years
"year_mentioned": complex_years,
}
assert len(fact["year_mentioned"]) == 2
mentions = fact.get("year_mentioned")
assert mentions is not None
assert len(mentions) == 2
year_2023 = fact["year_mentioned"][0]
assert isinstance(year_2023, dict)
assert year_2023["year"] == 2023
assert year_2023["confidence"] == 0.95
year_2023 = mentions[0]
assert year_2023.get("year") == 2023
assert year_2023.get("text") == "document title"
year_2024 = fact["year_mentioned"][1]
assert isinstance(year_2024, dict)
assert year_2024["year"] == 2024
assert isinstance(year_2024["notes"], list)
year_2024 = mentions[1]
assert year_2024.get("year") == 2024
assert year_2024.get("context") == "projected"
def test_empty_and_minimal_structures(self):
"""Test empty and minimal type structures."""

View File

@@ -2,13 +2,16 @@
import json
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import ValidationError
from biz_bud.core.types import JSONValue, StateDict
from biz_bud.tools.capabilities.extraction.legacy_tools import (
EXTRACTION_STATE_METHODS,
CategoryExtractionInput,
CategoryExtractionLangChainTool,
CategoryExtractionTool,
StatisticsExtractionInput,
@@ -24,6 +27,22 @@ from biz_bud.tools.capabilities.extraction.legacy_tools import (
)
def _get_list(value: Any) -> list[Any]:
return value if isinstance(value, list) else []
def _get_dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def _get_str(value: Any) -> str | None:
return value if isinstance(value, str) else None
def _to_state(data: dict[str, JSONValue]) -> StateDict:
return data
class TestStatisticsExtractionInput:
"""Test StatisticsExtractionInput schema validation."""
@@ -87,24 +106,34 @@ class TestStatisticsExtractionOutput:
{"text": "25%", "type": "percentage", "value": 25},
{"text": "$100M", "type": "monetary", "value": 100000000},
],
quality_scores={"overall": 0.8, "credibility": 0.7},
quality_scores={
"source_quality": 0.8,
"average_statistic_quality": 0.7,
"total_credibility_terms": 5,
},
total_facts=2,
extraction_metadata={},
)
assert len(output_data.statistics) == 2
assert output_data.quality_scores["overall"] == 0.8
assert output_data.quality_scores.get("source_quality") == 0.8
assert output_data.total_facts == 2
def test_output_schema_empty_values(self):
"""Test output schema with empty values."""
output_data = StatisticsExtractionOutput(
statistics=[],
quality_scores={},
quality_scores={
"source_quality": 0.0,
"average_statistic_quality": 0.0,
"total_credibility_terms": 0,
},
total_facts=0,
extraction_metadata={},
)
assert output_data.statistics == []
assert output_data.quality_scores == {}
assert output_data.quality_scores["total_credibility_terms"] == 0
assert output_data.total_facts == 0
@@ -164,7 +193,7 @@ class TestExtractStatisticsTool:
assert percentage_fact is not None
assert percentage_fact["value"] == 25
assert percentage_fact["source_url"] == url
assert percentage_fact["source_title"] == source_title
assert _get_str(percentage_fact.get("source_title")) == source_title
# Check monetary fact
monetary_fact = next(
@@ -172,7 +201,7 @@ class TestExtractStatisticsTool:
)
assert monetary_fact is not None
assert monetary_fact["value"] == 100000000
assert monetary_fact["currency"] == "USD"
assert _get_str(monetary_fact.get("currency")) == "USD"
def test_extract_statistics_no_data_found(self):
"""Test statistics extraction when no statistics are found."""
@@ -397,7 +426,7 @@ class TestExtractCategoryInformation:
mock_logger.info.assert_called_with("Empty content provided")
assert result["facts"] == []
assert result["relevance_score"] == 0.0
assert result["category"] == category
assert result.get("category") == category
@pytest.mark.asyncio
async def test_extract_category_information_invalid_url(self):
@@ -420,7 +449,7 @@ class TestExtractCategoryInformation:
mock_logger.info.assert_called_with(f"Invalid URL: {url}")
assert result["facts"] == []
assert result["category"] == category
assert result.get("category") == category
@pytest.mark.asyncio
async def test_extract_category_information_exception(self):
@@ -447,7 +476,7 @@ class TestExtractCategoryInformation:
mock_logger.error.assert_called_once()
assert result["facts"] == []
assert result["category"] == category
assert result.get("category") == category
@pytest.mark.asyncio
async def test_extract_category_information_whitespace_content(self):
@@ -584,8 +613,8 @@ class TestProcessContent:
assert isinstance(facts, list) and len(facts) == 1
fact = facts[0]
assert isinstance(fact, dict)
assert fact["source_title"] == source_title
assert fact["currency"] == "USD"
assert _get_str(fact.get("source_title")) == source_title
assert _get_str(fact.get("currency")) == "USD"
@pytest.mark.asyncio
async def test_process_content_no_facts_found(self):
@@ -716,7 +745,7 @@ class TestHelperFunctions:
assert result["facts"] == []
assert result["relevance_score"] == 0.0
assert result["processed_at"] == "2024-01-01T00:00:00Z"
assert result["category"] == category
assert result.get("category") == category
def test_get_timestamp(self):
"""Test _get_timestamp function."""
@@ -800,11 +829,11 @@ class TestCategoryExtractionLangChainTool:
tool = CategoryExtractionLangChainTool()
assert tool.name == "category_extraction_langchain"
assert "extract structured information" in tool.description.lower()
assert tool.args_schema == tool.CategoryExtractionInput
assert tool.args_schema is CategoryExtractionInput
def test_input_schema(self):
"""Test nested input schema."""
input_data = CategoryExtractionLangChainTool.CategoryExtractionInput(
input_data = CategoryExtractionInput(
content="Test content",
url="https://example.com",
category="technology",
@@ -870,12 +899,12 @@ class TestExtractionStateMethods:
@pytest.mark.asyncio
async def test_extract_statistics_from_state_successful(self):
"""Test successful statistics extraction from state."""
state = {
state = _to_state({
"text": "Revenue increased by 25%",
"url": "https://example.com",
"source_title": "Q4 Report",
"chunk_size": 4000,
}
})
mock_result = {
"statistics": [{"text": "25%", "type": "percentage", "value": 25}],
@@ -890,14 +919,14 @@ class TestExtractionStateMethods:
methods = create_extraction_state_methods()
result_state = await methods["extract_statistics_from_state"](state)
assert result_state["extracted_statistics"] == mock_result["statistics"]
assert result_state["statistics_quality_scores"] == mock_result["quality_scores"]
assert result_state["total_facts"] == mock_result["total_facts"]
assert _get_list(result_state.get("extracted_statistics")) == mock_result["statistics"]
assert _get_dict(result_state.get("statistics_quality_scores")) == mock_result["quality_scores"]
assert result_state.get("total_facts") == mock_result["total_facts"]
@pytest.mark.asyncio
async def test_extract_statistics_from_state_no_text(self):
"""Test statistics extraction from state with no text."""
state = {"url": "https://example.com"}
state = _to_state({"url": "https://example.com"})
with patch(
"biz_bud.tools.capabilities.extraction.legacy_tools.logger"
@@ -912,10 +941,10 @@ class TestExtractionStateMethods:
async def test_extract_statistics_from_state_fallback_text_sources(self):
"""Test statistics extraction with fallback text sources."""
# Test content fallback
state = {
state = _to_state({
"content": "Profit margin of 12%",
"url": "https://example.com",
}
})
mock_result = {
"statistics": [{"text": "12%", "type": "percentage", "value": 12}],
@@ -935,10 +964,10 @@ class TestExtractionStateMethods:
assert call_args["text"] == "Profit margin of 12%"
# Test search results fallback
state = {
state = _to_state({
"search_results": [{"snippet": "Growth rate of 8%"}],
"url": "https://example.com",
}
})
with patch('biz_bud.tools.capabilities.extraction.legacy_tools.extract_statistics') as mock_tool:
mock_tool.ainvoke = AsyncMock(return_value=mock_result)
@@ -951,10 +980,10 @@ class TestExtractionStateMethods:
assert call_args["text"] == "Growth rate of 8%"
# Test scraped content fallback
state = {
state = _to_state({
"scraped_content": {"content": "Market share of 30%"},
"url": "https://example.com",
}
})
with patch('biz_bud.tools.capabilities.extraction.legacy_tools.extract_statistics') as mock_tool:
mock_tool.ainvoke = AsyncMock(return_value=mock_result)
@@ -969,7 +998,7 @@ class TestExtractionStateMethods:
@pytest.mark.asyncio
async def test_extract_statistics_from_state_exception(self):
"""Test statistics extraction from state with exception."""
state = {"text": "Test text"}
state = _to_state({"text": "Test text"})
with patch('biz_bud.tools.capabilities.extraction.legacy_tools.extract_statistics') as mock_tool:
mock_tool.ainvoke = AsyncMock(side_effect=ValueError("Extract error"))
@@ -979,19 +1008,19 @@ class TestExtractionStateMethods:
methods = create_extraction_state_methods()
result_state = await methods["extract_statistics_from_state"](state)
assert "errors" in result_state
assert "Statistics extraction failed: Extract error" in result_state["errors"]
errors = _get_list(result_state.get("errors"))
assert "Statistics extraction failed: Extract error" in errors
mock_logger.error.assert_called_once()
@pytest.mark.asyncio
async def test_extract_category_info_from_state_successful(self):
"""Test successful category info extraction from state."""
state = {
state = _to_state({
"content": "AI technology is advancing rapidly",
"url": "https://example.com",
"category": "technology",
"source_title": "Tech News",
}
})
mock_result = {
"facts": [{"text": "AI advancing", "type": "trend"}],
@@ -1007,14 +1036,14 @@ class TestExtractionStateMethods:
methods = create_extraction_state_methods()
result_state = await methods["extract_category_info_from_state"](state)
assert result_state["extracted_facts"] == mock_result["facts"]
assert result_state["relevance_score"] == mock_result["relevance_score"]
assert result_state["extraction_processed_at"] == mock_result["processed_at"]
assert _get_list(result_state.get("extracted_facts")) == mock_result["facts"]
assert result_state.get("relevance_score") == mock_result["relevance_score"]
assert result_state.get("extraction_processed_at") == mock_result["processed_at"]
@pytest.mark.asyncio
async def test_extract_category_info_from_state_missing_fields(self):
"""Test category info extraction with missing required fields."""
state = {"content": "Test content"} # Missing url and category
state = _to_state({"content": "Test content"}) # Missing url and category
with patch(
"biz_bud.tools.capabilities.extraction.legacy_tools.logger"
@@ -1028,11 +1057,11 @@ class TestExtractionStateMethods:
@pytest.mark.asyncio
async def test_extract_category_info_from_state_fallback_fields(self):
"""Test category info extraction with fallback field names."""
state = {
state = _to_state({
"text": "Tech content", # Fallback for content
"source_url": "https://example.com", # Fallback for url
"research_category": "technology", # Fallback for category
}
})
mock_result = {
"facts": [],
@@ -1058,11 +1087,11 @@ class TestExtractionStateMethods:
@pytest.mark.asyncio
async def test_extract_category_info_from_state_exception(self):
"""Test category info extraction from state with exception."""
state = {
state = _to_state({
"content": "Test content",
"url": "https://example.com",
"category": "tech",
}
})
with patch(
"biz_bud.tools.capabilities.extraction.legacy_tools.extract_category_information",
@@ -1075,8 +1104,8 @@ class TestExtractionStateMethods:
methods = create_extraction_state_methods()
result_state = await methods["extract_category_info_from_state"](state)
assert "errors" in result_state
assert "Category extraction failed: Category error" in result_state["errors"]
errors = _get_list(result_state.get("errors"))
assert "Category extraction failed: Category error" in errors
mock_logger.error.assert_called_once()
def test_extraction_state_methods_constant(self):
@@ -1167,7 +1196,7 @@ class TestAdditionalCoverage:
def test_category_extraction_input_validation(self):
"""Test CategoryExtractionInput validation."""
# Valid input
input_data = CategoryExtractionLangChainTool.CategoryExtractionInput(
input_data = CategoryExtractionInput(
content="Test content",
url="https://example.com",
category="technology",
@@ -1175,7 +1204,7 @@ class TestAdditionalCoverage:
assert input_data.source_title is None
# Test with all fields
input_complete = CategoryExtractionLangChainTool.CategoryExtractionInput(
input_complete = CategoryExtractionInput(
content="Complete content",
url="https://example.com",
category="finance",
@@ -1186,10 +1215,10 @@ class TestAdditionalCoverage:
@pytest.mark.asyncio
async def test_extract_statistics_from_state_empty_search_results(self):
"""Test statistics extraction with empty search results."""
state = {
state = _to_state({
"search_results": [], # Empty list
"url": "https://example.com",
}
})
with patch(
"biz_bud.tools.capabilities.extraction.legacy_tools.logger"
@@ -1203,10 +1232,10 @@ class TestAdditionalCoverage:
@pytest.mark.asyncio
async def test_extract_statistics_from_state_empty_scraped_content(self):
"""Test statistics extraction with empty scraped content."""
state = {
state = _to_state({
"scraped_content": {}, # Empty dict
"url": "https://example.com",
}
})
with patch(
"biz_bud.tools.capabilities.extraction.legacy_tools.logger"

View File

@@ -1,10 +1,12 @@
"""Comprehensive tests for single URL processor tool."""
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import ValidationError
from biz_bud.core.types import JSONValue
from biz_bud.tools.capabilities.extraction.single_url_processor import (
ProcessSingleUrlInput,
process_single_url_tool,
@@ -41,7 +43,7 @@ class TestProcessSingleUrlInput:
input_data = ProcessSingleUrlInput(
url="https://test.com/article",
query="Extract technical specifications",
config=complex_config,
config=cast(dict[str, JSONValue], complex_config),
)
assert input_data.config == complex_config

View File

@@ -54,6 +54,14 @@ def parse_action_arguments_impl(text: str) -> dict[str, Any]:
}
def _get_list(value: Any) -> list[Any]:
return value if isinstance(value, list) else []
def _get_dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def clean_and_normalize_text_impl(text: str, normalize_quotes: bool = True, normalize_spaces: bool = True, remove_html: bool = True) -> dict[str, Any]:
"""Clean and normalize text (test wrapper)."""
try:
@@ -241,7 +249,7 @@ class TestExtractStructuredContent:
assert result["structured_data"] == mock_result["data"]
assert result["source_type"] == "mixed"
assert result["confidence"] == 0.8
assert set(result["extraction_types"]) == {
assert set(_get_list(result.get("extraction_types"))) == {
"json",
"lists",
"key_value_pairs",
@@ -380,7 +388,7 @@ class TestExtractKeyValueData:
assert result["found"] is True
assert result["key_value_pairs"] == mock_kv_pairs
assert result["total_pairs"] == 3
assert set(result["keys"]) == {"Name", "Age", "City"}
assert set(_get_list(result.get("keys"))) == {"Name", "Age", "City"}
def test_extract_key_value_no_pairs_found(self):
"""Test key-value extraction when no pairs are found."""
@@ -578,7 +586,7 @@ class TestParseActionArguments:
assert result["found"] is True
assert result["action_args"] == mock_args
assert result["total_args"] == 2
assert set(result["arg_keys"]) == {"query", "limit"}
assert set(_get_list(result.get("arg_keys"))) == {"query", "limit"}
def test_parse_action_arguments_no_args_found(self):
"""Test action argument parsing when no args are found."""
@@ -654,19 +662,19 @@ class TestExtractThoughtActionSequences:
result = extract_thought_action_sequences_impl(text)
assert result.get("success") is True
assert result["found"] is True
assert result["total_pairs"] == 2
assert len(result["thought_action_pairs"]) == 2
assert (
result["thought_action_pairs"][0]["thought"]
== "I need to search for information"
)
assert result["thought_action_pairs"][0]["action"] == "search"
assert (
result["thought_action_pairs"][1]["thought"]
== "Now I should analyze the results"
)
assert result["thought_action_pairs"][1]["action"] == "analyze"
assert result.get("found") is True
assert result.get("total_pairs") == 2
pairs = result.get("thought_action_pairs", [])
assert isinstance(pairs, list)
assert len(pairs) == 2
first_pair = pairs[0]
second_pair = pairs[1]
assert isinstance(first_pair, dict)
assert isinstance(second_pair, dict)
assert first_pair.get("thought") == "I need to search for information"
assert first_pair.get("action") == "search"
assert second_pair.get("thought") == "Now I should analyze the results"
assert second_pair.get("action") == "analyze"
def test_extract_thought_action_no_pairs_found(self):
"""Test thought-action extraction when no pairs are found."""
@@ -679,9 +687,10 @@ class TestExtractThoughtActionSequences:
result = extract_thought_action_sequences_impl(text)
assert result.get("success") is True
assert result["found"] is False
assert result["thought_action_pairs"] == []
assert result["total_pairs"] == 0
assert result.get("found") is False
pairs = result.get("thought_action_pairs")
assert isinstance(pairs, list) and pairs == []
assert result.get("total_pairs") == 0
def test_extract_thought_action_exception(self):
"""Test thought-action extraction with exception."""
@@ -697,9 +706,9 @@ class TestExtractThoughtActionSequences:
result = extract_thought_action_sequences_impl(text)
assert result.get("success") is False
assert result["found"] is False
assert result["thought_action_pairs"] == []
assert result["total_pairs"] == 0
assert result.get("found") is False
assert _get_list(result.get("thought_action_pairs")) == []
assert result.get("total_pairs") == 0
assert result["error"] == "TA parsing error"
mock_logger.error.assert_called_once()
@@ -714,10 +723,13 @@ class TestExtractThoughtActionSequences:
result = extract_thought_action_sequences_impl(text)
assert result.get("success") is True
assert result["found"] is True
assert result["total_pairs"] == 1
assert result["thought_action_pairs"][0]["thought"] == "Think"
assert result["thought_action_pairs"][0]["action"] == "Act"
assert result.get("found") is True
assert result.get("total_pairs") == 1
pairs = _get_list(result.get("thought_action_pairs"))
assert len(pairs) == 1
pair_data = _get_dict(pairs[0])
assert pair_data.get("thought") == "Think"
assert pair_data.get("action") == "Act"
class TestCleanAndNormalizeText:
@@ -746,10 +758,14 @@ class TestCleanAndNormalizeText:
assert result["cleaned_text"] == 'Hello "world" Extra spaces'
assert result["original_length"] == len(text)
assert result["cleaned_length"] == len('Hello "world" Extra spaces')
assert "html_removed" in result["transformations_applied"]
assert "quotes_normalized" in result["transformations_applied"]
assert "whitespace_normalized" in result["transformations_applied"]
assert result["reduction_ratio"] > 0
transformations = result.get("transformations_applied") or []
assert isinstance(transformations, list)
assert "html_removed" in transformations
assert "quotes_normalized" in transformations
assert "whitespace_normalized" in transformations
reduction_ratio = result.get("reduction_ratio", 0)
assert isinstance(reduction_ratio, (int, float))
assert reduction_ratio > 0
def test_clean_and_normalize_selective_options(self):
"""Test text cleaning with selective options."""
@@ -863,15 +879,19 @@ Third paragraph."""
result = analyze_text_structure_impl(text)
assert result.get("success") is True
assert result["total_characters"] == len(text)
assert result["total_words"] == len(text.split())
assert result["total_lines"] == len(text.split("\n"))
assert result["total_paragraphs"] == 3
assert result["total_sentences"] == 5
assert result["estimated_tokens"] == mock_token_count
assert result["avg_words_per_sentence"] > 0
assert result["avg_sentences_per_paragraph"] > 0
assert len(result["sentences"]) <= 10 # Preview limit
assert result.get("total_characters") == len(text)
assert result.get("total_words") == len(text.split())
assert result.get("total_lines") == len(text.split("\n"))
assert result.get("total_paragraphs") == 3
assert result.get("total_sentences") == 5
assert result.get("estimated_tokens") == mock_token_count
avg_words = result.get("avg_words_per_sentence", 0)
avg_sentences = result.get("avg_sentences_per_paragraph", 0)
assert isinstance(avg_words, (int, float)) and avg_words > 0
assert isinstance(avg_sentences, (int, float)) and avg_sentences > 0
sentences_preview = result.get("sentences", [])
assert isinstance(sentences_preview, list)
assert len(sentences_preview) <= 10 # Preview limit
def test_analyze_text_structure_single_paragraph(self):
"""Test text structure analysis with single paragraph."""
@@ -970,8 +990,10 @@ Third paragraph."""
result = analyze_text_structure_impl(text)
assert result.get("success") is True
assert result["total_sentences"] == 15
assert len(result["sentences"]) == 10 # Preview limited to first 10
assert result.get("total_sentences") == 15
sentences_preview = result.get("sentences", [])
assert isinstance(sentences_preview, list)
assert len(sentences_preview) == 10 # Preview limited to first 10
def test_analyze_text_structure_whitespace_handling(self):
"""Test text structure analysis with various whitespace scenarios."""

View File

@@ -408,13 +408,11 @@ class TestDiscoveryProvidersIntegration:
# Test that get_discovery_methods returns a list
methods = provider.get_discovery_methods()
methods_is_list = isinstance(methods, list)
methods_not_empty = len(methods) > 0
all_methods_strings = all(isinstance(method, str) for method in methods)
methods_not_empty = bool(methods)
return (has_discover_urls and has_get_discovery_methods and
discover_callable and methods_callable and
methods_is_list and methods_not_empty and all_methods_strings)
methods_not_empty)
provider_test_results = [test_provider_interface(provider) for provider in providers]
failed_provider_indices = [i for i, passed in enumerate(provider_test_results) if not passed]

View File

@@ -35,19 +35,17 @@ class TestBasicValidationProvider:
expected_config = create_validation_config(level=ValidationLevel.BASIC)
assert provider.config == expected_config
assert provider.timeout == expected_config["timeout"]
assert provider.timeout == expected_config.get("timeout", provider.timeout)
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = create_validation_config(
level=ValidationLevel.BASIC,
timeout=25.0,
retry_attempts=1,
)
config = create_validation_config(level=ValidationLevel.BASIC)
config["timeout"] = 25.0
config["retry_attempts"] = 1
provider = BasicValidationProvider(config)
assert provider.config == config
assert provider.timeout == 25.0
assert provider.timeout == config.get("timeout", provider.timeout)
def test_initialization_none_config(self):
"""Test initialization with None configuration."""
@@ -55,7 +53,7 @@ class TestBasicValidationProvider:
expected_config = create_validation_config(level=ValidationLevel.BASIC)
assert provider.config == expected_config
assert provider.timeout == expected_config["timeout"]
assert provider.timeout == expected_config.get("timeout", provider.timeout)
def test_get_validation_level(self):
"""Test get_validation_level method."""
@@ -145,19 +143,17 @@ class TestStandardValidationProvider:
expected_config = create_validation_config(level=ValidationLevel.STANDARD)
assert provider.config == expected_config
assert provider.timeout == expected_config["timeout"]
assert provider.timeout == expected_config.get("timeout", provider.timeout)
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = create_validation_config(
level=ValidationLevel.STANDARD,
timeout=45.0,
retry_attempts=3,
)
config = create_validation_config(level=ValidationLevel.STANDARD)
config["timeout"] = 45.0
config["retry_attempts"] = 3
provider = StandardValidationProvider(config)
assert provider.config == config
assert provider.timeout == 45.0
assert provider.timeout == config.get("timeout", provider.timeout)
def test_get_validation_level(self):
"""Test get_validation_level method."""
@@ -336,7 +332,7 @@ class TestStrictValidationProvider:
expected_config = create_validation_config(level=ValidationLevel.STRICT)
assert provider.config == expected_config
assert provider.timeout == expected_config["timeout"]
assert provider.timeout == expected_config.get("timeout", provider.timeout)
assert "application/octet-stream" in provider.blocked_content_types
assert "application/pdf" in provider.blocked_content_types
assert len(provider.blocked_content_types) >= 5
@@ -344,14 +340,12 @@ class TestStrictValidationProvider:
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
custom_blocked_types = ["application/json", "text/plain"]
config = create_validation_config(
level=ValidationLevel.STRICT,
timeout=120.0,
blocked_content_types=custom_blocked_types,
)
config = create_validation_config(level=ValidationLevel.STRICT)
config["timeout"] = 120.0
config["blocked_content_types"] = list(custom_blocked_types)
provider = StrictValidationProvider(config)
assert provider.timeout == 120.0
assert provider.timeout == config.get("timeout", provider.timeout)
assert provider.blocked_content_types == custom_blocked_types
def test_get_validation_level(self):

View File

@@ -398,8 +398,8 @@ class TestFactoryFunctions:
config = create_validation_config()
assert isinstance(config, dict)
assert config["timeout"] == 30.0
assert config["retry_attempts"] == 3
assert config.get("timeout") == 30.0
assert config.get("retry_attempts") == 3
assert "blocked_content_types" in config
def test_create_validation_config_custom(self):
@@ -410,8 +410,8 @@ class TestFactoryFunctions:
retry_attempts=5
)
assert config["timeout"] == 60.0
assert config["retry_attempts"] == 5
assert config.get("timeout") == 60.0
assert config.get("retry_attempts") == 5
def test_create_normalization_config_standard(self):
"""Test create_normalization_config with standard strategy."""
@@ -419,25 +419,25 @@ class TestFactoryFunctions:
assert isinstance(config, dict)
# Standard strategy should use defaults
assert config["default_protocol"] == "https"
assert config.get("default_protocol") == "https"
def test_create_normalization_config_conservative(self):
"""Test create_normalization_config with conservative strategy."""
config = create_normalization_config(NormalizationStrategy.CONSERVATIVE)
assert config["normalize_protocol"] is False
assert config["remove_www"] is False
assert config["sort_query_params"] is False
assert config["remove_trailing_slash"] is False
assert config.get("normalize_protocol") is False
assert config.get("remove_www") is False
assert config.get("sort_query_params") is False
assert config.get("remove_trailing_slash") is False
def test_create_normalization_config_aggressive(self):
"""Test create_normalization_config with aggressive strategy."""
config = create_normalization_config(NormalizationStrategy.AGGRESSIVE)
assert config["normalize_protocol"] is True
assert config["remove_www"] is True
assert config["sort_query_params"] is True
assert config["remove_trailing_slash"] is True
assert config.get("normalize_protocol") is True
assert config.get("remove_www") is True
assert config.get("sort_query_params") is True
assert config.get("remove_trailing_slash") is True
def test_create_normalization_config_custom_override(self):
"""Test create_normalization_config with custom overrides."""
@@ -446,8 +446,8 @@ class TestFactoryFunctions:
normalize_protocol=True # Override conservative default
)
assert config["normalize_protocol"] is True # Override applied
assert config["remove_www"] is False # Conservative default maintained
assert config.get("normalize_protocol") is True # Override applied
assert config.get("remove_www") is False # Conservative default maintained
def test_create_normalization_config_type_filtering(self):
"""Test create_normalization_config filters problematic types."""
@@ -457,34 +457,34 @@ class TestFactoryFunctions:
default_protocol=True # Should be filtered out and default applied
)
assert config["default_protocol"] == "https" # Default applied
assert config.get("default_protocol") == "https" # Default applied
def test_create_discovery_config_comprehensive(self):
"""Test create_discovery_config with comprehensive method."""
config = create_discovery_config(DiscoveryMethod.COMPREHENSIVE)
assert config["parse_sitemaps"] is True
assert config["parse_robots_txt"] is True
assert config["extract_links_from_html"] is True
assert config["max_pages"] == 1000
assert config.get("parse_sitemaps") is True
assert config.get("parse_robots_txt") is True
assert config.get("extract_links_from_html") is True
assert config.get("max_pages") == 1000
def test_create_discovery_config_sitemap_only(self):
"""Test create_discovery_config with sitemap only method."""
config = create_discovery_config(DiscoveryMethod.SITEMAP_ONLY)
assert config["parse_sitemaps"] is True
assert config["parse_robots_txt"] is False
assert config["extract_links_from_html"] is False
assert config.get("parse_sitemaps") is True
assert config.get("parse_robots_txt") is False
assert config.get("extract_links_from_html") is False
def test_create_discovery_config_html_parsing(self):
"""Test create_discovery_config with HTML parsing method."""
config = create_discovery_config(DiscoveryMethod.HTML_PARSING)
assert config["parse_sitemaps"] is False
assert config["parse_robots_txt"] is False
assert config["extract_links_from_html"] is True
assert config["max_pages"] == 100 # Lower default for HTML parsing
assert config["max_depth"] == 1
assert config.get("parse_sitemaps") is False
assert config.get("parse_robots_txt") is False
assert config.get("extract_links_from_html") is True
assert config.get("max_pages") == 100 # Lower default for HTML parsing
assert config.get("max_depth") == 1
def test_create_discovery_config_custom_max_pages(self):
"""Test create_discovery_config with custom max_pages."""
@@ -493,7 +493,7 @@ class TestFactoryFunctions:
max_pages=2000
)
assert config["max_pages"] == 2000
assert config.get("max_pages") == 2000
def test_create_discovery_config_type_filtering(self):
"""Test create_discovery_config filters problematic types."""
@@ -504,30 +504,30 @@ class TestFactoryFunctions:
user_agent=True # Should be filtered and default applied
)
assert config["parse_sitemaps"] is False
assert config["user_agent"] == "BusinessBuddy-URLProcessor/1.0"
assert config.get("parse_sitemaps") is False
assert config.get("user_agent") == "BusinessBuddy-URLProcessor/1.0"
def test_create_deduplication_config_hash_based(self):
"""Test create_deduplication_config with hash-based strategy."""
config = create_deduplication_config(DeduplicationStrategy.HASH_BASED)
assert config["cache_enabled"] is True
assert config["cache_size"] == 10000
assert config.get("cache_enabled") is True
assert config.get("cache_size") == 10000
def test_create_deduplication_config_advanced(self):
"""Test create_deduplication_config with advanced strategy."""
config = create_deduplication_config(DeduplicationStrategy.ADVANCED)
assert config["similarity_threshold"] == 0.8
assert config["cache_enabled"] is True
assert config["cache_size"] == 50000 # Larger cache for advanced methods
assert config.get("similarity_threshold") == 0.8
assert config.get("cache_enabled") is True
assert config.get("cache_size") == 50000 # Larger cache for advanced methods
def test_create_deduplication_config_domain_based(self):
"""Test create_deduplication_config with domain-based strategy."""
config = create_deduplication_config(DeduplicationStrategy.DOMAIN_BASED)
assert config["keep_shortest"] is True
assert config["cache_enabled"] is False # Not needed for domain-based
assert config.get("keep_shortest") is True
assert config.get("cache_enabled") is False # Not needed for domain-based
def test_create_deduplication_config_custom_override(self):
"""Test create_deduplication_config with custom overrides."""
@@ -536,7 +536,7 @@ class TestFactoryFunctions:
cache_size=20000
)
assert config["cache_size"] == 20000
assert config.get("cache_size") == 20000
def test_create_deduplication_config_type_filtering(self):
"""Test create_deduplication_config filters problematic types."""
@@ -547,8 +547,8 @@ class TestFactoryFunctions:
cache_size="5000.5" # String float should convert to int
)
assert config["cache_enabled"] is False
assert config["cache_size"] == 5000
assert config.get("cache_enabled") is False
assert config.get("cache_size") == 5000
def test_create_url_processing_config_default(self):
"""Test create_url_processing_config with defaults."""
@@ -596,9 +596,9 @@ class TestFactoryFunctions:
assert isinstance(config.deduplication_config, dict)
# Verify specific strategy settings are reflected
assert config.normalization_config["remove_www"] is False # Conservative
assert config.discovery_config["extract_links_from_html"] is True # HTML parsing
assert config.deduplication_config["keep_shortest"] is True # Domain-based
assert config.normalization_config.get("remove_www") is False # Conservative
assert config.discovery_config.get("extract_links_from_html") is True # HTML parsing
assert config.deduplication_config.get("keep_shortest") is True # Domain-based
def test_create_url_processing_config_with_valid_kwargs(self):
"""Test create_url_processing_config with valid additional kwargs."""
@@ -624,7 +624,7 @@ class TestEdgeCases:
NormalizationStrategy.STANDARD,
default_protocol=None
)
assert config["default_protocol"] == "https" # Default should be applied
assert config.get("default_protocol") == "https" # Default should be applied
def test_discovery_config_edge_cases(self):
"""Test edge cases in discovery config creation."""
@@ -634,8 +634,8 @@ class TestEdgeCases:
parse_sitemaps=1, # Should convert to True
parse_robots_txt=0 # Should convert to False
)
assert config["parse_sitemaps"] is True
assert config["parse_robots_txt"] is False
assert config.get("parse_sitemaps") is True
assert config.get("parse_robots_txt") is False
def test_deduplication_config_edge_cases(self):
"""Test edge cases in deduplication config creation."""
@@ -644,7 +644,7 @@ class TestEdgeCases:
DeduplicationStrategy.HASH_BASED,
cache_size="invalid" # Should use default
)
assert config["cache_size"] == 1000 # Default fallback
assert config.get("cache_size") == 1000 # Default fallback
def test_factory_functions_with_empty_dicts(self):
"""Test factory functions with empty configuration dictionaries."""

View File

@@ -145,7 +145,7 @@ class TestURLNormalizationProvider:
# Should not raise exception
provider = CompleteProvider()
assert provider.normalize_url("HTTP://EXAMPLE.COM") == "http://example.com"
assert provider.get_normalization_config()["lowercase_domain"] is True
assert provider.get_normalization_config().get("lowercase_domain") is True
class TestURLDiscoveryProvider:
@@ -471,7 +471,7 @@ class TestMultipleInheritanceScenarios:
provider = MultiProvider()
assert provider.get_validation_level() == "test"
assert provider.normalize_url("TEST") == "test"
assert provider.get_normalization_config()["lowercase_domain"] is True
assert provider.get_normalization_config().get("lowercase_domain") is True
def test_partial_implementation_fails(self):
"""Test that partial implementation of multiple interfaces fails."""
@@ -515,7 +515,7 @@ class TestMultipleInheritanceScenarios:
provider = OrderedProvider()
ordered_config = provider.get_normalization_config()
assert ordered_config["lowercase_domain"] is True
assert ordered_config.get("lowercase_domain") is True
# Check MRO includes all expected classes
mro = OrderedProvider.__mro__