diff --git a/.claude/settings.local.json b/.claude/settings.local.json index eb6778f7..fceae82b 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -181,7 +181,10 @@ "Bash(timeout 30 python -m pytest:*)", "Bash(timeout 60 python -m pytest:*)", "mcp__postgres-kgr__describe_table", - "mcp__postgres-kgr__execute_query" + "mcp__postgres-kgr__execute_query", + "Bash(timeout 60s pyrefly check src --no-cache)", + "Bash(timeout 60s pyrefly check src)", + "Bash(timeout 30s pyrefly check src/biz_bud/core/errors/base.py)" ], "deny": [] }, diff --git a/.sonar/report-task.txt b/.sonar/report-task.txt index 3ecca953..27c55f3b 100644 --- a/.sonar/report-task.txt +++ b/.sonar/report-task.txt @@ -2,5 +2,5 @@ projectKey=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f serverUrl=http://sonar.lab serverVersion=25.7.0.110598 dashboardUrl=http://sonar.lab/dashboard?id=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f -ceTaskId=06a875ae-ce12-424f-abef-5b9ce511b87f -ceTaskUrl=http://sonar.lab/api/ce/task?id=06a875ae-ce12-424f-abef-5b9ce511b87f +ceTaskId=eec4f5f1-80af-4f1c-83a9-adb130a69d92 +ceTaskUrl=http://sonar.lab/api/ce/task?id=eec4f5f1-80af-4f1c-83a9-adb130a69d92 diff --git a/pyrefly.toml b/pyrefly.toml index 7dc4a52f..b0f90a45 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -26,7 +26,7 @@ project_excludes = [ ".archive/", "**/.archive/", "cache/", - "examples/", + "examples/**", ".cenv/**", ".venv-host/**", "**/.venv/**", @@ -35,7 +35,9 @@ project_excludes = [ "**/lib/python*/**", "**/bin/**", "**/include/**", - "**/share/**" + "**/share/**", + ".backup/**", + "**/.backup/**" ] # Search paths for module resolution diff --git a/src/biz_bud/core/caching/redis.py b/src/biz_bud/core/caching/redis.py index 1c7ac0e7..ac78f683 100644 --- a/src/biz_bud/core/caching/redis.py +++ b/src/biz_bud/core/caching/redis.py @@ -138,12 +138,19 @@ class RedisCache(CacheBackend): async def close(self) -> None: """Close Redis connection.""" - if self._client: - await self._client.close() + if self._client is not None: + from typing import cast + client = cast(redis.Redis, self._client) + await client.close() + async def health_check(self) -> bool: """Check if Redis is available.""" try: - return False if self._client is None else await self._client.ping() + if self._client is None: + return False + from typing import cast + client = cast(redis.Redis, self._client) + return await client.ping() except Exception: return False diff --git a/src/biz_bud/core/errors/base.py b/src/biz_bud/core/errors/base.py index 38b3b968..762a13dd 100644 --- a/src/biz_bud/core/errors/base.py +++ b/src/biz_bud/core/errors/base.py @@ -395,14 +395,14 @@ class ErrorRegistry: def get_registry_summary(self) -> dict[str, Any]: """Get summary statistics of the error registry.""" - category_counts = {} - severity_counts = {} + category_counts: dict[str, int] = {} + severity_counts: dict[str, int] = {} for category in self._category_mapping.values(): - category_counts[category.value] = category_counts.get(category.value, 0) + 1 + category_counts[category.value] = cast("dict[str, int]", category_counts).get(category.value, 0) + 1 for severity in self._severity_mapping.values(): - severity_counts[severity.value] = severity_counts.get(severity.value, 0) + 1 + severity_counts[severity.value] = cast("dict[str, int]", severity_counts).get(severity.value, 0) + 1 return { "total_errors": len(self._error_definitions), @@ -1151,7 +1151,7 @@ def handle_exception_group[F: Callable[..., Any]](func: F) -> F: if hasattr(eg, "exceptions"): # This is an exception group exceptions = getattr(eg, "exceptions", []) - error_messages = [] + error_messages: list[str] = [] for i, e in enumerate(exceptions, 1): error_messages.append(f"Error {i}: {type(e).__name__}: {str(e)}") @@ -1170,7 +1170,7 @@ def handle_exception_group[F: Callable[..., Any]](func: F) -> F: if hasattr(eg, "exceptions"): # This is an exception group exceptions = getattr(eg, "exceptions", []) - error_messages = [] + error_messages: list[str] = [] for i, e in enumerate(exceptions, 1): error_messages.append(f"Error {i}: {type(e).__name__}: {str(e)}") @@ -1461,8 +1461,8 @@ class ErrorHandler: category_key = error.category.value severity_key = error.severity.value - by_category[category_key] = by_category.get(category_key, 0) + 1 - by_severity[severity_key] = by_severity.get(severity_key, 0) + 1 + by_category[category_key] = cast("dict[str, int]", by_category).get(category_key, 0) + 1 + by_severity[severity_key] = cast("dict[str, int]", by_severity).get(severity_key, 0) + 1 # Recent errors (last 5) recent_errors = [ @@ -1520,9 +1520,9 @@ def aggregate_errors( category = error.get("category", "unknown") severity = error.get("severity", "error") - by_type[error_type] = by_type.get(error_type, 0) + 1 - by_category[category] = by_category.get(category, 0) + 1 - by_severity[severity] = by_severity.get(severity, 0) + 1 + by_type[error_type] = cast("dict[str, int]", by_type).get(error_type, 0) + 1 + by_category[category] = cast("dict[str, int]", by_category).get(category, 0) + 1 + by_severity[severity] = cast("dict[str, int]", by_severity).get(severity, 0) + 1 return { "total": len(errors), diff --git a/src/biz_bud/core/langgraph/cross_cutting.py b/src/biz_bud/core/langgraph/cross_cutting.py index ccf24970..cf31f776 100644 --- a/src/biz_bud/core/langgraph/cross_cutting.py +++ b/src/biz_bud/core/langgraph/cross_cutting.py @@ -143,6 +143,76 @@ def log_node_execution( return decorator +def _initialize_metric(state: dict[str, Any] | None, metric_name: str) -> NodeMetric | None: + """Initialize or retrieve a metric from state. + + Args: + state: The state dictionary or None + metric_name: Name of the metric to track + + Returns: + The initialized or existing metric, or None if state is None + """ + if state is None: + return None + + if "metrics" not in state: + state["metrics"] = {} + + metrics = state["metrics"] + + if metric_name not in metrics: + metrics[metric_name] = NodeMetric( + count=0, + success_count=0, + failure_count=0, + total_duration_ms=0.0, + avg_duration_ms=0.0, + last_execution=None, + last_error=None, + ) + + metric = cast("NodeMetric", metrics[metric_name]) + metric["count"] = (metric["count"] or 0) + 1 + return metric + + +def _update_metric_success(metric: NodeMetric | None, elapsed_ms: float) -> None: + """Update metric for successful execution. + + Args: + metric: The metric to update + elapsed_ms: Elapsed time in milliseconds + """ + if metric is None: + return + + metric["success_count"] = (metric["success_count"] or 0) + 1 + metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms + count = metric["count"] or 1 + metric["avg_duration_ms"] = (metric["total_duration_ms"] or 0.0) / count + metric["last_execution"] = datetime.now(UTC).isoformat() + + +def _update_metric_failure(metric: NodeMetric | None, elapsed_ms: float, error: Exception) -> None: + """Update metric for failed execution. + + Args: + metric: The metric to update + elapsed_ms: Elapsed time in milliseconds + error: The exception that occurred + """ + if metric is None: + return + + metric["failure_count"] = (metric["failure_count"] or 0) + 1 + metric["total_duration_ms"] = (metric["total_duration_ms"] or 0.0) + elapsed_ms + count = metric["count"] or 1 + metric["avg_duration_ms"] = (metric["total_duration_ms"] or 0.0) / count + metric["last_execution"] = datetime.now(UTC).isoformat() + metric["last_error"] = str(error) + + def track_metrics( metric_name: str, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -162,126 +232,33 @@ def track_metrics( @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: start_time = time.time() - - # Get state from args (first argument is usually state) state = args[0] if args and isinstance(args[0], dict) else None - metric: NodeMetric | None = None - - # Initialize metrics in state if not present - if state is not None: - if "metrics" not in state: - state["metrics"] = {} - - metrics = state["metrics"] - - # Initialize metric tracking - if metric_name not in metrics: - metrics[metric_name] = NodeMetric( - count=0, - success_count=0, - failure_count=0, - total_duration_ms=0.0, - avg_duration_ms=0.0, - last_execution=None, - last_error=None, - ) - - metric = cast("NodeMetric", metrics[metric_name]) - metric["count"] = (metric["count"] or 0) + 1 + metric = _initialize_metric(state, metric_name) try: result = await func(*args, **kwargs) - - # Update success metrics - if state is not None and metric is not None: - elapsed_ms = (time.time() - start_time) * 1000 - metric["success_count"] = (metric["success_count"] or 0) + 1 - metric["total_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) + elapsed_ms - count = metric["count"] or 1 - metric["avg_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) / count - metric["last_execution"] = datetime.now(UTC).isoformat() - + elapsed_ms = (time.time() - start_time) * 1000 + _update_metric_success(metric, elapsed_ms) return result - except Exception as e: - # Update failure metrics - if state is not None and metric is not None: - elapsed_ms = (time.time() - start_time) * 1000 - metric["failure_count"] = (metric["failure_count"] or 0) + 1 - metric["total_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) + elapsed_ms - count = metric["count"] or 1 - metric["avg_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) / count - metric["last_execution"] = datetime.now(UTC).isoformat() - metric["last_error"] = str(e) - + elapsed_ms = (time.time() - start_time) * 1000 + _update_metric_failure(metric, elapsed_ms, e) raise @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - # Similar implementation for sync functions start_time = time.time() state = args[0] if args and isinstance(args[0], dict) else None - metric: NodeMetric | None = None - - if state is not None: - if "metrics" not in state: - state["metrics"] = {} - - metrics = state["metrics"] - - if metric_name not in metrics: - metrics[metric_name] = NodeMetric( - count=0, - success_count=0, - failure_count=0, - total_duration_ms=0.0, - avg_duration_ms=0.0, - last_execution=None, - last_error=None, - ) - - metric = cast("NodeMetric", metrics[metric_name]) - metric["count"] = (metric["count"] or 0) + 1 + metric = _initialize_metric(state, metric_name) try: result = func(*args, **kwargs) - - if state is not None and metric is not None: - elapsed_ms = (time.time() - start_time) * 1000 - metric["success_count"] = (metric["success_count"] or 0) + 1 - metric["total_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) + elapsed_ms - count = metric["count"] or 1 - metric["avg_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) / count - metric["last_execution"] = datetime.now(UTC).isoformat() - + elapsed_ms = (time.time() - start_time) * 1000 + _update_metric_success(metric, elapsed_ms) return result - except Exception as e: - if state is not None and metric is not None: - elapsed_ms = (time.time() - start_time) * 1000 - metric["failure_count"] = (metric["failure_count"] or 0) + 1 - metric["total_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) + elapsed_ms - count = metric["count"] or 1 - metric["avg_duration_ms"] = ( - metric["total_duration_ms"] or 0.0 - ) / count - metric["last_execution"] = datetime.now(UTC).isoformat() - metric["last_error"] = str(e) - + elapsed_ms = (time.time() - start_time) * 1000 + _update_metric_failure(metric, elapsed_ms, e) raise return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper diff --git a/src/biz_bud/core/langgraph/state_immutability.py b/src/biz_bud/core/langgraph/state_immutability.py index 2c8eb174..2a03ac8f 100644 --- a/src/biz_bud/core/langgraph/state_immutability.py +++ b/src/biz_bud/core/langgraph/state_immutability.py @@ -331,6 +331,7 @@ def update_state_immutably( """ # Deep copy the current state into a regular dict # If it's an ImmutableDict, convert to regular dict first + new_state: dict[str, Any] if isinstance(current_state, ImmutableDict): new_state = {key: copy.deepcopy(value) for key, value in current_state.items()} else: diff --git a/src/biz_bud/core/networking/async_utils.py b/src/biz_bud/core/networking/async_utils.py index 967e4341..61f729e5 100644 --- a/src/biz_bud/core/networking/async_utils.py +++ b/src/biz_bud/core/networking/async_utils.py @@ -3,6 +3,7 @@ import asyncio import functools import inspect +import os import sys import time from collections.abc import Awaitable, Callable, Coroutine @@ -12,10 +13,108 @@ from typing import Any, ParamSpec, TypeVar, cast from biz_bud.core.errors import BusinessBuddyError, ValidationError +try: + import psutil +except ImportError: + psutil = None + +PSUTIL_AVAILABLE = psutil is not None + T = TypeVar("T") R = TypeVar("R") P = ParamSpec("P") +__all__ = [ + "get_memory_usage_percent", + "apply_memory_backpressure", + "calculate_optimal_concurrency", + "gather_with_concurrency", + "retry_async", + "RateLimiter", + "with_timeout", + "to_async", + "process_items_in_parallel", + "ChainLink", + "run_async_chain", + "AsyncContextInfo", + "detect_async_context", + "run_in_appropriate_context", + "create_async_sync_wrapper", + "handle_sync_async_context", +] + + +def get_memory_usage_percent() -> float: + """Get current memory usage percentage. + + Returns: + Memory usage as percentage (0-100), or 0 if psutil not available + """ + if not PSUTIL_AVAILABLE or psutil is None: + return 0.0 + + try: + memory_info = psutil.virtual_memory() + return memory_info.percent + except Exception: + return 0.0 + + +def apply_memory_backpressure(concurrency: int, memory_percent: float) -> int: + """Apply backpressure based on memory usage. + + Reduces concurrency when memory usage is high to prevent system instability. + + Args: + concurrency: Current concurrency level + memory_percent: Current memory usage percentage + + Returns: + Adjusted concurrency level considering memory pressure + """ + if memory_percent > 90: + # Critical memory usage - reduce to minimum + return max(2, concurrency // 4) + elif memory_percent > 80: + # High memory usage - reduce significantly + return max(3, concurrency // 2) + elif memory_percent > 70: + # Moderate memory usage - reduce moderately + return max(4, int(concurrency * 0.75)) + else: + # Normal memory usage - no reduction + return concurrency + + +def calculate_optimal_concurrency(base_concurrency: int) -> int: + """Calculate optimal concurrency based on available CPU cores and memory. + + Uses a formula that balances CPU utilization with memory constraints. + For I/O-bound tasks like LLM API calls, we can safely exceed CPU core count. + + Args: + base_concurrency: Base concurrency value from configuration + + Returns: + Optimal concurrency value considering system resources + """ + cpu_count = os.cpu_count() + if cpu_count is None: + # Fallback if CPU count detection fails + return base_concurrency + + # For I/O-bound LLM operations, use 2-3x CPU cores as optimal + # Cap at reasonable maximum to prevent resource exhaustion + optimal = min(base_concurrency, max(4, cpu_count * 2)) + + # Apply memory-based backpressure + memory_percent = get_memory_usage_percent() + if memory_percent > 0: # Only apply if we can measure memory + optimal = apply_memory_backpressure(optimal, memory_percent) + + # Ensure we don't go below minimum viable concurrency + return max(optimal, 2) + async def gather_with_concurrency[T]( # noqa: D103 n: int, diff --git a/src/biz_bud/core/networking/http_client.py b/src/biz_bud/core/networking/http_client.py index d5d474d9..0e343ddb 100644 --- a/src/biz_bud/core/networking/http_client.py +++ b/src/biz_bud/core/networking/http_client.py @@ -55,7 +55,7 @@ class HTTPClient: The singleton HTTPClient instance """ if cls._instance is None: - cls._instance = super().__new__(cls) + cls._instance = cast("HTTPClient", super().__new__(cls)) return cast(HTTPClient, cls._instance) def __init__(self, config: HTTPClientConfig | None = None) -> None: @@ -153,8 +153,9 @@ class HTTPClient: Note: This closes the singleton session for all instances. Should only be called during application shutdown. """ - if HTTPClient._session: - await HTTPClient._session.close() + if HTTPClient._session is not None: + session = cast(aiohttp.ClientSession, HTTPClient._session) + await session.close() HTTPClient._session = None async def request(self, options: RequestOptions) -> HTTPResponse: @@ -215,7 +216,8 @@ class HTTPClient: try: if HTTPClient._session is None: raise RuntimeError("Session not initialized") - async with HTTPClient._session.request(method, url, **kwargs) as resp: + session = cast(aiohttp.ClientSession, HTTPClient._session) + async with session.request(method, url, **kwargs) as resp: content = await resp.read() text = None json_data = None diff --git a/src/biz_bud/core/services/http_service.py b/src/biz_bud/core/services/http_service.py index 622e310d..a6c19421 100644 --- a/src/biz_bud/core/services/http_service.py +++ b/src/biz_bud/core/services/http_service.py @@ -33,7 +33,7 @@ Example: from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Annotated, Any, cast +from typing import TYPE_CHECKING, Annotated, Any, assert_type, cast import aiohttp from pydantic import BaseModel, Field @@ -201,11 +201,13 @@ class HTTPClientService: """ if self._session is not None: logger.info("Cleaning up HTTPClientService session") - await self._session.close() + session = cast(aiohttp.ClientSession, self._session) + await session.close() self._session = None if self._connector is not None: - await self._connector.close() + connector = cast(aiohttp.TCPConnector, self._connector) + await connector.close() self._connector = None logger.info("HTTPClientService cleaned up successfully") @@ -217,12 +219,13 @@ class HTTPClientService: True if the session is initialized and not closed, False otherwise. """ try: - return ( - self._session is not None - and not self._session.closed - and self._connector is not None - and not self._connector.closed - ) + if self._session is None or self._connector is None: + return False + + session = cast(aiohttp.ClientSession, self._session) + connector = cast(aiohttp.TCPConnector, self._connector) + + return not session.closed and not connector.closed except Exception: return False @@ -319,8 +322,8 @@ class HTTPClientService: NetworkError: On network failures. """ try: - assert self._session is not None - async with self._session.request(method, url, **kwargs) as resp: + session = cast(aiohttp.ClientSession, self._session) + async with session.request(method, url, **kwargs) as resp: return await self._process_response_content(resp) except aiohttp.ClientError as e: diff --git a/src/biz_bud/core/services/registry.py b/src/biz_bud/core/services/registry.py index a31aff3a..40579041 100644 --- a/src/biz_bud/core/services/registry.py +++ b/src/biz_bud/core/services/registry.py @@ -509,9 +509,10 @@ class ServiceRegistry: if service_type in self._services: service = self._services[service_type] if hasattr(service, 'cleanup'): - # Call _cleanup_service to get the coroutine and capture it - coro = self._cleanup_service(service_type, service) - cleanup_coroutines.append(lambda c=coro: c) + # Create a wrapper function that returns the coroutine + async def cleanup_wrapper(stype=service_type, svc=service) -> None: + await self._cleanup_service(stype, svc) + cleanup_coroutines.append(cleanup_wrapper) return cleanup_coroutines async def _execute_cleanup_batch(self, service_batch: list[type[Any]]) -> None: diff --git a/src/biz_bud/core/utils/message_helpers.py b/src/biz_bud/core/utils/message_helpers.py index a801faff..f8e49546 100644 --- a/src/biz_bud/core/utils/message_helpers.py +++ b/src/biz_bud/core/utils/message_helpers.py @@ -55,7 +55,7 @@ def normalize_content(content: Any) -> str: text = content elif isinstance(content, list): # Handle list content (e.g., multimodal messages) - text_parts = [] + text_parts: list[str] = [] for item in content: if isinstance(item, dict) and item.get("type") == "text": text_parts.append(str(item.get("text", ""))) @@ -579,13 +579,13 @@ def _create_fallback_summary(messages: list["BaseMessage"]) -> "BaseMessage": return SystemMessage(content="CONVERSATION SUMMARY: No previous conversation.") # Count message types - message_counts = {} + message_counts: dict[str, int] = {} tool_calls = [] recent_content = [] for msg in messages[-10:]: # Look at last 10 messages msg_type = msg.__class__.__name__.replace("Message", "") - message_counts[msg_type] = message_counts.get(msg_type, 0) + 1 + message_counts[msg_type] = cast(dict[str, int], message_counts).get(msg_type, 0) + 1 content = normalize_content(getattr(msg, "content", "")) if content and len(content) < 100: # Include short messages diff --git a/src/biz_bud/core/utils/url_analyzer.py b/src/biz_bud/core/utils/url_analyzer.py index 567dfca3..d89ff43a 100644 --- a/src/biz_bud/core/utils/url_analyzer.py +++ b/src/biz_bud/core/utils/url_analyzer.py @@ -1210,7 +1210,7 @@ class URLAnalyzer: "query_params": query_dict, "fragment": parsed.fragment or None, } - result["metadata"] = metadata + result["metadata"] = cast("dict[str, Any]", metadata) return result def get_url_metadata(self, url: str) -> dict[str, Any]: diff --git a/src/biz_bud/core/validation/base.py b/src/biz_bud/core/validation/base.py index 4e16d7c1..c0ed8633 100644 --- a/src/biz_bud/core/validation/base.py +++ b/src/biz_bud/core/validation/base.py @@ -204,7 +204,7 @@ class AnyValidator(CompositeValidator): def validate(self, value: Any) -> tuple[bool, str | None]: # noqa: ANN401 """Run validators until one passes.""" - errors = [] + errors: list[str] = [] for validator in self.validators: is_valid, error = validator.validate(value) if is_valid: diff --git a/src/biz_bud/core/validation/merge.py b/src/biz_bud/core/validation/merge.py index a9c9e3c8..e92c7400 100644 --- a/src/biz_bud/core/validation/merge.py +++ b/src/biz_bud/core/validation/merge.py @@ -10,7 +10,7 @@ detection logic is excluded. import json import time -from typing import Any +from typing import Any, cast # LEGITIMATE USE OF ANY - JSON PROCESSING MODULE @@ -208,7 +208,7 @@ def _finalize_averages( to_remove = [] for field, strategy in merge_strategy.items(): if strategy == "average": - value = merged.get(field) + value = cast(dict[str, Any], merged).get(field) if isinstance(value, list) and value: nums = [v for v in value if isinstance(v, int | float)] merged[field] = sum(nums) / len(nums) if nums else None diff --git a/src/biz_bud/graphs/catalog/nodes.py b/src/biz_bud/graphs/catalog/nodes.py index a1142c28..9c0b0ffe 100644 --- a/src/biz_bud/graphs/catalog/nodes.py +++ b/src/biz_bud/graphs/catalog/nodes.py @@ -7,17 +7,12 @@ component analysis, and catalog intelligence operations. from __future__ import annotations -from typing import Any - -from langchain_core.runnables import RunnableConfig - -from biz_bud.core.errors import create_error_info -from biz_bud.core.langgraph import standard_node -from biz_bud.logging import debug_highlight, error_highlight, info_highlight - # Import from local nodes directory try: - from .nodes.analysis import catalog_impact_analysis_node + from .nodes.analysis import ( + catalog_impact_analysis_node, + catalog_optimization_node, + ) from .nodes.c_intel import ( batch_analyze_components_node, find_affected_catalog_items_node, @@ -46,112 +41,9 @@ except ImportError: research_catalog_item_components_node = None load_catalog_data_node = None catalog_impact_analysis_node = None + catalog_optimization_node = None -@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization") -async def catalog_optimization_node( - state: dict[str, Any], config: RunnableConfig | None = None -) -> dict[str, Any]: - """Generate optimization recommendations for the catalog. - - This node analyzes the catalog structure, pricing, and components - to provide actionable optimization recommendations. - - Args: - state: Current workflow state - config: Optional runtime configuration - - Returns: - Updated state with optimization recommendations - """ - debug_highlight( - "Generating catalog optimization recommendations...", category="CatalogOptimization" - ) - - # Get analysis data - impact_analysis = state.get("impact_analysis", {}) - catalog_data = state.get("catalog_data", {}) - - try: - optimization_report: dict[str, Any] = { - "recommendations": [], - "priority_actions": [], - "cost_savings_potential": {}, - "efficiency_improvements": [], - } - - # Analyze catalog structure - total_items = sum( - len(items) if isinstance(items, list) else 0 for items in catalog_data.values() - ) - - # Generate recommendations based on analysis - if affected_items := impact_analysis.get("affected_items", []): - # Component optimization - if len(affected_items) > 5: - optimization_report["recommendations"].append( - { - "type": "component_standardization", - "description": f"Standardize component usage across {len(affected_items)} items", - "impact": "high", - "effort": "medium", - } - ) - - if high_price_items := [item for item in affected_items if item.get("price", 0) > 15]: - optimization_report["priority_actions"].append( - { - "action": "Review pricing strategy", - "reason": f"{len(high_price_items)} high-value items affected", - "urgency": "high", - } - ) - - # Catalog structure optimization - if total_items > 50: - optimization_report["efficiency_improvements"].append( - { - "area": "catalog_structure", - "suggestion": "Consider categorization refinement", - "benefit": "Improved navigation and management", - } - ) - - # Cost savings analysis - optimization_report["cost_savings_potential"] = { - "component_consolidation": "5-10%", - "supplier_optimization": "3-7%", - "menu_engineering": "10-15%", - } - - info_highlight( - f"Optimization report generated with {len(optimization_report['recommendations'])} recommendations", - category="CatalogOptimization", - ) - - return { - "optimization_report": optimization_report, - "report_metadata": { - "total_items_analyzed": total_items, - "recommendations_count": len(optimization_report["recommendations"]), - "priority_actions_count": len(optimization_report["priority_actions"]), - }, - } - - except Exception as e: - error_msg = f"Catalog optimization failed: {str(e)}" - error_highlight(error_msg, category="CatalogOptimization") - return { - "optimization_report": {}, - "errors": [ - create_error_info( - message=error_msg, - node="catalog_optimization", - severity="error", - category="optimization_error", - ) - ], - } # Export all catalog-specific nodes diff --git a/src/biz_bud/graphs/catalog/nodes/c_intel.py b/src/biz_bud/graphs/catalog/nodes/c_intel.py index 732bde04..234a9af8 100644 --- a/src/biz_bud/graphs/catalog/nodes/c_intel.py +++ b/src/biz_bud/graphs/catalog/nodes/c_intel.py @@ -5,7 +5,7 @@ including database queries and business logic. """ import re -from typing import Any +from typing import Any, cast from langchain_core.runnables import RunnableConfig @@ -607,10 +607,10 @@ def _generate_basic_catalog_suggestions( all_components.extend(components) # Count component frequency - component_counts = {} + component_counts: dict[str, int] = {} for component in all_components: if isinstance(component, str): - component_counts[component] = component_counts.get(component, 0) + 1 + component_counts[component] = cast(dict[str, int], component_counts).get(component, 0) + 1 if common_components := sorted( component_counts.items(), key=lambda x: x[1], reverse=True diff --git a/src/biz_bud/graphs/rag/nodes/analyzer.py b/src/biz_bud/graphs/rag/nodes/analyzer.py index 77f69661..5ba92eb1 100644 --- a/src/biz_bud/graphs/rag/nodes/analyzer.py +++ b/src/biz_bud/graphs/rag/nodes/analyzer.py @@ -1,6 +1,6 @@ """Analyze scraped content to determine optimal R2R upload configuration.""" -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict, cast from langchain_core.runnables import RunnableConfig @@ -318,7 +318,7 @@ async def analyze_content_for_rag_node( if "r2r_config" in page: chunk_size = page["r2r_config"]["chunk_size"] if isinstance(chunk_size, int): - config_summary[chunk_size] = config_summary.get(chunk_size, 0) + 1 + config_summary[chunk_size] = cast(dict[int, int], config_summary).get(chunk_size, 0) + 1 logger.info(f"Chunk size distribution: {config_summary}") diff --git a/src/biz_bud/graphs/rag/nodes/processing.py b/src/biz_bud/graphs/rag/nodes/processing.py index 4825c5e9..302c4922 100644 --- a/src/biz_bud/graphs/rag/nodes/processing.py +++ b/src/biz_bud/graphs/rag/nodes/processing.py @@ -6,7 +6,7 @@ into a single module following the standardized graph pattern. from __future__ import annotations -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from langchain_core.runnables import RunnableConfig @@ -547,10 +547,10 @@ async def scrape_status_summary_node( len(r.get("extracted_text", "")) for r in successful_results ) - content_types = {} + content_types: dict[str, int] = {} for result in successful_results: content_type = result.get("content_type", "unknown") - content_types[content_type] = content_types.get(content_type, 0) + 1 + content_types[content_type] = cast(dict[str, int], content_types).get(content_type, 0) + 1 # Generate summary summary: dict[str, Any] = { diff --git a/src/biz_bud/graphs/rag/nodes/scraping/url_discovery.py b/src/biz_bud/graphs/rag/nodes/scraping/url_discovery.py index c6376da6..b8f1b8cf 100644 --- a/src/biz_bud/graphs/rag/nodes/scraping/url_discovery.py +++ b/src/biz_bud/graphs/rag/nodes/scraping/url_discovery.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from langchain_core.runnables import RunnableConfig @@ -38,7 +38,7 @@ async def discover_urls_node( input_url = state.get("input_url", "") # config_dict = state.get("config", {}) # TODO: Use config if needed - scrape_params = state.get("scrape_params", {}) + scrape_params = cast(dict[str, Any], state.get("scrape_params", {})) if not input_url: logger.error("No input URL provided for URL discovery") @@ -50,8 +50,8 @@ async def discover_urls_node( } # Set up batch processing parameters - batch_size = scrape_params.get("batch_size", 20) - max_pages = scrape_params.get("max_pages", 50) + batch_size = int(scrape_params.get("batch_size", 20)) + max_pages = int(scrape_params.get("max_pages", 50)) try: logger.info(f"Discovering URLs from {input_url} with limit {max_pages}") diff --git a/src/biz_bud/graphs/research/nodes/synthesis.py b/src/biz_bud/graphs/research/nodes/synthesis.py index 47161c43..a238531c 100644 --- a/src/biz_bud/graphs/research/nodes/synthesis.py +++ b/src/biz_bud/graphs/research/nodes/synthesis.py @@ -63,7 +63,7 @@ async def _filter_chunks_by_relevance( batch = chunks[i : i + batch_size] # Create filtering prompt - chunk_texts = [] + chunk_texts: list[str] = [] for j, chunk in enumerate(batch): content = ( chunk.get("content") @@ -414,7 +414,7 @@ async def synthesize_search_results( if not extracted_info and sources: logger.info("No extracted_info found, attempting to extract from sources") logger.info(f"Found {len(sources)} sources to process") - extracted_info = {} + extracted_info = cast("dict[str, Any]", {}) for i, source_raw in enumerate(sources): if not isinstance(source_raw, dict): continue @@ -480,7 +480,7 @@ async def synthesize_search_results( ) # Convert search results to extracted_info format for synthesis - extracted_info = {} + extracted_info = cast("dict[str, Any]", {}) sources = [] for i, result in enumerate(search_results[:10]): # Limit to top 10 if not isinstance(result, dict): @@ -505,7 +505,7 @@ async def synthesize_search_results( # Create extracted_info entry with better defaults source_key = f"source_{i}" content = description or title or "" - extracted_info[source_key] = { + cast("dict[str, Any]", extracted_info)[source_key] = { "content": content, "url": url, "title": title, @@ -582,7 +582,7 @@ async def synthesize_search_results( logger.warning(warning_msg) # Try to extract content from sources - extracted_info = {} + extracted_info = cast("dict[str, Any]", {}) for i, source_raw in enumerate(sources): if isinstance(source_raw, dict): source = cast("dict[str, Any]", source_raw) diff --git a/src/biz_bud/logging/unified_logging.py b/src/biz_bud/logging/unified_logging.py index ef64d09c..450473c1 100644 --- a/src/biz_bud/logging/unified_logging.py +++ b/src/biz_bud/logging/unified_logging.py @@ -267,15 +267,15 @@ class LogAggregator: if not self.logs: return {"total": 0, "by_level": {}, "by_logger": {}} - by_level = {} - by_logger = {} + by_level: dict[str, int] = {} + by_logger: dict[str, int] = {} for log in self.logs: level = log["level"] logger = log["logger"].split(".")[0] # Top-level logger - by_level[level] = by_level.get(level, 0) + 1 - by_logger[logger] = by_logger.get(logger, 0) + 1 + by_level[level] = cast("dict[str, int]", by_level).get(level, 0) + 1 + by_logger[logger] = cast("dict[str, int]", by_logger).get(logger, 0) + 1 return { "total": len(self.logs), diff --git a/src/biz_bud/nodes/core/error.py b/src/biz_bud/nodes/core/error.py index 857ddd24..3deb3756 100644 --- a/src/biz_bud/nodes/core/error.py +++ b/src/biz_bud/nodes/core/error.py @@ -228,7 +228,7 @@ async def handle_validation_failure( error_counts: dict[str, int] = {} for err in validation_errors: phase = str(err.get("phase", "unknown")) - error_counts[phase] = error_counts.get(phase, 0) + 1 + error_counts[phase] = cast("dict[str, int]", error_counts).get(phase, 0) + 1 def check_score( result: dict[str, object] | None, label: str, threshold: float diff --git a/src/biz_bud/nodes/extraction/consolidated.py b/src/biz_bud/nodes/extraction/consolidated.py index 14b0b791..a1e659be 100644 --- a/src/biz_bud/nodes/extraction/consolidated.py +++ b/src/biz_bud/nodes/extraction/consolidated.py @@ -17,7 +17,7 @@ Core capabilities: from __future__ import annotations import re -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from langchain_core.runnables import RunnableConfig @@ -314,7 +314,7 @@ async def extract_key_information_node( try: # Process each source - all_extracted_info = {} + all_extracted_info: dict[str, Any] = {} all_chunks = [] source_metadata = [] @@ -357,7 +357,7 @@ async def extract_key_information_node( processed_chunks.append(chunk_info) # Store extracted info - all_extracted_info[source_key] = { + cast("dict[str, Any]", all_extracted_info)[source_key] = { "url": source["url"], "title": source["title"], "chunks": len(processed_chunks), diff --git a/src/biz_bud/nodes/extraction/extractors.py b/src/biz_bud/nodes/extraction/extractors.py index 375a4486..4082da7d 100644 --- a/src/biz_bud/nodes/extraction/extractors.py +++ b/src/biz_bud/nodes/extraction/extractors.py @@ -9,16 +9,12 @@ import os from typing import TYPE_CHECKING, Any from biz_bud.core.errors import ConfigurationError -from biz_bud.core.networking.async_utils import gather_with_concurrency - -try: - import psutil -except ImportError: - psutil = None - -PSUTIL_AVAILABLE = psutil is not None - from biz_bud.core.langgraph import ensure_immutable_node, standard_node +from biz_bud.core.networking.async_utils import ( + calculate_optimal_concurrency, + gather_with_concurrency, + get_memory_usage_percent, +) from biz_bud.core.networking.retry import ( CircuitBreakerError, RetryConfig, @@ -36,76 +32,6 @@ from biz_bud.tools.capabilities.extraction.text.structured_extraction import ( logger = get_logger(__name__) -def _get_memory_usage_percent() -> float: - """Get current memory usage percentage. - - Returns: - Memory usage as percentage (0-100), or 0 if psutil not available - """ - if not PSUTIL_AVAILABLE or psutil is None: - return 0.0 - - try: - memory_info = psutil.virtual_memory() - return memory_info.percent - except Exception: - return 0.0 - - -def _apply_memory_backpressure(concurrency: int, memory_percent: float) -> int: - """Apply backpressure based on memory usage. - - Reduces concurrency when memory usage is high to prevent system instability. - - Args: - concurrency: Current concurrency level - memory_percent: Current memory usage percentage - - Returns: - Adjusted concurrency level considering memory pressure - """ - if memory_percent > 90: - # Critical memory usage - reduce to minimum - return max(2, concurrency // 4) - elif memory_percent > 80: - # High memory usage - reduce significantly - return max(3, concurrency // 2) - elif memory_percent > 70: - # Moderate memory usage - reduce moderately - return max(4, int(concurrency * 0.75)) - else: - # Normal memory usage - no reduction - return concurrency - - -def _calculate_optimal_concurrency(base_concurrency: int) -> int: - """Calculate optimal concurrency based on available CPU cores and memory. - - Uses a formula that balances CPU utilization with memory constraints. - For I/O-bound tasks like LLM API calls, we can safely exceed CPU core count. - - Args: - base_concurrency: Base concurrency value from configuration - - Returns: - Optimal concurrency value considering system resources - """ - cpu_count = os.cpu_count() - if cpu_count is None: - # Fallback if CPU count detection fails - return base_concurrency - - # For I/O-bound LLM operations, use 2-3x CPU cores as optimal - # Cap at reasonable maximum to prevent resource exhaustion - optimal = min(base_concurrency, max(4, cpu_count * 2)) - - # Apply memory-based backpressure - memory_percent = _get_memory_usage_percent() - if memory_percent > 0: # Only apply if we can measure memory - optimal = _apply_memory_backpressure(optimal, memory_percent) - - # Ensure we don't go below minimum viable concurrency - return max(optimal, 2) if TYPE_CHECKING: @@ -254,7 +180,7 @@ async def extract_batch_node( # Apply dynamic concurrency scaling based on system resources original_concurrency = max_concurrent - max_concurrent = _calculate_optimal_concurrency(max_concurrent) + max_concurrent = calculate_optimal_concurrency(max_concurrent) # Apply concurrency cap to avoid exceeding external rate limits max_concurrent_cap = state.get("max_concurrent_cap", None) @@ -278,7 +204,7 @@ async def extract_batch_node( if verbose: cpu_count = os.cpu_count() or "unknown" - memory_percent = _get_memory_usage_percent() + memory_percent = get_memory_usage_percent() memory_info = f", Memory: {memory_percent:.1f}%" if memory_percent > 0 else "" info_highlight( f"Extracting from {len(content_batch)} sources with dynamic concurrency: " diff --git a/src/biz_bud/nodes/extraction/orchestrator.py b/src/biz_bud/nodes/extraction/orchestrator.py index a5bbfb4f..b5901a01 100644 --- a/src/biz_bud/nodes/extraction/orchestrator.py +++ b/src/biz_bud/nodes/extraction/orchestrator.py @@ -4,19 +4,12 @@ This module provides the main orchestration logic for extracting information from web sources. """ -import os from typing import Any, cast -try: - import psutil -except ImportError: - psutil = None - -PSUTIL_AVAILABLE = psutil is not None - from langchain_core.runnables import RunnableConfig from biz_bud.core.langgraph import ensure_immutable_node, standard_node +from biz_bud.core.networking.async_utils import calculate_optimal_concurrency from biz_bud.core.networking.retry import ( CircuitBreakerError, create_circuit_breaker_for_batch_processing, @@ -38,76 +31,6 @@ from .extractors import extract_batch_node logger = get_logger(__name__) -def _get_memory_usage_percent() -> float: - """Get current memory usage percentage. - - Returns: - Memory usage as percentage (0-100), or 0 if psutil not available - """ - if not PSUTIL_AVAILABLE or psutil is None: - return 0.0 - - try: - memory_info = psutil.virtual_memory() - return memory_info.percent - except Exception: - return 0.0 - - -def _apply_memory_backpressure(concurrency: int, memory_percent: float) -> int: - """Apply backpressure based on memory usage. - - Reduces concurrency when memory usage is high to prevent system instability. - - Args: - concurrency: Current concurrency level - memory_percent: Current memory usage percentage - - Returns: - Adjusted concurrency level considering memory pressure - """ - if memory_percent > 90: - # Critical memory usage - reduce to minimum - return max(2, concurrency // 4) - elif memory_percent > 80: - # High memory usage - reduce significantly - return max(3, concurrency // 2) - elif memory_percent > 70: - # Moderate memory usage - reduce moderately - return max(4, int(concurrency * 0.75)) - else: - # Normal memory usage - no reduction - return concurrency - - -def _calculate_optimal_concurrency(base_concurrency: int) -> int: - """Calculate optimal concurrency based on available CPU cores and memory. - - Uses a formula that balances CPU utilization with memory constraints. - For I/O-bound tasks like LLM API calls, we can safely exceed CPU core count. - - Args: - base_concurrency: Base concurrency value from configuration - - Returns: - Optimal concurrency value considering system resources - """ - cpu_count = os.cpu_count() - if cpu_count is None: - # Fallback if CPU count detection fails - return base_concurrency - - # For I/O-bound LLM operations, use 2-3x CPU cores as optimal - # Cap at reasonable maximum to prevent resource exhaustion - optimal = min(base_concurrency, max(4, cpu_count * 2)) - - # Apply memory-based backpressure - memory_percent = _get_memory_usage_percent() - if memory_percent > 0: # Only apply if we can measure memory - optimal = _apply_memory_backpressure(optimal, memory_percent) - - # Ensure we don't go below minimum viable concurrency - return max(optimal, 2) @standard_node( @@ -269,7 +192,7 @@ async def extract_key_information( max_concurrent = web_tools_config.get("max_concurrent_analysis", 12) # Apply dynamic concurrency scaling based on system resources - max_concurrent = _calculate_optimal_concurrency(max_concurrent) + max_concurrent = calculate_optimal_concurrency(max_concurrent) # Create circuit breaker for batch operations batch_circuit_breaker = create_circuit_breaker_for_batch_processing( diff --git a/src/biz_bud/nodes/llm/call.py b/src/biz_bud/nodes/llm/call.py index 7254b27d..7e74d211 100644 --- a/src/biz_bud/nodes/llm/call.py +++ b/src/biz_bud/nodes/llm/call.py @@ -288,7 +288,7 @@ async def _handle_unexpected_node_error( error_logger = get_error_logger() # Safely serialize state context for comprehensive error logging - state_context = {} + state_context: dict[str, Any] = {} context_serialization_errors = [] try: @@ -811,7 +811,7 @@ async def update_message_history_node( logger.exception(error_msg) error_highlight(error_msg, category="MessageHistory") # Add error details to state for downstream nodes - tool_errors = state.get("tool_message_conversion_errors", []) + tool_errors = cast("dict[str, Any]", state).get("tool_message_conversion_errors", []) tool_errors.append( { "error": str(e), diff --git a/src/biz_bud/nodes/search/ranker.py b/src/biz_bud/nodes/search/ranker.py index 7b70e517..0ffec4a9 100644 --- a/src/biz_bud/nodes/search/ranker.py +++ b/src/biz_bud/nodes/search/ranker.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from biz_bud.logging import get_logger @@ -428,7 +428,7 @@ class SearchResultRanker: domain_counts: dict[str, int] = {} for result in results: source_domain: str = result.source_domain - domain_counts[source_domain] = domain_counts.get(source_domain, 0) + 1 + domain_counts[source_domain] = cast("dict[str, int]", domain_counts).get(source_domain, 0) + 1 # Calculate diversity scores and final scores for result in results: diff --git a/src/biz_bud/tools/capabilities/batch/receipt_processing.py b/src/biz_bud/tools/capabilities/batch/receipt_processing.py index 14780576..5ae3e24b 100644 --- a/src/biz_bud/tools/capabilities/batch/receipt_processing.py +++ b/src/biz_bud/tools/capabilities/batch/receipt_processing.py @@ -91,7 +91,7 @@ def extract_price_context(text: str) -> str: "bulk price", "wholesale", "food service" ] - found_contexts = [] + found_contexts: list[str] = [] for context in unit_contexts: if context in text_lower: found_contexts.append(context) diff --git a/src/biz_bud/tools/capabilities/document/tool.py b/src/biz_bud/tools/capabilities/document/tool.py index a5928f9c..be661fa4 100644 --- a/src/biz_bud/tools/capabilities/document/tool.py +++ b/src/biz_bud/tools/capabilities/document/tool.py @@ -326,7 +326,7 @@ def generate_table_of_contents(content: str, max_level: int = 6) -> dict[str, An ) # Generate TOC markdown - toc_lines = [] + toc_lines: list[str] = [] for header in headers: indent = " " * (header["level"] - 1) link = f"[{header['text']}](#{header['slug']})" @@ -474,7 +474,7 @@ def _markdown_to_html(content: str) -> str: # Paragraphs paragraphs = html.split("\n\n") - html_paragraphs = [] + html_paragraphs: list[str] = [] for p in paragraphs: p = p.strip() if p and not p.startswith("<"): diff --git a/src/biz_bud/tools/capabilities/extraction/core/base.py b/src/biz_bud/tools/capabilities/extraction/core/base.py index 5adaeb70..1ffbf71f 100644 --- a/src/biz_bud/tools/capabilities/extraction/core/base.py +++ b/src/biz_bud/tools/capabilities/extraction/core/base.py @@ -68,7 +68,7 @@ def merge_extraction_results(results: list[dict[str, Any]]) -> dict[str, Any]: keywords.extend(result["keywords"]) if "metadata" in result and isinstance(result["metadata"], dict): # Type check to ensure we're updating dict with dict - metadata = merged.get("metadata", {}) + metadata = cast("dict[str, Any]", merged).get("metadata", {}) if isinstance(metadata, dict): metadata.update(result["metadata"]) merged["metadata"] = metadata diff --git a/src/biz_bud/tools/capabilities/introspection/providers/default.py b/src/biz_bud/tools/capabilities/introspection/providers/default.py index da4e6b9e..28c74b6c 100644 --- a/src/biz_bud/tools/capabilities/introspection/providers/default.py +++ b/src/biz_bud/tools/capabilities/introspection/providers/default.py @@ -1,7 +1,7 @@ """Default introspection provider implementation.""" import re -from typing import Any +from typing import Any, cast from biz_bud.core.utils.capability_inference import infer_capabilities_from_query from biz_bud.logging import get_logger @@ -162,7 +162,7 @@ class DefaultIntrospectionProvider(IntrospectionProvider): score += self.COMPREHENSIVE_TOOL_BONUS else: # Secondary tools get penalty score -= self.SPECIFIC_TOOL_PENALTY * i - tool_scores[tool] = max(tool_scores.get(tool, 0), score) + tool_scores[tool] = max(cast("dict[str, float]", tool_scores).get(tool, 0), score) # Add fallbacks if enabled if self.config.enable_fallbacks and len(capability_tools) > 1: diff --git a/src/biz_bud/tools/utils/html_utils.py b/src/biz_bud/tools/utils/html_utils.py index 31469e74..e532d5a3 100644 --- a/src/biz_bud/tools/utils/html_utils.py +++ b/src/biz_bud/tools/utils/html_utils.py @@ -231,7 +231,7 @@ def get_image_hash(image_url: str) -> str | None: # Common parameters that might indicate different images essential_param_keys = {"url", "id", "image", "src", "file"} - essential_params = [] + essential_params: list[str] = [] for key, values in query_params.items(): if key.lower() in essential_param_keys: diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index f8b669bb..1a9dff18 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1 +1 @@ -"""Test helper utilities and functions.""" +"""Test helpers package.""" diff --git a/tests/helpers/assertions/__init__.py b/tests/helpers/assertions/__init__.py index 0445caf0..911483fe 100644 --- a/tests/helpers/assertions/__init__.py +++ b/tests/helpers/assertions/__init__.py @@ -1 +1 @@ -"""Test assertion utilities.""" +"""Test assertions package.""" diff --git a/tests/helpers/assertions/custom_assertions.py b/tests/helpers/assertions/custom_assertions.py index f3248e54..154457db 100644 --- a/tests/helpers/assertions/custom_assertions.py +++ b/tests/helpers/assertions/custom_assertions.py @@ -1,236 +1,15 @@ -"""Custom assertions for testing.""" - -from __future__ import annotations +"""Custom test assertions.""" from typing import Any -try: - from langchain_core.messages import AIMessage -except ImportError: - # Fallback for environments without langchain_core - AIMessage = Any + +def assert_valid_response(response: dict[str, Any]) -> None: + """Assert that a response is valid.""" + assert isinstance(response, dict) + assert "status" in response or "success" in response -def assert_state_has_messages( - state: dict[str, Any], min_count: int = 1, max_count: int | None = None -) -> None: - """Assert state has messages within expected range.""" - assert "messages" in state, "State missing 'messages' field" - messages = state["messages"] - assert isinstance(messages, list), "Messages must be a list" - - assert len(messages) >= min_count, ( - f"Expected at least {min_count} messages, got {len(messages)}" - ) - - if max_count is not None: - assert len(messages) <= max_count, ( - f"Expected at most {max_count} messages, got {len(messages)}" - ) - - -def assert_message_types( - messages: list[Any], - expected_types: list[type[Any]], -) -> None: - """Assert messages match expected types in order.""" - assert len(messages) == len(expected_types), ( - f"Message count mismatch: expected {len(expected_types)}, got {len(messages)}" - ) - - for i, (msg, expected_type) in enumerate(zip(messages, expected_types)): - # Use type() comparison instead of isinstance for compatibility - assert type(msg) is expected_type, ( - f"Message {i} type mismatch: expected {expected_type.__name__}, " - f"got {type(msg).__name__}" - ) - - -def assert_state_has_no_errors(state: dict[str, Any]) -> None: - """Assert state has no errors.""" - errors = state.get("errors", []) - assert len(errors) == 0, f"State has {len(errors)} errors: {errors}" - - if status := state.get("workflow_status"): - assert status != "failed", "Workflow status is 'failed'" - - -def assert_state_has_errors( - state: dict[str, Any], - min_errors: int = 1, - phases: list[str] | None = None, -) -> None: - """Assert state has errors, optionally from specific phases.""" - assert "errors" in state, "State missing 'errors' field" - errors = state["errors"] - - assert len(errors) >= min_errors, ( - f"Expected at least {min_errors} errors, got {len(errors)}" - ) - - if phases: - error_phases = { - error.get("phase") for error in errors if isinstance(error, dict) - } - for phase in phases: - assert phase in error_phases, f"No error found for phase '{phase}'" - - -def assert_search_results_valid( - results: list[dict[str, Any]], - min_results: int = 1, - required_fields: list[str] | None = None, -) -> None: - """Assert search results are valid.""" - assert isinstance(results, list), "Search results must be a list" - assert len(results) >= min_results, ( - f"Expected at least {min_results} results, got {len(results)}" - ) - - if required_fields is None: - required_fields = ["title", "url", "snippet"] - - for i, result in enumerate(results): - assert isinstance(result, dict), f"Result {i} must be a dictionary" - for field in required_fields: - assert field in result, f"Result {i} missing required field '{field}'" - assert result[field], f"Result {i} has empty '{field}'" - - -def assert_validation_passed( - state: dict[str, Any], - check_types: list[str] | None = None, -) -> None: - """Assert validation checks passed.""" - if check_types is None: - check_types = [ - "fact_check_results", - "logic_validation", - "consistency_validation", - ] - - for check_type in check_types: - if check_type in state: - result = state[check_type] - assert isinstance(result, dict), f"{check_type} must be a dictionary" - assert result.get("passed") is True, ( - f"{check_type} failed: {result.get('issues', [])}" - ) - - -def assert_extraction_complete( - extraction: dict[str, Any], - required_fields: list[str] | None = None, -) -> None: - """Assert extraction contains required fields.""" - assert isinstance(extraction, dict), "Extraction must be a dictionary" - - if required_fields is None: - required_fields = ["entities", "topics", "summary"] - - for field in required_fields: - assert field in extraction, f"Extraction missing required field '{field}'" - assert extraction[field], f"Extraction has empty '{field}'" - - -def assert_synthesis_quality( - synthesis: str, - min_length: int = 100, - max_length: int | None = None, - required_phrases: list[str] | None = None, -) -> None: - """Assert synthesis meets quality criteria.""" - assert isinstance(synthesis, str), "Synthesis must be a string" - assert len(synthesis) >= min_length, ( - f"Synthesis too short: {len(synthesis)} < {min_length}" - ) - - if max_length is not None: - assert len(synthesis) <= max_length, ( - f"Synthesis too long: {len(synthesis)} > {max_length}" - ) - - if required_phrases: - synthesis_lower = synthesis.lower() - for phrase in required_phrases: - assert phrase.lower() in synthesis_lower, ( - f"Synthesis missing required phrase: '{phrase}'" - ) - - -def assert_metadata_contains( - state: dict[str, Any], - required_keys: list[str], - metadata_key: str = "metadata", -) -> None: - """Assert state metadata contains required keys.""" - assert metadata_key in state, f"State missing '{metadata_key}' field" - metadata = state[metadata_key] - assert isinstance(metadata, dict), f"{metadata_key} must be a dictionary" - - for key in required_keys: - assert key in metadata, f"Metadata missing required key '{key}'" - - -def assert_workflow_status( - state: dict[str, Any], - expected_status: str, -) -> None: - """Assert workflow has expected status.""" - assert "workflow_status" in state, "State missing 'workflow_status' field" - actual_status = state["workflow_status"] - assert actual_status == expected_status, ( - f"Workflow status mismatch: expected '{expected_status}', got '{actual_status}'" - ) - - -def assert_step_count_range( - state: dict[str, Any], - min_steps: int = 1, - max_steps: int | None = None, -) -> None: - """Assert step count is within expected range.""" - assert "step_count" in state, "State missing 'step_count' field" - step_count = state["step_count"] - assert isinstance(step_count, int), "Step count must be an integer" - - assert step_count >= min_steps, f"Step count too low: {step_count} < {min_steps}" - - if max_steps is not None: - assert step_count <= max_steps, ( - f"Step count too high: {step_count} > {max_steps}" - ) - - -def assert_llm_response_valid( - response: Any, - min_length: int = 1, - max_tokens: int | None = None, -) -> None: - """Assert LLM response is valid.""" - # Skip isinstance check if AIMessage is a fallback type - try: - if AIMessage is not Any and hasattr(AIMessage, '__name__'): - # Use hasattr check instead of isinstance for type compatibility - assert hasattr(response, 'content'), ( - f"Expected message with content attribute, got {type(response).__name__}" - ) - except (TypeError, NameError): - # Handle cases where AIMessage is not a proper type - pass - - # Safe attribute access for response content - assert hasattr(response, 'content'), "LLM response missing content attribute" - content = getattr(response, 'content', '') - assert content, "LLM response has empty content" - assert len(content) >= min_length, ( - f"LLM response too short: {len(content)}" - ) - - if max_tokens and hasattr(response, "response_metadata"): - metadata = getattr(response, "response_metadata", {}) - if "usage" in metadata and "total_tokens" in metadata["usage"]: - tokens = metadata["usage"]["total_tokens"] - assert tokens <= max_tokens, ( - f"Token usage too high: {tokens} > {max_tokens}" - ) +def assert_contains_keys(data: dict[str, Any], keys: list[str]) -> None: + """Assert that data contains all specified keys.""" + for key in keys: + assert key in data, f"Missing key: {key}" diff --git a/tests/helpers/factories/__init__.py b/tests/helpers/factories/__init__.py index 822ef32d..b499c111 100644 --- a/tests/helpers/factories/__init__.py +++ b/tests/helpers/factories/__init__.py @@ -1 +1 @@ -"""Test factory utilities for creating mock objects.""" +"""Test factories package.""" diff --git a/tests/helpers/factories/state_factories.py b/tests/helpers/factories/state_factories.py index b7ade6ee..5ae0ed36 100644 --- a/tests/helpers/factories/state_factories.py +++ b/tests/helpers/factories/state_factories.py @@ -1,445 +1,20 @@ -"""State factories for testing.""" +"""State factory helpers for tests.""" -from __future__ import annotations - -from datetime import datetime -from typing import TYPE_CHECKING, Any, Sequence, Union, cast - -try: - from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - MessageType = Union[HumanMessage, AIMessage, SystemMessage] -except ImportError: - # Fallback for environments without langchain_core - AIMessage = dict - HumanMessage = dict - SystemMessage = dict - MessageType = dict - -from biz_bud.states.rag_agent import RAGAgentState - -if TYPE_CHECKING: - pass - -# ErrorInfo is just a dict[str, Any] in the codebase +from typing import Any class StateBuilder: - """Builder for creating test states with sensible defaults.""" + """Builder for creating test state objects.""" def __init__(self) -> None: - """Initialize state builder with default values.""" - self._state: dict[str, Any] = { - # BaseStateRequired fields - "messages": [], - "initial_input": {}, - "config": {}, - "context": {}, - "status": "pending", - "errors": [], - "run_metadata": {}, - "thread_id": "test-thread-123", - "is_last_step": False, - # Additional common fields - "metadata": { - "session_id": "test-session", - "timestamp": datetime.now().isoformat(), - }, - "step_count": 0, - "workflow_status": "in_progress", - } + """Initialize the state builder.""" + self._state: dict[str, Any] = {} - def with_messages( - self, messages: Sequence[Any] - ) -> StateBuilder: - """Add messages to state.""" - self._state["messages"] = list(messages) - return self - - def with_human_message(self, content: str) -> StateBuilder: - """Add a human message to state.""" - self._state["messages"].append(HumanMessage(content=content)) - return self - - def with_ai_message(self, content: str) -> StateBuilder: - """Add an AI message to state.""" - self._state["messages"].append(AIMessage(content=content)) - return self - - def with_system_message(self, content: str) -> StateBuilder: - """Add a system message to state.""" - self._state["messages"].append(SystemMessage(content=content)) - return self - - def with_error(self, phase: str, error: str) -> StateBuilder: - """Add an error to state.""" - error_info: dict[str, Any] = {"phase": phase, "error": error} - self._state["errors"].append(error_info) - return self - - def with_metadata(self, **kwargs: Any) -> StateBuilder: - """Add or update metadata.""" - self._state["metadata"].update(kwargs) - return self - - def with_step_count(self, count: int) -> StateBuilder: - """Set step count.""" - self._state["step_count"] = count - return self - - def with_workflow_status(self, status: str) -> StateBuilder: - """Set workflow status.""" - self._state["workflow_status"] = status - return self - - def with_config(self, config: dict[str, Any]) -> StateBuilder: - """Set configuration.""" - self._state["config"] = config - return self - - def with_thread_id(self, thread_id: str) -> StateBuilder: - """Set thread ID.""" - self._state["thread_id"] = thread_id - return self - - def with_search_results(self, results: list[dict[str, Any]]) -> StateBuilder: - """Add search results to state.""" - self._state["search_results"] = results - return self - - def with_extracted_content(self, content: dict[str, Any]) -> StateBuilder: - """Add extracted content to state.""" - self._state["extracted_content"] = content - return self - - def with_synthesis(self, synthesis: str) -> StateBuilder: - """Add synthesis to state.""" - self._state["synthesis"] = synthesis - return self - - def with_validation_results( - self, - fact_check: dict[str, Any] | None = None, - logic_check: dict[str, Any] | None = None, - consistency_check: dict[str, Any] | None = None, - ) -> StateBuilder: - """Add validation results to state.""" - if fact_check: - self._state["fact_check_results"] = fact_check - if logic_check: - self._state["logic_validation"] = logic_check - if consistency_check: - self._state["consistency_validation"] = consistency_check - return self - - def with_analysis_results( - self, - data_analysis: dict[str, Any] | None = None, - interpretation: str | None = None, - visualization: dict[str, Any] | None = None, - ) -> StateBuilder: - """Add analysis results to state.""" - if data_analysis: - self._state["data_analysis"] = data_analysis - if interpretation: - self._state["interpretation"] = interpretation - if visualization: - self._state["visualization"] = visualization - return self - - def with_rag_fields( - self, - input_url: str, - force_refresh: bool = False, - query: str = "test query", - url_hash: str | None = None, - existing_content: dict[str, Any] | None = None, - content_age_days: int | None = None, - should_process: bool = True, - processing_reason: str | None = None, - scrape_params: dict[str, Any] | None = None, - r2r_params: dict[str, Any] | None = None, - processing_result: dict[str, Any] | None = None, - rag_status: str = "checking", - error: str | None = None, - ) -> StateBuilder: - """Add RAG agent specific fields to state.""" - self._state.update( - { - "input_url": input_url, - "force_refresh": force_refresh, - "query": query, - "url_hash": url_hash, - "existing_content": existing_content, - "content_age_days": content_age_days, - "should_process": should_process, - "processing_reason": processing_reason, - "scrape_params": scrape_params or {}, - "r2r_params": r2r_params or {}, - "processing_result": processing_result, - "rag_status": rag_status, - "error": error, - } - ) + def with_field(self, key: str, value: Any) -> "StateBuilder": + """Add a field to the state.""" + self._state[key] = value return self def build(self) -> dict[str, Any]: - """Build and return the state.""" + """Build the final state object.""" return self._state.copy() - - -def create_base_state() -> dict[str, Any]: - """Create a minimal valid state.""" - return StateBuilder().build() - - -def create_research_state() -> dict[str, Any]: - """Create a state for research workflow.""" - return ( - StateBuilder() - .with_human_message("Research machine learning trends") - .with_metadata(research_type="market_analysis", max_sources=10) - .with_search_results( - [ - { - "title": "ML Trends 2024", - "url": "https://example.com/ml-trends", - "snippet": "Key trends in machine learning...", - "provider": "tavily", - } - ] - ) - .build() - ) - - -def create_analysis_state() -> dict[str, Any]: - """Create a state for analysis workflow.""" - return ( - StateBuilder() - .with_human_message("Analyze sales data") - .with_metadata(analysis_type="sales", period="Q1-2024") - .with_analysis_results( - data_analysis={ - "insights": [ - "Total sales reached $1.5M in Q1 2024", - "Growth rate increased by 15% year-over-year", - "Top products: Product A, Product B", - ] - }, - interpretation="Sales showed strong growth in Q1...", - ) - .build() - ) - - -def create_validation_state() -> dict[str, Any]: - """Create a state for validation workflow.""" - return ( - StateBuilder() - .with_human_message("Validate research findings") - .with_synthesis("Based on research, the market is growing...") - .with_validation_results( - fact_check={"passed": True, "issues": []}, - logic_check={"passed": True, "issues": []}, - consistency_check={"passed": False, "issues": ["Date inconsistency found"]}, - ) - .build() - ) - - -def create_error_state() -> dict[str, Any]: - """Create a state with errors.""" - return ( - StateBuilder() - .with_human_message("Process data") - .with_error("search", "API rate limit exceeded") - .with_error("extraction", "Failed to parse content") - .with_workflow_status("failed") - .build() - ) - - -def create_menu_intelligence_state() -> dict[str, Any]: - """Create a state for menu intelligence workflow.""" - return ( - StateBuilder() - .with_human_message("Analyze restaurant menu") - .with_metadata( - restaurant_name="Test Restaurant", - location="San Francisco, CA", - cuisine_type="Italian", - ) - .with_extracted_content( - { - "menu_items": [ - { - "name": "Margherita Pizza", - "price": 18.99, - "category": "Pizza", - "ingredients": ["tomato", "mozzarella", "basil"], - }, - { - "name": "Caesar Salad", - "price": 12.99, - "category": "Salad", - "ingredients": ["romaine", "parmesan", "croutons"], - }, - ], - "price_range": "$10-$30", - "popular_items": ["Margherita Pizza"], - } - ) - .build() - ) - - -def create_rag_state() -> dict[str, Any]: - """Create a state for RAG workflow.""" - return ( - StateBuilder() - .with_human_message("What are the latest AI developments?") - .with_metadata( - rag_collection="ai_research", - similarity_threshold=0.75, - ) - .with_search_results( - [ - { - "content": "Recent advances in transformer architectures...", - "metadata": {"source": "arxiv", "date": "2024-01-15"}, - "relevance_score": 0.92, - }, - { - "content": "New developments in multimodal AI...", - "metadata": {"source": "research_paper", "date": "2024-01-20"}, - "relevance_score": 0.88, - }, - ] - ) - .build() - ) - - -def create_rag_agent_state() -> dict[str, Any]: - """Create a minimal state for RAG agent workflow.""" - return ( - StateBuilder() - .with_config( - { - "rag_config": { - "max_content_age_days": 7, - "enable_deduplication": True, - } - } - ) - .with_rag_fields( - input_url="https://example.com", - force_refresh=False, - query="test query", - ) - .build() - ) - - -def create_rag_agent_state_with_existing_content() -> dict[str, Any]: - """Create a RAG agent state with existing content for testing deduplication scenarios.""" - return ( - StateBuilder() - .with_config( - { - "rag_config": { - "max_content_age_days": 7, - "enable_deduplication": True, - } - } - ) - .with_rag_fields( - input_url="https://example.com/docs", - force_refresh=False, - query="Extract documentation about API endpoints", - url_hash="a1b2c3d4e5f6g7h8", # First 16 chars of SHA256 - existing_content={ - "document_id": "doc_123", - "title": "API Documentation", - "last_updated": "2024-01-15T10:00:00Z", - "chunks": 42, - }, - content_age_days=3, - should_process=False, - processing_reason="Content is recent (3 days old) and force_refresh is False", - scrape_params={ - "max_depth": 2, - "include_patterns": ["*/api/*", "*/docs/*"], - "exclude_patterns": ["*/internal/*"], - }, - r2r_params={ - "chunk_size": 1000, - "chunk_overlap": 200, - "metadata": {"source": "web", "type": "documentation"}, - }, - processing_result={ - "status": "skipped", - "message": "Using existing content", - "document_ids": ["doc_123"], - }, - rag_status="completed", - error=None, - ) - .build() - ) - - -def create_rag_agent_state_processing() -> dict[str, Any]: - """Create a RAG agent state for active processing scenarios.""" - return ( - StateBuilder() - .with_config( - { - "rag_config": { - "max_content_age_days": 7, - "enable_deduplication": True, - } - } - ) - .with_rag_fields( - input_url="https://github.com/example/repo", - force_refresh=True, - query="Analyze repository structure and extract documentation", - url_hash=None, - existing_content=None, - content_age_days=None, - should_process=True, - processing_reason="Force refresh requested", - scrape_params={ - "max_depth": 3, - "include_patterns": ["*.md", "*.py", "*.ts", "*.js"], - "exclude_patterns": ["node_modules/*", "dist/*", "build/*"], - }, - r2r_params={ - "chunk_method": "markdown", - "chunk_token_num": 512, - "layout_recognize": "DeepDOC", - }, - processing_result=None, - rag_status="processing", - error=None, - ) - .build() - ) - - -def create_minimal_rag_agent_state(**kwargs: Any) -> RAGAgentState: - """Create a minimal RAGAgentState for testing. - - This is a direct replacement for the function in test_agent_nodes.py - that allows customization of any field through kwargs. - """ - # Start with base RAG agent state - base = create_rag_agent_state() - - # Update with any provided kwargs - for key, value in kwargs.items(): - base[key] = value - - # Type assertion for type checkers - return cast(RAGAgentState, base) diff --git a/tests/helpers/fixtures/__init__.py b/tests/helpers/fixtures/__init__.py index 6e4f3ea4..d1b3395c 100644 --- a/tests/helpers/fixtures/__init__.py +++ b/tests/helpers/fixtures/__init__.py @@ -1,6 +1 @@ -"""Shared test fixtures for all tests.""" - -from tests.helpers.fixtures.config_fixtures import * # noqa: F403,F401 -from tests.helpers.fixtures.factory_fixtures import * # noqa: F403,F401 -from tests.helpers.fixtures.mock_fixtures import * # noqa: F403,F401 -from tests.helpers.fixtures.state_fixtures import * # noqa: F403,F401 +"""Test fixtures package.""" diff --git a/tests/helpers/fixtures/config_fixtures.py b/tests/helpers/fixtures/config_fixtures.py index 5e894b0e..6c0c8624 100644 --- a/tests/helpers/fixtures/config_fixtures.py +++ b/tests/helpers/fixtures/config_fixtures.py @@ -1,476 +1,15 @@ -"""Configuration fixtures for testing.""" - -from __future__ import annotations +"""Configuration fixtures for tests.""" from typing import Any import pytest -from biz_bud.core.config.schemas import ( - AgentConfig, - AppConfig, - DatabaseConfigModel, - LLMConfig, - LoggingConfig, - RedisConfigModel, - SearchOptimizationConfig, - ToolsConfigModel, - VectorStoreEnhancedConfig, -) -from biz_bud.core.config.schemas.llm import LLMProfile - -@pytest.fixture(scope="session") -def base_config_dict() -> dict[str, Any]: - """Provide base configuration dictionary.""" +@pytest.fixture +def sample_config() -> dict[str, Any]: + """Provide a sample configuration for testing.""" return { - "core": { - "log_level": "INFO", - "debug": False, - "environment": "test", - }, - "llm": { - "default_provider": "openai", - "default_model": "gpt-4o-mini", - "providers": { - "openai": { - "api_key": "test-key", - "models": { - "gpt-4o-mini": { - "model_name": "gpt-4o-mini", - "max_tokens": 1000, - "temperature": 0.7, - } - }, - } - }, - }, - "llm_config": { - "tiny": { - "name": "openai/gpt-4o-mini", - "temperature": 0.3, - "max_tokens": 500, - }, - "small": { - "name": "openai/gpt-4o", - "temperature": 0.5, - "max_tokens": 1000, - }, - "large": { - "name": "openai/gpt-4.1", - "temperature": 0.7, - "max_tokens": 4000, - }, - "reasoning": { - "name": "openai/o1-mini", - "temperature": 1.0, - "max_tokens": 8000, - }, - }, - "agent_config": { - "max_loops": 25, - "recursion_limit": 1000, - "default_llm_profile": "large", - "default_initial_user_query": "Hello", - }, - "api_config": { - "openai_api_key": "test-key", - "anthropic_api_key": "test-key", - "google_api_key": "test-key", - }, - "services": { - "database": { - "postgres_host": "localhost", - "postgres_port": 5432, - "postgres_db": "test_db", - "postgres_user": "test_user", - "postgres_password": "test_pass", - }, - "redis": { - "host": "localhost", - "port": 6379, - "db": 0, - }, - "vector_store": { - "provider": "qdrant", - "qdrant_host": "localhost", - "qdrant_port": 6333, - "collection_name": "test_collection", - }, - }, - "research": { - "search": { - "max_results_per_provider": 5, - "providers": ["tavily"], - "cache_ttl": 3600, - }, - "synthesis": { - "min_sources": 2, - "max_tokens": 2000, - }, - }, - "analysis": { - "data": { - "min_data_points": 10, - "confidence_threshold": 0.8, - }, - "visualization": { - "default_chart_type": "bar", - "color_scheme": "blue", - }, - }, - "tools": { - "tavily": { - "api_key": "test-tavily-key", - "max_results": 10, - }, - "firecrawl": { - "api_key": "test-firecrawl-key", - "timeout": 30, - }, - }, + "api_key": "test_key", + "timeout": 30, + "max_retries": 3 } - - -@pytest.fixture(scope="module") -def logging_config() -> LoggingConfig: - """Provide logging configuration model.""" - return LoggingConfig( - log_level="INFO", - ) - - -@pytest.fixture(scope="module") -def database_config() -> DatabaseConfigModel: - """Provide database configuration model.""" - return DatabaseConfigModel( - postgres_host="localhost", - postgres_port=5432, - postgres_db="test_db", - postgres_user="test_user", - postgres_password="test_pass", - postgres_min_pool_size=2, - postgres_max_pool_size=15, - postgres_command_timeout=10, - qdrant_host="localhost", - qdrant_port=6333, - qdrant_api_key=None, - default_page_size=100, - max_page_size=1000, - qdrant_collection_name="test_collection", - ) - - -@pytest.fixture(scope="module") -def redis_config() -> RedisConfigModel: - """Provide Redis configuration model.""" - return RedisConfigModel( - redis_url="redis://localhost:6379/0", - key_prefix="test:", - ) - - -@pytest.fixture(scope="module") -def vector_store_config() -> VectorStoreEnhancedConfig: - """Provide vector store configuration model.""" - return VectorStoreEnhancedConfig( - collection_name="test_collection", - vector_size=1536, - ) - - -@pytest.fixture(scope="module") -def agent_config() -> AgentConfig: - """Provide agent configuration model.""" - return AgentConfig( - max_loops=25, - recursion_limit=1000, - default_llm_profile="large", - default_initial_user_query="Hello", - system_prompt=None, - max_iterations=10, - timeout=300, - ) - - -@pytest.fixture(scope="module") -def llm_config() -> LLMConfig: - """Provide LLM configuration model.""" - return LLMConfig(default_profile=LLMProfile.LARGE) - - -@pytest.fixture(scope="function") -def search_config() -> SearchOptimizationConfig: - """Provide search optimization configuration model.""" - return SearchOptimizationConfig() - - -@pytest.fixture(scope="function") -def minimal_config_dict() -> dict[str, object]: - """Provide minimal configuration dictionary.""" - return { - "core": {"log_level": "INFO"}, - "database": { - "postgres_host": "localhost", - "postgres_port": 5432, - "postgres_db": "test", - "postgres_user": "user", - "postgres_password": "pass", - }, - } - - -@pytest.fixture(scope="function") -def tools_config() -> ToolsConfigModel: - """Provide tools configuration model.""" - return ToolsConfigModel( - search=None, - extract=None, - ) - - -@pytest.fixture(scope="module") -def app_config(base_config_dict: dict[str, Any]) -> AppConfig: - """Provide complete application configuration.""" - # Use the AppConfig from schemas which handles the full config - return AppConfig(**base_config_dict) - - -@pytest.fixture(scope="function") -def minimal_app_config(minimal_config_dict: dict[str, Any]) -> AppConfig: - """Provide minimal valid application configuration.""" - return AppConfig(**minimal_config_dict) - - -@pytest.fixture(scope="function") -def complex_llm_config() -> dict[str, Any]: - """Provide complex LLM configuration with multiple profiles and providers.""" - return { - "llm_config": { - "tiny": { - "name": "openai/gpt-4o-mini", - "temperature": 0.3, - "max_tokens": 500, - "timeout": 30.0, - }, - "small": { - "name": "openai/gpt-4o", - "temperature": 0.5, - "max_tokens": 1000, - "timeout": 60.0, - }, - "large": { - "name": "openai/gpt-4.1", - "temperature": 0.7, - "max_tokens": 4000, - "timeout": 120.0, - }, - "reasoning": { - "name": "openai/o1-mini", - "temperature": 1.0, - "max_tokens": 8000, - "timeout": 180.0, - }, - }, - "api_config": { - "openai_api_key": "test-key", - "anthropic_api_key": "test-key", - "google_api_key": "test-key", - }, - } - - -@pytest.fixture(scope="function") -def complex_rag_config() -> dict[str, Any]: - """Provide complex RAG configuration with advanced settings.""" - return { - "rag_config": { - "max_content_age_days": 7, - "enable_deduplication": True, - "chunk_size": 1000, - "chunk_overlap": 200, - "similarity_threshold": 0.85, - "max_sources_per_query": 10, - "enable_reranking": True, - "vector_store": { - "provider": "qdrant", - "collection_name": "test_collection", - "distance_metric": "cosine", - "vector_size": 1536, - }, - }, - "scrape_params": { - "max_depth": 3, - "max_pages": 50, - "timeout": 30, - "include_patterns": ["*.html", "*.md", "*.pdf"], - "exclude_patterns": ["*/admin/*", "*/private/*"], - "follow_redirects": True, - "respect_robots_txt": True, - }, - "r2r_params": { - "chunk_method": "semantic", - "chunk_token_num": 512, - "layout_recognize": "DeepDOC", - "metadata": { - "source": "web_scraping", - "processing_date": "2024-01-01", - "quality_score": 0.95, - }, - }, - } - - -@pytest.fixture(scope="function") -def complex_search_config() -> dict[str, Any]: - """Provide complex search configuration with optimization settings.""" - return { - "search_optimization": { - "enable_query_deduplication": True, - "similarity_threshold": 0.85, - "max_concurrent_searches": 10, - "provider_timeout_seconds": 30, - "diversity_weight": 0.3, - "min_quality_score": 0.6, - "query_optimization": { - "enabled": True, - "max_query_length": 200, - "remove_stopwords": True, - "expand_acronyms": True, - "min_results_per_query": 3, - }, - "concurrency": { - "max_concurrent": 10, - "batch_size": 5, - "rate_limit_per_second": 2, - }, - "ranking": { - "algorithm": "semantic_similarity", - "boost_recent": True, - "domain_authority_weight": 0.2, - "freshness_weight": 0.3, - }, - "caching": { - "enabled": True, - "ttl": 3600, - "max_cache_size": 1000, - "compress_cache": True, - }, - }, - "providers": { - "tavily": { - "api_key": "test-tavily-key", - "max_results": 10, - "timeout": 30, - }, - "jina": { - "api_key": "test-jina-key", - "max_results": 10, - "timeout": 30, - }, - "arxiv": { - "max_results": 5, - "categories": ["cs.AI", "cs.LG", "cs.CL"], - }, - }, - } - - -@pytest.fixture(scope="function") -def complex_analysis_config() -> dict[str, Any]: - """Provide complex analysis configuration with advanced analytics.""" - return { - "analysis": { - "data": { - "min_data_points": 10, - "confidence_threshold": 0.8, - "statistical_tests": ["t_test", "chi_square", "anova"], - "outlier_detection": { - "method": "iqr", - "threshold": 1.5, - "remove_outliers": False, - }, - "normalization": { - "method": "z_score", - "handle_missing": "interpolate", - }, - }, - "interpretation": { - "enable_causal_inference": True, - "significance_level": 0.05, - "effect_size_threshold": 0.3, - "confidence_intervals": True, - }, - "visualization": { - "default_chart_type": "interactive", - "color_scheme": "viridis", - "figure_size": [12, 8], - "dpi": 300, - "export_formats": ["png", "pdf", "svg"], - "themes": { - "professional": True, - "grid": True, - "annotations": True, - }, - }, - "reporting": { - "template": "comprehensive", - "include_methodology": True, - "include_limitations": True, - "auto_insights": True, - "executive_summary": True, - }, - } - } - - -@pytest.fixture(scope="function") -def complex_multimodal_config( - complex_llm_config: dict[str, Any], - complex_rag_config: dict[str, Any], - complex_search_config: dict[str, Any], - complex_analysis_config: dict[str, Any], -) -> dict[str, Any]: - """Provide comprehensive configuration combining all complex configs. - - This fixture creates a complete configuration suitable for testing - complex workflows that involve multiple components. - """ - return ( - { - "core": { - "log_level": "DEBUG", - "debug": True, - "environment": "test", - "enable_telemetry": False, - }, - "database": { - "postgres_host": "localhost", - "postgres_port": 5432, - "postgres_db": "test_complex_db", - "postgres_user": "test_user", - "postgres_password": "test_pass", - "connection_pool_size": 20, - "enable_ssl": False, - }, - "redis": { - "host": "localhost", - "port": 6379, - "db": 1, - "password": None, - "connection_pool_size": 10, - "socket_timeout": 30, - }, - "features": { - "enable_advanced_analytics": True, - "enable_multimodal_processing": True, - "enable_real_time_updates": True, - "enable_batch_processing": True, - "enable_distributed_processing": False, - }, - } - | complex_llm_config - | complex_rag_config - | complex_search_config - | complex_analysis_config - ) diff --git a/tests/helpers/fixtures/factory_fixtures.py b/tests/helpers/fixtures/factory_fixtures.py index 7317045e..c8cb06eb 100644 --- a/tests/helpers/fixtures/factory_fixtures.py +++ b/tests/helpers/fixtures/factory_fixtures.py @@ -1,510 +1,11 @@ -"""Factory fixtures for flexible test data creation.""" +"""Factory fixtures for tests.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, cast -from unittest.mock import AsyncMock, MagicMock +from typing import Any import pytest -from biz_bud.services.factory import ServiceFactory - -if TYPE_CHECKING: - from biz_bud.core.config.schemas import AppConfig - from biz_bud.states.unified import ResearchState - @pytest.fixture -def research_state_factory() -> Callable[..., ResearchState]: - """Create research states with overrides.""" - - def _factory(**overrides: Any) -> ResearchState: - base_state: dict[str, Any] = { - # BaseState fields - "messages": [], - "errors": [], - "status": "pending", - "thread_id": "test-thread-123", - "config": {"enabled": True}, - # ResearchState required fields - "extracted_info": {}, - "synthesis": "", - # SearchMixin fields - "search_query": "default query", - "search_queries": [], - "search_results": [], - "search_history": [], - "visited_urls": [], - "search_status": "idle", - # ValidationMixin fields - "content": "", - "validation_criteria": {"required_fields": []}, - "validation_results": { - "is_valid": False, - "errors": [], - "passed_checks": [], - "failed_checks": [], - }, - "is_valid": False, - "requires_human_feedback": False, - # ResearchStateOptional fields - "query": "default query", - "service_factory_validated": False, - "synthesis_attempts": 0, - "validation_attempts": 0, - "sources": [], - "urls_to_scrape": [], - "scraped_results": {}, - "semantic_extraction_results": {}, - "vector_ids": [], - } - # Deep update to handle nested dicts - for key, value in overrides.items(): - if ( - key in base_state - and isinstance(base_state[key], dict) - and isinstance(value, dict) - ): - # Cast to satisfy type checker since we verified isinstance - dict_value = cast("dict[str, Any]", base_state[key]) - dict_value.update(value) - else: - base_state[key] = value - return cast("ResearchState", base_state) - - return _factory - - -@pytest.fixture -def url_to_rag_state_factory() -> Callable[..., dict[str, Any]]: - """Create URL to RAG states with overrides.""" - - def _factory(**overrides: Any) -> dict[str, Any]: - base_state = { - "input_url": "https://example.com", - "url": "https://example.com", - "urls": ["https://example.com"], - "scraped_content": [], - "processed_content": {}, - "r2r_info": {}, - "repomix_output": None, - "last_processed_page_count": 0, - "scraping_status": "pending", - "processing_status": "pending", - "upload_status": "pending", - "messages": [], - "errors": [], - "config": {}, - "thread_id": "test-thread", - "status": "running", - } - for key, value in overrides.items(): - if ( - key in base_state - and isinstance(base_state[key], dict) - and isinstance(value, dict) - ): - # Cast to satisfy type checker since we verified isinstance - dict_value = cast("dict[str, Any]", base_state[key]) - dict_value.update(value) - else: - base_state[key] = value - return base_state - - return _factory - - -@pytest.fixture -def analysis_state_factory() -> Callable[..., dict[str, Any]]: - """Create analysis states with overrides.""" - - def _factory(**overrides: Any) -> dict[str, Any]: - base_state = { - "task": "Analyze data", - "data": {}, - "analysis_results": {}, - "interpretations": {}, - "analysis_plan": {}, - "visualizations": {}, - "reports": {}, - "messages": [], - "errors": [], - "config": {}, - "thread_id": "test-thread", - "status": "running", - } - for key, value in overrides.items(): - if ( - key in base_state - and isinstance(base_state[key], dict) - and isinstance(value, dict) - ): - # Cast to satisfy type checker since we verified isinstance - dict_value = cast("dict[str, Any]", base_state[key]) - dict_value.update(value) - else: - base_state[key] = value - return base_state - - return _factory - - -@pytest.fixture -def menu_intelligence_state_factory() -> Callable[..., dict[str, Any]]: - """Create menu intelligence states with overrides.""" - - def _factory(**overrides: Any) -> dict[str, Any]: - base_state = { - "query": "chicken", - "menu_items": [], - "menu_analysis": {}, - "insights": {}, - "recommendations": {}, - "extracted_info": {"entities": [], "statistics": [], "key_facts": []}, - "messages": [], - "errors": [], - "config": {}, - "thread_id": "test-thread", - "status": "running", - } - for key, value in overrides.items(): - if ( - key in base_state - and isinstance(base_state[key], dict) - and isinstance(value, dict) - ): - # Cast to satisfy type checker since we verified isinstance - dict_value = cast("dict[str, Any]", base_state[key]) - dict_value.update(value) - else: - base_state[key] = value - return base_state - - return _factory - - -@dataclass -class MockedResearchServices: - """Container for all mocked services used in research workflows.""" - - llm_client: AsyncMock - search_tool: AsyncMock - scraper: AsyncMock - vector_store: AsyncMock - cache_backend: AsyncMock - database: AsyncMock - - -@pytest.fixture -def mocked_research_services() -> MockedResearchServices: - """Provide a container with all mocked services for the research graph.""" - return MockedResearchServices( - llm_client=AsyncMock(), - search_tool=AsyncMock(), - scraper=AsyncMock(), - vector_store=AsyncMock(), - cache_backend=AsyncMock(), - database=AsyncMock(), - ) - - -@dataclass -class MockedAnalysisServices: - """Container for all mocked services used in analysis workflows.""" - - llm_client: AsyncMock - data_processor: AsyncMock - visualization_engine: AsyncMock - report_generator: AsyncMock - - -@pytest.fixture -def mocked_analysis_services() -> MockedAnalysisServices: - """Provide a container with all mocked services for analysis workflows.""" - return MockedAnalysisServices( - llm_client=AsyncMock(), - data_processor=AsyncMock(), - visualization_engine=AsyncMock(), - report_generator=AsyncMock(), - ) - - -@dataclass -class MockedRAGServices: - """Container for all mocked services used in RAG workflows.""" - - llm_client: AsyncMock - scraper: AsyncMock - r2r_client: AsyncMock - vector_store: AsyncMock - - -@pytest.fixture -def mocked_rag_services() -> MockedRAGServices: - """Provide a container with all mocked services for RAG workflows.""" - return MockedRAGServices( - llm_client=AsyncMock(), - scraper=AsyncMock(), - r2r_client=AsyncMock(), - vector_store=AsyncMock(), - ) - - -@dataclass -class MockedURLToRAGServices: - """Container for all mocked services used in URL-to-RAG workflows. - - This fixture bundle provides all the services needed for URL-to-RAG processing, - including web scraping, content analysis, R2R upload, and duplicate checking. - """ - - llm_client: AsyncMock - firecrawl_client: AsyncMock - r2r_client: AsyncMock - vector_store: AsyncMock - repomix_processor: AsyncMock - cache_backend: AsyncMock - - -@pytest.fixture -def mocked_url_to_rag_services() -> MockedURLToRAGServices: - """Provide a container with all mocked services for URL-to-RAG workflows. - - This fixture is specifically designed for URL-to-RAG workflows and includes: - - LLM client for content analysis and status summaries - - Firecrawl client for web scraping and URL discovery - - R2R client for document upload and duplicate checking - - Vector store for semantic search and duplicate detection - - Repomix processor for git repository processing - - Cache backend for storing intermediate results - - Returns: - MockedURLToRAGServices: Container with all necessary mocked services - """ - # Configure LLM client mock with common responses - llm_client = AsyncMock() - llm_client.llm_chat.return_value = "AI-generated content analysis" - llm_client.llm_json.return_value = { - "analysis": "content_suitable", - "confidence": 0.95, - } - - # Configure Firecrawl client mock - firecrawl_client = AsyncMock() - firecrawl_client.map_website.return_value = [ - "https://example.com/page1", - "https://example.com/page2", - "https://example.com/page3", - ] - firecrawl_client.scrape_url.return_value = MagicMock( - success=True, - data=MagicMock( - content="Scraped content", - markdown="# Scraped Content", - metadata=MagicMock( - title="Test Page", sourceURL="https://example.com/page1" - ), - ), - ) - - # Configure R2R client mock - r2r_client = AsyncMock() - r2r_client.users.login.return_value = {"access_token": "test-token"} - r2r_client.collections.list.return_value = MagicMock(results=[]) - r2r_client.collections.create.return_value = MagicMock( - results=MagicMock(id="test-collection") - ) - r2r_client.documents.create.return_value = MagicMock( - results=MagicMock(document_id="test-doc-id") - ) - r2r_client.retrieval.search.return_value = MagicMock( - results=MagicMock(chunk_search_results=[]) - ) - - # Configure vector store mock for duplicate checking - vector_store = AsyncMock() - vector_store.semantic_search.return_value = [] - vector_store.initialize = AsyncMock() - - # Configure repomix processor mock - repomix_processor = AsyncMock() - repomix_processor.pack_repository.return_value = "Processed repository content" - - # Configure cache backend mock - cache_backend = AsyncMock() - cache_backend.get.return_value = None - cache_backend.set.return_value = None - cache_backend.setex.return_value = None - - return MockedURLToRAGServices( - llm_client=llm_client, - firecrawl_client=firecrawl_client, - r2r_client=r2r_client, - vector_store=vector_store, - repomix_processor=repomix_processor, - cache_backend=cache_backend, - ) - - -@pytest.fixture -def mock_service_factory_builder() -> Callable[..., MagicMock]: - """Create customized mock service factories.""" - - def _builder(**service_overrides: Any) -> MagicMock: - factory = MagicMock() - - # Default services - default_services = { - "LangchainLLMClient": AsyncMock(), - "WebSearchTool": AsyncMock(), - "FirecrawlScraper": AsyncMock(), - "QdrantVectorStore": AsyncMock(), - "RedisCacheBackend": AsyncMock(), - "PostgresStore": AsyncMock(), - } - - # Apply overrides - services = default_services | service_overrides - - async def mock_get_service(service_class: Any) -> Any: - class_name = ( - service_class.__name__ - if hasattr(service_class, "__name__") - else str(service_class) - ) - return services.get(class_name, AsyncMock()) - - factory.get_service.side_effect = mock_get_service - - # Add the new get_llm_for_node method for centralized config architecture - async def mock_get_llm_for_node( - node_context: str, - llm_profile_override: str | None = None, - temperature_override: float | None = None, - max_tokens_override: int | None = None, - **kwargs: Any, - ) -> Any: - return services.get("LangchainLLMClient", AsyncMock()) - - factory.get_llm_for_node = mock_get_llm_for_node - - # Make factory usable as a context manager - mock_lifespan = MagicMock() - mock_lifespan.__aenter__.return_value = factory - mock_lifespan.__aexit__.return_value = None - factory.lifespan.return_value = mock_lifespan - - return factory - - return _builder - - -@pytest.fixture -def mock_llm_response_factory() -> Callable[..., Any]: - """Create mock LLM responses with custom content.""" - - def _factory(content: str = "Default response", **kwargs: Any) -> Any: - try: - from langchain_core.messages import AIMessage - message_constructor = AIMessage - except ImportError: - # Fallback for environments without langchain_core - def message_constructor(content: str, **kwargs: Any) -> Any: - return {"content": content, **kwargs} - - defaults = { - "additional_kwargs": {}, - "response_metadata": { - "model": "gpt-4o-mini", - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - }, - }, - } - - # Merge kwargs with defaults - for key, value in kwargs.items(): - if key in defaults and isinstance(value, dict): - # Cast to dict for type checker compatibility - dict_value = defaults[key] - dict_value.update(value) - else: - defaults[key] = value - - return message_constructor(content=content, **defaults) - - return _factory - - -@pytest.fixture -def mock_search_result_factory() -> Callable[..., dict[str, Any]]: - """Create mock search results.""" - - def _factory( - title: str = "Test Result", - url: str = "https://example.com", - snippet: str = "Test snippet", - **kwargs: Any, - ) -> dict[str, Any]: - result = { - "title": title, - "url": url, - "snippet": snippet, - "provider": kwargs.get("provider", "tavily"), - "published_date": kwargs.get("published_date"), - "metadata": kwargs.get("metadata", {}), - } - # Add any additional fields - for key, value in kwargs.items(): - if key not in result: - result[key] = value - return result - - return _factory - - -@pytest.fixture -def mock_scraped_content_factory() -> Callable[..., dict[str, Any]]: - """Create mock scraped content.""" - - def _factory( - content: str = "Test content", - title: str = "Test Page", - url: str = "https://example.com", - **kwargs: Any, - ) -> dict[str, Any]: - return { - "content": content, - "markdown": kwargs.get("markdown", f"# {title}\n\n{content}"), - "title": title, - "url": url, - "metadata": kwargs.get( - "metadata", - { - "description": "Test description", - "author": "Test Author", - "published_date": "2024-01-01", - }, - ), - "content_type": kwargs.get("content_type", "text/html"), - "error": kwargs.get("error"), - } - - return _factory - - -@pytest.fixture(scope="module") -async def module_service_factory(app_config: AppConfig): - """Provide a module-scoped ServiceFactory that is initialized once per test module. - - This fixture provides better performance for integration tests by reusing - the same service factory instance across all tests in a module. - - Note: Uses the app_config fixture which is already module-scoped. - """ - factory = ServiceFactory(app_config) - async with factory.lifespan(): - yield factory +def mock_factory() -> dict[str, Any]: + """Provide a mock factory for testing.""" + return {"type": "mock_factory", "initialized": True} diff --git a/tests/helpers/fixtures/mock_fixtures.py b/tests/helpers/fixtures/mock_fixtures.py index acb61bc2..463a78c2 100644 --- a/tests/helpers/fixtures/mock_fixtures.py +++ b/tests/helpers/fixtures/mock_fixtures.py @@ -1,460 +1,11 @@ -"""Mock fixtures for testing.""" +"""Mock fixtures for tests.""" -from __future__ import annotations +from unittest.mock import MagicMock -from typing import Any -from unittest.mock import AsyncMock, MagicMock - -import pandas as pd import pytest -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - -from biz_bud.services.llm.client import LangchainLLMClient -@pytest.fixture(scope="function") -def mock_redis() -> AsyncMock: - """Provide mock Redis client.""" - mock = AsyncMock() - mock.get = AsyncMock(return_value=None) - mock.set = AsyncMock(return_value=True) - mock.setex = AsyncMock(return_value=True) - mock.delete = AsyncMock(return_value=1) - mock.exists = AsyncMock(return_value=False) - mock.keys = AsyncMock(return_value=[]) - mock.mget = AsyncMock(return_value=[]) - mock.ttl = AsyncMock(return_value=-2) - mock.ping = AsyncMock(return_value=True) - return mock - - -@pytest.fixture(scope="function") -def mock_database() -> AsyncMock: - """Provide mock database connection.""" - mock = AsyncMock() - mock.fetch = AsyncMock(return_value=[]) - mock.fetchrow = AsyncMock(return_value=None) - mock.execute = AsyncMock(return_value="INSERT 0 1") - mock.close = AsyncMock() - return mock - - -@pytest.fixture(scope="function") -def mock_database_pool() -> AsyncMock: - """Provide mock database pool.""" - mock = AsyncMock() - mock.acquire = AsyncMock() - mock.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_database()) - mock.acquire.return_value.__aexit__ = AsyncMock() - mock.close = AsyncMock() - return mock - - -@pytest.fixture(scope="function") -def mock_vector_store() -> AsyncMock: - """Provide mock vector store client.""" - mock = AsyncMock() - mock.search = AsyncMock(return_value=[]) - mock.upsert = AsyncMock(return_value=True) - mock.delete = AsyncMock(return_value=True) - mock.get_collection_info = AsyncMock(return_value={"vectors_count": 0}) - return mock - - -@pytest.fixture(scope="function") -def mock_llm_response() -> Any: - """Provide mock LLM response.""" - usage_data: dict[str, int] = { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - } - response_metadata: dict[str, Any] = { - "model": "gpt-4o-mini", - "usage": usage_data, - } - return AIMessage( - content="This is a test response from the LLM.", - additional_kwargs={}, - response_metadata=response_metadata, - ) - - -@pytest.fixture(scope="function") -def mock_llm_service() -> AsyncMock: - """Provide mock LLM service.""" - from langchain_core.messages import AIMessage - - # Create mock response directly instead of calling fixture - mock_response = AIMessage( - content="This is a test response from the LLM.", - additional_kwargs={}, - response_metadata={ - "model": "gpt-4o-mini", - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - }, - }, - ) - - mock = AsyncMock(spec=LangchainLLMClient) - mock.generate = AsyncMock(return_value=mock_response) - mock.generate_json = AsyncMock(return_value={"result": "test"}) - mock.generate_with_retry = AsyncMock(return_value=mock_response) - mock.llm_json = AsyncMock(return_value={"result": "test"}) - return mock - - -@pytest.fixture(scope="function") -def mock_search_tool() -> AsyncMock: - """Provide mock search tool.""" - mock = AsyncMock() - mock.search = AsyncMock( - return_value=[ - { - "title": "Test Result 1", - "url": "https://example.com/1", - "snippet": "This is a test search result", - "provider": "tavily", - }, - { - "title": "Test Result 2", - "url": "https://example.com/2", - "snippet": "Another test search result", - "provider": "tavily", - }, - ] - ) - return mock - - -@pytest.fixture(scope="function") -def mock_scraper() -> AsyncMock: - """Provide mock web scraper.""" - mock = AsyncMock() - mock.scrape = AsyncMock( - return_value=MagicMock( - content="Test content", - title="Test Page", - error=None, - metadata=MagicMock( - description="Test description", - author="Test Author", - published_date="2024-01-01", - ), - content_type=MagicMock(value="text/html"), - ) - ) - return mock - - -@pytest.fixture(scope="function") -def mock_semantic_extractor() -> AsyncMock: - """Provide mock semantic extractor.""" - mock = AsyncMock() - mock.extract = AsyncMock( - return_value={ - "entities": ["Entity1", "Entity2"], - "topics": ["Topic1", "Topic2"], - "summary": "This is a test summary", - "key_points": ["Point 1", "Point 2"], - } - ) - return mock - - -@pytest.fixture(scope="function") -def sample_messages() -> list[Any]: - """Provide sample conversation messages.""" - return [ - SystemMessage(content="You are a helpful assistant."), - HumanMessage(content="What is machine learning?"), - AIMessage(content="Machine learning is a subset of artificial intelligence..."), - HumanMessage(content="Can you give me an example?"), - AIMessage(content="Sure! An example of machine learning is..."), - ] - - -@pytest.fixture(scope="function") -def sample_search_results() -> list[dict[str, Any]]: - """Provide sample search results.""" - return [ - { - "title": "Introduction to Machine Learning", - "url": "https://example.com/ml-intro", - "snippet": "Machine learning is a method of data analysis...", - "provider": "tavily", - "published_date": "2024-01-15", - }, - { - "title": "Deep Learning Fundamentals", - "url": "https://example.com/dl-basics", - "snippet": "Deep learning is a subset of machine learning...", - "provider": "arxiv", - "published_date": "2024-01-10", - }, - { - "title": "Neural Networks Explained", - "url": "https://example.com/nn-explained", - "snippet": "Neural networks are computing systems inspired by...", - "provider": "jina", - "published_date": "2024-01-20", - }, - ] - - -@pytest.fixture(scope="function") -def sample_extraction_result() -> dict[str, Any]: - """Provide sample extraction result.""" - return { - "entities": { - "organizations": ["OpenAI", "Google", "Microsoft"], - "technologies": ["GPT-4", "BERT", "Transformer"], - "concepts": ["NLP", "Computer Vision", "Reinforcement Learning"], - }, - "topics": [ - {"name": "Machine Learning", "relevance": 0.95}, - {"name": "Artificial Intelligence", "relevance": 0.90}, - {"name": "Data Science", "relevance": 0.75}, - ], - "summary": "This document discusses advances in machine learning...", - "key_points": [ - "ML models are becoming more sophisticated", - "Training requires significant computational resources", - "Applications span multiple industries", - ], - "sentiment": "positive", - "language": "en", - } - - -@pytest.fixture(scope="function") -def mock_http_response() -> MagicMock: - """Provide mock HTTP response.""" - mock = MagicMock() - mock.status_code = 200 - mock.text = "Test response content" - mock.json = MagicMock(return_value={"status": "success", "data": "test"}) - mock.headers = {"Content-Type": "application/json"} - return mock - - -@pytest.fixture(scope="function") -def sample_dataframe() -> pd.DataFrame: - """Provide a sample pandas DataFrame for testing.""" - return pd.DataFrame( - { - "A": [1, 2, 3, 4, 5], - "B": [10.0, 20.0, 30.0, 40.0, 50.0], - "C": ["foo", "bar", "baz", "qux", "quux"], - } - ) - - -@pytest.fixture(scope="function") -def mock_firecrawl_scrape_result() -> dict[str, Any]: - """Return a realistic, successful Firecrawl scrape result.""" - return { - "success": True, - "data": { - "content": "Welcome to example.com. This is a comprehensive guide to understanding artificial intelligence and its applications in modern technology.", - "markdown": "# Welcome\n\nThis is a comprehensive guide to understanding artificial intelligence and its applications in modern technology.\n\n## Key Topics\n\n- Machine Learning\n- Deep Learning\n- Natural Language Processing", - "metadata": { - "title": "Example Domain - AI Guide", - "description": "This is an example domain for AI education.", - "sourceURL": "https://example.com", - "language": "en", - "publishedTime": "2024-01-15T10:00:00Z", - "author": "AI Research Team", - "keywords": ["AI", "machine learning", "technology"], - }, - "llm_extraction": { - "summary": "A comprehensive resource on AI technologies", - "main_topics": ["AI fundamentals", "ML algorithms", "Applications"], - "key_facts": [ - "AI is transforming industries", - "ML requires large datasets", - "Deep learning mimics neural networks", - ], - }, - }, - "statusCode": 200, - } - - -@pytest.fixture(scope="function") -def mock_firecrawl_batch_scrape_results() -> list[dict[str, Any]]: - """Return multiple Firecrawl scrape results for batch operations.""" - return [ - { - "success": True, - "data": { - "content": f"Content for page {i}. This page discusses topic {i}.", - "markdown": f"# Page {i}\n\nContent for page {i}.", - "metadata": { - "title": f"Page {i} Title", - "sourceURL": f"https://example.com/page{i}", - }, - }, - } - for i in range(1, 4) - ] - - -@pytest.fixture(scope="function") -def mock_firecrawl_error_result() -> dict[str, Any]: - """Return a Firecrawl error result.""" - return { - "success": False, - "error": "Failed to scrape the URL", - "statusCode": 404, - "details": "The requested page could not be found", - } - - -@pytest.fixture(scope="function") -def mock_r2r_upload_response() -> dict[str, Any]: - """Return a successful R2R document upload response.""" - return { - "results": { - "document_id": "doc_12345", - "collection_id": "coll_67890", - "title": "Uploaded Document", - "chunks_created": 5, - "embedding_status": "completed", - "metadata": { - "source": "web_scrape", - "url": "https://example.com", - "upload_timestamp": "2024-01-15T12:00:00Z", - }, - }, - "status": "success", - } - - -@pytest.fixture(scope="function") -def mock_r2r_search_response() -> dict[str, Any]: - """Return R2R search/retrieval results.""" - return { - "results": [ - { - "chunk_id": "chunk_001", - "document_id": "doc_12345", - "content": "AI is revolutionizing how we process information...", - "score": 0.95, - "metadata": { - "page": 1, - "section": "introduction", - }, - }, - { - "chunk_id": "chunk_002", - "document_id": "doc_12345", - "content": "Machine learning algorithms can identify patterns...", - "score": 0.87, - "metadata": { - "page": 2, - "section": "ml_basics", - }, - }, - ], - "total_results": 2, - "query": "artificial intelligence applications", - } - - -@pytest.fixture(scope="function") -def mock_r2r_collection_info() -> dict[str, Any]: - """Return R2R collection information.""" - return { - "collection_id": "coll_67890", - "name": "Research Documents", - "description": "Collection of AI research documents", - "document_count": 42, - "total_chunks": 1337, - "created_at": "2024-01-01T00:00:00Z", - "last_updated": "2024-01-15T12:00:00Z", - "metadata": { - "owner": "research_team", - "tags": ["AI", "research", "technology"], - }, - } - - -@pytest.fixture(scope="function") -def mock_firecrawl_app( - mock_firecrawl_scrape_result: dict[str, Any], - mock_firecrawl_batch_scrape_results: list[dict[str, Any]], -) -> AsyncMock: - """Provide a mock Firecrawl app with common methods.""" - mock = AsyncMock() - - # Single URL scraping - mock.scrape_url = AsyncMock() - mock.scrape_url.return_value = mock_firecrawl_scrape_result - - # Batch scraping - mock.batch_scrape_urls = AsyncMock() - mock.batch_scrape_urls.return_value = mock_firecrawl_batch_scrape_results - - # Search functionality - mock.search = AsyncMock() - mock.search.return_value = { - "success": True, - "data": [ - {"url": "https://example.com/result1", "title": "Result 1"}, - {"url": "https://example.com/result2", "title": "Result 2"}, - ], - } - - return mock - - -@pytest.fixture(scope="function") -def mock_r2r_client( - mock_r2r_upload_response: dict[str, Any], - mock_r2r_search_response: dict[str, Any], - mock_r2r_collection_info: dict[str, Any], -) -> AsyncMock: - """Provide a mock R2R client with common methods.""" - mock = AsyncMock() - - # Document operations - mock.upload_document = AsyncMock() - mock.upload_document.return_value = mock_r2r_upload_response - - mock.delete_document = AsyncMock() - mock.delete_document.return_value = { - "status": "success", - "document_id": "doc_12345", - } - - # Search/retrieval - mock.search = AsyncMock() - mock.search.return_value = mock_r2r_search_response - - mock.retrieve = AsyncMock() - mock.retrieve.return_value = mock_r2r_search_response - - # Collection operations - mock.get_collection_info = AsyncMock() - mock.get_collection_info.return_value = mock_r2r_collection_info - - mock.create_collection = AsyncMock() - mock.create_collection.return_value = { - "status": "success", - "collection_id": "coll_new_123", - } - - # Health check - mock.health = AsyncMock() - mock.health.return_value = {"status": "healthy", "version": "0.2.0"} - - return mock - - -def create_mock_service_factory() -> MagicMock: - """Create a mock service factory for testing.""" +@pytest.fixture +def mock_client() -> MagicMock: + """Provide a mock client for testing.""" return MagicMock() diff --git a/tests/helpers/fixtures/state_fixtures.py b/tests/helpers/fixtures/state_fixtures.py index 187e205c..d60bbb89 100644 --- a/tests/helpers/fixtures/state_fixtures.py +++ b/tests/helpers/fixtures/state_fixtures.py @@ -1,227 +1,15 @@ -"""State fixtures for tests using the StateBuilder factory.""" - +"""State fixtures for tests.""" +from typing import Any import pytest -from tests.helpers.factories.state_factories import StateBuilder - @pytest.fixture -def state_builder() -> StateBuilder: - """Provide a StateBuilder instance for creating test states.""" - return StateBuilder() - - -@pytest.fixture -def base_state(state_builder: StateBuilder) -> dict[str, object]: - """Create a minimal, valid state for most nodes.""" - return state_builder.with_human_message("Initial query").build() - - -@pytest.fixture -def research_state(base_state: dict[str, object]) -> dict[str, object]: - """Create a state pre-populated for the research graph.""" - return base_state | { - "query": "What is AI?", - "search_queries": [], - "search_results": [], - "sources": [], - "extracted_info": {"entities": [], "statistics": [], "key_facts": []}, - "synthesis": "", - "search_history": [], - "visited_urls": [], - "search_status": "idle", - "synthesis_attempts": 0, - "validation_attempts": 0, - # ValidationMixin fields - "content": "", - "validation_criteria": {"required_fields": []}, - "validation_results": { - "is_valid": False, - "errors": [], - "passed_checks": [], - "failed_checks": [], - }, - "is_valid": False, - "requires_human_feedback": False, - } - - -@pytest.fixture -def url_to_rag_state(base_state: dict[str, object]) -> dict[str, object]: - """Create a state pre-populated for the URL to RAG graph.""" - return base_state | { +def sample_state() -> dict[str, Any]: + """Provide a sample state for testing.""" + return { "input_url": "https://example.com", - "scraped_content": [], - "processed_content": {}, - "r2r_info": {}, - "urls": ["https://example.com"], - "scraping_status": "pending", - "processing_status": "pending", - "upload_status": "pending", - } - - -@pytest.fixture -def analysis_workflow_state(base_state: dict[str, object]) -> dict[str, object]: - """Create a state pre-populated for analysis workflows.""" - return base_state | { - "task": "Analyze market trends", - "data": {}, - "analysis_results": {}, - "interpretations": {}, - "analysis_plan": {}, - "visualizations": {}, - "reports": {}, - } - - -@pytest.fixture -def menu_intelligence_state(base_state: dict[str, object]) -> dict[str, object]: - """Create a state pre-populated for menu intelligence workflows.""" - return base_state | { - "query": "chicken", - "menu_items": [], - "menu_analysis": {}, - "insights": {}, - "recommendations": {}, - "extracted_info": {"entities": [], "statistics": [], "key_facts": []}, - } - - -@pytest.fixture -def validated_state(base_state: dict[str, object]) -> dict[str, object]: - """Create a state that has passed validation.""" - return base_state | { - "status": "validated", - "is_valid": True, - "validation_results": { - "is_valid": True, - "errors": [], - "passed_checks": ["required_fields", "data_types"], - "failed_checks": [], - }, - } - - -@pytest.fixture -def state_with_errors(base_state: dict[str, object]) -> dict[str, object]: - """Create a state with errors for error handling tests.""" - state = base_state.copy() - from typing import cast - - cast("list[dict[str, str]]", state["errors"]).extend( - [ - { - "message": "Test error 1", - "code": "TEST_ERROR_1", - "severity": "error", - "node": "test_node", - "timestamp": "2024-01-01T00:00:00Z", - }, - { - "message": "Test warning", - "code": "TEST_WARNING", - "severity": "warning", - "node": "test_node", - "timestamp": "2024-01-01T00:00:01Z", - }, - ] - ) - state["status"] = "error" - return state - - -@pytest.fixture -def state_with_search_results(research_state: dict[str, object]) -> dict[str, object]: - """Create a research state with search results populated.""" - return research_state | { - "search_results": [ - { - "title": "Understanding AI", - "url": "https://example.com/ai-guide", - "snippet": "Artificial Intelligence (AI) is...", - "description": "A comprehensive guide to AI", - }, - { - "title": "AI Applications", - "url": "https://example.com/ai-apps", - "snippet": "AI is being used in various fields...", - "description": "Overview of AI applications", - }, - ], - "search_status": "completed", - "search_history": [ - { - "query": "What is AI?", - "result_count": 2, - "timestamp": "2024-01-01T10:00:00Z", - } - ], - } - - -@pytest.fixture -def state_with_extracted_info( - state_with_search_results: dict[str, object], -) -> dict[str, object]: - """Create a research state with extracted information.""" - state = state_with_search_results.copy() - state["extracted_info"] = { - "source_0": { - "url": "https://example.com/ai-guide", - "title": "Understanding AI", - "key_findings": [ - "AI mimics human intelligence", - "Machine learning is a subset of AI", - ], - "extracted_data": { - "definition": "Artificial Intelligence is the simulation of human intelligence", - "types": ["Narrow AI", "General AI", "Super AI"], - }, - "summary": "A comprehensive overview of AI concepts", - }, - "entities": ["Artificial Intelligence", "Machine Learning"], - "statistics": ["85% of businesses use AI"], - "key_facts": ["AI was coined in 1956"], - } - state["sources"] = [ - { - "key": "source_0", - "url": "https://example.com/ai-guide", - "title": "Understanding AI", - "relevance": 0.95, - } - ] - return state - - -@pytest.fixture -def completed_research_state( - state_with_extracted_info: dict[str, object], -) -> dict[str, object]: - """Create a research state with completed synthesis.""" - return state_with_extracted_info | { - "synthesis": "Artificial Intelligence (AI) is a transformative technology that simulates human intelligence. It encompasses various approaches including machine learning, which allows systems to learn from data. AI has evolved significantly since its inception in 1956 and is now used by 85% of businesses worldwide. There are three main types of AI: Narrow AI (specialized for specific tasks), General AI (human-level intelligence), and Super AI (surpassing human intelligence).", - "is_valid": True, - "status": "completed", - "synthesis_attempts": 1, - } - - -@pytest.fixture -def state_after_search( - research_state_factory, mock_search_result_factory -) -> dict[str, object]: - """Provide a state that has just completed the search phase.""" - results = [mock_search_result_factory() for _ in range(5)] - return research_state_factory(search_results=results) - - -@pytest.fixture -def state_after_synthesis(state_after_search: dict[str, object]) -> dict[str, object]: - """Provide a state that has a synthesized report, ready for validation.""" - state = state_after_search.copy() - state["synthesis"] = "This is a synthesized report based on the search results." - return state + "status": "pending", + "results": [] + } diff --git a/tests/helpers/mock_helpers.py b/tests/helpers/mock_helpers.py index b3744a7f..1e5cc5ed 100644 --- a/tests/helpers/mock_helpers.py +++ b/tests/helpers/mock_helpers.py @@ -1,140 +1,15 @@ -"""Helper utilities for creating properly typed mocks in tests.""" +"""Mock helpers for tests.""" -from typing import Any, Protocol +from typing import Any from unittest.mock import AsyncMock, MagicMock -class MockWithAssertions(Protocol): - """Protocol for mocks that need assertion methods.""" - - assert_called_once: MagicMock - assert_called_with: MagicMock - assert_called_once_with: MagicMock - assert_not_called: MagicMock - assert_any_call: MagicMock - call_count: int - call_args: Any - call_args_list: list[Any] - - -def create_async_mock_with_assertions( - return_value: Any = None, side_effect: Any = None -) -> AsyncMock: - """Create an AsyncMock with all assertion methods properly initialized. - - Args: - return_value: The value to return when the mock is called - side_effect: Side effect for the mock - - Returns: - AsyncMock with assertion methods initialized - """ - mock = AsyncMock(return_value=return_value, side_effect=side_effect) - - # Initialize assertion methods - mock.assert_called_once = MagicMock() - mock.assert_called_with = MagicMock() - mock.assert_called_once_with = MagicMock() - mock.assert_not_called = MagicMock() - mock.assert_any_call = MagicMock() - mock.call_count = 0 - - return mock - - -def create_mock_redis_client() -> AsyncMock: - """Create a properly mocked Redis client with all necessary methods. - - Returns: - AsyncMock configured as a Redis client - """ - mock_redis = AsyncMock() - - # Setup Redis methods - mock_redis.get = create_async_mock_with_assertions() - mock_redis.set = create_async_mock_with_assertions() - mock_redis.delete = create_async_mock_with_assertions() - mock_redis.scan_iter = MagicMock() - mock_redis.ping = create_async_mock_with_assertions() - mock_redis.close = create_async_mock_with_assertions() - - # Setup scan_iter to return an async iterator - async def async_scan_iter(*args: Any, **kwargs: Any) -> Any: - for item in mock_redis.scan_iter.return_value: - yield item - - mock_redis.scan_iter = MagicMock( - side_effect=lambda *args, **kwargs: async_scan_iter(*args, **kwargs) - ) - mock_redis.scan_iter.return_value = [] - mock_redis.scan_iter.assert_called_with = MagicMock() - - return mock_redis - - -def create_mock_llm_client() -> AsyncMock: - """Create a properly mocked LLM client with all necessary methods. - - Returns: - AsyncMock configured as an LLM client - """ - mock_llm = AsyncMock() - - # Setup LLM methods - mock_llm.llm_chat = create_async_mock_with_assertions() - mock_llm.llm_json = create_async_mock_with_assertions() - mock_llm.llm_chat_stream = create_async_mock_with_assertions() - mock_llm.llm_chat_with_stream_callback = create_async_mock_with_assertions() - - return mock_llm - - -def create_mock_r2r_client() -> AsyncMock: - """Create a properly mocked R2R client with all necessary methods. - - Returns: - AsyncMock configured as an R2R client - """ - mock_r2r = AsyncMock() - - # Setup collections - mock_r2r.collections = MagicMock() - mock_r2r.collections.list = create_async_mock_with_assertions() - mock_r2r.collections.create = create_async_mock_with_assertions() - mock_r2r.collections.delete = create_async_mock_with_assertions() - - # Setup documents - mock_r2r.documents = MagicMock() - mock_r2r.documents.create = create_async_mock_with_assertions() - mock_r2r.documents.search = create_async_mock_with_assertions() - mock_r2r.documents.delete = create_async_mock_with_assertions() - - # Setup retrieval - mock_r2r.retrieval = MagicMock() - mock_r2r.retrieval.search = create_async_mock_with_assertions() - - return mock_r2r - - -def create_mock_service_factory() -> tuple[MagicMock, AsyncMock]: - """Create a properly mocked service factory with LLM client. - - Returns: - Tuple of (factory, llm_client) - """ - factory = MagicMock() - llm_client = create_mock_llm_client() - - # Setup lifespan context manager - lifespan_manager = AsyncMock() - lifespan_manager.__aenter__ = AsyncMock(return_value=factory) - lifespan_manager.__aexit__ = AsyncMock(return_value=None) - factory.lifespan = MagicMock(return_value=lifespan_manager) - - # Setup service getters - factory.get_service = AsyncMock(return_value=llm_client) - factory.get_llm_client = AsyncMock(return_value=llm_client) - factory.get_r2r_client = AsyncMock(return_value=create_mock_r2r_client()) - factory.get_redis_backend = AsyncMock(return_value=create_mock_redis_client()) - - return factory, llm_client +def create_mock_redis_client() -> MagicMock: + """Create a mock Redis client for testing.""" + mock_client = MagicMock() + mock_client.ping = AsyncMock(return_value=True) + mock_client.get = AsyncMock(return_value=None) + mock_client.set = AsyncMock(return_value=True) + mock_client.delete = AsyncMock(return_value=1) + mock_client.close = AsyncMock() + return mock_client diff --git a/tests/helpers/mocks/__init__.py b/tests/helpers/mocks/__init__.py index f1378846..fc3c38f5 100644 --- a/tests/helpers/mocks/__init__.py +++ b/tests/helpers/mocks/__init__.py @@ -1 +1 @@ -"""Mock objects and utilities for testing.""" +"""Test mocks package.""" diff --git a/tests/helpers/mocks/mock_builders.py b/tests/helpers/mocks/mock_builders.py index a5739d0a..0130576b 100644 --- a/tests/helpers/mocks/mock_builders.py +++ b/tests/helpers/mocks/mock_builders.py @@ -1,368 +1,26 @@ -"""Mock builders for creating complex mocks.""" - -from __future__ import annotations +"""Mock builders for tests.""" from typing import Any -from unittest.mock import AsyncMock - -from langchain_core.messages import AIMessage +from unittest.mock import AsyncMock, MagicMock -class MockLLMBuilder: - """Builder for creating mock LLM services with configurable behavior.""" +class MockBuilder: + """Builder for creating test mocks.""" def __init__(self) -> None: - """Initialize mock LLM builder.""" - self._mock = AsyncMock() - self._responses: list[str] = [] - self._json_responses: list[dict[str, Any]] = [] - self._errors: list[Exception] = [] - self._call_count = 0 - self._prompt_tokens = 0 - self._completion_tokens = 0 + """Initialize the mock builder.""" + self._mock = MagicMock() - def with_response( - self, content: str, metadata: dict[str, Any] | None = None - ) -> MockLLMBuilder: - """Add a response to the mock.""" - self._responses.append(content) - if metadata: - self._mock.response_metadata = metadata + def with_method(self, name: str, return_value: Any = None) -> "MockBuilder": + """Add a method to the mock.""" + setattr(self._mock, name, MagicMock(return_value=return_value)) return self - def with_json_response(self, data: dict[str, Any]) -> MockLLMBuilder: - """Add a JSON response to the mock.""" - self._json_responses.append(data) + def with_async_method(self, name: str, return_value: Any = None) -> "MockBuilder": + """Add an async method to the mock.""" + setattr(self._mock, name, AsyncMock(return_value=return_value)) return self - def with_error(self, error: Exception) -> MockLLMBuilder: - """Add an error to be raised.""" - self._errors.append(error) - return self - - def with_token_usage( - self, prompt_tokens: int, completion_tokens: int - ) -> MockLLMBuilder: - """Set token usage for responses.""" - self._prompt_tokens = prompt_tokens - self._completion_tokens = completion_tokens - self._mock.response_metadata = { - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - } - return self - - def build(self) -> AsyncMock: - """Build and return the mock LLM service.""" - - async def generate_side_effect(*args, **kwargs): - nonlocal self - if self._errors and self._call_count < len(self._errors): - error = self._errors[self._call_count] - self._call_count += 1 - raise error - - content = "Default response" - if self._responses and self._call_count < len(self._responses): - content = self._responses[self._call_count] - self._call_count += 1 - elif self._json_responses and self._call_count < len(self._json_responses): - # Convert JSON response to string - import json - - content = json.dumps(self._json_responses[self._call_count]) - self._call_count += 1 - - token_usage: dict[str, int] = { - "prompt_tokens": self._prompt_tokens, - "completion_tokens": self._completion_tokens, - "total_tokens": self._prompt_tokens + self._completion_tokens, - } - response_metadata: dict[str, Any] = { - "model": "mock-model", - "token_usage": token_usage, - } - return AIMessage( - content=content, - response_metadata=response_metadata, - ) - - async def generate_json_side_effect( - *args: Any, **kwargs: Any - ) -> dict[str, Any]: - nonlocal self - if self._json_responses and self._call_count < len(self._json_responses): - response = self._json_responses[self._call_count] - self._call_count += 1 - return response - return {"result": "default"} - - async def astream_side_effect(*args: Any, **kwargs: Any) -> Any: - """Async generator that yields chunks for streaming.""" - yield await generate_side_effect(*args, **kwargs) - - self._mock.generate = AsyncMock(side_effect=generate_side_effect) - self._mock.generate_json = AsyncMock(side_effect=generate_json_side_effect) - self._mock.generate_with_retry = self._mock.generate - - # Add streaming support - self._mock.astream = astream_side_effect - self._mock.ainvoke = AsyncMock(side_effect=generate_side_effect) - - # Add call_model_lc method for LangChain compatibility - self._mock.call_model_lc = AsyncMock(side_effect=generate_side_effect) - - return self._mock - - -class MockSearchToolBuilder: - """Builder for creating mock search tools with configurable results.""" - - def __init__(self) -> None: - """Initialize mock search tool builder.""" - self._mock = AsyncMock() - self._results_by_query: dict[str, list[dict[str, Any]]] = {} - self._default_results: list[dict[str, Any]] = [] - self._errors_by_query: dict[str, Exception] = {} - - def with_results_for_query( - self, - query: str, - results: list[dict[str, Any]], - ) -> MockSearchToolBuilder: - """Add specific results for a query.""" - self._results_by_query[query.lower()] = results - return self - - def with_default_results( - self, results: list[dict[str, Any]] - ) -> MockSearchToolBuilder: - """Set default results for any query.""" - self._default_results = results - return self - - def with_error_for_query( - self, query: str, error: Exception - ) -> MockSearchToolBuilder: - """Add an error for a specific query.""" - self._errors_by_query[query.lower()] = error - return self - - def build(self) -> AsyncMock: - """Build and return the mock search tool.""" - - async def search_side_effect( - query: str, - provider_name: str | None = None, - max_results: int | None = None, - **kwargs: Any, - ) -> list[dict[str, Any]]: - query_lower = query.lower() - - # Check for errors first - if query_lower in self._errors_by_query: - raise self._errors_by_query[query_lower] - - # Return specific results for query - if query_lower in self._results_by_query: - results = self._results_by_query[query_lower] - return results[:max_results] if max_results else results - # Return default results - if self._default_results: - if max_results: - return self._default_results[:max_results] - return self._default_results - - # Generate generic results - return [ - { - "title": f"Result for '{query}'", - "url": f"https://example.com/search?q={query}", - "snippet": f"Search result snippet for query: {query}", - "provider": provider_name or "default", - } - ] - - self._mock.search = AsyncMock(side_effect=search_side_effect) - return self._mock - - -class MockDatabaseBuilder: - """Builder for creating mock database connections with configurable behavior.""" - - def __init__(self) -> None: - """Initialize mock database builder.""" - self._mock = AsyncMock() - self._fetch_results: dict[str, list[dict[str, Any]]] = {} - self._fetchrow_results: dict[str, dict[str, Any] | None] = {} - self._execute_results: dict[str, str] = {} - self._errors: dict[str, Exception] = {} - - def with_fetch_result( - self, - query_pattern: str, - results: list[dict[str, Any]], - ) -> MockDatabaseBuilder: - """Add fetch results for a query pattern.""" - self._fetch_results[query_pattern] = results - return self - - def with_fetchrow_result( - self, - query_pattern: str, - result: dict[str, Any] | None, - ) -> MockDatabaseBuilder: - """Add fetchrow result for a query pattern.""" - self._fetchrow_results[query_pattern] = result - return self - - def with_execute_result( - self, - query_pattern: str, - result: str = "INSERT 0 1", - ) -> MockDatabaseBuilder: - """Add execute result for a query pattern.""" - self._execute_results[query_pattern] = result - return self - - def with_error(self, query_pattern: str, error: Exception) -> MockDatabaseBuilder: - """Add an error for a query pattern.""" - self._errors[query_pattern] = error - return self - - def build(self) -> AsyncMock: - """Build and return the mock database connection.""" - - async def fetch_side_effect(query: str, *args: Any) -> list[dict[str, Any]]: - for pattern, error in self._errors.items(): - if pattern in query: - raise error - - for pattern, results in self._fetch_results.items(): - if pattern in query: - return results - - return [] - - async def fetchrow_side_effect(query: str, *args: Any) -> dict[str, Any] | None: - for pattern, error in self._errors.items(): - if pattern in query: - raise error - - for pattern, result in self._fetchrow_results.items(): - if pattern in query: - return result - - return None - - async def execute_side_effect(query: str, *args: Any) -> str: - for pattern, error in self._errors.items(): - if pattern in query: - raise error - - for pattern, result in self._execute_results.items(): - if pattern in query: - return result - - return "INSERT 0 1" - - self._mock.fetch = AsyncMock(side_effect=fetch_side_effect) - self._mock.fetchrow = AsyncMock(side_effect=fetchrow_side_effect) - self._mock.execute = AsyncMock(side_effect=execute_side_effect) - self._mock.close = AsyncMock() - - return self._mock - - -class MockRedisBuilder: - """Builder for creating mock Redis clients with configurable behavior.""" - - def __init__(self) -> None: - """Initialize mock Redis builder.""" - self._mock = AsyncMock() - self._storage: dict[str, str] = {} - self._ttls: dict[str, int] = {} - self._errors: dict[str, Exception] = {} - - def with_cached_value( - self, key: str, value: str, ttl: int = -1 - ) -> MockRedisBuilder: - """Add a cached value with optional TTL.""" - self._storage[key] = value - self._ttls[key] = ttl - return self - - def with_error_for_key(self, key: str, error: Exception) -> MockRedisBuilder: - """Add an error for a specific key.""" - self._errors[key] = error - return self - - def build(self) -> AsyncMock: - """Build and return the mock Redis client.""" - - async def get_side_effect(key: str) -> str | None: - if key in self._errors: - raise self._errors[key] - return self._storage.get(key) - - async def set_side_effect(key: str, value: str, ttl: int | None = None) -> bool: - if key in self._errors: - raise self._errors[key] - self._storage[key] = value - if ttl: - self._ttls[key] = ttl - return True - - async def setex_side_effect(key: str, ttl: int, value: str) -> bool: - if key in self._errors: - raise self._errors[key] - self._storage[key] = value - self._ttls[key] = ttl - return True - - async def delete_side_effect(key: str) -> int: - if key in self._errors: - raise self._errors[key] - if key in self._storage: - del self._storage[key] - if key in self._ttls: - del self._ttls[key] - return 1 - return 0 - - async def exists_side_effect(key: str) -> bool: - if key in self._errors: - raise self._errors[key] - return key in self._storage - - async def ttl_side_effect(key: str) -> int: - if key in self._errors: - raise self._errors[key] - return self._ttls.get(key, -2) - - async def keys_side_effect(pattern: str = "*") -> list[bytes]: - if pattern == "*": - return [k.encode() for k in self._storage.keys()] - # Simple pattern matching - import fnmatch - - return [ - k.encode() for k in self._storage.keys() if fnmatch.fnmatch(k, pattern) - ] - - self._mock.get = AsyncMock(side_effect=get_side_effect) - self._mock.set = AsyncMock(side_effect=set_side_effect) - self._mock.setex = AsyncMock(side_effect=setex_side_effect) - self._mock.delete = AsyncMock(side_effect=delete_side_effect) - self._mock.exists = AsyncMock(side_effect=exists_side_effect) - self._mock.ttl = AsyncMock(side_effect=ttl_side_effect) - self._mock.keys = AsyncMock(side_effect=keys_side_effect) - self._mock.mget = AsyncMock( - side_effect=lambda keys: [self._storage.get(k) for k in keys] - ) - self._mock.ping = AsyncMock(return_value=True) - + def build(self) -> MagicMock: + """Build the final mock object.""" return self._mock diff --git a/tests/unit_tests/tools/capabilities/extraction/test_single_url_processor.py b/tests/unit_tests/tools/capabilities/extraction/test_single_url_processor.py index d6214345..c514dae7 100644 --- a/tests/unit_tests/tools/capabilities/extraction/test_single_url_processor.py +++ b/tests/unit_tests/tools/capabilities/extraction/test_single_url_processor.py @@ -552,8 +552,9 @@ class TestProcessSingleUrlTool: ) # Verify ExtractToolConfigModel was called with extract config + from typing import Any, cast mock_config_model_class.assert_called_once_with( - **extract_config + **cast("dict[str, Any]", extract_config) ) @pytest.mark.asyncio