feat: add Paperless NGX agent with robust error handling (#42)
* feat: add Paperless NGX agent with robust error handling - Implement ReAct agent for document management with Paperless NGX - Add comprehensive error handling to prevent crashes on missing credentials - Use global ServiceFactory singleton pattern for dependency injection - Integrate edge helpers for error routing and retry logic - Add user-friendly error messages when Paperless is not configured - Support runtime configuration through RunnableConfig - Include all Paperless tools (search, update, tags, correspondents, etc.) - Add factory function for LangGraph API compatibility - Ensure graceful degradation when credentials are missing * Apply suggestions from code review Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * feat: enhance custom tool node for Paperless NGX with improved error handling - Introduced ReActAgentState for better state management in custom tool node. - Added explicit configuration validation and user-friendly error messages for missing Paperless NGX credentials. - Enhanced logging for configuration usage and error reporting. - Updated agent factory to ensure proper configuration handling for tool execution. * fix: update and add files, resolve pre-commit and safe.directory issues * chore: update project configurations and enhance documentation - Added new commands to CLAUDE.local.md for formatting and type checking. - Updated Makefile to use 'basedpyright' for linting instead of 'pyright'. - Included 'pytest' as a new dependency in pyproject.toml. - Cleaned up pyrightconfig.json by removing unnecessary stubPath. - Deleted obsolete test_fixes_summary.md file. - Introduced new test-failures.txt and type_errors.txt files for better error tracking. - Added new devcontainer configuration files for improved development environment setup. These changes collectively enhance project organization, improve documentation clarity, and streamline the development workflow. * chore: add pre-commit as a dependency in pyproject.toml - Included "pre-commit>=4.2.0" in the dependencies section of pyproject.toml to enhance code quality checks and maintainability. * chore: update project configurations and enhance documentation - Added new commands to CLAUDE.local.md for formatting and type checking. - Updated Makefile to use `basedpyright` for linting instead of `pyright`. - Included `pytest` as a dependency in pyproject.toml for testing. - Cleaned up pyrightconfig.json by removing unnecessary stubPath. - Deleted obsolete test_fixes_summary.md and added new test-failures.txt for tracking test failures. - Introduced new devcontainer configuration files for improved development environment setup. These changes collectively enhance project organization, improve testing capabilities, and provide clearer development guidelines. * chore: update devcontainer configuration to include host .venv bind mount - Added a bind mount for the host's .venv directory to the devcontainer configuration for shared access, avoiding conflicts with the existing .venv volume. * chore: add pre-commit as a development dependency in pyproject.toml - Included "pre-commit>=4.2.0" in the dev dependencies section of pyproject.toml to enhance code quality checks and maintainability. * chore: remove obsolete basedpyright_output.json and update LLMConfig initialization - Deleted the obsolete basedpyright_output.json file. - Updated LLMConfig initialization in app.py to set a default profile using LLMProfile.LARGE for improved configuration management. * fix: improve formatting and consistency in Paperless NGX client and tests - Added missing newlines for better readability in paperless.py and test_paperless.py. - Ensured consistent formatting in function calls and docstrings across the codebase. - Enhanced clarity in custom field query examples within the documentation. * fix: improve output formatting and assertions in crash tests - Adjusted print statements in run_crash_tests.py for consistent newline usage. - Enhanced assertions in test_concurrency_races.py to clarify expected behavior during concurrent operations. - Improved readability of markdown output in summary reports. * fix: refactor paperless_ngx_agent_factory to be asynchronous - Changed paperless_ngx_agent_factory from a synchronous to an asynchronous function. - Removed unnecessary nested asyncio handling and streamlined agent creation process. - Updated comments for clarity regarding checkpointer usage. * fix: resolve ruff linting errors - Fix B027 errors by adding return statements to non-abstract ainit methods - Fix E501 line length errors in embeddings.py - Fix ANN204, D105, ANN202, ANN401 type annotation errors - Add proper type annotations for __await__ and service_factory 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * fix: update project configurations and enhance documentation - Added examples directory to .pyreflyignore for improved file management. - Updated Makefile to include new test_watch target for better testing workflow. - Refined pyproject.toml and pyrightconfig.json for enhanced project configuration. - Introduced new check_singletons.md documentation for clarity on singleton checks. - Removed obsolete type_errors.txt file to streamline error tracking. These changes collectively improve project organization, enhance documentation clarity, and refine testing processes. * fix: enhance graph validation and improve type handling in extraction service - Updated `validate_all_graphs` to handle coroutine functions more effectively and skip those requiring arguments. - Improved type casting in `ComponentExtractor` to ensure proper handling of matches. - Enhanced message handling in `call_model_node` to ensure proper indexing and debugging output. - Refactored `scrape_status_summary_node` for clearer progress calculation and improved handling of R2R information. - Streamlined error handling and type casting in `SemanticExtractionService` for better response processing. - Improved test readability and structure in `test_check_duplicate_error_handling.py` with consistent patching style. These changes collectively enhance the robustness and clarity of the validation and extraction processes. * fix: improve scraper name validation in extraction orchestrator - Updated validation for `scraper_name` in `extract_key_information` and `process_single_url` functions to ensure it is a string. - Simplified logic to default to "beautifulsoup" if `scraper_name` is not a valid string. These changes enhance the robustness of the extraction process by ensuring proper type handling for scraper names. * fix: restore type safety in response formatting - Updated `format_response_for_caller` to ensure `validation_issues_list` is a list of strings and refined URL extraction from sources. - These changes enhance type safety and improve the robustness of response data handling. * fix: enhance type safety in content validation and service factory - Added type checks to `validate_content` and `preprocess_content` functions to ensure inputs are strings. - Improved date validation in `count_recent_sources` to check for both presence and type. - Updated condition for updating `llm_kwargs` in `ServiceFactory` to ensure `kwargs` is a dictionary. These changes collectively enhance type safety and robustness in validation and service handling. * fix: enhance type safety in content extraction and analysis - Added comments to clarify that `extracted_content` and `catalog_metadata` are always dictionaries, improving type safety in `identify_component_focus_node`, `find_affected_catalog_items_node`, and `batch_analyze_components_node`. - Simplified checks for `catalog_items` to ensure they are lists, enhancing robustness in handling extracted content. These changes collectively improve type handling and clarity in the content extraction and analysis processes. * Fix error handling for mixed error types in validation checks Co-authored-by: travis.vas <travis.vas@gmail.com> * fix: enhance type checking in URL filtering and human feedback processing - Added type checks to ensure results are dictionaries in `filter_search_results` and human feedback functions. - Improved robustness by skipping non-dictionary results, enhancing type safety in data handling. These changes collectively improve type handling and clarity in the scraping and validation processes. * fix: enhance type checking in configuration handling - Updated `check_existing_content_node` and `store_processing_metadata` functions to check if `config` is an instance of `AppConfig` before validation. - Improved type safety by ensuring that the correct type is used for `app_config`. These changes collectively enhance type handling and robustness in configuration processing. * fix: improve code readability and type safety in various nodes - Refactored conditional expressions in multiple nodes to enhance readability and maintainability. - Ensured consistent type handling for extracted content and configuration parameters across several functions. - These changes collectively improve clarity and robustness in data processing and validation. * fix: improve type safety in RouterConfig category and severity handling - Refactored category and severity processing in RouterConfig to use direct enumeration instead of casting. - Added type checks to raise errors for invalid types, enhancing robustness in configuration handling. These changes collectively improve type safety and error handling in the RouterConfig class. --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Cursor Agent <cursoragent@cursor.com>
This commit is contained in:
173
.claude/commands/check_singletons.md
Normal file
173
.claude/commands/check_singletons.md
Normal file
@@ -0,0 +1,173 @@
|
||||
Ensure that the modules, functions, and classes in $ARGUMENTS have adopted my global singleton patterns
|
||||
|
||||
### **Tier 1: Core Architectural Pillars**
|
||||
|
||||
These are the most critical, high-level abstractions that every developer should use to interact with the application's core services and configuration.
|
||||
|
||||
---
|
||||
|
||||
#### **1. The Global Service Factory**
|
||||
|
||||
This is the **single most important singleton** in your application. It provides centralized, asynchronous, and cached access to all major services.
|
||||
|
||||
* **Primary Accessor**: `get_global_factory(config: AppConfig | None = None)`
|
||||
* **Location**: `src/biz_bud/services/factory.py`
|
||||
* **Purpose**: To provide a singleton instance of the `ServiceFactory`, which in turn creates and manages the lifecycle of essential services like LLM clients, database connections, and caches.
|
||||
* **Usage Pattern**:
|
||||
```python
|
||||
# In any async part of your application
|
||||
from src.biz_bud.services.factory import get_global_factory
|
||||
|
||||
service_factory = await get_global_factory()
|
||||
llm_client = await service_factory.get_llm_client()
|
||||
vector_store = await service_factory.get_vector_store()
|
||||
db_service = await service_factory.get_db_service()
|
||||
```
|
||||
* **Instead of**:
|
||||
* Instantiating services like `LangchainLLMClient` or `PostgresStore` directly.
|
||||
* Managing database connection pools or API clients manually in different parts of the code.
|
||||
|
||||
---
|
||||
|
||||
#### **2. Application Configuration Loading**
|
||||
|
||||
Configuration is managed centrally and should be accessed through these standardized functions to ensure all overrides (from environment variables or runtime) are correctly applied.
|
||||
|
||||
* **Primary Accessors**: `load_config()` and `load_config_async()`
|
||||
* **Location**: `src/biz_bud/config/loader.py`
|
||||
* **Purpose**: To load the `AppConfig` from `config.yaml`, merge it with environment variables, and return a validated Pydantic model. The async version is for use within an existing event loop.
|
||||
* **Runtime Override Helper**: `resolve_app_config_with_overrides(runnable_config: RunnableConfig)`
|
||||
* **Purpose**: **This is the standard pattern for graphs.** It takes the base `AppConfig` and intelligently merges it with runtime parameters passed in a `RunnableConfig` (e.g., `llm_profile_override`, `temperature`).
|
||||
* **Usage Pattern (Inside a Graph Factory)**:
|
||||
```python
|
||||
from src.biz_bud.config.loader import resolve_app_config_with_overrides
|
||||
from src.biz_bud.services.factory import ServiceFactory
|
||||
|
||||
def my_graph_factory(config: dict[str, Any]) -> Pregel:
|
||||
runnable_config = RunnableConfig(configurable=config.get("configurable", {}))
|
||||
app_config = resolve_app_config_with_overrides(runnable_config=runnable_config)
|
||||
service_factory = ServiceFactory(app_config)
|
||||
# ... inject factory into a graph or use it to build nodes
|
||||
```
|
||||
* **Instead of**:
|
||||
* Manually reading `config.yaml` with `pyyaml`.
|
||||
* Using `os.getenv()` scattered throughout the codebase.
|
||||
* Manually parsing `RunnableConfig` inside every node.
|
||||
|
||||
---
|
||||
|
||||
### **Tier 2: Standardized Interaction Patterns**
|
||||
|
||||
These are the common patterns and helpers for core tasks like AI model interaction, caching, and error handling.
|
||||
|
||||
---
|
||||
|
||||
#### **3. LLM Interaction**
|
||||
|
||||
All interactions with Large Language Models should go through standardized nodes or clients to ensure consistency in configuration, message handling, and error management.
|
||||
|
||||
* **Primary Graph Node**: `call_model_node(state: dict, config: RunnableConfig)`
|
||||
* **Location**: `src/biz_bud/nodes/llm/call.py`
|
||||
* **Purpose**: This is the **standard node for all LLM calls** within a LangGraph workflow. It correctly resolves the LLM profile (`tiny`, `small`, `large`), handles message history, parses tool calls, and manages exceptions.
|
||||
* **Service Factory Method**: `ServiceFactory.get_llm_for_node(node_context: str, llm_profile_override: str | None)`
|
||||
* **Location**: `src/biz_bud/services/factory.py`
|
||||
* **Purpose**: For custom nodes that require more complex logic, this method provides a pre-configured, wrapped LLM client from the factory. The `node_context` helps select an appropriate default model size.
|
||||
* **Instead of**:
|
||||
* Directly importing and using `ChatOpenAI`, `ChatAnthropic`, etc.
|
||||
* Manually constructing message lists or handling API errors for each LLM call.
|
||||
* Implementing your own retry logic for LLM calls.
|
||||
|
||||
---
|
||||
|
||||
#### **4. Caching System**
|
||||
|
||||
The project provides a default, asynchronous, in-memory cache and a Redis-backed cache. Direct interaction should be minimal; prefer the decorator.
|
||||
|
||||
* **Primary Decorator**: `@cache_async(ttl: int)`
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/caching/decorators.py`
|
||||
* **Purpose**: The standard way to cache the results of any `async` function. It automatically generates a key based on the function and its arguments.
|
||||
* **Singleton Accessors**:
|
||||
* `get_default_cache_async()`: Gets the default in-memory cache instance.
|
||||
* `ServiceFactory.get_redis_cache()`: Gets the Redis cache backend if configured.
|
||||
* **Usage Pattern**:
|
||||
```python
|
||||
from bb_core.caching import cache_async
|
||||
|
||||
@cache_async(ttl=3600) # Cache for 1 hour
|
||||
async def my_expensive_api_call(arg1: str, arg2: int) -> dict:
|
||||
# ... implementation
|
||||
```
|
||||
* **Instead of**:
|
||||
* Implementing your own caching logic with dictionaries or files.
|
||||
* Instantiating `InMemoryCache` or `RedisCache` manually.
|
||||
|
||||
---
|
||||
|
||||
#### **5. Error Handling & Lifecycle Subsystem**
|
||||
|
||||
This is a comprehensive, singleton-based system for robust error management and application lifecycle.
|
||||
|
||||
* **Global Singletons**:
|
||||
* `get_error_aggregator()`: (`errors/aggregator.py`) Use to report errors for deduplication and rate-limiting.
|
||||
* `get_error_router()`: (`errors/router.py`) Use to define and apply routing logic for different error types.
|
||||
* `get_error_logger()`: (`errors/logger.py`) Use for consistent, structured error logging.
|
||||
* **Primary Decorator**: `@handle_errors(error_type)`
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/errors/base.py`
|
||||
* **Purpose**: Wraps functions to automatically catch common exceptions (`httpx` errors, `pydantic` validation errors) and convert them into standardized `BusinessBuddyError` types.
|
||||
* **Lifecycle Management**:
|
||||
* `get_singleton_manager()`: (`services/singleton_manager.py`) Main accessor for the lifecycle manager.
|
||||
* `cleanup_all_singletons()`: Use this at application shutdown to gracefully close all registered services (DB pools, HTTP sessions, etc.).
|
||||
* **Instead of**:
|
||||
* Using generic `try...except Exception` blocks.
|
||||
* Manually logging error details with `logger.error()`.
|
||||
* Forgetting to close resources like database connections.
|
||||
|
||||
---
|
||||
|
||||
### **Tier 3: Reusable Helpers & Utilities**
|
||||
|
||||
These are specific tools and helpers for common, recurring tasks across the codebase.
|
||||
|
||||
---
|
||||
|
||||
#### **6. Asynchronous and Networking Utilities**
|
||||
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/networking/async_utils.py`
|
||||
* **Key Helpers**:
|
||||
* `gather_with_concurrency(n, *tasks)`: Runs multiple awaitables concurrently with a semaphore to limit parallelism.
|
||||
* `retry_async(...)`: A decorator to add exponential backoff retry logic to any `async` function.
|
||||
* `RateLimiter(calls_per_second)`: An `async` context manager to enforce rate limits.
|
||||
* `HTTPClient`: The base client for making robust HTTP requests, managed by the `ServiceFactory`.
|
||||
|
||||
---
|
||||
|
||||
#### **7. Graph State & Node Helpers**
|
||||
|
||||
* **State Management**: `StateUpdater(base_state)`
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/langgraph/state_immutability.py`
|
||||
* **Purpose**: Provides a **safe, fluent API** for updating graph state immutably (e.g., `updater.set("key", val).append("list_key", item).build()`).
|
||||
* **Instead of**: Directly mutating the state dictionary (`state["key"] = value`), which can cause difficult-to-debug issues in concurrent or resumable graphs.
|
||||
* **Node Validation Decorators**: `@validate_node_input(Model)` and `@validate_node_output(Model)`
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/validation/graph_validation.py`
|
||||
* **Purpose**: Ensures that the input to a node and the output from a node conform to a specified Pydantic model, automatically adding errors to the state if validation fails.
|
||||
* **Edge Helpers**:
|
||||
* **Location**: `packages/business-buddy-core/src/bb_core/edge_helpers/`
|
||||
* **Purpose**: A suite of pre-built, configurable routing functions for common graph conditions (e.g., `handle_error`, `retry_on_failure`, `check_critical_error`, `should_continue`). Use these to define conditional edges in your graphs.
|
||||
|
||||
---
|
||||
|
||||
#### **8. High-Level Workflow Tools**
|
||||
|
||||
These tools abstract away complex, multi-step processes like searching and scraping.
|
||||
|
||||
* **Unified Scraper**: `UnifiedScraper(config: ScrapeConfig)`
|
||||
* **Location**: `packages/business-buddy-tools/src/bb_tools/scrapers/unified.py`
|
||||
* **Purpose**: The single entry point for all web scraping. It automatically selects the best strategy (BeautifulSoup, Firecrawl, Jina) based on the URL and configuration.
|
||||
* **Unified Search**: `UnifiedSearchTool(config: SearchConfig)`
|
||||
* **Location**: `packages/business-buddy-tools/src/bb_tools/search/unified.py`
|
||||
* **Purpose**: The single entry point for web searches. It can query multiple providers (Tavily, Jina, Arxiv) in parallel and deduplicates results.
|
||||
* **Search & Scrape Fetcher**: `WebContentFetcher(search_tool, scraper)`
|
||||
* **Location**: `packages/business-buddy-tools/src/bb_tools/actions/fetch.py`
|
||||
* **Purpose**: A high-level tool that combines the `UnifiedSearchTool` and `UnifiedScraper` to perform a complete "search and scrape" workflow.
|
||||
|
||||
By consistently using these established patterns and utilities, you will improve code quality, reduce bugs, and make the entire project easier to understand and maintain.
|
||||
88
.claude/commands/check_types.md
Normal file
88
.claude/commands/check_types.md
Normal file
@@ -0,0 +1,88 @@
|
||||
Check the following code for quality and strict type safety: $ARGUMENTS
|
||||
|
||||
# Type Safety and Type Checking Rules
|
||||
|
||||
## 1. Type Safety
|
||||
|
||||
- All function parameters and variables must have explicit types declared or be inferable with strict typing.
|
||||
- **No usage of `Any` is allowed**, either explicitly or implicitly.
|
||||
- Make sure all types are explicit and specific. Avoid overly broad types like `object`.
|
||||
- Use typing features from the standard library (`typing` or `typing_extensions`), preferably Python 3.9+ syntax (`list[str]` instead of `List[str]`).
|
||||
- Avoid legacy/deprecated typing constructs, e.g., no `NotRequired` from older `typing-extensions` versions or Pydantic v1.0 conventions.
|
||||
- Use Pydantic v2+ patterns if necessary or alternatives that are modern and well-maintained.
|
||||
|
||||
---
|
||||
|
||||
## 2. Return Types
|
||||
|
||||
- Every function and method must have an explicit return type annotation, even if returning `None`.
|
||||
- For async functions, specify return types with `async def func(...) -> Awaitable[ReturnType]` or more commonly just `async def func(...) -> ReturnType`.
|
||||
- Avoid implicit returns or assumed returns.
|
||||
- If the function returns multiple types (e.g., `Union`), make it clear and as narrow as possible.
|
||||
|
||||
---
|
||||
|
||||
## 3. Type Hints
|
||||
|
||||
- Use inline type hints for variables where applicable and useful to improve readability.
|
||||
- Use `TypedDict`, `Literal`, `Protocol`, `Final`, `TypeAlias` and other modern typing features appropriately.
|
||||
- **No shortcuts:** do not disable or skip type checking rules or warnings anywhere.
|
||||
- Errors and warnings raised by pyrefly, ruff, and basedpyright must be resolved, not ignored.
|
||||
- Don't use third-party typing decorators or hacks that circumvent standard Python type hints.
|
||||
|
||||
---
|
||||
|
||||
## 4. No Ignoring or Shortcuts
|
||||
|
||||
- Do not add `# type: ignore` or lint ignores.
|
||||
- All issues flagged by linters/type-checkers must be handled by fixing the root problem.
|
||||
- Don't commit code with ignored errors or warnings.
|
||||
- Avoid complexity that makes typing impossible or too complex; refactor instead.
|
||||
|
||||
---
|
||||
|
||||
## 5. Tools and Integration
|
||||
|
||||
- Use **pyrefly** for autocompletion and precise inline type hints while coding.
|
||||
- Use **ruff** to enforce style and typing rules:
|
||||
- Enable rules that enforce function annotations (flake8-annotations or pyright analogs).
|
||||
- Enable rules to disallow `Any` and enforce consistent type hints.
|
||||
- Use **basedpyright** for type checking:
|
||||
- Treat all warnings/errors as errors.
|
||||
- Set strict mode enabled.
|
||||
- Fix all type errors reported.
|
||||
|
||||
---
|
||||
|
||||
## Example Enforcement Sample
|
||||
|
||||
```python
|
||||
from typing import List
|
||||
|
||||
def get_usernames(users: List[str]) -> List[str]:
|
||||
# Explicit input and output types
|
||||
return [user.strip().lower() for user in users]
|
||||
|
||||
async def fetch_data(url: str) -> dict[str, str]:
|
||||
# Explicit return type for async func
|
||||
response = await some_http_client.get(url)
|
||||
return response.json()
|
||||
```
|
||||
|
||||
- No use of `Any` anywhere.
|
||||
- All functions/types explicitly annotated.
|
||||
- No places where errors are ignored.
|
||||
|
||||
---
|
||||
|
||||
## Summary Checklist
|
||||
|
||||
- All functions/methods have explicit return types.
|
||||
- No use of `Any` types.
|
||||
- No `# type: ignore` or lint ignores.
|
||||
- No legacy/disallowed typing patterns (`NotRequired`, Pydantic v1, etc.).
|
||||
- Correct use of modern python typing features.
|
||||
- Passes all pyrefly suggestions without disabling checks.
|
||||
- Passes all ruff lint rules, especially regarding annotations.
|
||||
- Passes all basedpyright type checks with strict mode and zero errors.
|
||||
|
||||
@@ -4,18 +4,45 @@
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"command": "cd $(git rev-parse --show-toplevel) && /home/vasceannie/repos/biz-budz/scripts/black-file.sh",
|
||||
"command": "cd $(git rev-parse --show-toplevel) && ./scripts/black-file.sh",
|
||||
"type": "command"
|
||||
}
|
||||
],
|
||||
"matcher": "Write|Edit|MultiEdit"
|
||||
},
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"command": "echo 'BLOCKED: Using SKIP= with git commit is forbidden. Run proper pre-commit hooks instead.' && exit 1",
|
||||
"type": "command"
|
||||
}
|
||||
],
|
||||
"matcher": "Bash.*SKIP=.*git commit"
|
||||
},
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"command": "echo 'BLOCKED: Using --no-verify with git commit is forbidden. Pre-commit hooks must run.' && exit 1",
|
||||
"type": "command"
|
||||
}
|
||||
],
|
||||
"matcher": "Bash.*git commit.*--no-verify"
|
||||
},
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"command": "echo 'BLOCKED: Using git commit with -n flag (no-verify) is forbidden. Pre-commit hooks must run.' && exit 1",
|
||||
"type": "command"
|
||||
}
|
||||
],
|
||||
"matcher": "Bash.*git commit.*-n"
|
||||
}
|
||||
],
|
||||
"PostToolUse": [
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"command": "cd $(git rev-parse --show-toplevel) && /home/vasceannie/repos/biz-budz/scripts/lint-file.sh",
|
||||
"command": "cd $(git rev-parse --show-toplevel) && ./scripts/lint-file.sh",
|
||||
"type": "command"
|
||||
}
|
||||
],
|
||||
|
||||
57
.devcontainer/Dockerfile
Normal file
57
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,57 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:1-3.12-bookworm
|
||||
|
||||
# Install additional system dependencies
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends \
|
||||
postgresql-client \
|
||||
redis-tools \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Node.js and npm
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install UV package manager globally
|
||||
RUN pip install --no-cache-dir uv
|
||||
|
||||
# Install global npm packages
|
||||
RUN npm install -g task-master-ai repomix @anthropic-ai/claude-code
|
||||
|
||||
# Create cache directories with proper permissions
|
||||
RUN mkdir -p /home/vscode/.cache/uv \
|
||||
&& chown -R vscode:vscode /home/vscode/.cache
|
||||
|
||||
# Set up Python environment paths
|
||||
ENV PYTHONPATH="/workspace/src:/workspace:$PYTHONPATH"
|
||||
|
||||
# Allow vscode user to use sudo without password for permission fixes
|
||||
RUN echo "vscode ALL=(ALL) NOPASSWD: /bin/chown, /bin/chmod, /usr/bin/chown, /usr/bin/chmod" >> /etc/sudoers
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /workspace
|
||||
|
||||
# Create a startup script to fix permissions on volumes and ensure group membership
|
||||
RUN echo '#!/bin/bash\n\
|
||||
# Fix ownership of workspace directories\n\
|
||||
if [ -d "/workspace/.venv" ]; then\n\
|
||||
sudo chown -R vscode:vscode /workspace/.venv\n\
|
||||
fi\n\
|
||||
# Ensure vscode user is in root group for shared permissions\n\
|
||||
if ! groups vscode | grep -q "\broot\b"; then\n\
|
||||
sudo usermod -a -G root vscode\n\
|
||||
fi\n\
|
||||
# Fix permissions on config and cache directories if they exist\n\
|
||||
if [ -d "/home/vscode/.config" ]; then\n\
|
||||
sudo chown -R vscode:vscode /home/vscode/.config 2>/dev/null || true\n\
|
||||
fi\n\
|
||||
if [ -d "/home/vscode/.cache" ]; then\n\
|
||||
sudo chown -R vscode:vscode /home/vscode/.cache 2>/dev/null || true\n\
|
||||
fi\n\
|
||||
exec "$@"' > /usr/local/bin/docker-entrypoint.sh \
|
||||
&& chmod +x /usr/local/bin/docker-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
95
.devcontainer/devcontainer.json
Normal file
95
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,95 @@
|
||||
{
|
||||
"name": "Business Buddy Dev Container",
|
||||
"dockerComposeFile": "docker-compose.yml",
|
||||
"service": "app",
|
||||
"workspaceFolder": "/workspace",
|
||||
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/python:1": {
|
||||
"version": "3.12",
|
||||
"installJupyterlab": false
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "lts"
|
||||
},
|
||||
"ghcr.io/devcontainers/features/git:1": {},
|
||||
"ghcr.io/devcontainers/features/github-cli:1": {},
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": true,
|
||||
"configureZshAsDefaultShell": true,
|
||||
"installOhMyZsh": true,
|
||||
"installOhMyZshConfig": true,
|
||||
"upgradePackages": true
|
||||
}
|
||||
},
|
||||
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"settings": {
|
||||
"python.defaultInterpreterPath": "/workspace/.venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.pylintEnabled": false,
|
||||
"python.formatting.provider": "none",
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit",
|
||||
"source.fixAll": "explicit"
|
||||
}
|
||||
},
|
||||
"terminal.integrated.defaultProfile.linux": "zsh",
|
||||
"git.enabled": true,
|
||||
"git.path": "/usr/local/bin/git",
|
||||
"git.allowNoVerifyCommit": false,
|
||||
"git.alwaysSignOff": false,
|
||||
"git.useEditorAsCommitInput": true,
|
||||
"git.enableCommitSigning": false
|
||||
},
|
||||
"extensions": [
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"charliermarsh.ruff",
|
||||
"ms-azuretools.vscode-docker",
|
||||
"GitHub.copilot",
|
||||
"ms-vscode.makefile-tools",
|
||||
"donjayamanne.githistory",
|
||||
"streetsidesoftware.code-spell-checker"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
"forwardPorts": [2024, 5432, 6379, 6333],
|
||||
|
||||
"postCreateCommand": "bash .devcontainer/setup.sh",
|
||||
|
||||
"containerEnv": {
|
||||
"PYTHONPATH": "/workspace/src:/workspace",
|
||||
"DATABASE_URL": "postgres://user:password@postgres:5432/langgraph_db",
|
||||
"REDIS_URL": "redis://redis:6379/0",
|
||||
"QDRANT_URL": "http://qdrant:6333",
|
||||
"CHOKIDAR_USEPOLLING": "true",
|
||||
"WATCHPACK_POLLING": "true",
|
||||
"LANGGRAPH_HOST": "0.0.0.0",
|
||||
"LANGGRAPH_PORT": "2024",
|
||||
"LANGCHAIN_TRACING_V2": "true",
|
||||
"LANGCHAIN_PROJECT": "biz-budz-dev",
|
||||
"LANGSMITH_TRACING": "true"
|
||||
},
|
||||
|
||||
"mounts": [
|
||||
"source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"source=${localEnv:HOME}/.config,target=/home/vscode/.config,type=bind,consistency=cached",
|
||||
"source=${localEnv:HOME}/.cache,target=/home/vscode/.cache,type=bind,consistency=cached"
|
||||
],
|
||||
|
||||
"remoteUser": "vscode",
|
||||
|
||||
"runArgs": [
|
||||
"--user=1000:1000",
|
||||
"--group-add=0"
|
||||
],
|
||||
|
||||
"updateContentCommand": "bash .devcontainer/fix-permissions.sh"
|
||||
}
|
||||
125
.devcontainer/docker-compose.yml
Normal file
125
.devcontainer/docker-compose.yml
Normal file
@@ -0,0 +1,125 @@
|
||||
|
||||
services:
|
||||
app:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: .devcontainer/Dockerfile
|
||||
ports:
|
||||
- "2024:2024" # LangGraph Studio
|
||||
user: "1000:1000"
|
||||
volumes:
|
||||
# Main workspace with delegated consistency for better performance
|
||||
- type: bind
|
||||
source: ..
|
||||
target: /workspace
|
||||
consistency: delegated
|
||||
# Mount .venv to a volume to hide host's .venv
|
||||
- venv-volume:/workspace/.venv
|
||||
# Bind mount host .venv for shared access (different name to avoid conflicts)
|
||||
- ./.venv-host:/workspace/.venv-host
|
||||
# UV cache volume
|
||||
- uv-cache:/home/vscode/.cache/uv
|
||||
# Additional bind mounts for hot reload
|
||||
- type: bind
|
||||
source: ../src
|
||||
target: /workspace/src
|
||||
consistency: consistent
|
||||
- type: bind
|
||||
source: ../tests
|
||||
target: /workspace/tests
|
||||
consistency: consistent
|
||||
- type: bind
|
||||
source: ../packages
|
||||
target: /workspace/packages
|
||||
consistency: consistent
|
||||
# Mount host config and cache directories for sharing permissions and API access
|
||||
- type: bind
|
||||
source: ${HOME}/.config
|
||||
target: /home/vscode/.config
|
||||
consistency: cached
|
||||
- type: bind
|
||||
source: ${HOME}/.cache
|
||||
target: /home/vscode/.cache
|
||||
consistency: cached
|
||||
environment:
|
||||
# Python configuration
|
||||
PYTHONPATH: /workspace/src:/workspace
|
||||
# Database connections
|
||||
DATABASE_URL: postgres://user:password@postgres:5432/langgraph_db
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
# Development environment
|
||||
NODE_ENV: development
|
||||
PYTHON_ENV: development
|
||||
# LangGraph configuration
|
||||
LANGGRAPH_HOST: 0.0.0.0
|
||||
LANGGRAPH_PORT: 2024
|
||||
# LangSmith/LangChain tracing configuration
|
||||
LANGCHAIN_TRACING_V2: "true"
|
||||
LANGCHAIN_PROJECT: "biz-budz-dev"
|
||||
LANGSMITH_TRACING: "true"
|
||||
# Pass through API keys from host environment
|
||||
LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY:-}
|
||||
LANGSMITH_API_KEY: ${LANGSMITH_API_KEY:-}
|
||||
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
|
||||
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
||||
TAVILY_API_KEY: ${TAVILY_API_KEY:-}
|
||||
PERPLEXITY_API_KEY: ${PERPLEXITY_API_KEY:-}
|
||||
depends_on:
|
||||
- postgres
|
||||
- redis
|
||||
- qdrant
|
||||
networks:
|
||||
- devcontainer-network
|
||||
|
||||
postgres:
|
||||
image: postgres:14
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- postgres-data:/var/lib/postgresql/data
|
||||
- ./init-items.sql:/docker-entrypoint-initdb.d/init-items.sql:ro
|
||||
environment:
|
||||
POSTGRES_USER: user
|
||||
POSTGRES_PASSWORD: password
|
||||
POSTGRES_DB: langgraph_db
|
||||
networks:
|
||||
- devcontainer-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U user -d langgraph_db"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
redis:
|
||||
image: redis:7
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
networks:
|
||||
- devcontainer-network
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- qdrant-storage:/qdrant/storage
|
||||
environment:
|
||||
QDRANT__SERVICE__HTTP_PORT: 6333
|
||||
networks:
|
||||
- devcontainer-network
|
||||
|
||||
volumes:
|
||||
postgres-data:
|
||||
redis-data:
|
||||
qdrant-storage:
|
||||
venv-volume:
|
||||
uv-cache:
|
||||
|
||||
networks:
|
||||
devcontainer-network:
|
||||
driver: bridge
|
||||
40
.devcontainer/fix-permissions.sh
Normal file
40
.devcontainer/fix-permissions.sh
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
# Quick fix script for permissions and API access in running container
|
||||
|
||||
echo "🔧 Fixing permissions and configuring API access..."
|
||||
|
||||
# Fix ownership of cache directories
|
||||
sudo chown -R vscode:vscode /home/vscode/.cache 2>/dev/null || true
|
||||
sudo chown -R vscode:vscode /home/vscode/.config 2>/dev/null || true
|
||||
|
||||
# Create cache directory with proper permissions
|
||||
mkdir -p /home/vscode/.cache/uv
|
||||
touch /home/vscode/.cache/uv/CACHEDIR.TAG
|
||||
chmod -R 755 /home/vscode/.cache/uv
|
||||
|
||||
# Create config directories for LangSmith/LangChain
|
||||
mkdir -p /home/vscode/.config/langchain
|
||||
mkdir -p /home/vscode/.config/langsmith
|
||||
|
||||
# Ensure vscode user is in root group for shared permissions
|
||||
if ! groups vscode | grep -q "\broot\b"; then
|
||||
sudo usermod -a -G root vscode
|
||||
echo "✓ Added vscode user to root group"
|
||||
fi
|
||||
|
||||
# Copy shell profiles
|
||||
if [ -f /home/vscode/.zshrc.host ]; then
|
||||
cp /home/vscode/.zshrc.host /home/vscode/.zshrc
|
||||
sed -i 's|/home/vasceannie|/home/vscode|g' /home/vscode/.zshrc
|
||||
echo "✓ Copied .zshrc"
|
||||
fi
|
||||
|
||||
if [ -f /home/vscode/.bashrc.host ]; then
|
||||
cp /home/vscode/.bashrc.host /home/vscode/.bashrc
|
||||
echo "✓ Copied .bashrc"
|
||||
fi
|
||||
|
||||
# Set proper permissions for workspace
|
||||
sudo chown -R vscode:vscode /workspace/.venv 2>/dev/null || true
|
||||
|
||||
echo "✅ Permissions and API access configured!"
|
||||
146
.devcontainer/init-items.sql
Normal file
146
.devcontainer/init-items.sql
Normal file
@@ -0,0 +1,146 @@
|
||||
CREATE TABLE IF NOT EXISTS items (
|
||||
id VARCHAR(32) PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
online_name TEXT,
|
||||
hidden BOOLEAN DEFAULT FALSE,
|
||||
available BOOLEAN DEFAULT TRUE,
|
||||
auto_manage BOOLEAN DEFAULT FALSE,
|
||||
price INTEGER NOT NULL,
|
||||
price_type VARCHAR(16) NOT NULL,
|
||||
default_tax_rates BOOLEAN,
|
||||
is_revenue BOOLEAN,
|
||||
modified_time BIGINT NOT NULL,
|
||||
deleted BOOLEAN DEFAULT FALSE,
|
||||
enabled_online BOOLEAN DEFAULT TRUE,
|
||||
alternate_name TEXT,
|
||||
code TEXT,
|
||||
sku TEXT,
|
||||
unit_name TEXT,
|
||||
cost INTEGER,
|
||||
price_without_vat INTEGER
|
||||
);
|
||||
|
||||
INSERT INTO items (
|
||||
id, name, online_name, hidden, available, auto_manage, price, price_type, default_tax_rates, is_revenue,
|
||||
modified_time, deleted, enabled_online, alternate_name, code, sku, unit_name, cost, price_without_vat
|
||||
) VALUES
|
||||
('8FWKTEEY952NJ','LGE Can Tastee Cheese 4.8-Lbs','LGE Can Tastee Cheese 4.8-Lbs',FALSE,TRUE,FALSE,7000,'FIXED',TRUE,TRUE,1742946876000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('AES43GBJYVMHE','Chili Wings 3 for $4.00','Chili Wings 3 for $4.00',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1741900252000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('0PEJY11DE4K1Y','LG Brown Stew Chicken Only','LG Brown Stew Chicken Only',FALSE,TRUE,FALSE,2500,'FIXED',TRUE,TRUE,1741546970000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('94TSPNQ53MHT2','LG Beef Ribs','LG Beef Ribs',FALSE,TRUE,FALSE,1800,'FIXED',TRUE,TRUE,1740336153000,FALSE,TRUE,'','','','',0,0),
|
||||
('TRMEF1K37ZVBY','Reg Beef Ribs','Reg Beef Ribs',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1740336132000,FALSE,TRUE,'','','','',0,0),
|
||||
('XDVX9NQGGNHWM','LG Cookup Salmon','LG Cookup Salmon',FALSE,TRUE,FALSE,1450,'FIXED',TRUE,TRUE,1736307052000,FALSE,TRUE,'','','','',0,0),
|
||||
('3ZSY0QBJXADYP','Reg Cookup Salmon','Reg Cookup Salmon',FALSE,TRUE,FALSE,1250,'FIXED',TRUE,TRUE,1736307028000,FALSE,TRUE,'','','','',0,0),
|
||||
('FVE90V5KANJMC','Mini Cookup Salmon','Mini Cookup Salmon',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1741888827000,FALSE,TRUE,'','','','',0,0),
|
||||
('Y5S0TVYTEPMPY','Christmas 2024 Package For 2-4 People','Christmas 2024 Package For 2-4 People',TRUE,FALSE,FALSE,22500,'FIXED',TRUE,TRUE,1740882286000,FALSE,TRUE,'','','','',0,0),
|
||||
('JCRTY9PR0PWGM','Christmas 2024 Package For 6-8 People','Christmas 2024 Package For 6-8 People',TRUE,FALSE,FALSE,55000,'FIXED',TRUE,TRUE,1740882290000,FALSE,TRUE,'','','','',0,0),
|
||||
('550XACY435GEG','Christmas 2024 Package For 10-15 People','Christmas 2024 Package For 10-15 People',TRUE,FALSE,FALSE,75000,'FIXED',TRUE,TRUE,1740882282000,FALSE,TRUE,'','','','',0,0),
|
||||
('EDKQXCK0M7SQW','Christmas 2024 Package For 20-25 People','Christmas 2024 Package For 20-25 People',TRUE,FALSE,FALSE,95000,'FIXED',TRUE,TRUE,1740882288000,FALSE,TRUE,'','','','',0,0),
|
||||
('0TRKQ08BW6RNY','SM Tray Spinach Rice','SM Tray Spinach Rice',FALSE,TRUE,FALSE,6000,'FIXED',TRUE,TRUE,1732042296000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('3J7VACG51026P','SM Tray Spinach','SM Tray Spinach',FALSE,TRUE,FALSE,7000,'FIXED',TRUE,TRUE,1732042206000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('8XZS0RP2YAXGP','LG Tray Veggie Fried Rice','LG Tray Veggie Fried Rice',FALSE,TRUE,FALSE,12500,'FIXED',TRUE,TRUE,1732041639000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('VRS51MX1PY98E','LG Veggie Fried Rice','LG Veggie Fried Rice',FALSE,TRUE,FALSE,1400,'FIXED',TRUE,TRUE,1732040104000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('KWMNPW3RTVJ6Y','Reg Veggie Fried Rice','Reg Veggie Fried Rice',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1732040028000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('QX3077QP7ECAP','Mini Veggie Fried Rice','Mini Veggie Fried Rice',FALSE,TRUE,FALSE,800,'FIXED',TRUE,TRUE,1732039964000,FALSE,TRUE,'','','','',0,0),
|
||||
('FPBG4P3PMDZYC','LG TRAY ITAL STEW','LG TRAY ITAL STEW',FALSE,TRUE,FALSE,13000,'FIXED',TRUE,TRUE,1737753773000,FALSE,TRUE,'','','','',0,0),
|
||||
('KSM43DRQMAWZT','SM TRAY ITAL STEW','SM TRAY ITAL STEW',FALSE,TRUE,FALSE,6500,'FIXED',TRUE,TRUE,1737753765000,FALSE,TRUE,'','','','',0,0),
|
||||
('RG1KS6KPK9TKR','SM Tray Jerk Chicken Lo Mein','SM Tray Jerk Chicken Lo Mein',FALSE,TRUE,FALSE,7500,'FIXED',TRUE,TRUE,1731973151000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('CPQ44F6PXZ232','SM TRAY CANDID YAMS','SM TRAY CANDID YAMS',FALSE,TRUE,FALSE,4000,'FIXED',TRUE,TRUE,1732045788000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('26MBWTPXPB2B6','PACKAGE FOR 6-8 PEOPLE','PACKAGE FOR 6-8 PEOPLE',TRUE,FALSE,FALSE,32500,'FIXED',TRUE,TRUE,1740894161000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('7YTB27AP3VENC','PACKAGE FOR 2-4 PEOPLE','PACKAGE FOR 2-4 PEOPLE',TRUE,FALSE,FALSE,15000,'FIXED',TRUE,TRUE,1740894155000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('FJM2JJKF1WAE0','PACKAGE FOR 10-15 PEOPLE','PACKAGE FOR 10-15 PEOPLE',TRUE,FALSE,FALSE,47500,'FIXED',TRUE,TRUE,1741890382000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('WKKVKJ6SK0G46','PACKAGE FOR 20-25 PEOPLE','PACKAGE FOR 20-25 PEOPLE',TRUE,FALSE,FALSE,90000,'FIXED',TRUE,TRUE,1740894159000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('EQ3AN5BC9Z9GE','SEAFOOD DINNER PACKAGE 8-12 PEOPLE','SEAFOOD DINNER PACKAGE 8-12 PEOPLE',TRUE,FALSE,FALSE,45000,'FIXED',TRUE,TRUE,1741892428000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('6CZDMYB936MP6','LG TRAY JERK CHICKEN AND SHRIMP LOMEIN','LG TRAY JERK CHICKEN AND SHRIMP LOMEIN',FALSE,TRUE,FALSE,19000,'FIXED',TRUE,TRUE,1730731144000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('Q40XKXMYZSYVR','SM TRAY JERK CHICKEN AND SHRIMP LOMEIN','SM TRAY JERK CHICKEN AND SHRIMP LOMEIN',FALSE,TRUE,FALSE,9500,'FIXED',TRUE,TRUE,1730731082000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('1ENMJYCQCVBJ4','LG TRAY JERK CHICKEN AND SHRIMP RASTA PASTA','LG TRAY JERK CHICKEN AND SHRIMP RASTA PASTA',FALSE,TRUE,FALSE,24000,'FIXED',TRUE,TRUE,1730730985000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('RQGMC7Q52E8VT','SM TRAY JERK CHICKEN AND SHRIMP RASTA PASTA','SM TRAY JERK CHICKEN AND SHRIMP RASTA PASTA',FALSE,TRUE,FALSE,12000,'FIXED',TRUE,TRUE,1730730724000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('DEDEAPYWRV60P','SMALL TRAY COOKUP CODFISH','SMALL TRAY COOKUP CODFISH',FALSE,TRUE,FALSE,12500,'FIXED',TRUE,TRUE,1729783003000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('7JWZG5VQ2FNW0','1 DOZEN CODFISH FRITTER','1 DOZEN CODFISH FRITTER',FALSE,TRUE,FALSE,2400,'FIXED',TRUE,TRUE,1740879802000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('Z2592B3GP15PA','1 DOZEN FESTIVAL','1 DOZEN FESTIVAL',FALSE,TRUE,FALSE,2400,'FIXED',TRUE,TRUE,1740879835000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('J70AD21A88E5C','I DOZEN FRIED DUMPLING','I DOZEN FRIED DUMPLING',TRUE,FALSE,FALSE,1800,'FIXED',TRUE,TRUE,1740890803000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
('TPK211AZHRN8E','Extra Pumpkin Or Spinach Rice-LG','Extra Pumpkin Or Spinach Rice-LG',FALSE,TRUE,FALSE,300,'FIXED',TRUE,TRUE,1726181679000,FALSE,TRUE,'','','','',0,0),
|
||||
('Z2R5SFT2FK2QP','Extra Pumpkin Or Spinach Rice-Reg','Extra Pumpkin Or Spinach Rice-Reg',FALSE,TRUE,FALSE,200,'FIXED',TRUE,TRUE,1726181558000,FALSE,TRUE,'','','','',0,0),
|
||||
('7KDDT1QWR57GC','Quinoi','Quinoi',TRUE,FALSE,FALSE,300,'FIXED',TRUE,TRUE,1741891643000,FALSE,TRUE,'','','','',0,0),
|
||||
('MGR82XDRZKFAE','Whole Cheese Cake','Whole Cheese Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726006044000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('F1TB52KPDAFMG','Whole Chocolate Cake','Whole Chocolate Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726005980000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('AFXX5PGWX3GK8','Whole Red Velvet Cake','Whole Red Velvet Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726005743000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('WTJ801A75ZAK0','LG Tray Esco Chicken','LG Tray Esco Chicken',FALSE,TRUE,FALSE,14500,'FIXED',TRUE,TRUE,1725666148000,FALSE,TRUE,'','','','',0,0),
|
||||
('DMD91V58QH1FY','Esco Snapper Meal','Esco Snapper Meal',FALSE,TRUE,FALSE,2000,'FIXED',TRUE,TRUE,1741545187000,FALSE,TRUE,'','','','',0,0),
|
||||
('3867387970RC0','Steam Butter Fish','Steam Butter Fish',TRUE,FALSE,FALSE,1500,'FIXED',TRUE,TRUE,1741893336000,FALSE,TRUE,'','','','',0,0),
|
||||
('JF3GTKY0392YW','Bammy','Bammy',FALSE,TRUE,FALSE,200,'FIXED',TRUE,TRUE,1729341142000,FALSE,TRUE,'','','','',0,0),
|
||||
('DXBBZ2FCVGX02','Steam Snapper Fish','Steam Snapper Fish',FALSE,TRUE,FALSE,2200,'FIXED',TRUE,TRUE,1741893681000,FALSE,TRUE,'','','','',0,0),
|
||||
('1P1F01N96ZWJG','Roots Man','Roots Man',FALSE,TRUE,FALSE,600,'FIXED',TRUE,TRUE,1722967143000,FALSE,TRUE,'','','','',0,0),
|
||||
('SGC3DAXDH9ZBG','Fiwi Lemongrass Tea','Fiwi Lemongrass Tea',FALSE,TRUE,FALSE,499,'FIXED',TRUE,TRUE,1721353508000,FALSE,TRUE,'','','','',0,0),
|
||||
('8PJ2VCBEV4T1Y','Extra Mac & Cheese','Extra Mac & Cheese',FALSE,TRUE,FALSE,150,'FIXED',TRUE,TRUE,1717690516000,FALSE,TRUE,'','','','',0,0),
|
||||
('0Y0FJHHCWZKTT','LG Butter Shrimp (9 Shrimp)','LG Butter Shrimp (9 Shrimp)',FALSE,TRUE,FALSE,2100,'FIXED',TRUE,TRUE,1741547134000,FALSE,TRUE,'','','','',0,0),
|
||||
('7Y2H5RX0XSX3Y','Extra Rasta Pasta-LG','Extra Rasta Pasta-LG',FALSE,TRUE,FALSE,600,'FIXED',TRUE,TRUE,1716147863000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('1V7H5DXSCC2Y6','LG Shrimp LoMein','LG Shrimp LoMein',FALSE,TRUE,FALSE,22500,'FIXED',TRUE,TRUE,1715966030000,FALSE,TRUE,'','','','',0,0),
|
||||
('AWBNJDJQQBSD4','Stakehouse Cheesecake','Stakehouse Cheesecake',FALSE,TRUE,FALSE,550,'FIXED',TRUE,TRUE,1715048741000,FALSE,TRUE,'','','','',0,0),
|
||||
('V7EYC23J5C63J','Fruit Of Life Coconut Water','Fruit Of Life Coconut Water',TRUE,FALSE,FALSE,499,'FIXED',TRUE,TRUE,1740889976000,FALSE,TRUE,'','','','',0,0),
|
||||
('BFWXPV6B6Z2VC','Esco Saltfish Meal-Large','Esco Saltfish Meal-Large',FALSE,TRUE,FALSE,1700,'FIXED',TRUE,TRUE,1740887749000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('8A1HWX6V88F2Y','Esco Saltfish Meal-Regular ','Esco Saltfish Meal-Regular ',FALSE,TRUE,FALSE,1400,'FIXED',TRUE,TRUE,1712318630000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('3JP9245D50S9Y','Esco Saltfish Meal-Mini','Esco Saltfish Meal-Mini',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1740887717000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('VECQK0YJZZEZ4','Soursop With Lime','Soursop With Lime',FALSE,TRUE,FALSE,599,'FIXED',TRUE,TRUE,1714774724000,FALSE,TRUE,'','','','',0,0),
|
||||
('2H4XS57ZD9D9T','SOURSOP LEAF W/LEMON & GINGER','SOURSOP LEAF W/LEMON & GINGER',FALSE,TRUE,FALSE,499,'FIXED',TRUE,TRUE,1711841482000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('YHDKA6KPQFXRC','Mackerel In Tomato Sauce-Large','Mackerel In Tomato Sauce-Large',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1741888772000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('66JY11P9BC4BE','Mackerel In Tomato Sauce-Regular ','Mackerel In Tomato Sauce-Regular ',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1741888756000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('DFYHFNPWVV9F2','Mackerel In Tomato Sauce-Mini','Mackerel In Tomato Sauce-Mini',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1741888738000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('8NEBXD6RF0QSM','HTB Easter Bun','HTB Easter Bun',TRUE,FALSE,FALSE,1850,'FIXED',TRUE,TRUE,1740890798000,FALSE,TRUE,'','','','',0,0),
|
||||
('DM6RTJH8KVKS6','SM CAN TASTEE CHEESE 1.1 LB','SM CAN TASTEE CHEESE 1.1 LB',TRUE,FALSE,FALSE,2199,'FIXED',TRUE,TRUE,1742946771000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('28AKKR5RY1YJ8','SM CAN TASTEE CHEESE 8.8OZ','SM CAN TASTEE CHEESE 8.8OZ',TRUE,FALSE,FALSE,1399,'FIXED',TRUE,TRUE,1742946780000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('132FXN33YJB1P','9" Fruit Cake','9" Fruit Cake',FALSE,TRUE,FALSE,8500,'FIXED',TRUE,TRUE,1734399306000,FALSE,TRUE,'','','','',0,0),
|
||||
('4X9KC9YZV9HG8','8" Fruit Cake','8" Fruit Cake',FALSE,TRUE,FALSE,6500,'FIXED',TRUE,TRUE,1734399293000,FALSE,TRUE,'','','','',0,0),
|
||||
('4133GEYQJMDZ4','Extra shrimp Fried Rice-Mini','Extra shrimp Fried Rice-Mini',FALSE,TRUE,FALSE,250,'FIXED',TRUE,TRUE,1740889316000,FALSE,TRUE,'','','','',0,0),
|
||||
('WSWZD03ZPJXBA','Extra shrimp Fried Rice-Lge','Extra shrimp Fried Rice-Lge',FALSE,TRUE,FALSE,500,'FIXED',TRUE,TRUE,1741545377000,FALSE,TRUE,'','','','',0,0),
|
||||
('YW7Q7B1R1RBBT','Extra shrimp Fried Rice-Reg','Extra shrimp Fried Rice-Reg',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1709922064000,FALSE,TRUE,'','','','',0,0),
|
||||
('N7BQK2H3HJQ8Y','Mini Oxtail Rasta Pasta','Mini Oxtail Rasta Pasta',TRUE,FALSE,FALSE,1300,'FIXED',TRUE,TRUE,1741888888000,FALSE,TRUE,'','','','',0,0),
|
||||
('SGNFMEHEDME44','Oxtail Rasta Pasta-Large','Oxtail Rasta Pasta-Large',TRUE,FALSE,FALSE,2100,'FIXED',TRUE,TRUE,1742955126000,FALSE,TRUE,'','','','',0,0),
|
||||
('FXJRNBKNXPWN4','Oxtail Rasta Pasta-Regular','Oxtail Rasta Pasta-Regular',TRUE,FALSE,FALSE,1800,'FIXED',TRUE,TRUE,1741890372000,FALSE,TRUE,'','','','',0,0),
|
||||
('SCKJSJBP6655W','Gift card','Gift card',FALSE,TRUE,FALSE,0,'VARIABLE',FALSE,FALSE,1707342109000,FALSE,TRUE,NULL,'CLOVER_GIFT_CARD',NULL,NULL,NULL,NULL),
|
||||
('18RWNDTG7Q8H4','Bag of Cookies','Bag of Cookies',TRUE,FALSE,FALSE,895,'FIXED',TRUE,TRUE,1740880223000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('YGV43SZXQVJZ0','Extra Veggie Rice-Large','Extra Veggie Rice-Large',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1741545489000,FALSE,TRUE,'','','','',0,0),
|
||||
('E0YYQ8NBQ6K32','Extra Veggie Rice-Mini','Extra Veggie Rice-Mini',FALSE,TRUE,FALSE,150,'FIXED',TRUE,TRUE,1705960154000,FALSE,TRUE,'','','','',0,0),
|
||||
('TEDBEC74W0788','Extra Veggie Rice-Reg','Extra Veggie Rice-Reg',FALSE,TRUE,FALSE,250,'FIXED',TRUE,TRUE,1705960103000,FALSE,TRUE,'','','','',0,0),
|
||||
('7K83E2CGP5A72','Veggie Soup for 200 People','Veggie Soup for 200 People',TRUE,FALSE,FALSE,30000,'FIXED',TRUE,TRUE,1741894706000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('KZRXCM1SJY9RE','FRUIT CAKE W/ICING SLICE','FRUIT CAKE W/ICING SLICE',FALSE,TRUE,FALSE,495,'FIXED',TRUE,TRUE,1704486470000,FALSE,TRUE,'','','','',0,0),
|
||||
('66QNN6HH1Z28C','Jerk Chicken Fried Rice-LG','Jerk Chicken Fried Rice-LG',FALSE,TRUE,FALSE,1550,'FIXED',TRUE,TRUE,1732040452000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('GT6DEXTX8B57P','Jerk Chicken Fried Rice-Reg','Jerk Chicken Fried Rice-Reg',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1740890941000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('ZJ6MH0DXTZ5VG','Jerk Chicken Fried Rice-Mini','Jerk Chicken Fried Rice-Mini',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1740890971000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('79AZD841FFGYE','2024 Thanksgiving Package -20-25 People','2024 Thanksgiving Package -20-25 People',TRUE,FALSE,FALSE,90000,'FIXED',TRUE,TRUE,1740879977000,FALSE,TRUE,'','','','',0,0),
|
||||
('GM610GMEHNJ74','One Gallon Sorrel','One Gallon Sorrel',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1699824929000,FALSE,TRUE,'','','','',0,0),
|
||||
('NDZV105PFGPD6','One Gallon Lemonade','One Gallon Lemonade',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1699824899000,FALSE,TRUE,'','','','',0,0),
|
||||
('ZYP2T4VBEGYBR','One Gallon Carrot Juice','One Gallon Carrot Juice',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1741890338000,FALSE,TRUE,'','','','',0,0),
|
||||
('TR67T34YF7HKA','2024 Thanksgiving Package 10-15 People','2024 Thanksgiving Package 10-15 People',TRUE,FALSE,FALSE,47500,'FIXED',TRUE,TRUE,1740879984000,FALSE,TRUE,'','','','',0,0),
|
||||
('TX3NFGDBEH7RT','Veggie LoMein Mini','Veggie LoMein Mini',FALSE,TRUE,FALSE,900,'FIXED',TRUE,TRUE,1699729415000,FALSE,TRUE,'','','','',0,0),
|
||||
('X1T0JPPCWR5MG','Veggie LoMein Reg','Veggie LoMein Reg',FALSE,TRUE,FALSE,1200,'FIXED',TRUE,TRUE,1699729309000,FALSE,TRUE,'','','','',0,0),
|
||||
('FFW13B7P1QX6A','Veggie LoMein Lge','Veggie LoMein Lge',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1699729351000,FALSE,TRUE,'','','','',0,0),
|
||||
('D7K8GCZVRJH4Y','Reg Cow Foot Only','Reg Cow Foot Only',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1697315969000,FALSE,TRUE,'','','','',0,0),
|
||||
('K09YHN1WEQ0RY','Reg Stew Peas Only','Reg Stew Peas Only',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1697315886000,FALSE,TRUE,'','','','',0,0),
|
||||
('CPCQEYF5F3M8Y','CURRY SHRIMP & LOBSTER COMBO','CURRY SHRIMP & LOBSTER COMBO',TRUE,FALSE,FALSE,2200,'FIXED',TRUE,TRUE,1740887468000,FALSE,TRUE,'','','','',0,0),
|
||||
('KE17PY026J59W','1 Piece Crab Leg','1 Piece Crab Leg',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1710544918000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
|
||||
('D8V3PSJTF79VP','Regular Steam King Fish','Regular Steam King Fish',FALSE,TRUE,FALSE,1600,'FIXED',TRUE,TRUE,1741901309000,FALSE,TRUE,'','','','',0,0),
|
||||
('NJ2AA5WVYEE9A','SHRIMP LO MEIN LG','SHRIMP LO MEIN LG',FALSE,TRUE,FALSE,1975,'FIXED',TRUE,TRUE,1741892531000,FALSE,TRUE,'','','','',0,0),
|
||||
('8KJ4A7K88XDNA','SHRIMP LO MEIN REG','SHRIMP LO MEIN REG',FALSE,TRUE,FALSE,1675,'FIXED',TRUE,TRUE,1741892562000,FALSE,TRUE,'','','','',0,0),
|
||||
('V3FGX8ZBXY052','SHRIMP LO MEIN-MINI','SHRIMP LO MEIN-MINI',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1717700888000,FALSE,TRUE,'','','','',0,0),
|
||||
('9VP3RZ7EQFY84','This Is It Watermelon','This Is It Watermelon',FALSE,TRUE,FALSE,450,'FIXED',TRUE,TRUE,1701397722000,FALSE,TRUE,'','','','',0,0)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
online_name = EXCLUDED.online_name,
|
||||
hidden = EXCLUDED.hidden,
|
||||
available = EXCLUDED.available,
|
||||
auto_manage = EXCLUDED.auto_manage,
|
||||
price = EXCLUDED.price,
|
||||
price_type = EXCLUDED.price_type,
|
||||
default_tax_rates = EXCLUDED.default_tax_rates,
|
||||
is_revenue = EXCLUDED.is_revenue,
|
||||
modified_time = EXCLUDED.modified_time,
|
||||
deleted = EXCLUDED.deleted,
|
||||
enabled_online = EXCLUDED.enabled_online,
|
||||
alternate_name = EXCLUDED.alternate_name,
|
||||
code = EXCLUDED.code,
|
||||
sku = EXCLUDED.sku,
|
||||
unit_name = EXCLUDED.unit_name,
|
||||
cost = EXCLUDED.cost,
|
||||
price_without_vat = EXCLUDED.price_without_vat
|
||||
;
|
||||
118
.devcontainer/post-create.sh
Normal file
118
.devcontainer/post-create.sh
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "🚀 Setting up Business Buddy development environment..."
|
||||
|
||||
# Set up shell configuration
|
||||
echo "🐚 Configuring shell environment..."
|
||||
|
||||
# Ensure zsh is the default shell
|
||||
if command -v zsh &> /dev/null; then
|
||||
sudo chsh -s $(which zsh) vscode || true
|
||||
fi
|
||||
|
||||
# Create necessary directories for zsh plugins if they don't exist
|
||||
mkdir -p ~/.cache ~/.local/bin ~/.config
|
||||
|
||||
# Fix UV cache permissions
|
||||
if [ -d ~/.cache/uv ]; then
|
||||
sudo chown -R vscode:vscode ~/.cache/uv
|
||||
fi
|
||||
mkdir -p ~/.cache/uv
|
||||
touch ~/.cache/uv/CACHEDIR.TAG
|
||||
chmod -R 755 ~/.cache/uv
|
||||
|
||||
# Set up zsh history
|
||||
touch ~/.zsh_history
|
||||
chmod 600 ~/.zsh_history
|
||||
|
||||
# Copy host shell profiles if they exist
|
||||
if [ -f ~/.zshrc.host ]; then
|
||||
cp ~/.zshrc.host ~/.zshrc
|
||||
# Update paths in .zshrc for container environment
|
||||
sed -i 's|/home/vasceannie|/home/vscode|g' ~/.zshrc
|
||||
fi
|
||||
|
||||
if [ -f ~/.bashrc.host ]; then
|
||||
cp ~/.bashrc.host ~/.bashrc
|
||||
fi
|
||||
|
||||
# Source P10k config if it exists
|
||||
if [ -f ~/.p10k.zsh ]; then
|
||||
echo "✓ Powerlevel10k configuration found"
|
||||
fi
|
||||
|
||||
# Ensure we're in the workspace directory
|
||||
cd /workspace
|
||||
|
||||
# Create virtual environment if it doesn't exist
|
||||
if [ ! -d ".venv" ]; then
|
||||
echo "📦 Creating Python virtual environment..."
|
||||
python3.12 -m venv .venv
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Install UV in the virtual environment if not already installed
|
||||
if ! command -v uv &> /dev/null; then
|
||||
echo "📦 Installing UV package manager..."
|
||||
pip install --upgrade pip
|
||||
pip install uv
|
||||
fi
|
||||
|
||||
# Install project dependencies
|
||||
echo "📦 Installing Python dependencies..."
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
# Install pre-commit hooks
|
||||
echo "🔗 Installing pre-commit hooks..."
|
||||
pre-commit install || true
|
||||
|
||||
# Create .env file from example if it doesn't exist
|
||||
if [ ! -f ".env" ]; then
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo "📝 Created .env file from .env.example - please add your API keys"
|
||||
else
|
||||
echo "⚠️ No .env.example file found - please create a .env file with your API keys"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Initialize Task Master if not already initialized
|
||||
if [ ! -d ".taskmaster" ]; then
|
||||
echo "📋 Initializing Task Master AI..."
|
||||
task-master init --yes --rules claude cursor || true
|
||||
fi
|
||||
|
||||
# Wait for services to be ready
|
||||
echo "⏳ Waiting for services to start..."
|
||||
sleep 5
|
||||
|
||||
# Check database connection
|
||||
echo "🔍 Checking database connection..."
|
||||
until pg_isready -h postgres -p 5432 -U user; do
|
||||
echo "Waiting for PostgreSQL..."
|
||||
sleep 2
|
||||
done
|
||||
|
||||
# Check Redis connection
|
||||
echo "🔍 Checking Redis connection..."
|
||||
until redis-cli -h redis ping; do
|
||||
echo "Waiting for Redis..."
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "✅ Development environment setup complete!"
|
||||
echo ""
|
||||
echo "📌 Next steps:"
|
||||
echo " 1. Add your API keys to the .env file"
|
||||
echo " 2. Run 'make test' to verify the setup"
|
||||
echo " 3. Start coding! 🎉"
|
||||
echo ""
|
||||
echo "📚 Useful commands:"
|
||||
echo " - make test : Run tests with coverage"
|
||||
echo " - make lint-all : Run all linters"
|
||||
echo " - make format : Format code"
|
||||
echo " - task-master list : View Task Master tasks"
|
||||
echo ""
|
||||
44
.devcontainer/setup.sh
Normal file
44
.devcontainer/setup.sh
Normal file
@@ -0,0 +1,44 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "🚀 Setting up Business Buddy development environment..."
|
||||
|
||||
# Ensure we're in the workspace directory
|
||||
cd /workspace
|
||||
|
||||
# Create virtual environment if it doesn't exist
|
||||
if [ ! -d ".venv" ]; then
|
||||
echo "📦 Creating Python virtual environment..."
|
||||
uv venv .venv --python 3.12
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Install project dependencies using UV
|
||||
echo "📦 Installing Python dependencies..."
|
||||
uv pip install -e ".[dev]" || echo "⚠️ Some dependencies failed to install"
|
||||
|
||||
# Install pre-commit hooks
|
||||
echo "🔗 Installing pre-commit hooks..."
|
||||
pre-commit install || true
|
||||
|
||||
# Create .env file from example if it doesn't exist
|
||||
if [ ! -f ".env" ]; then
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo "📝 Created .env file from .env.example"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Fix permissions - ensure all directories exist first
|
||||
echo "🔧 Fixing permissions..."
|
||||
mkdir -p /home/vscode/.cache/uv
|
||||
sudo chown -R vscode:vscode /home/vscode/.cache || true
|
||||
|
||||
echo "✅ Development environment setup complete!"
|
||||
echo ""
|
||||
echo "📌 Next steps:"
|
||||
echo " 1. Add your API keys to the .env file"
|
||||
echo " 2. Run 'source .venv/bin/activate' to activate the virtual environment"
|
||||
echo " 3. Run 'make test' to verify the setup"
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -68,7 +68,8 @@ cover/
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
.cenv/
|
||||
.venv-host/
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
|
||||
@@ -2,17 +2,13 @@
|
||||
# Install with: pip install pre-commit && pre-commit install
|
||||
|
||||
repos:
|
||||
# Ruff - Fast Python linter and formatter
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.3
|
||||
# Black - Python code formatter
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff (linter)
|
||||
args: [--fix, --unsafe-fixes]
|
||||
types_or: [python, pyi]
|
||||
- id: ruff-format
|
||||
name: ruff (formatter)
|
||||
types_or: [python, pyi]
|
||||
- id: black
|
||||
language_version: python3
|
||||
types: [python, pyi]
|
||||
|
||||
# Basic file checks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
@@ -44,29 +40,6 @@ repos:
|
||||
stages: [pre-commit]
|
||||
verbose: true
|
||||
|
||||
# Optional: Spell checking for documentation
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
name: codespell
|
||||
description: Checks for common misspellings in text files
|
||||
entry: codespell
|
||||
language: python
|
||||
types: [text]
|
||||
exclude: |
|
||||
(?x)^(
|
||||
\.git/.*|
|
||||
\.pytest_cache/.*|
|
||||
\.venv/.*|
|
||||
__pycache__/.*|
|
||||
.*\.pyc|
|
||||
\.codespellignore|
|
||||
tests/cassettes/.*
|
||||
)$
|
||||
args: [--ignore-words=.codespellignore]
|
||||
|
||||
# Configuration for pre-commit
|
||||
default_stages: [pre-commit]
|
||||
fail_fast: false
|
||||
minimum_pre_commit_version: '3.0.0'
|
||||
|
||||
@@ -4,6 +4,15 @@ dist/
|
||||
__pycache__/
|
||||
.venv/
|
||||
venv/
|
||||
.cenv/
|
||||
.venv-host/
|
||||
**/.venv/
|
||||
**/venv/
|
||||
**/site-packages/
|
||||
**/lib/python*/
|
||||
**/bin/
|
||||
**/include/
|
||||
**/share/
|
||||
.git/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
@@ -14,3 +23,6 @@ htmlcov/
|
||||
*.pyo
|
||||
.archive/
|
||||
**/.archive/
|
||||
node_modules/
|
||||
cache/
|
||||
examples/
|
||||
|
||||
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@@ -16,7 +16,7 @@
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter"
|
||||
},
|
||||
"python.pythonPath": "${workspaceFolder}/.venv/bin/python3.12",
|
||||
"python.pythonPath": "",
|
||||
"mypy-type-checker.interpreter": [
|
||||
"${workspaceFolder}/.venv/bin/python3.12"
|
||||
],
|
||||
|
||||
400
CLAUDE.local.md
400
CLAUDE.local.md
@@ -1,4 +1,4 @@
|
||||
# CLAUDE.md
|
||||
# CLAUDE.local.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
@@ -17,11 +17,20 @@ make test TEST_FILE=tests/unit_tests/nodes/llm/test_unit_call.py
|
||||
|
||||
# Run single test function
|
||||
pytest tests/path/to/test.py::test_function_name -v
|
||||
|
||||
# Run crash tests for resilience testing
|
||||
python tests/crash_tests/run_crash_tests.py
|
||||
|
||||
# Run specific test categories
|
||||
pytest -m "not slow" # Skip slow tests
|
||||
pytest -m integration # Only integration tests
|
||||
pytest -m e2e # Only end-to-end tests
|
||||
pytest -m unit # Only unit tests
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Run all linters (ruff, mypy, pyrefly, codespell) - ALWAYS run before committing
|
||||
# Run all linters (ruff, basedpyright, pyrefly, codespell) - ALWAYS run before committing
|
||||
make lint-all
|
||||
|
||||
# Format code with ruff
|
||||
@@ -32,6 +41,15 @@ make pre-commit
|
||||
|
||||
# Advanced type checking with Pyrefly
|
||||
pyrefly check .
|
||||
|
||||
# Lint single file
|
||||
make lint-file FILE_PATH=path/to/file.py
|
||||
|
||||
# Format single file
|
||||
make black FILE_PATH=path/to/file.py
|
||||
|
||||
# Generate coverage report
|
||||
make coverage-report
|
||||
```
|
||||
|
||||
### Development Services
|
||||
@@ -43,13 +61,142 @@ make start
|
||||
make stop
|
||||
|
||||
# Quick setup for new developers
|
||||
make setup
|
||||
|
||||
# Complete environment setup
|
||||
./scripts/setup-dev.sh
|
||||
```
|
||||
|
||||
### Package Management
|
||||
|
||||
The project uses UV for all package management:
|
||||
|
||||
```bash
|
||||
# Install main project with all packages
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
# Install individual packages for development
|
||||
uv pip install -e packages/business-buddy-core
|
||||
uv pip install -e packages/business-buddy-extraction
|
||||
uv pip install -e packages/business-buddy-tools
|
||||
|
||||
# Sync dependencies
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for business research and analysis.
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
biz-budz/ # Main project root (monorepo)
|
||||
├── src/biz_bud/ # Main application
|
||||
│ ├── agents/ # Specialized agent implementations
|
||||
│ ├── config/ # Configuration system
|
||||
│ ├── graphs/ # LangGraph workflow orchestration
|
||||
│ ├── nodes/ # Modular processing units
|
||||
│ ├── services/ # External dependency abstractions
|
||||
│ ├── states/ # TypedDict state management
|
||||
│ └── utils/ # Utility functions and helpers
|
||||
├── packages/ # Modular utility packages
|
||||
│ ├── business-buddy-core/ # Core utilities & helpers
|
||||
│ │ └── src/bb_core/
|
||||
│ │ ├── caching/ # Cache management system
|
||||
│ │ ├── edge_helpers/ # LangGraph edge routing utilities
|
||||
│ │ ├── errors/ # Error handling system
|
||||
│ │ ├── langgraph/ # LangGraph-specific utilities
|
||||
│ │ ├── logging/ # Logging system
|
||||
│ │ ├── networking/ # Network utilities
|
||||
│ │ ├── validation/ # Validation system
|
||||
│ │ ├── utils/ # General utilities
|
||||
│ ├── business-buddy-extraction/ # Entity & data extraction
|
||||
│ │ └── src/bb_extraction/
|
||||
│ │ ├── core/ # Core extraction framework
|
||||
│ │ ├── domain/ # Domain-specific extractors
|
||||
│ │ ├── numeric/ # Numeric data extraction
|
||||
│ │ ├── statistics/ # Statistical extraction
|
||||
│ │ ├── text/ # Text processing
|
||||
│ │ └── tools.py # Extraction tools
|
||||
│ └── business-buddy-tools/ # Web tools, scrapers, API clients
|
||||
│ └── src/bb_tools/
|
||||
│ ├── actions/ # High-level action workflows
|
||||
│ ├── api_clients/ # API client implementations
|
||||
│ │ ├── arxiv.py # ArXiv API client
|
||||
│ │ ├── base.py # Base API client
|
||||
│ │ ├── firecrawl.py # Firecrawl API client
|
||||
│ │ ├── jina.py # Jina API client
|
||||
│ │ ├── paperless.py # Paperless NGX client
|
||||
│ │ ├── r2r.py # R2R API client
|
||||
│ │ └── tavily.py # Tavily search client
|
||||
│ ├── apis/ # API abstractions
|
||||
│ │ ├── arxiv.py # ArXiv API abstraction
|
||||
│ │ ├── firecrawl.py # Firecrawl API abstraction
|
||||
│ │ └── jina/ # Jina API modules
|
||||
│ ├── browser/ # Browser automation
|
||||
│ │ ├── base.py # Base browser interface
|
||||
│ │ ├── browser.py # Selenium browser implementation
|
||||
│ │ ├── browser_helper.py # Browser utilities
|
||||
│ │ ├── driverless_browser.py # Driverless browser
|
||||
│ │ └── js/ # JavaScript utilities
|
||||
│ │ └── overlay.js # Browser overlay script
|
||||
│ ├── flows/ # Workflow implementations
|
||||
│ │ ├── agent_creator.py # Agent creation workflows
|
||||
│ │ ├── catalog_inspect.py # Catalog inspection
|
||||
│ │ ├── fetch.py # Content fetching flows
|
||||
│ │ ├── human_assistance.py # Human interaction flows
|
||||
│ │ ├── md_processing.py # Markdown processing
|
||||
│ │ ├── query_processing.py # Query processing
|
||||
│ │ ├── report_gen.py # Report generation
|
||||
│ │ └── scrape.py # Scraping workflows
|
||||
│ ├── interfaces/ # Protocol definitions
|
||||
│ │ └── web_tools.py # Web tools protocols
|
||||
│ ├── loaders/ # Data loaders
|
||||
│ │ └── web_base_loader.py # Web content loader
|
||||
│ ├── r2r/ # R2R integration
|
||||
│ │ └── tools.py # R2R tools
|
||||
│ ├── scrapers/ # Web scraping implementations
|
||||
│ │ ├── base.py # Base scraper interface
|
||||
│ │ ├── beautiful_soup.py # BeautifulSoup scraper
|
||||
│ │ ├── pymupdf.py # PyMuPDF scraper
|
||||
│ │ ├── strategies/ # Scraping strategies
|
||||
│ │ │ ├── beautifulsoup.py # BeautifulSoup strategy
|
||||
│ │ │ ├── firecrawl.py # Firecrawl strategy
|
||||
│ │ │ └── jina.py # Jina strategy
|
||||
│ │ ├── unified.py # Unified scraper
|
||||
│ │ ├── unified_scraper.py # Alternative unified scraper
|
||||
│ │ └── utils/ # Scraping utilities
|
||||
│ │ └── __init__.py # Scraping utilities
|
||||
│ ├── search/ # Search implementations
|
||||
│ │ ├── base.py # Base search interface
|
||||
│ │ ├── providers/ # Search providers
|
||||
│ │ │ ├── arxiv.py # ArXiv search provider
|
||||
│ │ │ ├── jina.py # Jina search provider
|
||||
│ │ │ └── tavily.py # Tavily search provider
|
||||
│ │ ├── unified.py # Unified search tool
|
||||
│ │ └── web_search.py # Web search implementation
|
||||
│ ├── stores/ # Data storage
|
||||
│ │ └── database.py # Database storage utilities
|
||||
│ ├── stubs/ # Type stubs
|
||||
│ │ ├── langgraph.pyi # LangGraph stubs
|
||||
│ │ └── r2r.pyi # R2R stubs
|
||||
│ ├── utils/ # Tool utilities
|
||||
│ │ └── html_utils.py # HTML processing utilities
|
||||
│ ├── constants.py # Tool constants
|
||||
│ ├── interfaces.py # Tool interfaces
|
||||
│ └── models.py # Data models
|
||||
├── tests/ # Comprehensive test suite
|
||||
│ ├── unit_tests/ # Unit tests with mocks
|
||||
│ ├── integration_tests/ # Integration tests
|
||||
│ ├── e2e/ # End-to-end tests
|
||||
│ ├── crash_tests/ # Resilience & failure tests
|
||||
│ └── manual/ # Manual test scripts
|
||||
├── docker/ # Docker configurations
|
||||
├── scripts/ # Development and deployment scripts
|
||||
└── .taskmaster/ # Task Master AI project files
|
||||
```
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Graphs** (`src/biz_bud/graphs/`): Define workflow orchestration using LangGraph state machines
|
||||
@@ -64,6 +211,7 @@ This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for
|
||||
- `llm/`: LLM interaction layer
|
||||
- `research/`: Web search, extraction, synthesis with optimization
|
||||
- `validation/`: Content and logic validation, human feedback
|
||||
- `integrations/`: External service integrations (Firecrawl, Repomix, etc.)
|
||||
|
||||
3. **States** (`src/biz_bud/states/`): TypedDict-based state management for type safety across workflows
|
||||
|
||||
@@ -72,11 +220,23 @@ This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for
|
||||
- Database (PostgreSQL via asyncpg)
|
||||
- Vector store (Qdrant)
|
||||
- Cache (Redis)
|
||||
- Singleton management for expensive resources
|
||||
|
||||
5. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
|
||||
- Pydantic models for validation
|
||||
5. **Agents** (`src/biz_bud/agents/`): Specialized agent implementations
|
||||
- `ngx_agent.py`: Paperless NGX integration agent
|
||||
- `rag_agent.py`: RAG workflow agent
|
||||
- `research_agent.py`: Research automation agent
|
||||
|
||||
6. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
|
||||
- Schema-based configuration (`schemas/`): Typed configuration models
|
||||
- Environment variables override `config.yaml` defaults
|
||||
- LLM profiles (tiny, small, large, reasoning)
|
||||
- Service-specific configurations
|
||||
|
||||
7. **Packages** (`packages/`): Modular utility libraries
|
||||
- **business-buddy-core**: Core utilities, error handling, edge helpers
|
||||
- **business-buddy-extraction**: Entity and data extraction tools
|
||||
- **business-buddy-tools**: Web tools, scrapers, API clients
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
@@ -227,166 +387,106 @@ uv run pre-commit install
|
||||
- Use RunnableConfig to make services accessible and configurable at runtime.
|
||||
- Employ decorators and wrappers to add cross-cutting concerns like logging, caching, or metrics without cluttering core logic.
|
||||
|
||||
## Commands
|
||||
- Never use `pip` directly - always use `uv`
|
||||
- Don't modify `tasks.json` manually when using Task Master
|
||||
- Always run `make lint-all` before committing
|
||||
- Use absolute imports from package roots
|
||||
- Avoid circular imports between packages
|
||||
- Always use the centralized ServiceFactory for external services
|
||||
- Never hardcode API keys - use environment variables
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run all tests with coverage (uses pytest-xdist for parallel execution)
|
||||
make test
|
||||
## Configuration System
|
||||
|
||||
# Run tests in watch mode
|
||||
make test_watch
|
||||
Business Buddy uses a sophisticated configuration system:
|
||||
|
||||
# Run specific test file
|
||||
make test TEST_FILE=tests/unit_tests/nodes/llm/test_unit_call.py
|
||||
### Schema-based Configuration (`src/biz_bud/config/schemas/`)
|
||||
- `analysis.py`: Analysis configuration schemas
|
||||
- `app.py`: Application-wide settings
|
||||
- `core.py`: Core configuration types
|
||||
- `llm.py`: LLM provider configurations
|
||||
- `research.py`: Research workflow settings
|
||||
- `services.py`: External service configurations
|
||||
- `tools.py`: Tool-specific settings
|
||||
|
||||
# Run single test function
|
||||
pytest tests/path/to/test.py::test_function_name -v
|
||||
### Usage
|
||||
```python
|
||||
from biz_bud.config.loader import load_config
|
||||
config = load_config()
|
||||
|
||||
# Access typed configurations
|
||||
llm_config = config.llm_config
|
||||
research_config = config.research_config
|
||||
service_config = config.service_config
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
## Docker Services
|
||||
|
||||
The project requires these services (via Docker):
|
||||
- **PostgreSQL**: Main database with asyncpg
|
||||
- **Redis**: Caching and session management
|
||||
- **Qdrant**: Vector database for embeddings
|
||||
|
||||
Start all services:
|
||||
```bash
|
||||
# Run all linters (ruff, mypy, pyrefly, codespell) - ALWAYS run before committing
|
||||
make lint-all
|
||||
|
||||
# Format code with ruff
|
||||
make format
|
||||
|
||||
# Run pre-commit hooks (recommended)
|
||||
make pre-commit
|
||||
|
||||
# Advanced type checking with Pyrefly
|
||||
pyrefly check .
|
||||
make start # Uses docker/compose-dev.yaml
|
||||
make stop # Stop and clean up
|
||||
```
|
||||
|
||||
## Architecture
|
||||
## Import Guidelines
|
||||
|
||||
This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for business research and analysis.
|
||||
```python
|
||||
# From main application
|
||||
from biz_bud.nodes.analysis import data_node
|
||||
from biz_bud.services.factory import ServiceFactory
|
||||
from biz_bud.states.research import ResearchState
|
||||
|
||||
### Core Components
|
||||
# From packages (always use full path)
|
||||
from bb_core.edge_helpers import create_conditional_edge
|
||||
from bb_extraction.domain import CompanyExtractor
|
||||
from bb_tools.api_clients import TavilyClient
|
||||
|
||||
1. **Graphs** (`src/biz_bud/graphs/`): Define workflow orchestration using LangGraph state machines
|
||||
- `research.py`: Market research workflow implementation
|
||||
- `graph.py`: Main agent graph with reasoning and action cycles
|
||||
- `research_agent.py`: Research-specific agent workflow
|
||||
- `menu_intelligence.py`: Menu analysis subgraph
|
||||
|
||||
2. **Nodes** (`src/biz_bud/nodes/`): Modular processing units
|
||||
- `analysis/`: Data analysis, interpretation, planning, visualization
|
||||
- `core/`: Input/output handling, error management
|
||||
- `llm/`: LLM interaction layer
|
||||
- `research/`: Web search, extraction, synthesis with optimization
|
||||
- `validation/`: Content and logic validation, human feedback
|
||||
|
||||
3. **States** (`src/biz_bud/states/`): TypedDict-based state management for type safety across workflows
|
||||
|
||||
4. **Services** (`src/biz_bud/services/`): Abstract external dependencies
|
||||
- LLM providers (Anthropic, OpenAI, Google, Cohere, etc.)
|
||||
- Database (PostgreSQL via asyncpg)
|
||||
- Vector store (Qdrant)
|
||||
- Cache (Redis)
|
||||
|
||||
5. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
|
||||
- Pydantic models for validation
|
||||
- Environment variables override `config.yaml` defaults
|
||||
- LLM profiles (tiny, small, large, reasoning)
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
- **State-Driven Workflows**: All graphs use TypedDict states for type-safe data flow
|
||||
- **Decorator Pattern**: `@log_config` and `@error_handling` for cross-cutting concerns
|
||||
- **Service Abstraction**: Clean interfaces for external dependencies
|
||||
- **Modular Nodes**: Each node has a single responsibility and can be tested independently
|
||||
- **Parallel Processing**: Search and extraction operations utilize asyncio for performance
|
||||
|
||||
### Testing Strategy
|
||||
|
||||
- Unit tests in `tests/unit_tests/` with mocked dependencies
|
||||
- Integration tests in `tests/integration_tests/` for full workflows
|
||||
- E2E tests in `tests/e2e/` for complete system validation
|
||||
- VCR cassettes for API mocking in `tests/cassettes/`
|
||||
- Test markers: `slow`, `integration`, `unit`, `e2e`, `web`, `browser`
|
||||
- Coverage requirement: 70% minimum
|
||||
|
||||
### Test Architecture
|
||||
|
||||
#### Test Organization
|
||||
- **Naming Convention**: All test files follow `test_*.py` pattern
|
||||
- Unit tests: `test_<module_name>.py`
|
||||
- Integration tests: `test_<feature>_integration.py`
|
||||
- E2E tests: `test_<workflow>_e2e.py`
|
||||
- Manual tests: `test_<feature>_manual.py`
|
||||
|
||||
#### Test Helpers (`tests/helpers/`)
|
||||
- **Assertions** (`assertions/custom_assertions.py`): Reusable assertion functions
|
||||
- **Factories** (`factories/state_factories.py`): State builders for creating test data
|
||||
- **Fixtures** (`fixtures/`): Shared pytest fixtures
|
||||
- `config_fixtures.py`: Configuration mocks and test configs
|
||||
- `mock_fixtures.py`: Common mock objects
|
||||
- **Mocks** (`mocks/mock_builders.py`): Builder classes for complex mocks
|
||||
- `MockLLMBuilder`: Creates mock LLM clients with configurable responses
|
||||
- `StateBuilder`: Creates typed state objects for workflows
|
||||
|
||||
#### Key Testing Patterns
|
||||
1. **Async Testing**: Use `@pytest.mark.asyncio` for async functions
|
||||
2. **Mock Builders**: Use builder pattern for complex mocks
|
||||
```python
|
||||
mock_llm = MockLLMBuilder()
|
||||
.with_model("gpt-4")
|
||||
.with_response("Test response")
|
||||
.build()
|
||||
```
|
||||
3. **State Factories**: Create valid state objects easily
|
||||
```python
|
||||
state = StateBuilder.research_state()
|
||||
.with_query("test query")
|
||||
.with_search_results([...])
|
||||
.build()
|
||||
```
|
||||
4. **Service Factory Mocking**: Mock the service factory for dependency injection
|
||||
```python
|
||||
with patch("biz_bud.utils.service_helpers.get_service_factory",
|
||||
return_value=mock_service_factory):
|
||||
# Test code here
|
||||
```
|
||||
|
||||
#### Common Test Patterns
|
||||
- **E2E Workflow Tests**: Test complete workflows with mocked external services
|
||||
- **Resilient Node Tests**: Nodes should handle failures gracefully
|
||||
- Extraction continues even if vector storage fails
|
||||
- Partial results are returned when some operations fail
|
||||
- **Configuration Tests**: Validate Pydantic models and config schemas
|
||||
- **Import Testing**: Ensure all public APIs are importable
|
||||
|
||||
### Environment Setup
|
||||
|
||||
```bash
|
||||
# Prerequisites: Python 3.12+, UV package manager, Docker
|
||||
|
||||
# Create and activate virtual environment
|
||||
uv venv
|
||||
source .venv/bin/activate # Always use this activation path
|
||||
|
||||
# Install dependencies with UV
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
# Install pre-commit hooks
|
||||
uv run pre-commit install
|
||||
|
||||
# Create .env file with required API keys:
|
||||
# TAVILY_API_KEY=your_key
|
||||
# OPENAI_API_KEY=your_key (or other LLM provider keys)
|
||||
# Never use relative imports across packages
|
||||
```
|
||||
|
||||
## Development Principles
|
||||
## Architectural Patterns
|
||||
|
||||
- **Type Safety**: No `Any` types or `# type: ignore` annotations allowed
|
||||
- **Documentation**: Imperative docstrings with punctuation
|
||||
- **Package Management**: Always use UV, not pip
|
||||
- **Pre-commit**: Never skip pre-commit checks
|
||||
- **Testing**: Write tests for new functionality, maintain 70%+ coverage
|
||||
- **Error Handling**: Use centralized decorators for consistency
|
||||
### State Management
|
||||
- All states are TypedDict-based for type safety
|
||||
- States are immutable within nodes
|
||||
- Use reducer functions for state updates
|
||||
|
||||
## Development Warnings
|
||||
### Service Factory Pattern
|
||||
- Centralized service creation via `ServiceFactory`
|
||||
- Singleton management for expensive resources
|
||||
- Dependency injection throughout
|
||||
|
||||
- Do not try and launch 'langgraph dev' or any variation
|
||||
### Error Handling
|
||||
- Centralized error aggregation
|
||||
- Namespace-based error routing
|
||||
- Comprehensive telemetry integration
|
||||
|
||||
### Async-First Design
|
||||
- All I/O operations are async
|
||||
- Proper connection pooling
|
||||
- Graceful degradation on failures
|
||||
|
||||
## Development Tools
|
||||
|
||||
### Pyrefly Configuration (`pyrefly.toml`)
|
||||
- Advanced type checking beyond mypy
|
||||
- Monorepo-aware with package path resolution
|
||||
- Custom import handling for external libraries
|
||||
|
||||
### Pre-commit Hooks (`.pre-commit-config.yaml`)
|
||||
- Automated code quality checks
|
||||
- Includes ruff, pyrefly, codespell
|
||||
- File size and merge conflict checks
|
||||
|
||||
### Additional Make Commands
|
||||
```bash
|
||||
make setup # Complete setup for new machines
|
||||
make lint-file FILE_PATH=path/to/file.py # Single file linting
|
||||
make black FILE_PATH=path/to/file.py # Format single file
|
||||
make pyrefly # Run pyrefly type checking
|
||||
make coverage-report # Generate HTML coverage report
|
||||
```
|
||||
44
Makefile
44
Makefile
@@ -13,13 +13,19 @@ TEST_FILE ?= tests/
|
||||
ifeq ($(OS),Windows_NT)
|
||||
ACTIVATE = .venv\Scripts\activate
|
||||
PYTHON = python
|
||||
else ifneq (,$(wildcard .venv-host))
|
||||
ACTIVATE = source .venv-host/bin/activate
|
||||
PYTHON = python3
|
||||
else ifneq (,$(wildcard .cenv))
|
||||
ACTIVATE = source .cenv/bin/activate
|
||||
PYTHON = python3
|
||||
else
|
||||
ACTIVATE = source .venv/bin/activate
|
||||
PYTHON = python3
|
||||
endif
|
||||
|
||||
test:
|
||||
@bash -c "$(ACTIVATE) && PYTHONPATH=src coverage run --source=biz_bud -m pytest -n 4 tests/"
|
||||
@bash -c "$(ACTIVATE) && PYTHONPATH=src coverage run --source=biz_bud -m pytest -n 8 tests/"
|
||||
@bash -c "$(ACTIVATE) && coverage report --show-missing"
|
||||
|
||||
test_watch:
|
||||
@@ -112,7 +118,7 @@ lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
# Legacy lint targets - now use pre-commit
|
||||
lint lint_diff lint_package lint_tests:
|
||||
@echo "Running linting via pre-commit hooks..."
|
||||
@bash -c "$(ACTIVATE) && pre-commit run ruff --all-files"
|
||||
@bash -c "$(ACTIVATE) && pre-commit run black --all-files"
|
||||
|
||||
pyrefly:
|
||||
@bash -c "$(ACTIVATE) && pyrefly check ."
|
||||
@@ -121,33 +127,23 @@ pyrefly:
|
||||
lint-all: pre-commit
|
||||
@echo "\n🔍 Running additional type checks..."
|
||||
@bash -c "$(ACTIVATE) && pyrefly check . || true"
|
||||
@bash -c "$(ACTIVATE) && ruff check . || true"
|
||||
|
||||
# Run pre-commit hooks (single source of truth for linting)
|
||||
pre-commit:
|
||||
@echo "🔧 Running pre-commit hooks..."
|
||||
@echo "This includes: ruff (lint + format), pyrefly, codespell, and file checks"
|
||||
@echo "This includes: black (format), pyrefly, and file checks"
|
||||
@bash -c "$(ACTIVATE) && pre-commit run --all-files"
|
||||
|
||||
# Format code using pre-commit
|
||||
format format_diff:
|
||||
@echo "Formatting code via pre-commit..."
|
||||
@bash -c "$(ACTIVATE) && pre-commit run ruff-format --all-files || true"
|
||||
@bash -c "$(ACTIVATE) && pre-commit run ruff --all-files || true"
|
||||
|
||||
spell_check:
|
||||
@bash -c "$(ACTIVATE) && codespell --toml pyproject.toml"
|
||||
|
||||
spell_fix:
|
||||
@bash -c "$(ACTIVATE) && codespell --toml pyproject.toml -w"
|
||||
@bash -c "$(ACTIVATE) && pre-commit run black --all-files || true"
|
||||
|
||||
# Single file linting for hooks (expects FILE_PATH variable)
|
||||
lint-file:
|
||||
ifdef FILE_PATH
|
||||
@echo "🔍 Linting $(FILE_PATH)..."
|
||||
@bash -c "$(ACTIVATE) && pyrefly check '$(FILE_PATH)'"
|
||||
@bash -c "$(ACTIVATE) && ruff check '$(FILE_PATH)' --fix"
|
||||
@bash -c "$(ACTIVATE) && pyright '$(FILE_PATH)'"
|
||||
@echo "✅ Linting complete"
|
||||
else
|
||||
@echo "❌ FILE_PATH not provided"
|
||||
@@ -164,6 +160,20 @@ else
|
||||
@exit 1
|
||||
endif
|
||||
|
||||
######################
|
||||
# LANGGRAPH DEVELOPMENT
|
||||
######################
|
||||
|
||||
# Start LangGraph development server with proper container configuration
|
||||
langgraph-dev:
|
||||
@echo "🚀 Starting LangGraph development server..."
|
||||
@bash -c "$(ACTIVATE) && langgraph dev --host 0.0.0.0 --port 2024"
|
||||
|
||||
# Start LangGraph development server with local studio
|
||||
langgraph-dev-local:
|
||||
@echo "🚀 Starting LangGraph development server with local studio..."
|
||||
@bash -c "$(ACTIVATE) && langgraph dev --host 0.0.0.0 --port 2024 --studio-local"
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
@@ -174,8 +184,8 @@ help:
|
||||
@echo 'start - start Docker services (postgres, redis, qdrant)'
|
||||
@echo 'stop - stop Docker services'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run ruff linter'
|
||||
@echo 'lint-all - run all linters and type checkers (ruff, mypy, pyrefly, codespell)'
|
||||
@echo 'lint - run black formatter via pre-commit'
|
||||
@echo 'lint-all - run all type checkers (pyrefly)'
|
||||
@echo 'pre-commit - run all pre-commit hooks'
|
||||
@echo 'pyrefly - run pyrefly type checking'
|
||||
@echo 'test - run unit tests in parallel (pytest-xdist) with coverage'
|
||||
@@ -183,6 +193,8 @@ help:
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo 'test_watch - run unit tests in watch mode'
|
||||
@echo 'coverage-report - generate HTML coverage report at htmlcov/index.html'
|
||||
@echo 'langgraph-dev - start LangGraph dev server (for containers/devcontainer)'
|
||||
@echo 'langgraph-dev-local - start LangGraph dev server with local studio'
|
||||
@echo 'tree - show tree of .py files in src/'
|
||||
|
||||
coverage-report:
|
||||
|
||||
@@ -49,6 +49,7 @@ logging:
|
||||
# LLM profiles configuration
|
||||
# Env Override: e.g., TINY_LLM_NAME, LARGE_LLM_TEMPERATURE
|
||||
llm_config:
|
||||
default_profile: "large" # Options: tiny, small, large, reasoning
|
||||
tiny:
|
||||
name: "openai/gpt-4.1-mini"
|
||||
temperature: 0.7
|
||||
|
||||
0
docker/entrypoint.sh
Executable file → Normal file
0
docker/entrypoint.sh
Executable file → Normal file
0
examples/crawl_r2r_docs.py
Executable file → Normal file
0
examples/crawl_r2r_docs.py
Executable file → Normal file
0
examples/crawl_r2r_docs_fixed.py
Executable file → Normal file
0
examples/crawl_r2r_docs_fixed.py
Executable file → Normal file
@@ -84,7 +84,7 @@ async def example_search_and_scrape():
|
||||
print(f"Found and scraped {len(results)} search results")
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, dict):
|
||||
if result:
|
||||
print(f"\n{i + 1}. {result.get('title', 'No title')}")
|
||||
print(f" URL: {result.get('url', 'No URL')}")
|
||||
markdown = result.get("markdown")
|
||||
|
||||
@@ -50,7 +50,7 @@ async def test_rag_agent_with_firecrawl():
|
||||
# Show processing result
|
||||
if result.get("processing_result"):
|
||||
processing_result = result["processing_result"]
|
||||
if processing_result and isinstance(processing_result, dict):
|
||||
if processing_result:
|
||||
if processing_result.get("skipped"):
|
||||
print(f"\nSkipped: {processing_result.get('reason')}")
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
"catalog_research": "./src/biz_bud/graphs/catalog_research.py:catalog_research_factory",
|
||||
"url_to_r2r": "./src/biz_bud/graphs/url_to_r2r.py:url_to_r2r_graph_factory",
|
||||
"rag_agent": "./src/biz_bud/agents/rag_agent.py:create_rag_agent_for_api",
|
||||
"error_handling": "./src/biz_bud/graphs/error_handling.py:error_handling_graph_factory"
|
||||
"error_handling": "./src/biz_bud/graphs/error_handling.py:error_handling_graph_factory",
|
||||
"paperless_ngx_agent": "./src/biz_bud/agents/ngx_agent.py:paperless_ngx_agent_factory"
|
||||
},
|
||||
"env": ".env"
|
||||
}
|
||||
|
||||
@@ -105,10 +105,10 @@ from bb_core.errors import (
|
||||
|
||||
# Helpers
|
||||
from bb_core.helpers import (
|
||||
_is_sensitive_field,
|
||||
_redact_sensitive_data,
|
||||
create_error_details,
|
||||
is_sensitive_field,
|
||||
preserve_url_fields,
|
||||
redact_sensitive_data,
|
||||
safe_serialize_response,
|
||||
)
|
||||
|
||||
@@ -266,8 +266,8 @@ __all__ = [
|
||||
"Tone",
|
||||
# Helpers
|
||||
"preserve_url_fields",
|
||||
"_is_sensitive_field",
|
||||
"_redact_sensitive_data",
|
||||
"is_sensitive_field",
|
||||
"redact_sensitive_data",
|
||||
"safe_serialize_response",
|
||||
# Networking
|
||||
"gather_with_concurrency",
|
||||
|
||||
@@ -106,3 +106,14 @@ class CacheBackend(ABC):
|
||||
"""
|
||||
for key in keys:
|
||||
await self.delete(key)
|
||||
|
||||
async def ainit(self) -> None:
|
||||
"""Initialize the cache backend.
|
||||
|
||||
This method can be overridden by implementations that need
|
||||
async initialization. The default implementation does nothing.
|
||||
|
||||
Note: This is intentionally non-abstract to provide a default
|
||||
implementation for backends that don't need initialization.
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -52,7 +52,7 @@ class AsyncFileCacheBackend[T](CacheBackend[T]):
|
||||
async def ainit(self) -> None:
|
||||
"""Async initialization method for compatibility."""
|
||||
if not self._initialized:
|
||||
await self._file_cache._ensure_initialized()
|
||||
await self._file_cache.ensure_initialized()
|
||||
self._initialized = True
|
||||
|
||||
async def get(self, key: str) -> T | None:
|
||||
|
||||
@@ -48,7 +48,7 @@ class CacheBackend[T](ABC):
|
||||
"""Clear all cache entries."""
|
||||
...
|
||||
|
||||
async def ainit(self) -> None: # noqa: B027
|
||||
async def ainit(self) -> None:
|
||||
"""Initialize the cache backend.
|
||||
|
||||
This method can be overridden by implementations that need
|
||||
@@ -57,4 +57,4 @@ class CacheBackend[T](ABC):
|
||||
Note: This is intentionally non-abstract to provide a default
|
||||
implementation for backends that don't need initialization.
|
||||
"""
|
||||
pass
|
||||
return None
|
||||
|
||||
@@ -191,7 +191,7 @@ def cache_async(
|
||||
# Convert ParamSpecKwargs to dict for processing
|
||||
kwargs_dict = cast("dict[str, object]", kwargs)
|
||||
force_refresh = kwargs_dict.pop("force_refresh", False)
|
||||
kwargs = cast("P.kwargs", kwargs_dict)
|
||||
# Note: kwargs remains unchanged since we work with kwargs_dict
|
||||
|
||||
# Generate cache key (excluding force_refresh from key generation)
|
||||
try:
|
||||
@@ -207,7 +207,7 @@ def cache_async(
|
||||
cache_key = _generate_cache_key(
|
||||
func.__name__,
|
||||
cast("tuple[object, ...]", args),
|
||||
cast("dict[str, object]", kwargs),
|
||||
kwargs_dict,
|
||||
key_prefix,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -55,11 +55,18 @@ class FileCache(CacheBackend):
|
||||
self.serializer = serializer
|
||||
self.key_prefix = key_prefix
|
||||
self._initialized = False
|
||||
self._init_lock = asyncio.Lock()
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def _ensure_initialized(self) -> None:
|
||||
async def ensure_initialized(self) -> None:
|
||||
"""Ensure cache directory exists."""
|
||||
if not self._initialized:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
async with self._init_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.cache_dir.mkdir, parents=True, exist_ok=True
|
||||
@@ -90,7 +97,7 @@ class FileCache(CacheBackend):
|
||||
Returns:
|
||||
Cached bytes or None if not found or expired
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
await self.ensure_initialized()
|
||||
file_path = self._get_file_path(key)
|
||||
|
||||
try:
|
||||
@@ -165,7 +172,7 @@ class FileCache(CacheBackend):
|
||||
value: Value to store as bytes
|
||||
ttl: Time-to-live in seconds (None uses default TTL)
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
await self.ensure_initialized()
|
||||
file_path = self._get_file_path(key)
|
||||
|
||||
# Use default TTL if not specified and cache has TTL
|
||||
@@ -178,18 +185,20 @@ class FileCache(CacheBackend):
|
||||
"ttl": effective_ttl,
|
||||
}
|
||||
|
||||
temp_path = None
|
||||
try:
|
||||
# Serialize the entire structure
|
||||
if self.serializer == "pickle":
|
||||
content = pickle.dumps(cache_data)
|
||||
else:
|
||||
# For JSON, ensure bytes are encoded properly
|
||||
if isinstance(value, bytes):
|
||||
cache_data["value"] = value.decode("utf-8", errors="replace")
|
||||
cache_data["value"] = value.decode("utf-8", errors="replace")
|
||||
content = json.dumps(cache_data).encode("utf-8")
|
||||
|
||||
# Write atomically using a temporary file
|
||||
temp_path = file_path.with_suffix(".tmp")
|
||||
# Write atomically using a temporary file with unique name
|
||||
import uuid
|
||||
|
||||
temp_path = file_path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}")
|
||||
async with aiofiles.open(temp_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
@@ -199,11 +208,12 @@ class FileCache(CacheBackend):
|
||||
except (OSError, pickle.PickleError, json.JSONDecodeError) as e:
|
||||
self.logger.error(f"Failed to write cache file {file_path}: {e}")
|
||||
# Clean up temp file if it exists
|
||||
try:
|
||||
if temp_path.exists():
|
||||
await asyncio.to_thread(os.remove, temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
if temp_path is not None:
|
||||
try:
|
||||
if temp_path.exists():
|
||||
await asyncio.to_thread(os.remove, temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Remove value from cache.
|
||||
@@ -211,25 +221,33 @@ class FileCache(CacheBackend):
|
||||
Args:
|
||||
key: Cache key
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
await self.ensure_initialized()
|
||||
file_path = self._get_file_path(key)
|
||||
await self._delete_file(file_path)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all cache entries."""
|
||||
await self._ensure_initialized()
|
||||
await self.ensure_initialized()
|
||||
|
||||
try:
|
||||
# List all cache files
|
||||
cache_files = [
|
||||
f
|
||||
for f in self.cache_dir.iterdir()
|
||||
if f.is_file() and f.suffix == ".cache"
|
||||
]
|
||||
# List all cache files in a thread-safe way
|
||||
def list_cache_files() -> list[Path]:
|
||||
if not self.cache_dir.exists():
|
||||
return []
|
||||
return [
|
||||
f
|
||||
for f in self.cache_dir.iterdir()
|
||||
if f.is_file() and f.suffix == ".cache"
|
||||
]
|
||||
|
||||
# Delete all cache files
|
||||
for file_path in cache_files:
|
||||
await self._delete_file(file_path)
|
||||
cache_files = await asyncio.to_thread(list_cache_files)
|
||||
|
||||
# Delete all cache files concurrently
|
||||
if cache_files:
|
||||
await asyncio.gather(
|
||||
*[self._delete_file(file_path) for file_path in cache_files],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
except OSError as e:
|
||||
self.logger.error(f"Failed to clear cache directory: {e}")
|
||||
@@ -289,8 +307,14 @@ class FileCache(CacheBackend):
|
||||
items: Dictionary mapping keys to values
|
||||
ttl: Time-to-live in seconds (None uses default TTL)
|
||||
"""
|
||||
# Use asyncio.gather for parallel storage
|
||||
# Use asyncio.gather for parallel storage with limited concurrency
|
||||
semaphore = asyncio.Semaphore(10) # Limit concurrent file operations
|
||||
|
||||
async def set_with_semaphore(key: str, value: bytes) -> None:
|
||||
async with semaphore:
|
||||
await self.set(key, value, ttl)
|
||||
|
||||
await asyncio.gather(
|
||||
*[self.set(key, value, ttl) for key, value in items.items()],
|
||||
*[set_with_semaphore(key, value) for key, value in items.items()],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Edge helpers for LangGraph routing and conditional logic.
|
||||
|
||||
This module provides factory functions and edge helpers for creating flexible
|
||||
routing logic in LangGraph workflows. Edge helpers are functions that accept
|
||||
state and return enum values or strings indicating the next node(s) to route to.
|
||||
|
||||
Example:
|
||||
from bb_core.edge_helpers import create_enum_router, should_continue
|
||||
|
||||
# Create a custom router
|
||||
router = create_enum_router({"yes": "continue_node", "no": "end_node"})
|
||||
graph.add_conditional_edges("source_node", router)
|
||||
|
||||
# Use built-in edge helpers
|
||||
graph.add_conditional_edges("agent", should_continue)
|
||||
"""
|
||||
|
||||
from bb_core.edge_helpers.core import (
|
||||
create_bool_router,
|
||||
create_enum_router,
|
||||
create_status_router,
|
||||
create_threshold_router,
|
||||
)
|
||||
from bb_core.edge_helpers.error_handling import (
|
||||
fallback_to_default,
|
||||
handle_error,
|
||||
retry_on_failure,
|
||||
)
|
||||
from bb_core.edge_helpers.flow_control import (
|
||||
multi_step_progress,
|
||||
should_continue,
|
||||
timeout_check,
|
||||
)
|
||||
from bb_core.edge_helpers.monitoring import (
|
||||
check_resource_availability,
|
||||
log_and_monitor,
|
||||
trigger_notifications,
|
||||
)
|
||||
from bb_core.edge_helpers.user_interaction import (
|
||||
escalate_to_human,
|
||||
human_interrupt,
|
||||
pass_status_to_user,
|
||||
user_feedback_loop,
|
||||
)
|
||||
from bb_core.edge_helpers.validation import (
|
||||
check_accuracy,
|
||||
check_confidence_level,
|
||||
check_data_freshness,
|
||||
check_privacy_compliance,
|
||||
validate_output_format,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core factories
|
||||
"create_enum_router",
|
||||
"create_bool_router",
|
||||
"create_status_router",
|
||||
"create_threshold_router",
|
||||
# Flow control
|
||||
"should_continue",
|
||||
"timeout_check",
|
||||
"multi_step_progress",
|
||||
# Error handling
|
||||
"handle_error",
|
||||
"retry_on_failure",
|
||||
"fallback_to_default",
|
||||
# Validation
|
||||
"check_accuracy",
|
||||
"check_confidence_level",
|
||||
"validate_output_format",
|
||||
"check_privacy_compliance",
|
||||
"check_data_freshness",
|
||||
# User interaction
|
||||
"human_interrupt",
|
||||
"pass_status_to_user",
|
||||
"user_feedback_loop",
|
||||
"escalate_to_human",
|
||||
# Monitoring
|
||||
"log_and_monitor",
|
||||
"check_resource_availability",
|
||||
"trigger_notifications",
|
||||
]
|
||||
248
packages/business-buddy-core/src/bb_core/edge_helpers/core.py
Normal file
248
packages/business-buddy-core/src/bb_core/edge_helpers/core.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Core routing factory functions for edge helpers.
|
||||
|
||||
This module provides factory functions for creating flexible routing functions
|
||||
that can be used as conditional edges in LangGraph workflows.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, Protocol, TypeVar
|
||||
|
||||
# Generic state type for edge functions
|
||||
StateT = TypeVar("StateT")
|
||||
|
||||
|
||||
class StateProtocol(Protocol):
|
||||
"""Protocol for state objects that can be used with edge helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a value from the state."""
|
||||
...
|
||||
|
||||
|
||||
def create_enum_router(
|
||||
enum_to_target: dict[str, str],
|
||||
state_key: str = "routing_decision",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that maps enum values to target nodes.
|
||||
|
||||
Args:
|
||||
enum_to_target: Mapping of enum values to target node names
|
||||
state_key: Key in state to check for routing decision
|
||||
default_target: Default target if no match found
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_enum_router({
|
||||
"continue": "next_step",
|
||||
"retry": "retry_node",
|
||||
"error": "error_handler"
|
||||
})
|
||||
graph.add_conditional_edges("source_node", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
enum_value = state.get(state_key)
|
||||
else:
|
||||
enum_value = getattr(state, state_key, None)
|
||||
|
||||
return enum_to_target.get(str(enum_value), default_target)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_bool_router(
|
||||
true_target: str,
|
||||
false_target: str,
|
||||
state_key: str = "condition",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on a boolean condition in state.
|
||||
|
||||
Args:
|
||||
true_target: Target node when condition is True
|
||||
false_target: Target node when condition is False
|
||||
state_key: Key in state to check for boolean value
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_bool_router("success_node", "failure_node", "is_valid")
|
||||
graph.add_conditional_edges("validator", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
condition = state.get(state_key, False)
|
||||
else:
|
||||
condition = getattr(state, state_key, False)
|
||||
|
||||
return true_target if condition else false_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_status_router(
|
||||
status_mapping: dict[str, str],
|
||||
state_key: str = "status",
|
||||
default_target: str = "end",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on operation status.
|
||||
|
||||
Args:
|
||||
status_mapping: Mapping of status values to target nodes
|
||||
state_key: Key in state containing status value
|
||||
default_target: Default target if status not found in mapping
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_status_router({
|
||||
"pending": "wait_node",
|
||||
"running": "monitor_node",
|
||||
"completed": "success_node",
|
||||
"failed": "error_node"
|
||||
})
|
||||
graph.add_conditional_edges("task_checker", router)
|
||||
"""
|
||||
return create_enum_router(status_mapping, state_key, default_target)
|
||||
|
||||
|
||||
def create_threshold_router(
|
||||
threshold: float,
|
||||
above_target: str,
|
||||
below_target: str,
|
||||
state_key: str = "score",
|
||||
equal_target: str | None = None,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on numeric threshold comparison.
|
||||
|
||||
Args:
|
||||
threshold: Threshold value to compare against
|
||||
above_target: Target when value is above threshold
|
||||
below_target: Target when value is below threshold
|
||||
state_key: Key in state containing numeric value
|
||||
equal_target: Optional target when value equals threshold
|
||||
(defaults to above_target)
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_threshold_router(
|
||||
0.8, "high_confidence", "low_confidence", "confidence"
|
||||
)
|
||||
graph.add_conditional_edges("scorer", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
value = state.get(state_key, 0.0)
|
||||
else:
|
||||
value = getattr(state, state_key, 0.0)
|
||||
|
||||
try:
|
||||
numeric_value = float(value)
|
||||
if numeric_value > threshold:
|
||||
return above_target
|
||||
elif numeric_value < threshold:
|
||||
return below_target
|
||||
else:
|
||||
return equal_target or above_target
|
||||
except (ValueError, TypeError):
|
||||
return below_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_field_presence_router(
|
||||
required_fields: list[str],
|
||||
complete_target: str,
|
||||
incomplete_target: str,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on presence of required fields in state.
|
||||
|
||||
Args:
|
||||
required_fields: List of field names that must be present
|
||||
complete_target: Target when all fields are present
|
||||
incomplete_target: Target when any field is missing
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_field_presence_router(
|
||||
["user_input", "processed_data", "validation_result"],
|
||||
"proceed",
|
||||
"gather_missing_data"
|
||||
)
|
||||
graph.add_conditional_edges("data_checker", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
for field in required_fields:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
value = state.get(field)
|
||||
else:
|
||||
value = getattr(state, field, None)
|
||||
|
||||
if value is None or value == "":
|
||||
return incomplete_target
|
||||
|
||||
return complete_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_list_length_router(
|
||||
min_length: int,
|
||||
sufficient_target: str,
|
||||
insufficient_target: str,
|
||||
state_key: str = "items",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on list length in state.
|
||||
|
||||
Args:
|
||||
min_length: Minimum required length
|
||||
sufficient_target: Target when list meets minimum length
|
||||
insufficient_target: Target when list is too short
|
||||
state_key: Key in state containing the list
|
||||
|
||||
Returns:
|
||||
Router function that accepts state and returns target node name
|
||||
|
||||
Example:
|
||||
router = create_list_length_router(
|
||||
3, "process_results", "gather_more", "search_results"
|
||||
)
|
||||
graph.add_conditional_edges("result_checker", router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
items = state.get(state_key, [])
|
||||
else:
|
||||
items = getattr(state, state_key, [])
|
||||
|
||||
# Only process actual lists/sequences, not strings or other types
|
||||
if not isinstance(items, (list, tuple)):
|
||||
return insufficient_target
|
||||
|
||||
try:
|
||||
return (
|
||||
sufficient_target if len(items) >= min_length else insufficient_target
|
||||
)
|
||||
except TypeError:
|
||||
return insufficient_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# Type aliases for common router patterns
|
||||
BoolRouter = Callable[[StateT], Literal["true", "false"]]
|
||||
StatusRouter = Callable[[StateT], Literal["pending", "running", "completed", "failed"]]
|
||||
ContinueRouter = Callable[[StateT], Literal["continue", "end"]]
|
||||
120
packages/business-buddy-core/src/bb_core/edge_helpers/edges.md
Normal file
120
packages/business-buddy-core/src/bb_core/edge_helpers/edges.md
Normal file
@@ -0,0 +1,120 @@
|
||||
What it would look like
|
||||
|
||||
/packages/bb_core/
|
||||
edge_helpers/
|
||||
__init__.py
|
||||
routing.py
|
||||
error_handling.py
|
||||
accuracy_checks.py
|
||||
...
|
||||
|
||||
|
||||
Edge helper functions would be generic routing functions that accept the current state and return an enum or string indicating the next node(s) to route to.
|
||||
These functions can be parameterized or configured at runtime with mappings of enum values to target node names, allowing flexible routing without hardcoding targets.
|
||||
The package can provide factory functions or classes that generate routing functions given a source node and a dynamic mapping of enum-to-target nodes.
|
||||
The routing functions would follow a consistent interface, e.g., def route(state: State) -> Literal[...] or def route(state: State) -> str.
|
||||
|
||||
Example conceptual snippet:
|
||||
python
|
||||
|
||||
from typing import Callable, Dict, Literal, TypeVar
|
||||
|
||||
State = TypeVar("State")
|
||||
|
||||
def create_enum_router(
|
||||
enum_to_target: Dict[str, str],
|
||||
) -> Callable[[State], str]:
|
||||
def router(state: State) -> str:
|
||||
# Determine enum value from state (custom logic)
|
||||
enum_value = determine_enum_value(state)
|
||||
return enum_to_target.get(enum_value, "default_node")
|
||||
return router
|
||||
|
||||
How to organize it at runtime
|
||||
|
||||
At graph construction or initialization time, you instantiate these routing functions with the current source node context and the dynamic mapping of enum values to target nodes.
|
||||
You then register these routing functions as conditional edges on the graph, e.g.:
|
||||
|
||||
python
|
||||
|
||||
graph.add_conditional_edges("source_node", create_enum_router(mapping))
|
||||
|
||||
This allows the graph to remain flexible and extensible, with routing logic decoupled from static graph definitions.
|
||||
|
||||
|
||||
|
||||
should_continue
|
||||
Decide whether to continue or end based on tool calls in the last AI message.
|
||||
|
||||
check_accuracy
|
||||
Route based on whether the agent's last response meets a defined accuracy threshold.
|
||||
|
||||
handle_error
|
||||
Detect error messages or exceptions and route to error handling or retry nodes.
|
||||
|
||||
append_visual_aids
|
||||
Determine if visual aids (images, charts) should be appended to the response.
|
||||
|
||||
pass_status_to_user
|
||||
Route to nodes that update the human user with status or progress messages.
|
||||
|
||||
human_interrupt
|
||||
Detect if a human user has requested to interrupt or stop the agent.
|
||||
|
||||
timeout_check
|
||||
Route based on whether a timeout has occurred during processing.
|
||||
|
||||
retry_on_failure
|
||||
Decide to retry a failed step or escalate after a certain number of attempts.
|
||||
|
||||
fallback_to_default
|
||||
Route to a default or safe response if the agent output is empty or nonsensical.
|
||||
|
||||
check_confidence_level
|
||||
Route based on the confidence score of the LLM's output.
|
||||
|
||||
validate_output_format
|
||||
Check if the output matches expected schema or format, route to correction if not.
|
||||
|
||||
multi_step_progress
|
||||
Route based on progress through a multi-step workflow (e.g., step 1, step 2, etc.).
|
||||
|
||||
user_feedback_loop
|
||||
Route to nodes that solicit or process user feedback on the agent's response.
|
||||
|
||||
escalate_to_human
|
||||
Route to a human agent or supervisor if automated handling fails or is insufficient.
|
||||
|
||||
check_resource_availability
|
||||
Route based on availability of external resources or APIs needed for next steps.
|
||||
|
||||
handle_ambiguous_input
|
||||
Detect ambiguous or unclear user input and route to clarification nodes.
|
||||
|
||||
language_detection
|
||||
Route based on detected language of user input for multilingual support.
|
||||
|
||||
check_privacy_compliance
|
||||
Route to privacy checks or redaction if sensitive data is detected.
|
||||
|
||||
log_and_monitor
|
||||
Route to logging or monitoring nodes for audit or debugging purposes.
|
||||
|
||||
load_balance
|
||||
Route requests to different nodes or agents for load balancing.
|
||||
|
||||
check_user_authentication
|
||||
Route based on user authentication or authorization status.
|
||||
|
||||
handle_rate_limiting
|
||||
Detect API rate limits and route to wait or fallback nodes.
|
||||
|
||||
check_data_freshness
|
||||
Route based on whether data used is up-to-date or stale.
|
||||
|
||||
trigger_notifications
|
||||
Route to nodes that send notifications or alerts to users or admins.
|
||||
|
||||
conditional_branching_by_intent
|
||||
Route based on detected user intent or topic classification.
|
||||
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Error handling edge helpers for robust workflow management.
|
||||
|
||||
This module provides edge helpers for detecting errors, implementing retry logic,
|
||||
and routing to appropriate error handling or recovery nodes.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from bb_core.edge_helpers.core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def handle_error(
|
||||
error_types: dict[str, str] | None = None,
|
||||
error_key: str = "error",
|
||||
default_target: str = "generic_error_handler",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that handles different types of errors.
|
||||
|
||||
Args:
|
||||
error_types: Mapping of error type/class names to handler nodes
|
||||
error_key: Key in state containing error information
|
||||
default_target: Default target if error type not in mapping
|
||||
|
||||
Returns:
|
||||
Router function that routes based on error type
|
||||
|
||||
Example:
|
||||
error_router = handle_error({
|
||||
"ValidationError": "validation_recovery",
|
||||
"NetworkError": "network_retry",
|
||||
"AuthenticationError": "auth_failure_handler"
|
||||
})
|
||||
graph.add_conditional_edges("error_detector", error_router)
|
||||
"""
|
||||
if error_types is None:
|
||||
error_types = {
|
||||
"ValidationError": "validation_error_handler",
|
||||
"NetworkError": "network_error_handler",
|
||||
"TimeoutError": "timeout_error_handler",
|
||||
"AuthenticationError": "auth_error_handler",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
|
||||
if error is None:
|
||||
return "no_error"
|
||||
|
||||
# Handle different error formats
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type:
|
||||
return error_types.get(error_type, default_target)
|
||||
|
||||
return default_target
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def retry_on_failure(
|
||||
max_retries: int = 3,
|
||||
retry_count_key: str = "retry_count",
|
||||
error_key: str = "error",
|
||||
) -> Callable[
|
||||
[dict[str, Any] | StateProtocol],
|
||||
Literal["retry", "max_retries_exceeded", "success"],
|
||||
]:
|
||||
"""Create a router that implements retry logic for failures.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_count_key: Key in state tracking retry attempts
|
||||
error_key: Key in state containing error information
|
||||
|
||||
Returns:
|
||||
Router function that returns "retry", "max_retries_exceeded", or "success"
|
||||
|
||||
Example:
|
||||
retry_router = retry_on_failure(max_retries=5)
|
||||
graph.add_conditional_edges("failure_handler", retry_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["retry", "max_retries_exceeded", "success"]:
|
||||
# Check if there's an error
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
retry_count = state.get(retry_count_key, 0)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
retry_count = getattr(state, retry_count_key, 0)
|
||||
|
||||
# No error means success
|
||||
if error is None:
|
||||
return "success"
|
||||
|
||||
try:
|
||||
current_retries = int(retry_count)
|
||||
if current_retries >= max_retries:
|
||||
return "max_retries_exceeded"
|
||||
else:
|
||||
return "retry"
|
||||
except (ValueError, TypeError):
|
||||
return "retry"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def fallback_to_default(
|
||||
fallback_conditions: list[str] | None = None,
|
||||
output_key: str = "output",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["use_output", "fallback"]]:
|
||||
"""Create a router that falls back to default when output is inadequate.
|
||||
|
||||
Args:
|
||||
fallback_conditions: List of conditions that trigger fallback
|
||||
output_key: Key in state containing output to check
|
||||
|
||||
Returns:
|
||||
Router function that returns "use_output" or "fallback"
|
||||
|
||||
Example:
|
||||
fallback_router = fallback_to_default([
|
||||
"empty_output", "low_confidence", "invalid_format"
|
||||
])
|
||||
graph.add_conditional_edges("output_validator", fallback_router)
|
||||
"""
|
||||
if fallback_conditions is None:
|
||||
fallback_conditions = ["empty_output", "low_confidence", "error"]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["use_output", "fallback"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
output = state.get(output_key)
|
||||
else:
|
||||
output = getattr(state, output_key, None)
|
||||
|
||||
# Check for empty or null output
|
||||
if output is None or output == "":
|
||||
return "fallback"
|
||||
|
||||
# Check for other fallback conditions in state
|
||||
for condition in fallback_conditions:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
condition_value = state.get(condition, False)
|
||||
else:
|
||||
condition_value = getattr(state, condition, False)
|
||||
|
||||
if condition_value:
|
||||
return "fallback"
|
||||
|
||||
return "use_output"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_critical_error(
|
||||
critical_error_types: list[str] | None = None,
|
||||
error_key: str = "error",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["critical", "non_critical"]]:
|
||||
"""Create a router that identifies critical errors requiring immediate attention.
|
||||
|
||||
Args:
|
||||
critical_error_types: List of error types considered critical
|
||||
error_key: Key in state containing error information
|
||||
|
||||
Returns:
|
||||
Router function that returns "critical" or "non_critical"
|
||||
|
||||
Example:
|
||||
critical_router = check_critical_error([
|
||||
"SecurityError", "DataCorruptionError", "SystemFailure"
|
||||
])
|
||||
graph.add_conditional_edges("error_classifier", critical_router)
|
||||
"""
|
||||
if critical_error_types is None:
|
||||
critical_error_types = [
|
||||
"SecurityError",
|
||||
"AuthenticationError",
|
||||
"AuthorizationError",
|
||||
"DataCorruptionError",
|
||||
"SystemFailure",
|
||||
"OutOfMemoryError",
|
||||
]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["critical", "non_critical"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
error = state.get(error_key)
|
||||
else:
|
||||
error = getattr(state, error_key, None)
|
||||
|
||||
if error is None:
|
||||
return "non_critical"
|
||||
|
||||
# Determine error type
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type and error_type in critical_error_types:
|
||||
return "critical"
|
||||
|
||||
return "non_critical"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def escalation_policy(
|
||||
escalation_threshold: int = 3,
|
||||
failure_count_key: str = "consecutive_failures",
|
||||
escalation_types: dict[str, str] | None = None,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that implements error escalation policies.
|
||||
|
||||
Args:
|
||||
escalation_threshold: Number of failures before escalation
|
||||
failure_count_key: Key in state tracking consecutive failures
|
||||
escalation_types: Mapping of failure types to escalation targets
|
||||
|
||||
Returns:
|
||||
Router function that handles escalation logic
|
||||
|
||||
Example:
|
||||
escalation_router = escalation_policy(
|
||||
escalation_threshold=3,
|
||||
escalation_types={
|
||||
"timeout": "timeout_escalation",
|
||||
"validation": "validation_escalation"
|
||||
}
|
||||
)
|
||||
graph.add_conditional_edges("failure_monitor", escalation_router)
|
||||
"""
|
||||
if escalation_types is None:
|
||||
escalation_types = {
|
||||
"timeout": "timeout_escalation",
|
||||
"network": "network_escalation",
|
||||
"validation": "validation_escalation",
|
||||
"auth": "security_escalation",
|
||||
"security": "security_escalation",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
failure_count = state.get(failure_count_key, 0)
|
||||
error = state.get("error")
|
||||
else:
|
||||
failure_count = getattr(state, failure_count_key, 0)
|
||||
error = getattr(state, "error", None)
|
||||
|
||||
try:
|
||||
failures = int(failure_count)
|
||||
if failures < escalation_threshold:
|
||||
return "continue_monitoring"
|
||||
|
||||
# Determine escalation type based on error
|
||||
if error:
|
||||
error_type = None
|
||||
if isinstance(error, dict):
|
||||
error_type = error.get("type") or error.get("error_type")
|
||||
elif isinstance(error, str):
|
||||
error_type = error.lower()
|
||||
elif hasattr(error, "__class__"):
|
||||
error_type = error.__class__.__name__
|
||||
|
||||
if error_type:
|
||||
for escalation_key, target in escalation_types.items():
|
||||
if escalation_key.lower() in error_type.lower():
|
||||
return target
|
||||
|
||||
return "generic_escalation"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return "continue_monitoring"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def circuit_breaker(
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: int = 60,
|
||||
state_key: str = "circuit_breaker_state",
|
||||
failure_count_key: str = "failure_count",
|
||||
last_failure_key: str = "last_failure_time",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], Literal["allow", "reject", "probe"]]:
|
||||
"""Create a router that implements circuit breaker pattern.
|
||||
|
||||
Args:
|
||||
failure_threshold: Number of failures before opening circuit
|
||||
recovery_timeout: Seconds to wait before trying to close circuit
|
||||
state_key: Key storing circuit state (closed/open/half_open)
|
||||
failure_count_key: Key tracking failure count
|
||||
last_failure_key: Key storing last failure timestamp
|
||||
|
||||
Returns:
|
||||
Router function implementing circuit breaker logic
|
||||
|
||||
Example:
|
||||
breaker_router = circuit_breaker(failure_threshold=3, recovery_timeout=30)
|
||||
graph.add_conditional_edges("service_call", breaker_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["allow", "reject", "probe"]:
|
||||
import time
|
||||
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
circuit_state = state.get(state_key, "closed")
|
||||
failure_count = state.get(failure_count_key, 0)
|
||||
last_failure = state.get(last_failure_key, 0)
|
||||
else:
|
||||
circuit_state = getattr(state, state_key, "closed")
|
||||
failure_count = getattr(state, failure_count_key, 0)
|
||||
last_failure = getattr(state, last_failure_key, 0)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Convert numeric values with error handling
|
||||
try:
|
||||
failure_count = int(failure_count)
|
||||
except (ValueError, TypeError):
|
||||
failure_count = 0
|
||||
|
||||
try:
|
||||
last_failure = float(last_failure)
|
||||
except (ValueError, TypeError):
|
||||
last_failure = 0.0
|
||||
|
||||
if circuit_state == "open":
|
||||
# Check if recovery timeout has passed
|
||||
if current_time - last_failure >= recovery_timeout:
|
||||
return "probe" # Try half-open state
|
||||
else:
|
||||
return "reject" # Still in timeout
|
||||
|
||||
elif circuit_state == "half_open":
|
||||
return "probe" # Allow one request to test
|
||||
|
||||
else: # closed state
|
||||
if failure_count > failure_threshold:
|
||||
return "reject" # Open circuit
|
||||
else:
|
||||
return "allow" # Normal operation
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,262 @@
|
||||
"""Flow control edge helpers for managing workflow progression.
|
||||
|
||||
This module provides edge helpers for controlling the flow of execution
|
||||
in LangGraph workflows, including continuation logic, timeout checks,
|
||||
and multi-step progress tracking.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from bb_core.edge_helpers.core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def should_continue(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["continue", "end"]:
|
||||
"""Decide whether to continue or end based on tool calls in the last AI message.
|
||||
|
||||
Checks for the presence of tool calls in the last message to determine
|
||||
if the workflow should continue processing or end.
|
||||
|
||||
Args:
|
||||
state: State containing messages with potential tool calls
|
||||
|
||||
Returns:
|
||||
"continue" if tool calls are present, "end" otherwise
|
||||
|
||||
Example:
|
||||
graph.add_conditional_edges("agent", should_continue)
|
||||
"""
|
||||
# Check for messages in state
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
messages = state.get("messages", [])
|
||||
else:
|
||||
messages = getattr(state, "messages", [])
|
||||
|
||||
if not messages:
|
||||
return "end"
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
# Check various formats for tool calls
|
||||
if isinstance(last_message, dict):
|
||||
# Check for tool_calls in additional_kwargs (must be a list)
|
||||
additional_kwargs = last_message.get("additional_kwargs", {})
|
||||
tool_calls = additional_kwargs.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||
return "continue"
|
||||
|
||||
# Check for tool_calls directly (must be a list)
|
||||
tool_calls = last_message.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||
return "continue"
|
||||
|
||||
# Check for function_call
|
||||
if last_message.get("function_call"):
|
||||
return "continue"
|
||||
|
||||
# Check if message object has tool_calls attribute (must be a list)
|
||||
if hasattr(last_message, "tool_calls"):
|
||||
tool_calls = getattr(last_message, "tool_calls", None)
|
||||
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||
return "continue"
|
||||
|
||||
# Check if message object has additional_kwargs with tool_calls (must be a list)
|
||||
if hasattr(last_message, "additional_kwargs"):
|
||||
additional_kwargs = getattr(last_message, "additional_kwargs", None) or {}
|
||||
if isinstance(additional_kwargs, dict):
|
||||
tool_calls = additional_kwargs.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||
return "continue"
|
||||
|
||||
return "end"
|
||||
|
||||
|
||||
def timeout_check(
|
||||
timeout_seconds: float = 300.0,
|
||||
start_time_key: str = "start_time",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks for timeout conditions.
|
||||
|
||||
Args:
|
||||
timeout_seconds: Maximum allowed execution time in seconds
|
||||
start_time_key: Key in state containing start timestamp
|
||||
|
||||
Returns:
|
||||
Router function that returns "timeout" or "continue"
|
||||
|
||||
Example:
|
||||
timeout_router = timeout_check(timeout_seconds=60.0)
|
||||
graph.add_conditional_edges("long_task", timeout_router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> Literal["timeout", "continue"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
start_time = state.get(start_time_key)
|
||||
else:
|
||||
start_time = getattr(state, start_time_key, None)
|
||||
|
||||
if start_time is None:
|
||||
# No start time recorded, assume we should continue
|
||||
return "continue"
|
||||
|
||||
current_time = time.time()
|
||||
elapsed = current_time - start_time
|
||||
|
||||
return "timeout" if elapsed > timeout_seconds else "continue"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def multi_step_progress(
|
||||
total_steps: int,
|
||||
step_key: str = "current_step",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router for multi-step workflow progress tracking.
|
||||
|
||||
Args:
|
||||
total_steps: Total number of steps in the workflow
|
||||
step_key: Key in state containing current step number
|
||||
|
||||
Returns:
|
||||
Router function that returns "next_step", "complete", or "error"
|
||||
|
||||
Example:
|
||||
progress_router = multi_step_progress(total_steps=5)
|
||||
graph.add_conditional_edges("step_processor", progress_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["next_step", "complete", "error"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
current_step = state.get(step_key, 0)
|
||||
else:
|
||||
current_step = getattr(state, step_key, 0)
|
||||
|
||||
try:
|
||||
step_num = int(current_step)
|
||||
if step_num < 0:
|
||||
return "error"
|
||||
elif step_num >= total_steps:
|
||||
return "complete"
|
||||
else:
|
||||
return "next_step"
|
||||
except (ValueError, TypeError):
|
||||
return "error"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_iteration_limit(
|
||||
max_iterations: int = 10,
|
||||
iteration_key: str = "iteration_count",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that enforces iteration limits to prevent infinite loops.
|
||||
|
||||
Args:
|
||||
max_iterations: Maximum allowed iterations
|
||||
iteration_key: Key in state containing iteration counter
|
||||
|
||||
Returns:
|
||||
Router function that returns "continue" or "limit_reached"
|
||||
|
||||
Example:
|
||||
limit_router = check_iteration_limit(max_iterations=5)
|
||||
graph.add_conditional_edges("retry_node", limit_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["continue", "limit_reached"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
iterations = state.get(iteration_key, 0)
|
||||
else:
|
||||
iterations = getattr(state, iteration_key, 0)
|
||||
|
||||
try:
|
||||
iter_count = int(iterations)
|
||||
return "limit_reached" if iter_count >= max_iterations else "continue"
|
||||
except (ValueError, TypeError):
|
||||
return "continue"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_completion_criteria(
|
||||
required_conditions: list[str],
|
||||
condition_prefix: str = "completed_",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks if all completion criteria are met.
|
||||
|
||||
Args:
|
||||
required_conditions: List of condition names that must be true
|
||||
condition_prefix: Prefix for condition keys in state
|
||||
|
||||
Returns:
|
||||
Router function that returns "complete" or "continue"
|
||||
|
||||
Example:
|
||||
completion_router = check_completion_criteria([
|
||||
"data_validated", "report_generated", "notifications_sent"
|
||||
])
|
||||
graph.add_conditional_edges("final_check", completion_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["complete", "continue"]:
|
||||
for condition in required_conditions:
|
||||
condition_key = f"{condition_prefix}{condition}"
|
||||
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
is_complete = state.get(condition_key, False)
|
||||
else:
|
||||
is_complete = getattr(state, condition_key, False)
|
||||
|
||||
if not is_complete:
|
||||
return "continue"
|
||||
|
||||
return "complete"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_workflow_state(
|
||||
state_transitions: dict[str, str],
|
||||
state_key: str = "workflow_state",
|
||||
default_target: str = "error",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on workflow state transitions.
|
||||
|
||||
Args:
|
||||
state_transitions: Mapping of current states to next targets
|
||||
state_key: Key in state containing current workflow state
|
||||
default_target: Default target if state not found in transitions
|
||||
|
||||
Returns:
|
||||
Router function that returns target based on workflow state
|
||||
|
||||
Example:
|
||||
workflow_router = check_workflow_state({
|
||||
"initialized": "processing",
|
||||
"processing": "validation",
|
||||
"validation": "completion",
|
||||
"completion": "end"
|
||||
})
|
||||
graph.add_conditional_edges("state_manager", workflow_router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
current_state = state.get(state_key)
|
||||
else:
|
||||
current_state = getattr(state, state_key, None)
|
||||
|
||||
return state_transitions.get(str(current_state), default_target)
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,394 @@
|
||||
"""Monitoring and operational edge helpers for system management.
|
||||
|
||||
This module provides edge helpers for monitoring system health, checking
|
||||
resource availability, triggering notifications, and managing operational
|
||||
concerns like load balancing and rate limiting.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from bb_core.edge_helpers.core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def log_and_monitor(
|
||||
log_levels: dict[str, str] | None = None,
|
||||
log_level_key: str = "log_level",
|
||||
should_log_key: str = "should_log",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that determines logging and monitoring actions.
|
||||
|
||||
Args:
|
||||
log_levels: Mapping of log levels to monitoring actions
|
||||
log_level_key: Key in state containing log level
|
||||
should_log_key: Key in state indicating if logging is enabled
|
||||
|
||||
Returns:
|
||||
Router function that returns monitoring action or "no_logging"
|
||||
|
||||
Example:
|
||||
monitor_router = log_and_monitor({
|
||||
"debug": "debug_monitor",
|
||||
"info": "info_monitor",
|
||||
"warning": "alert_monitor",
|
||||
"error": "urgent_monitor"
|
||||
})
|
||||
graph.add_conditional_edges("logger", monitor_router)
|
||||
"""
|
||||
if log_levels is None:
|
||||
log_levels = {
|
||||
"debug": "debug_log",
|
||||
"info": "info_log",
|
||||
"warning": "warning_alert",
|
||||
"error": "error_alert",
|
||||
"critical": "critical_alert",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
should_log = state.get(should_log_key, True)
|
||||
log_level = state.get(log_level_key, "info")
|
||||
else:
|
||||
should_log = getattr(state, should_log_key, True)
|
||||
log_level = getattr(state, log_level_key, "info")
|
||||
|
||||
if not should_log:
|
||||
return "no_logging"
|
||||
|
||||
level_str = str(log_level).lower()
|
||||
return log_levels.get(level_str, "info_log")
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_resource_availability(
|
||||
resource_thresholds: dict[str, float] | None = None,
|
||||
resources_key: str = "resource_usage",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks system resource availability.
|
||||
|
||||
Args:
|
||||
resource_thresholds: Mapping of resource names to threshold percentages
|
||||
resources_key: Key in state containing resource usage data
|
||||
|
||||
Returns:
|
||||
Router function that returns "resources_available" or resource constraint
|
||||
|
||||
Example:
|
||||
resource_router = check_resource_availability({
|
||||
"cpu": 0.8, # 80% threshold
|
||||
"memory": 0.9, # 90% threshold
|
||||
"disk": 0.95 # 95% threshold
|
||||
})
|
||||
graph.add_conditional_edges("resource_check", resource_router)
|
||||
"""
|
||||
if resource_thresholds is None:
|
||||
resource_thresholds = {
|
||||
"cpu": 0.8,
|
||||
"memory": 0.9,
|
||||
"disk": 0.95,
|
||||
"network": 0.8,
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
resources = state.get(resources_key)
|
||||
else:
|
||||
resources = getattr(state, resources_key, None)
|
||||
|
||||
if resources is None or not isinstance(resources, dict):
|
||||
return "resources_unknown"
|
||||
|
||||
# Check each resource against its threshold
|
||||
for resource_name, threshold in resource_thresholds.items():
|
||||
usage = resources.get(resource_name, 0.0)
|
||||
try:
|
||||
usage_percent = float(usage)
|
||||
if usage_percent >= threshold:
|
||||
return f"{resource_name}_constrained"
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
return "resources_available"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def trigger_notifications(
|
||||
notification_rules: dict[str, str] | None = None,
|
||||
alert_level_key: str = "alert_level",
|
||||
notification_enabled_key: str = "notifications_enabled",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that triggers appropriate notifications.
|
||||
|
||||
Args:
|
||||
notification_rules: Mapping of alert levels to notification types
|
||||
alert_level_key: Key in state containing alert level
|
||||
notification_enabled_key: Key in state for notification toggle
|
||||
|
||||
Returns:
|
||||
Router function that returns notification type or "no_notification"
|
||||
|
||||
Example:
|
||||
notify_router = trigger_notifications({
|
||||
"low": "email_notification",
|
||||
"medium": "slack_notification",
|
||||
"high": "sms_notification",
|
||||
"critical": "phone_notification"
|
||||
})
|
||||
graph.add_conditional_edges("alert_handler", notify_router)
|
||||
"""
|
||||
if notification_rules is None:
|
||||
notification_rules = {
|
||||
"low": "email_notification",
|
||||
"medium": "slack_notification",
|
||||
"high": "sms_notification",
|
||||
"critical": "phone_notification",
|
||||
"emergency": "all_channels",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
notifications_enabled = state.get(notification_enabled_key, True)
|
||||
alert_level = state.get(alert_level_key)
|
||||
else:
|
||||
notifications_enabled = getattr(state, notification_enabled_key, True)
|
||||
alert_level = getattr(state, alert_level_key, None)
|
||||
|
||||
if not notifications_enabled or alert_level is None:
|
||||
return "no_notification"
|
||||
|
||||
level_str = str(alert_level).lower()
|
||||
return notification_rules.get(level_str, "no_notification")
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_rate_limiting(
|
||||
rate_limits: dict[str, dict[str, float]] | None = None,
|
||||
request_counts_key: str = "request_counts",
|
||||
time_window_key: str = "time_window",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that enforces rate limiting policies.
|
||||
|
||||
Args:
|
||||
rate_limits: Mapping of endpoints/operations to rate limit configs
|
||||
request_counts_key: Key in state containing request count data
|
||||
time_window_key: Key in state containing time window info
|
||||
|
||||
Returns:
|
||||
Router function that returns "allowed" or "rate_limited"
|
||||
|
||||
Example:
|
||||
rate_router = check_rate_limiting({
|
||||
"api_calls": {"max_requests": 100, "window_seconds": 60},
|
||||
"file_uploads": {"max_requests": 10, "window_seconds": 60}
|
||||
})
|
||||
graph.add_conditional_edges("rate_limiter", rate_router)
|
||||
"""
|
||||
if rate_limits is None:
|
||||
rate_limits = {
|
||||
"default": {"max_requests": 100, "window_seconds": 60},
|
||||
"api_calls": {"max_requests": 100, "window_seconds": 60},
|
||||
"heavy_operations": {"max_requests": 10, "window_seconds": 300},
|
||||
}
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["allowed", "rate_limited"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
request_counts = state.get(request_counts_key, {})
|
||||
time_window_data = state.get(time_window_key, {})
|
||||
else:
|
||||
request_counts = getattr(state, request_counts_key, {})
|
||||
time_window_data = getattr(state, time_window_key, {})
|
||||
|
||||
# Use time from time_window_data if available, otherwise use system time
|
||||
current_time = time_window_data.get("current_time", time.time())
|
||||
|
||||
# Check rate limits for each configured endpoint
|
||||
for endpoint, limits in rate_limits.items():
|
||||
if endpoint not in request_counts:
|
||||
continue
|
||||
|
||||
max_requests = limits.get("max_requests", 100)
|
||||
window_seconds = limits.get("window_seconds", 60)
|
||||
|
||||
# Get request history for this endpoint
|
||||
endpoint_data = request_counts.get(endpoint, {})
|
||||
request_count = endpoint_data.get("count", 0)
|
||||
window_start = endpoint_data.get("window_start", current_time)
|
||||
|
||||
# Check if we're still in the same time window
|
||||
if (
|
||||
current_time - window_start < window_seconds
|
||||
and request_count >= max_requests
|
||||
):
|
||||
return "rate_limited"
|
||||
|
||||
return "allowed"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def load_balance(
|
||||
load_balancing_strategy: str = "round_robin",
|
||||
available_nodes: list[str] | None = None,
|
||||
node_status_key: str = "node_status",
|
||||
current_node_key: str = "current_node_index",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that implements load balancing across nodes.
|
||||
|
||||
Args:
|
||||
load_balancing_strategy: Strategy to use ("round_robin", "least_loaded")
|
||||
available_nodes: List of available node names
|
||||
node_status_key: Key in state containing node health status
|
||||
current_node_key: Key in state tracking current node index
|
||||
|
||||
Returns:
|
||||
Router function that returns selected node name
|
||||
|
||||
Example:
|
||||
balance_router = load_balance(
|
||||
strategy="round_robin",
|
||||
available_nodes=["node1", "node2", "node3"]
|
||||
)
|
||||
graph.add_conditional_edges("load_balancer", balance_router)
|
||||
"""
|
||||
if available_nodes is None:
|
||||
available_nodes = ["primary_node", "secondary_node"]
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
node_status = state.get(node_status_key, {})
|
||||
current_index = state.get(current_node_key, 0)
|
||||
else:
|
||||
node_status = getattr(state, node_status_key, {})
|
||||
current_index = getattr(state, current_node_key, 0)
|
||||
|
||||
# Filter healthy nodes
|
||||
healthy_nodes = []
|
||||
for node in available_nodes:
|
||||
node_health = node_status.get(node, {})
|
||||
if node_health.get("healthy", True): # Default to healthy
|
||||
healthy_nodes.append(node)
|
||||
|
||||
if not healthy_nodes:
|
||||
return "no_available_nodes"
|
||||
|
||||
if load_balancing_strategy == "round_robin":
|
||||
try:
|
||||
index = int(current_index) % len(healthy_nodes)
|
||||
return healthy_nodes[index]
|
||||
except (ValueError, TypeError):
|
||||
return healthy_nodes[0]
|
||||
|
||||
elif load_balancing_strategy == "least_loaded":
|
||||
# Find node with lowest load
|
||||
min_load = float("inf")
|
||||
selected_node = healthy_nodes[0]
|
||||
|
||||
for node in healthy_nodes:
|
||||
node_data = node_status.get(node, {})
|
||||
load = node_data.get("load", 0.0)
|
||||
try:
|
||||
if float(load) < min_load:
|
||||
min_load = float(load)
|
||||
selected_node = node
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
return selected_node
|
||||
|
||||
else: # Default to first healthy node
|
||||
return healthy_nodes[0]
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def health_check(
|
||||
health_criteria: dict[str, Any] | None = None,
|
||||
health_status_key: str = "health_status",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that performs system health checks.
|
||||
|
||||
Args:
|
||||
health_criteria: Criteria for determining system health
|
||||
health_status_key: Key in state containing health check results
|
||||
|
||||
Returns:
|
||||
Router function that returns "healthy", "degraded", or "unhealthy"
|
||||
|
||||
Example:
|
||||
health_router = health_check({
|
||||
"response_time_ms": 1000,
|
||||
"error_rate": 0.05,
|
||||
"uptime_percent": 0.99
|
||||
})
|
||||
graph.add_conditional_edges("health_monitor", health_router)
|
||||
"""
|
||||
if health_criteria is None:
|
||||
health_criteria = {
|
||||
"response_time_ms": 1000,
|
||||
"error_rate": 0.05,
|
||||
"cpu_usage": 0.8,
|
||||
"memory_usage": 0.9,
|
||||
}
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["healthy", "degraded", "unhealthy"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
health_status = state.get(health_status_key, {})
|
||||
else:
|
||||
health_status = getattr(state, health_status_key, {})
|
||||
|
||||
if not isinstance(health_status, dict):
|
||||
return "unhealthy"
|
||||
|
||||
unhealthy_count = 0
|
||||
degraded_count = 0
|
||||
|
||||
for metric, threshold in health_criteria.items():
|
||||
current_value = health_status.get(metric)
|
||||
if current_value is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
value = float(current_value)
|
||||
threshold_val = float(threshold)
|
||||
|
||||
# Different metrics have different "good" directions
|
||||
if metric in [
|
||||
"error_rate",
|
||||
"response_time_ms",
|
||||
"cpu_usage",
|
||||
"memory_usage",
|
||||
"latency",
|
||||
]:
|
||||
# Lower is better
|
||||
if value > threshold_val * 1.5: # 150% of threshold
|
||||
unhealthy_count += 1
|
||||
elif value > threshold_val:
|
||||
degraded_count += 1
|
||||
else:
|
||||
# Higher is better (like uptime_percent)
|
||||
if value < threshold_val * 0.8: # 80% of threshold
|
||||
unhealthy_count += 1
|
||||
elif value < threshold_val:
|
||||
degraded_count += 1
|
||||
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if unhealthy_count > 0:
|
||||
return "unhealthy"
|
||||
elif degraded_count > 0:
|
||||
return "degraded"
|
||||
else:
|
||||
return "healthy"
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,326 @@
|
||||
"""User interaction edge helpers for human-in-the-loop workflows.
|
||||
|
||||
This module provides edge helpers for managing user interactions, interrupts,
|
||||
feedback loops, and escalation to human operators.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from bb_core.edge_helpers.core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def human_interrupt(
|
||||
interrupt_signals: list[str] | None = None,
|
||||
interrupt_key: str = "human_interrupt",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that detects human interruption requests.
|
||||
|
||||
Args:
|
||||
interrupt_signals: List of signals that indicate interrupt request
|
||||
interrupt_key: Key in state containing interrupt signal
|
||||
|
||||
Returns:
|
||||
Router function that returns "interrupt" or "continue"
|
||||
|
||||
Example:
|
||||
interrupt_router = human_interrupt([
|
||||
"stop", "pause", "cancel", "abort"
|
||||
])
|
||||
graph.add_conditional_edges("user_input_check", interrupt_router)
|
||||
"""
|
||||
if interrupt_signals is None:
|
||||
interrupt_signals = ["stop", "pause", "cancel", "abort", "interrupt", "halt"]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["interrupt", "continue"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
signal = state.get(interrupt_key)
|
||||
else:
|
||||
signal = getattr(state, interrupt_key, None)
|
||||
|
||||
if signal is None:
|
||||
return "continue"
|
||||
|
||||
signal_str = str(signal).lower().strip()
|
||||
normalized_interrupt_signals = {s.lower() for s in interrupt_signals}
|
||||
return "interrupt" if signal_str in normalized_interrupt_signals else "continue"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def pass_status_to_user(
|
||||
status_levels: dict[str, str] | None = None,
|
||||
status_key: str = "status",
|
||||
notify_key: str = "notify_user",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that determines when to notify users of status.
|
||||
|
||||
Args:
|
||||
status_levels: Mapping of status values to notification urgency
|
||||
status_key: Key in state containing current status
|
||||
notify_key: Key in state indicating if user should be notified
|
||||
|
||||
Returns:
|
||||
Router function that returns notification priority or "no_notification"
|
||||
|
||||
Example:
|
||||
status_router = pass_status_to_user({
|
||||
"error": "urgent",
|
||||
"warning": "medium",
|
||||
"completed": "low"
|
||||
})
|
||||
graph.add_conditional_edges("status_monitor", status_router)
|
||||
"""
|
||||
if status_levels is None:
|
||||
status_levels = {
|
||||
"error": "urgent",
|
||||
"failed": "urgent",
|
||||
"warning": "medium",
|
||||
"completed": "low",
|
||||
"success": "low",
|
||||
"in_progress": "info",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
status = state.get(status_key)
|
||||
should_notify = state.get(notify_key, True)
|
||||
else:
|
||||
status = getattr(state, status_key, None)
|
||||
should_notify = getattr(state, notify_key, True)
|
||||
|
||||
if not should_notify or status is None:
|
||||
return "no_notification"
|
||||
|
||||
status_str = str(status).lower()
|
||||
return status_levels.get(status_str, "no_notification")
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def user_feedback_loop(
|
||||
feedback_required_conditions: list[str] | None = None,
|
||||
feedback_key: str = "requires_feedback",
|
||||
feedback_type_key: str = "feedback_type",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that manages user feedback collection.
|
||||
|
||||
Args:
|
||||
feedback_required_conditions: Conditions that trigger feedback request
|
||||
feedback_key: Key in state indicating if feedback is required
|
||||
feedback_type_key: Key in state specifying type of feedback needed
|
||||
|
||||
Returns:
|
||||
Router function that returns feedback type or "no_feedback"
|
||||
|
||||
Example:
|
||||
feedback_router = user_feedback_loop([
|
||||
"low_confidence", "ambiguous_input", "multiple_options"
|
||||
])
|
||||
graph.add_conditional_edges("feedback_check", feedback_router)
|
||||
"""
|
||||
if feedback_required_conditions is None:
|
||||
feedback_required_conditions = [
|
||||
"low_confidence",
|
||||
"ambiguous_input",
|
||||
"multiple_options",
|
||||
"validation_failed",
|
||||
]
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
# Check if feedback is explicitly required
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
requires_feedback = state.get(feedback_key, False)
|
||||
feedback_type = state.get(feedback_type_key, "general")
|
||||
else:
|
||||
requires_feedback = getattr(state, feedback_key, False)
|
||||
feedback_type = getattr(state, feedback_type_key, "general")
|
||||
|
||||
if requires_feedback:
|
||||
return f"feedback_{feedback_type}"
|
||||
|
||||
# Check for conditions that trigger feedback
|
||||
for condition in feedback_required_conditions:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
condition_met = state.get(condition, False)
|
||||
else:
|
||||
condition_met = getattr(state, condition, False)
|
||||
|
||||
if condition_met:
|
||||
return f"feedback_{condition}"
|
||||
|
||||
return "no_feedback"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def escalate_to_human(
|
||||
escalation_triggers: dict[str, str] | None = None,
|
||||
auto_escalate_key: str = "auto_escalate",
|
||||
escalation_reason_key: str = "escalation_reason",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that escalates to human operators when needed.
|
||||
|
||||
Args:
|
||||
escalation_triggers: Mapping of trigger conditions to escalation types
|
||||
auto_escalate_key: Key in state for automatic escalation flag
|
||||
escalation_reason_key: Key in state containing escalation reason
|
||||
|
||||
Returns:
|
||||
Router function that returns escalation type or "continue_automated"
|
||||
|
||||
Example:
|
||||
escalation_router = escalate_to_human({
|
||||
"critical_error": "immediate",
|
||||
"multiple_failures": "urgent",
|
||||
"user_request": "standard"
|
||||
})
|
||||
graph.add_conditional_edges("escalation_check", escalation_router)
|
||||
"""
|
||||
if escalation_triggers is None:
|
||||
escalation_triggers = {
|
||||
"critical_error": "immediate",
|
||||
"security_issue": "immediate",
|
||||
"multiple_failures": "urgent",
|
||||
"user_request": "standard",
|
||||
"manual_review": "standard",
|
||||
"complex_case": "expert",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
# Check for automatic escalation flag
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
auto_escalate = state.get(auto_escalate_key, False)
|
||||
escalation_reason = state.get(escalation_reason_key)
|
||||
else:
|
||||
auto_escalate = getattr(state, auto_escalate_key, False)
|
||||
escalation_reason = getattr(state, escalation_reason_key, None)
|
||||
|
||||
if auto_escalate and escalation_reason:
|
||||
reason_str = str(escalation_reason).lower()
|
||||
for trigger, escalation_type in escalation_triggers.items():
|
||||
if trigger in reason_str:
|
||||
return f"escalate_{escalation_type}"
|
||||
return "escalate_standard"
|
||||
|
||||
# Check for specific escalation triggers in state
|
||||
for trigger, escalation_type in escalation_triggers.items():
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
trigger_present = state.get(trigger, False)
|
||||
else:
|
||||
trigger_present = getattr(state, trigger, False)
|
||||
|
||||
if trigger_present:
|
||||
return f"escalate_{escalation_type}"
|
||||
|
||||
return "continue_automated"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_user_authorization(
|
||||
required_permissions: list[str] | None = None,
|
||||
user_key: str = "user",
|
||||
permissions_key: str = "permissions",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks user authorization levels.
|
||||
|
||||
Args:
|
||||
required_permissions: List of permissions required for access
|
||||
user_key: Key in state containing user information
|
||||
permissions_key: Key in user data containing permissions list
|
||||
|
||||
Returns:
|
||||
Router function that returns "authorized" or "unauthorized"
|
||||
|
||||
Example:
|
||||
auth_router = check_user_authorization([
|
||||
"read_data", "write_reports", "admin_access"
|
||||
])
|
||||
graph.add_conditional_edges("permission_check", auth_router)
|
||||
"""
|
||||
if required_permissions is None:
|
||||
required_permissions = ["basic_access"]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["authorized", "unauthorized"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
user = state.get(user_key)
|
||||
else:
|
||||
user = getattr(state, user_key, None)
|
||||
|
||||
if user is None:
|
||||
return "unauthorized"
|
||||
|
||||
# Extract user permissions
|
||||
user_permissions = []
|
||||
if isinstance(user, dict):
|
||||
user_permissions = user.get(permissions_key, [])
|
||||
elif hasattr(user, permissions_key):
|
||||
user_permissions = getattr(user, permissions_key, [])
|
||||
|
||||
if not isinstance(user_permissions, list):
|
||||
return "unauthorized"
|
||||
|
||||
# Check if user has all required permissions
|
||||
for permission in required_permissions:
|
||||
if permission not in user_permissions:
|
||||
return "unauthorized"
|
||||
|
||||
return "authorized"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def collect_user_input(
|
||||
input_types: dict[str, str] | None = None,
|
||||
pending_input_key: str = "pending_user_input",
|
||||
input_type_key: str = "input_type",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that manages user input collection.
|
||||
|
||||
Args:
|
||||
input_types: Mapping of input types to collection methods
|
||||
pending_input_key: Key in state indicating pending input
|
||||
input_type_key: Key in state specifying type of input needed
|
||||
|
||||
Returns:
|
||||
Router function that returns input collection method or "no_input_needed"
|
||||
|
||||
Example:
|
||||
input_router = collect_user_input({
|
||||
"text": "text_input_form",
|
||||
"choice": "multiple_choice",
|
||||
"file": "file_upload"
|
||||
})
|
||||
graph.add_conditional_edges("input_manager", input_router)
|
||||
"""
|
||||
if input_types is None:
|
||||
input_types = {
|
||||
"text": "text_input",
|
||||
"choice": "choice_input",
|
||||
"confirmation": "confirm_input",
|
||||
"file": "file_input",
|
||||
"numeric": "number_input",
|
||||
}
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> str:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
pending = state.get(pending_input_key, False)
|
||||
input_type = state.get(input_type_key, "text")
|
||||
else:
|
||||
pending = getattr(state, pending_input_key, False)
|
||||
input_type = getattr(state, input_type_key, "text")
|
||||
|
||||
if not pending:
|
||||
return "no_input_needed"
|
||||
|
||||
input_type_str = str(input_type).lower()
|
||||
return input_types.get(input_type_str, "text_input")
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,320 @@
|
||||
"""Validation edge helpers for quality control and data integrity.
|
||||
|
||||
This module provides edge helpers for validating outputs, checking accuracy,
|
||||
confidence levels, format compliance, and data privacy requirements.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from bb_core.edge_helpers.core import StateProtocol
|
||||
|
||||
StateT = TypeVar("StateT", bound=StateProtocol)
|
||||
|
||||
|
||||
def check_accuracy(
|
||||
threshold: float = 0.8,
|
||||
accuracy_key: str = "accuracy_score",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks if accuracy meets threshold.
|
||||
|
||||
Args:
|
||||
threshold: Minimum accuracy threshold (0.0 to 1.0)
|
||||
accuracy_key: Key in state containing accuracy score
|
||||
|
||||
Returns:
|
||||
Router function that returns "high_accuracy" or "low_accuracy"
|
||||
|
||||
Example:
|
||||
accuracy_router = check_accuracy(threshold=0.85)
|
||||
graph.add_conditional_edges("quality_check", accuracy_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["high_accuracy", "low_accuracy"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
accuracy = state.get(accuracy_key, 0.0)
|
||||
else:
|
||||
accuracy = getattr(state, accuracy_key, 0.0)
|
||||
|
||||
try:
|
||||
score = float(accuracy)
|
||||
return "high_accuracy" if score >= threshold else "low_accuracy"
|
||||
except (ValueError, TypeError):
|
||||
return "low_accuracy"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_confidence_level(
|
||||
threshold: float = 0.7,
|
||||
confidence_key: str = "confidence",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router based on confidence score threshold.
|
||||
|
||||
Args:
|
||||
threshold: Minimum confidence threshold (0.0 to 1.0)
|
||||
confidence_key: Key in state containing confidence score
|
||||
|
||||
Returns:
|
||||
Router function that returns "high_confidence" or "low_confidence"
|
||||
|
||||
Example:
|
||||
confidence_router = check_confidence_level(threshold=0.75)
|
||||
graph.add_conditional_edges("llm_output", confidence_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["high_confidence", "low_confidence"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
confidence = state.get(confidence_key, 0.0)
|
||||
else:
|
||||
confidence = getattr(state, confidence_key, 0.0)
|
||||
|
||||
try:
|
||||
score = float(confidence)
|
||||
return "high_confidence" if score >= threshold else "low_confidence"
|
||||
except (ValueError, TypeError):
|
||||
return "low_confidence"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def validate_output_format(
|
||||
expected_format: str = "json",
|
||||
output_key: str = "output",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that validates output format.
|
||||
|
||||
Args:
|
||||
expected_format: Expected format ("json", "xml", "csv", "text")
|
||||
output_key: Key in state containing output to validate
|
||||
|
||||
Returns:
|
||||
Router function that returns "valid_format" or "invalid_format"
|
||||
|
||||
Example:
|
||||
format_router = validate_output_format(expected_format="json")
|
||||
graph.add_conditional_edges("formatter", format_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["valid_format", "invalid_format"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
output = state.get(output_key, "")
|
||||
else:
|
||||
output = getattr(state, output_key, "")
|
||||
|
||||
if not output:
|
||||
return "invalid_format"
|
||||
|
||||
try:
|
||||
if expected_format.lower() == "json":
|
||||
json.loads(str(output))
|
||||
return "valid_format"
|
||||
elif expected_format.lower() == "xml":
|
||||
# Basic XML validation
|
||||
if str(output).strip().startswith("<") and str(output).strip().endswith(
|
||||
">"
|
||||
):
|
||||
return "valid_format"
|
||||
return "invalid_format"
|
||||
elif expected_format.lower() == "csv":
|
||||
# Basic CSV validation - check for comma separation
|
||||
lines = str(output).strip().split("\n")
|
||||
if len(lines) > 0 and "," in lines[0]:
|
||||
return "valid_format"
|
||||
return "invalid_format"
|
||||
else: # Default to text validation
|
||||
if isinstance(output, str) and len(output.strip()) > 0:
|
||||
return "valid_format"
|
||||
return "invalid_format"
|
||||
|
||||
except Exception:
|
||||
return "invalid_format"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_privacy_compliance(
|
||||
sensitive_patterns: list[str] | None = None,
|
||||
content_key: str = "content",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks for privacy compliance.
|
||||
|
||||
Args:
|
||||
sensitive_patterns: List of regex patterns for sensitive data
|
||||
content_key: Key in state containing content to check
|
||||
|
||||
Returns:
|
||||
Router function that returns "compliant" or "privacy_violation"
|
||||
|
||||
Example:
|
||||
privacy_router = check_privacy_compliance([
|
||||
r'\\b\\d{3}-\\d{2}-\\d{4}\\b', # SSN
|
||||
r'\\b\\d{4}[- ]?\\d{4}[- ]?\\d{4}[- ]?\\d{4}\\b' # Credit card
|
||||
])
|
||||
graph.add_conditional_edges("privacy_check", privacy_router)
|
||||
"""
|
||||
if sensitive_patterns is None:
|
||||
sensitive_patterns = [
|
||||
r"\b\d{3}-\d{2}-\d{4}\b", # SSN pattern
|
||||
r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b", # Credit card pattern
|
||||
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
|
||||
r"\b\d{3}[- ]?\d{3}[- ]?\d{4}\b", # Phone number
|
||||
]
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["compliant", "privacy_violation"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
content = state.get(content_key, "")
|
||||
else:
|
||||
content = getattr(state, content_key, "")
|
||||
|
||||
content_str = str(content)
|
||||
|
||||
for pattern in sensitive_patterns:
|
||||
if re.search(pattern, content_str, re.IGNORECASE):
|
||||
return "privacy_violation"
|
||||
|
||||
return "compliant"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_data_freshness(
|
||||
max_age_seconds: int = 3600, # 1 hour default
|
||||
timestamp_key: str = "timestamp",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that checks if data is fresh enough.
|
||||
|
||||
Args:
|
||||
max_age_seconds: Maximum age in seconds before data is stale
|
||||
timestamp_key: Key in state containing data timestamp
|
||||
|
||||
Returns:
|
||||
Router function that returns "fresh" or "stale"
|
||||
|
||||
Example:
|
||||
freshness_router = check_data_freshness(max_age_seconds=1800) # 30 minutes
|
||||
graph.add_conditional_edges("data_check", freshness_router)
|
||||
"""
|
||||
|
||||
def router(state: dict[str, Any] | StateProtocol) -> Literal["fresh", "stale"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
timestamp = state.get(timestamp_key)
|
||||
else:
|
||||
timestamp = getattr(state, timestamp_key, None)
|
||||
|
||||
if timestamp is None:
|
||||
return "stale"
|
||||
|
||||
try:
|
||||
# Handle different timestamp formats
|
||||
if isinstance(timestamp, str):
|
||||
# Try parsing ISO format
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
timestamp_value = dt.timestamp()
|
||||
else:
|
||||
timestamp_value = float(timestamp)
|
||||
|
||||
current_time = time.time()
|
||||
age = current_time - timestamp_value
|
||||
|
||||
return "fresh" if age <= max_age_seconds else "stale"
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return "stale"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def validate_required_fields(
|
||||
required_fields: list[str],
|
||||
strict_mode: bool = True,
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that validates presence of required fields.
|
||||
|
||||
Args:
|
||||
required_fields: List of field names that must be present
|
||||
strict_mode: If True, fields must be non-empty; if False, just present
|
||||
|
||||
Returns:
|
||||
Router function that returns "valid" or "missing_fields"
|
||||
|
||||
Example:
|
||||
field_router = validate_required_fields([
|
||||
"user_id", "request_data", "timestamp"
|
||||
])
|
||||
graph.add_conditional_edges("input_validator", field_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["valid", "missing_fields"]:
|
||||
for field in required_fields:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
value = state.get(field)
|
||||
else:
|
||||
value = getattr(state, field, None)
|
||||
|
||||
if value is None:
|
||||
return "missing_fields"
|
||||
|
||||
if strict_mode and (
|
||||
value == "" or (isinstance(value, list) and len(value) == 0)
|
||||
):
|
||||
return "missing_fields"
|
||||
|
||||
return "valid"
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def check_output_length(
|
||||
min_length: int = 1,
|
||||
max_length: int | None = None,
|
||||
content_key: str = "output",
|
||||
) -> Callable[[dict[str, Any] | StateProtocol], str]:
|
||||
"""Create a router that validates output length constraints.
|
||||
|
||||
Args:
|
||||
min_length: Minimum required length
|
||||
max_length: Maximum allowed length (None for no limit)
|
||||
content_key: Key in state containing content to check
|
||||
|
||||
Returns:
|
||||
Router function that returns "valid_length", "too_short", or "too_long"
|
||||
|
||||
Example:
|
||||
length_router = check_output_length(min_length=10, max_length=1000)
|
||||
graph.add_conditional_edges("length_validator", length_router)
|
||||
"""
|
||||
|
||||
def router(
|
||||
state: dict[str, Any] | StateProtocol,
|
||||
) -> Literal["valid_length", "too_short", "too_long"]:
|
||||
if hasattr(state, "get") or isinstance(state, dict):
|
||||
content = state.get(content_key, "")
|
||||
else:
|
||||
content = getattr(state, content_key, "")
|
||||
|
||||
content_length = len(str(content))
|
||||
|
||||
if content_length < min_length:
|
||||
return "too_short"
|
||||
elif max_length is not None and content_length > max_length:
|
||||
return "too_long"
|
||||
else:
|
||||
return "valid_length"
|
||||
|
||||
return router
|
||||
@@ -13,8 +13,6 @@ def get_embedding_client() -> Any:
|
||||
ImportError: If main application dependencies not available
|
||||
"""
|
||||
try:
|
||||
from biz_bud.services.factory import ServiceFactory # noqa: F401
|
||||
|
||||
# This is a placeholder - actual implementation would require ServiceFactory
|
||||
raise ImportError(
|
||||
"Embedding client dependencies not available. "
|
||||
@@ -61,20 +59,10 @@ def get_embeddings_instance(
|
||||
Raises:
|
||||
ImportError: If main application dependencies not available
|
||||
"""
|
||||
# Try to create the appropriate embeddings instance
|
||||
try:
|
||||
if embedding_provider == "openai":
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create OpenAIEmbeddings with explicit model parameter
|
||||
embeddings_kwargs = {"model": model or "text-embedding-3-small", **kwargs}
|
||||
return OpenAIEmbeddings(**embeddings_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding provider: {embedding_provider}")
|
||||
except ImportError as e:
|
||||
# If dependencies not available, raise our custom error
|
||||
raise ImportError(
|
||||
"Embeddings instance not available in bb_core. "
|
||||
"This function requires the main biz_bud application "
|
||||
"with langchain dependencies."
|
||||
) from e
|
||||
# This is a placeholder implementation - actual embedding functionality
|
||||
# requires the main biz_bud application with proper dependency injection
|
||||
raise ImportError(
|
||||
"Embedding service not available in bb_core. "
|
||||
"This function requires the main biz_bud application with "
|
||||
"proper dependency injection."
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ This module provides mechanisms for:
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
@@ -129,6 +130,7 @@ class RateLimitWindow:
|
||||
window_size: int = 60 # seconds
|
||||
max_errors: int = 10
|
||||
timestamps: deque[float] = field(default_factory=deque)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False)
|
||||
|
||||
def is_allowed(self) -> bool:
|
||||
"""Check if error reporting is allowed within rate limit.
|
||||
@@ -136,18 +138,21 @@ class RateLimitWindow:
|
||||
Returns:
|
||||
True if within rate limit, False otherwise
|
||||
"""
|
||||
current_time = time.time()
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
|
||||
# Remove old timestamps outside the window
|
||||
while self.timestamps and self.timestamps[0] < current_time - self.window_size:
|
||||
self.timestamps.popleft()
|
||||
# Remove old timestamps outside the window
|
||||
while (
|
||||
self.timestamps and self.timestamps[0] < current_time - self.window_size
|
||||
):
|
||||
self.timestamps.popleft()
|
||||
|
||||
# Check if we're within the limit
|
||||
if len(self.timestamps) < self.max_errors:
|
||||
self.timestamps.append(current_time)
|
||||
return True
|
||||
# Check if we're within the limit
|
||||
if len(self.timestamps) < self.max_errors:
|
||||
self.timestamps.append(current_time)
|
||||
return True
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
def time_until_allowed(self) -> float:
|
||||
"""Calculate time until next error is allowed.
|
||||
@@ -155,14 +160,15 @@ class RateLimitWindow:
|
||||
Returns:
|
||||
Seconds until next error can be reported
|
||||
"""
|
||||
if not self.timestamps:
|
||||
return 0.0
|
||||
with self._lock:
|
||||
if not self.timestamps:
|
||||
return 0.0
|
||||
|
||||
current_time = time.time()
|
||||
oldest_timestamp = self.timestamps[0]
|
||||
time_until_expiry = (oldest_timestamp + self.window_size) - current_time
|
||||
current_time = time.time()
|
||||
oldest_timestamp = self.timestamps[0]
|
||||
time_until_expiry = (oldest_timestamp + self.window_size) - current_time
|
||||
|
||||
return max(0.0, time_until_expiry)
|
||||
return max(0.0, time_until_expiry)
|
||||
|
||||
|
||||
class ErrorAggregator:
|
||||
@@ -185,6 +191,7 @@ class ErrorAggregator:
|
||||
"""
|
||||
self.dedup_window = dedup_window
|
||||
self.aggregate_similar = aggregate_similar
|
||||
self._lock = threading.RLock() # Use RLock for potential recursive calls
|
||||
|
||||
# Storage for aggregated errors by fingerprint
|
||||
self.aggregated_errors: dict[str, AggregatedError] = {}
|
||||
@@ -222,47 +229,48 @@ class ErrorAggregator:
|
||||
Returns:
|
||||
Tuple of (should_report, reason_if_not)
|
||||
"""
|
||||
fingerprint = ErrorFingerprint.from_error_info(error)
|
||||
fingerprint_hash = fingerprint.hash
|
||||
current_time = time.time()
|
||||
with self._lock:
|
||||
fingerprint = ErrorFingerprint.from_error_info(error)
|
||||
fingerprint_hash = fingerprint.hash
|
||||
current_time = time.time()
|
||||
|
||||
# Check deduplication
|
||||
if fingerprint_hash in self.recent_errors:
|
||||
time_since_last = current_time - self.recent_errors[fingerprint_hash]
|
||||
if time_since_last < self.dedup_window:
|
||||
# Check deduplication
|
||||
if fingerprint_hash in self.recent_errors:
|
||||
time_since_last = current_time - self.recent_errors[fingerprint_hash]
|
||||
if time_since_last < self.dedup_window:
|
||||
return (
|
||||
False,
|
||||
f"Duplicate error suppressed ({time_since_last:.1f}s ago)",
|
||||
)
|
||||
|
||||
# Check rate limiting
|
||||
details = error.get("details", {})
|
||||
severity = details.get("severity", ErrorSeverity.ERROR.value)
|
||||
|
||||
# Check severity-specific rate limit
|
||||
if (
|
||||
severity in self.rate_limiters
|
||||
and not self.rate_limiters[severity].is_allowed()
|
||||
):
|
||||
wait_time = self.rate_limiters[severity].time_until_allowed()
|
||||
return (
|
||||
False,
|
||||
f"Duplicate error suppressed ({time_since_last:.1f}s ago)",
|
||||
f"Rate limit exceeded for {severity} errors ({wait_time:.1f}s)",
|
||||
)
|
||||
|
||||
# Check rate limiting
|
||||
details = error.get("details", {})
|
||||
severity = details.get("severity", ErrorSeverity.ERROR.value)
|
||||
# Check global rate limit
|
||||
if not self.global_rate_limiter.is_allowed():
|
||||
wait_time = self.global_rate_limiter.time_until_allowed()
|
||||
return False, f"Global rate limit exceeded (wait {wait_time:.1f}s)"
|
||||
|
||||
# Check severity-specific rate limit
|
||||
if (
|
||||
severity in self.rate_limiters
|
||||
and not self.rate_limiters[severity].is_allowed()
|
||||
):
|
||||
wait_time = self.rate_limiters[severity].time_until_allowed()
|
||||
return (
|
||||
False,
|
||||
f"Rate limit exceeded for {severity} errors ({wait_time:.1f}s)",
|
||||
)
|
||||
# Update tracking
|
||||
self.recent_errors[fingerprint_hash] = current_time
|
||||
|
||||
# Check global rate limit
|
||||
if not self.global_rate_limiter.is_allowed():
|
||||
wait_time = self.global_rate_limiter.time_until_allowed()
|
||||
return False, f"Global rate limit exceeded (wait {wait_time:.1f}s)"
|
||||
# Clean old entries periodically
|
||||
if len(self.recent_errors) > 1000:
|
||||
self._cleanup_old_entries()
|
||||
|
||||
# Update tracking
|
||||
self.recent_errors[fingerprint_hash] = current_time
|
||||
|
||||
# Clean old entries periodically
|
||||
if len(self.recent_errors) > 1000:
|
||||
self._cleanup_old_entries()
|
||||
|
||||
return True, None
|
||||
return True, None
|
||||
|
||||
def add_error(self, error: ErrorInfo) -> AggregatedError:
|
||||
"""Add an error to aggregation.
|
||||
@@ -273,19 +281,20 @@ class ErrorAggregator:
|
||||
Returns:
|
||||
Aggregated error information
|
||||
"""
|
||||
fingerprint = ErrorFingerprint.from_error_info(error)
|
||||
with self._lock:
|
||||
fingerprint = ErrorFingerprint.from_error_info(error)
|
||||
|
||||
if self.aggregate_similar and fingerprint.hash in self.aggregated_errors:
|
||||
# Update existing aggregation
|
||||
aggregated = self.aggregated_errors[fingerprint.hash]
|
||||
aggregated.add_occurrence(error)
|
||||
else:
|
||||
# Create new aggregation
|
||||
aggregated = AggregatedError(fingerprint=fingerprint)
|
||||
aggregated.add_occurrence(error)
|
||||
self.aggregated_errors[fingerprint.hash] = aggregated
|
||||
if self.aggregate_similar and fingerprint.hash in self.aggregated_errors:
|
||||
# Update existing aggregation
|
||||
aggregated = self.aggregated_errors[fingerprint.hash]
|
||||
aggregated.add_occurrence(error)
|
||||
else:
|
||||
# Create new aggregation
|
||||
aggregated = AggregatedError(fingerprint=fingerprint)
|
||||
aggregated.add_occurrence(error)
|
||||
self.aggregated_errors[fingerprint.hash] = aggregated
|
||||
|
||||
return aggregated
|
||||
return aggregated
|
||||
|
||||
def get_aggregated_errors(
|
||||
self,
|
||||
@@ -301,24 +310,27 @@ class ErrorAggregator:
|
||||
Returns:
|
||||
List of aggregated errors
|
||||
"""
|
||||
current_time = datetime.now(UTC)
|
||||
results = []
|
||||
with self._lock:
|
||||
current_time = datetime.now(UTC)
|
||||
results = []
|
||||
|
||||
for aggregated in self.aggregated_errors.values():
|
||||
if aggregated.count < min_count:
|
||||
continue
|
||||
|
||||
if time_window:
|
||||
time_since_last = (current_time - aggregated.last_seen).total_seconds()
|
||||
if time_since_last > time_window:
|
||||
for aggregated in self.aggregated_errors.values():
|
||||
if aggregated.count < min_count:
|
||||
continue
|
||||
|
||||
results.append(aggregated)
|
||||
if time_window:
|
||||
time_since_last = (
|
||||
current_time - aggregated.last_seen
|
||||
).total_seconds()
|
||||
if time_since_last > time_window:
|
||||
continue
|
||||
|
||||
# Sort by count (descending) and recency
|
||||
results.sort(key=lambda x: (-x.count, x.last_seen), reverse=True)
|
||||
results.append(aggregated)
|
||||
|
||||
return results
|
||||
# Sort by count (descending) and recency
|
||||
results.sort(key=lambda x: (-x.count, x.last_seen), reverse=True)
|
||||
|
||||
return results
|
||||
|
||||
def get_error_summary(self) -> dict[str, Any]:
|
||||
"""Get summary of aggregated errors.
|
||||
@@ -326,58 +338,61 @@ class ErrorAggregator:
|
||||
Returns:
|
||||
Summary statistics
|
||||
"""
|
||||
total_errors = sum(agg.count for agg in self.aggregated_errors.values())
|
||||
unique_errors = len(self.aggregated_errors)
|
||||
with self._lock:
|
||||
total_errors = sum(agg.count for agg in self.aggregated_errors.values())
|
||||
unique_errors = len(self.aggregated_errors)
|
||||
|
||||
# Group by category
|
||||
by_category = defaultdict(int)
|
||||
by_severity = defaultdict(int)
|
||||
by_node = defaultdict(int)
|
||||
# Group by category
|
||||
by_category = defaultdict(int)
|
||||
by_severity = defaultdict(int)
|
||||
by_node = defaultdict(int)
|
||||
|
||||
for aggregated in self.aggregated_errors.values():
|
||||
by_category[aggregated.fingerprint.category] += aggregated.count
|
||||
for aggregated in self.aggregated_errors.values():
|
||||
by_category[aggregated.fingerprint.category] += aggregated.count
|
||||
|
||||
# Get severity from sample
|
||||
if aggregated.sample_errors:
|
||||
severity = (
|
||||
aggregated.sample_errors[0]
|
||||
.get("details", {})
|
||||
.get("severity", "unknown")
|
||||
)
|
||||
by_severity[severity] += aggregated.count
|
||||
# Get severity from sample
|
||||
if aggregated.sample_errors:
|
||||
severity = (
|
||||
aggregated.sample_errors[0]
|
||||
.get("details", {})
|
||||
.get("severity", "unknown")
|
||||
)
|
||||
by_severity[severity] += aggregated.count
|
||||
|
||||
if aggregated.fingerprint.node:
|
||||
by_node[aggregated.fingerprint.node] += aggregated.count
|
||||
if aggregated.fingerprint.node:
|
||||
by_node[aggregated.fingerprint.node] += aggregated.count
|
||||
|
||||
return {
|
||||
"total_errors": total_errors,
|
||||
"unique_errors": unique_errors,
|
||||
"by_category": dict(by_category),
|
||||
"by_severity": dict(by_severity),
|
||||
"by_node": dict(by_node),
|
||||
"top_errors": [
|
||||
{
|
||||
"fingerprint": agg.fingerprint.hash,
|
||||
"type": agg.fingerprint.error_type,
|
||||
"category": agg.fingerprint.category,
|
||||
"count": agg.count,
|
||||
"first_seen": agg.first_seen.isoformat(),
|
||||
"last_seen": agg.last_seen.isoformat(),
|
||||
}
|
||||
for agg in self.get_aggregated_errors(min_count=2)[:10]
|
||||
],
|
||||
}
|
||||
return {
|
||||
"total_errors": total_errors,
|
||||
"unique_errors": unique_errors,
|
||||
"by_category": dict(by_category),
|
||||
"by_severity": dict(by_severity),
|
||||
"by_node": dict(by_node),
|
||||
"top_errors": [
|
||||
{
|
||||
"fingerprint": agg.fingerprint.hash,
|
||||
"type": agg.fingerprint.error_type,
|
||||
"category": agg.fingerprint.category,
|
||||
"count": agg.count,
|
||||
"first_seen": agg.first_seen.isoformat(),
|
||||
"last_seen": agg.last_seen.isoformat(),
|
||||
}
|
||||
for agg in self.get_aggregated_errors(min_count=2)[:10]
|
||||
],
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all aggregation state."""
|
||||
self.aggregated_errors.clear()
|
||||
self.recent_errors.clear()
|
||||
for limiter in self.rate_limiters.values():
|
||||
limiter.timestamps.clear()
|
||||
self.global_rate_limiter.timestamps.clear()
|
||||
with self._lock:
|
||||
self.aggregated_errors.clear()
|
||||
self.recent_errors.clear()
|
||||
for limiter in self.rate_limiters.values():
|
||||
limiter.timestamps.clear()
|
||||
self.global_rate_limiter.timestamps.clear()
|
||||
|
||||
def _cleanup_old_entries(self) -> None:
|
||||
"""Remove old entries from recent errors tracking."""
|
||||
# Lock is already held by caller
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (self.dedup_window * 2)
|
||||
|
||||
@@ -388,6 +403,7 @@ class ErrorAggregator:
|
||||
|
||||
# Global instance for easy access
|
||||
_error_aggregator: ErrorAggregator | None = None
|
||||
_error_aggregator_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_error_aggregator() -> ErrorAggregator:
|
||||
@@ -397,13 +413,19 @@ def get_error_aggregator() -> ErrorAggregator:
|
||||
Global ErrorAggregator instance
|
||||
"""
|
||||
global _error_aggregator
|
||||
if _error_aggregator is None:
|
||||
if _error_aggregator is not None:
|
||||
return _error_aggregator
|
||||
|
||||
with _error_aggregator_lock:
|
||||
if _error_aggregator is not None:
|
||||
return _error_aggregator
|
||||
_error_aggregator = ErrorAggregator()
|
||||
return _error_aggregator
|
||||
return _error_aggregator
|
||||
|
||||
|
||||
def reset_error_aggregator() -> None:
|
||||
"""Reset the global error aggregator."""
|
||||
global _error_aggregator
|
||||
if _error_aggregator is not None:
|
||||
_error_aggregator.reset()
|
||||
with _error_aggregator_lock:
|
||||
if _error_aggregator is not None:
|
||||
_error_aggregator.reset()
|
||||
|
||||
@@ -177,7 +177,7 @@ class ErrorRegistry:
|
||||
instance = cast("ErrorRegistry", super().__new__(cls))
|
||||
cls._instance = instance
|
||||
instance._initialize_registry()
|
||||
return cast("ErrorRegistry", cls._instance)
|
||||
return cls._instance
|
||||
|
||||
def _initialize_registry(self) -> None:
|
||||
"""Initialize the error registry with default mappings."""
|
||||
|
||||
@@ -290,13 +290,17 @@ def create_formatted_error(
|
||||
# Auto-categorize if no category or error_code provided
|
||||
if not category and not error_code:
|
||||
auto_category, auto_code = categorize_error(exception)
|
||||
category = auto_category.value
|
||||
category = auto_category
|
||||
if auto_code and not error_code:
|
||||
error_code = auto_code
|
||||
|
||||
# Get ErrorCategory enum
|
||||
try:
|
||||
error_category = ErrorCategory(category or "unknown")
|
||||
if isinstance(category, ErrorCategory):
|
||||
error_category: ErrorCategory = category
|
||||
else:
|
||||
category_str = category or "unknown"
|
||||
error_category = cast(ErrorCategory, ErrorCategory(category_str))
|
||||
except ValueError:
|
||||
error_category = ErrorCategory.UNKNOWN
|
||||
|
||||
@@ -304,7 +308,7 @@ def create_formatted_error(
|
||||
formatted_message = ErrorMessageFormatter.format_error_message(
|
||||
message=message,
|
||||
error_code=error_code,
|
||||
category=cast("ErrorCategory", error_category),
|
||||
category=error_category,
|
||||
template_type=template_type,
|
||||
**context,
|
||||
)
|
||||
|
||||
@@ -160,17 +160,27 @@ class RouterConfig:
|
||||
|
||||
# Categories
|
||||
if "categories" in config:
|
||||
condition.categories = cast(
|
||||
"list[ErrorCategory]",
|
||||
[ErrorCategory(cat) for cat in config["categories"]],
|
||||
)
|
||||
categories_list: list[ErrorCategory] = []
|
||||
for cat in config["categories"]:
|
||||
if isinstance(cat, str):
|
||||
categories_list.append(cast(ErrorCategory, ErrorCategory(cat)))
|
||||
elif isinstance(cat, ErrorCategory):
|
||||
categories_list.append(cat)
|
||||
else:
|
||||
raise TypeError(f"Invalid type for category: {cat!r} (type: {type(cat)})")
|
||||
condition.categories = categories_list
|
||||
|
||||
# Severities
|
||||
if "severities" in config:
|
||||
condition.severities = cast(
|
||||
"list[ErrorSeverity]",
|
||||
[ErrorSeverity(sev) for sev in config["severities"]],
|
||||
)
|
||||
severities_list: list[ErrorSeverity] = []
|
||||
for sev in config["severities"]:
|
||||
if isinstance(sev, str):
|
||||
severities_list.append(cast(ErrorSeverity, ErrorSeverity(sev)))
|
||||
elif isinstance(sev, ErrorSeverity):
|
||||
severities_list.append(sev)
|
||||
else:
|
||||
raise TypeError(f"Invalid type for severity: {sev!r} (type: {type(sev)})")
|
||||
condition.severities = severities_list
|
||||
|
||||
# Nodes
|
||||
if "nodes" in config:
|
||||
@@ -204,10 +214,13 @@ class RouterConfig:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build config dict
|
||||
config = {"default_action": self.router.default_action.value, "routes": []}
|
||||
config: dict[str, Any] = {
|
||||
"default_action": self.router.default_action.value,
|
||||
"routes": [],
|
||||
}
|
||||
|
||||
for route in self.router.routes:
|
||||
route_config = {
|
||||
route_config: dict[str, Any] = {
|
||||
"name": route.name,
|
||||
"action": route.action.value,
|
||||
"priority": route.priority,
|
||||
|
||||
@@ -375,12 +375,10 @@ class ErrorTelemetry:
|
||||
else:
|
||||
# Default: log the alert
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
logger = cast("logging.Logger", logging.getLogger(__name__))
|
||||
level = cast(
|
||||
"int",
|
||||
logging.ERROR if severity in ["critical", "error"] else logging.WARNING,
|
||||
logger = logging.getLogger(__name__)
|
||||
level = (
|
||||
logging.ERROR if severity in ["critical", "error"] else logging.WARNING
|
||||
)
|
||||
logger.log(
|
||||
level,
|
||||
|
||||
@@ -79,7 +79,7 @@ def create_error_details(
|
||||
}
|
||||
|
||||
|
||||
def _is_sensitive_field(field_name: str) -> bool:
|
||||
def is_sensitive_field(field_name: str) -> bool:
|
||||
"""Check if a field name indicates sensitive data.
|
||||
|
||||
Args:
|
||||
@@ -88,16 +88,13 @@ def _is_sensitive_field(field_name: str) -> bool:
|
||||
Returns:
|
||||
True if the field name indicates sensitive data
|
||||
"""
|
||||
if not isinstance(field_name, str):
|
||||
return False
|
||||
|
||||
field_lower = field_name.lower()
|
||||
|
||||
# Check against all sensitive patterns
|
||||
return any(re.match(pattern, field_lower) for pattern in SENSITIVE_PATTERNS)
|
||||
|
||||
|
||||
def _redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
|
||||
def redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
|
||||
"""Recursively redact sensitive data from nested structures.
|
||||
|
||||
Args:
|
||||
@@ -113,13 +110,13 @@ def _redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
|
||||
if isinstance(data, dict):
|
||||
result = {}
|
||||
for key, value in data.items():
|
||||
if _is_sensitive_field(key):
|
||||
if is_sensitive_field(key):
|
||||
result[key] = REDACTED_VALUE
|
||||
else:
|
||||
result[key] = _redact_sensitive_data(value, max_depth - 1)
|
||||
result[key] = redact_sensitive_data(value, max_depth - 1)
|
||||
return result
|
||||
elif isinstance(data, list):
|
||||
return [_redact_sensitive_data(item, max_depth - 1) for item in data]
|
||||
return [redact_sensitive_data(item, max_depth - 1) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -165,7 +162,7 @@ def safe_serialize_response(response: Any) -> dict[str, Any]:
|
||||
# Handle lists directly (return as-is after redaction)
|
||||
elif isinstance(response, list):
|
||||
# Apply redaction directly to the list and return it
|
||||
return _redact_sensitive_data(response)
|
||||
return redact_sensitive_data(response)
|
||||
# Handle built-in types without __dict__
|
||||
elif response is None or isinstance(response, (str, int, float, bool)):
|
||||
data = {"type": type(response).__name__, "value": str(response)}
|
||||
@@ -180,7 +177,7 @@ def safe_serialize_response(response: Any) -> dict[str, Any]:
|
||||
data = {"type": type(response).__name__, "value": str(response)}
|
||||
|
||||
# Redact sensitive data from the result
|
||||
return _redact_sensitive_data(data)
|
||||
return redact_sensitive_data(data)
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from ..logging import get_logger
|
||||
@@ -188,8 +188,7 @@ def track_metrics(
|
||||
)
|
||||
|
||||
metric = cast("NodeMetric", metrics[metric_name])
|
||||
if metric is not None:
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
@@ -205,7 +204,7 @@ def track_metrics(
|
||||
metric["avg_duration_ms"] = (
|
||||
metric["total_duration_ms"] or 0.0
|
||||
) / count
|
||||
metric["last_execution"] = datetime.utcnow().isoformat()
|
||||
metric["last_execution"] = datetime.now(UTC).isoformat()
|
||||
|
||||
return result
|
||||
|
||||
@@ -221,7 +220,7 @@ def track_metrics(
|
||||
metric["avg_duration_ms"] = (
|
||||
metric["total_duration_ms"] or 0.0
|
||||
) / count
|
||||
metric["last_execution"] = datetime.utcnow().isoformat()
|
||||
metric["last_execution"] = datetime.now(UTC).isoformat()
|
||||
metric["last_error"] = str(e)
|
||||
|
||||
raise
|
||||
@@ -251,8 +250,7 @@ def track_metrics(
|
||||
)
|
||||
|
||||
metric = cast("NodeMetric", metrics[metric_name])
|
||||
if metric is not None:
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
metric["count"] = (metric["count"] or 0) + 1
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
@@ -267,7 +265,7 @@ def track_metrics(
|
||||
metric["avg_duration_ms"] = (
|
||||
metric["total_duration_ms"] or 0.0
|
||||
) / count
|
||||
metric["last_execution"] = datetime.utcnow().isoformat()
|
||||
metric["last_execution"] = datetime.now(UTC).isoformat()
|
||||
|
||||
return result
|
||||
|
||||
@@ -282,7 +280,7 @@ def track_metrics(
|
||||
metric["avg_duration_ms"] = (
|
||||
metric["total_duration_ms"] or 0.0
|
||||
) / count
|
||||
metric["last_execution"] = datetime.utcnow().isoformat()
|
||||
metric["last_execution"] = datetime.now(UTC).isoformat()
|
||||
metric["last_error"] = str(e)
|
||||
|
||||
raise
|
||||
@@ -339,7 +337,7 @@ def handle_errors(
|
||||
"node": func.__name__,
|
||||
"error": str(e),
|
||||
"type": type(e).__name__,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -373,7 +371,7 @@ def handle_errors(
|
||||
"node": func.__name__,
|
||||
"error": str(e),
|
||||
"type": type(e).__name__,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ enabling consistent configuration injection across all nodes and tools.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import StateGraph
|
||||
@@ -53,17 +53,18 @@ def configure_graph_with_injection(
|
||||
continue
|
||||
|
||||
# Create wrapper that injects config
|
||||
wrapped_node = create_config_injected_node(node_func, base_config)
|
||||
|
||||
# Replace the node
|
||||
graph_builder.nodes[node_name] = wrapped_node
|
||||
if callable(node_func):
|
||||
callable_node = cast(Callable[..., object], node_func)
|
||||
wrapped_node = create_config_injected_node(callable_node, base_config)
|
||||
# Replace the node
|
||||
graph_builder.nodes[node_name] = wrapped_node
|
||||
|
||||
return graph_builder
|
||||
|
||||
|
||||
def create_config_injected_node(
|
||||
node_func: Callable[..., object], base_config: RunnableConfig
|
||||
) -> Callable[..., object]:
|
||||
) -> Any:
|
||||
"""Create a node wrapper that injects RunnableConfig.
|
||||
|
||||
Args:
|
||||
@@ -86,7 +87,7 @@ def create_config_injected_node(
|
||||
# Node already expects config, wrap to provide it
|
||||
@wraps(node_func)
|
||||
async def config_aware_wrapper(
|
||||
state: dict, config: RunnableConfig | None = None
|
||||
state: dict[str, Any], config: RunnableConfig | None = None
|
||||
) -> Any:
|
||||
# Merge base config with runtime config
|
||||
if config:
|
||||
@@ -102,7 +103,7 @@ def create_config_injected_node(
|
||||
else:
|
||||
return node_func(state, config=merged_config)
|
||||
|
||||
wrapped = RunnableLambda(config_aware_wrapper).with_config(base_config)
|
||||
wrapped: Any = RunnableLambda(config_aware_wrapper).with_config(base_config)
|
||||
else:
|
||||
# Node doesn't expect config, just wrap with config
|
||||
wrapped = RunnableLambda(node_func).with_config(base_config)
|
||||
|
||||
@@ -166,12 +166,12 @@ class ConfigurationProvider:
|
||||
# Copy metadata
|
||||
self_metadata = getattr(self._config, "metadata", {})
|
||||
other_metadata = getattr(other, "metadata", {})
|
||||
merged.metadata = {**self_metadata, **other_metadata}
|
||||
merged.metadata = {**self_metadata, **other_metadata} # type: ignore[attr-defined]
|
||||
|
||||
# Copy configurable
|
||||
self_configurable = getattr(self._config, "configurable", {})
|
||||
other_configurable = getattr(other, "configurable", {})
|
||||
merged.configurable = {**self_configurable, **other_configurable}
|
||||
merged.configurable = {**self_configurable, **other_configurable} # type: ignore[attr-defined]
|
||||
|
||||
# Copy other attributes
|
||||
for attr in ["tags", "callbacks", "recursion_limit"]:
|
||||
@@ -207,13 +207,14 @@ class ConfigurationProvider:
|
||||
config = RunnableConfig()
|
||||
|
||||
# Set configurable values
|
||||
config.configurable = {
|
||||
configurable_dict = {
|
||||
"app_config": app_config,
|
||||
"service_factory": service_factory,
|
||||
}
|
||||
config.configurable = configurable_dict # type: ignore[attr-defined]
|
||||
|
||||
# Set metadata
|
||||
config.metadata = metadata
|
||||
config.metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
return cls(config)
|
||||
|
||||
@@ -273,7 +274,7 @@ def create_runnable_config(
|
||||
if max_tokens_override is not None:
|
||||
configurable["max_tokens_override"] = max_tokens_override
|
||||
|
||||
config.configurable = configurable
|
||||
config.configurable = configurable # type: ignore[attr-defined]
|
||||
|
||||
# Set metadata
|
||||
metadata = {}
|
||||
@@ -287,6 +288,6 @@ def create_runnable_config(
|
||||
if session_id is not None:
|
||||
metadata["session_id"] = session_id
|
||||
|
||||
config.metadata = metadata
|
||||
config.metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
return config
|
||||
|
||||
@@ -91,6 +91,7 @@ def setup_logging(
|
||||
root.setLevel(numeric_level)
|
||||
|
||||
# Add console handler
|
||||
console_handler = None
|
||||
if use_rich:
|
||||
console_handler = SafeRichHandler(
|
||||
console=_console,
|
||||
@@ -114,11 +115,12 @@ def setup_logging(
|
||||
if hasattr(console_handler, "setFormatter"):
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Set handler level and add to root
|
||||
if hasattr(console_handler, "setLevel"):
|
||||
console_handler.setLevel(numeric_level)
|
||||
if hasattr(root, "addHandler"):
|
||||
root.addHandler(console_handler)
|
||||
# Set handler level and add to root (only if handler was created)
|
||||
if console_handler is not None:
|
||||
if hasattr(console_handler, "setLevel"):
|
||||
console_handler.setLevel(numeric_level)
|
||||
if hasattr(root, "addHandler"):
|
||||
root.addHandler(console_handler)
|
||||
|
||||
# Add file handler if specified
|
||||
if log_file:
|
||||
|
||||
@@ -47,6 +47,9 @@ def log_function_call(
|
||||
if logger is None:
|
||||
logger = get_logger(func.__module__)
|
||||
|
||||
# Assert logger is not None after assignment
|
||||
assert logger is not None
|
||||
|
||||
# Log function call
|
||||
message_parts = [f"Calling {func.__name__}"]
|
||||
if include_args and (args or kwargs):
|
||||
|
||||
@@ -337,8 +337,9 @@ class APIClient:
|
||||
) -> APIResponse:
|
||||
"""Make request with retry logic."""
|
||||
# Normalize method to RequestMethod
|
||||
method_val = RequestMethod(method) if isinstance(method, str) else method
|
||||
method_val = cast(RequestMethod, method_val)
|
||||
method_val = (
|
||||
method if isinstance(method, RequestMethod) else RequestMethod(method)
|
||||
)
|
||||
# Merge headers
|
||||
request_headers: dict[str, str] = {**self.headers}
|
||||
if headers:
|
||||
@@ -351,10 +352,10 @@ class APIClient:
|
||||
try:
|
||||
# Log request
|
||||
logger.info(
|
||||
f"Making {method_val.value} request to: {url}",
|
||||
f"Making {method_val} request to: {url}",
|
||||
extra={
|
||||
"operation": "api_request",
|
||||
"method": method_val.value,
|
||||
"method": method_val,
|
||||
"url": url,
|
||||
"attempt": attempt + 1,
|
||||
},
|
||||
@@ -600,8 +601,10 @@ def create_api_client(
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
config_kwargs = {"timeout": timeout, "max_retries": max_retries}
|
||||
config = RequestConfig(**config_kwargs)
|
||||
config: RequestConfig = RequestConfig.model_validate({
|
||||
"timeout": timeout,
|
||||
"max_retries": max_retries
|
||||
})
|
||||
|
||||
if client_type == "basic":
|
||||
return APIClient(base_url=base_url, headers=headers, config=config)
|
||||
@@ -649,8 +652,10 @@ def proxied_rate_limited_request(
|
||||
import asyncio
|
||||
|
||||
async def _make_request() -> object:
|
||||
config_kwargs = {"timeout": timeout_val}
|
||||
async with APIClient(config=RequestConfig(**config_kwargs)) as client:
|
||||
config: RequestConfig = RequestConfig.model_validate({
|
||||
"timeout": timeout_val
|
||||
})
|
||||
async with APIClient(config=config) as client:
|
||||
# Map method string to RequestMethod enum
|
||||
method_upper: str = method.upper()
|
||||
try:
|
||||
|
||||
@@ -4,10 +4,13 @@ import asyncio
|
||||
import functools
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -149,7 +152,7 @@ def exponential_backoff(
|
||||
|
||||
|
||||
async def retry_with_backoff(
|
||||
func: Callable, # Simplified type to avoid overload issues
|
||||
func: Callable[..., Any],
|
||||
config: RetryConfig,
|
||||
*args,
|
||||
**kwargs,
|
||||
@@ -174,12 +177,14 @@ async def retry_with_backoff(
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# For async functions
|
||||
result = await func(*args, **kwargs)
|
||||
return cast("T", result)
|
||||
async_func = cast(Callable[..., Awaitable[T]], func)
|
||||
result = await async_func(*args, **kwargs)
|
||||
return result
|
||||
else:
|
||||
# For sync functions
|
||||
result = func(*args, **kwargs)
|
||||
return cast("T", result)
|
||||
sync_func = cast(Callable[..., T], func)
|
||||
result = sync_func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
# Check if exception is in the allowed list
|
||||
if not any(isinstance(e, exc_type) for exc_type in config.exceptions):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Type definitions for networking module."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
try:
|
||||
from typing import NotRequired
|
||||
@@ -17,7 +17,7 @@ class HTTPResponse(TypedDict):
|
||||
headers: dict[str, str]
|
||||
content: bytes
|
||||
text: NotRequired[str]
|
||||
json: NotRequired[dict | list]
|
||||
json: NotRequired[dict[str, Any] | list[Any]]
|
||||
|
||||
|
||||
class RequestOptions(TypedDict):
|
||||
@@ -27,7 +27,7 @@ class RequestOptions(TypedDict):
|
||||
url: str
|
||||
headers: NotRequired[dict[str, str]]
|
||||
params: NotRequired[dict[str, str]]
|
||||
json: NotRequired[dict | list]
|
||||
json: NotRequired[dict[str, Any] | list[Any]]
|
||||
data: NotRequired[bytes | str]
|
||||
timeout: NotRequired[float | tuple[float, float]]
|
||||
follow_redirects: NotRequired[bool]
|
||||
|
||||
@@ -74,12 +74,12 @@ class ServiceHelperRemovedError(ImportError):
|
||||
)
|
||||
|
||||
|
||||
def get_service_factory(*args, **kwargs): # type: ignore
|
||||
def get_service_factory(*args, **kwargs):
|
||||
"""REMOVED: Use biz_bud.services.factory.get_global_factory() instead."""
|
||||
raise ServiceHelperRemovedError("get_service_factory")
|
||||
|
||||
|
||||
def get_service_factory_sync(*args, **kwargs): # type: ignore
|
||||
def get_service_factory_sync(*args, **kwargs):
|
||||
"""REMOVED: Use biz_bud.services.factory.get_global_factory() instead."""
|
||||
raise ServiceHelperRemovedError("get_service_factory_sync")
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ class WebSearchHistoryEntry(TypedDict, total=False):
|
||||
|
||||
query: str
|
||||
timestamp: str
|
||||
results: list[dict]
|
||||
results: list[dict[str, Any]]
|
||||
summary: str
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ class ApiResponseDataTypedDict(TypedDict, total=False):
|
||||
validation_passed: bool | None
|
||||
validation_issues: list[str]
|
||||
entities: NotRequired[list[str]]
|
||||
report_metadata: NotRequired[dict]
|
||||
report_metadata: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class ApiResponseTypedDict(TypedDict, total=False):
|
||||
@@ -246,7 +246,7 @@ class ToolCallTypedDict(TypedDict):
|
||||
class ParsedInputTypedDict(TypedDict, total=False):
|
||||
"""Structured input payload with 'raw_payload' and 'user_query' keys."""
|
||||
|
||||
raw_payload: dict
|
||||
raw_payload: dict[str, Any]
|
||||
user_query: str
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
"""Core utilities for Business Buddy framework."""
|
||||
|
||||
from bb_core.utils.lazy_loader import (
|
||||
LazyProxy,
|
||||
ThreadSafeLazyLoader,
|
||||
create_lazy_loader,
|
||||
)
|
||||
from bb_core.utils.url_normalizer import URLNormalizer
|
||||
|
||||
__all__ = ["URLNormalizer"]
|
||||
__all__ = ["URLNormalizer", "ThreadSafeLazyLoader", "LazyProxy", "create_lazy_loader"]
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Thread-safe lazy loading utilities for singleton pattern implementations.
|
||||
|
||||
This module provides thread-safe utilities for lazy loading expensive resources
|
||||
like graphs and configurations in a multi-threaded environment.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
class ThreadSafeLazyLoader[T]:
|
||||
"""Thread-safe lazy loader for singleton instances."""
|
||||
|
||||
def __init__(self, factory: Callable[[], T]) -> None:
|
||||
"""Initialize the lazy loader with a factory function.
|
||||
|
||||
Args:
|
||||
factory: Function that creates the instance when called
|
||||
|
||||
"""
|
||||
self._factory = factory
|
||||
self._instance: T | None = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_instance(self) -> T:
|
||||
"""Get the singleton instance, creating it if necessary.
|
||||
|
||||
Returns:
|
||||
The singleton instance
|
||||
|
||||
"""
|
||||
# Double-checked locking pattern for thread safety
|
||||
if self._instance is None:
|
||||
with self._lock:
|
||||
# Check again after acquiring lock
|
||||
if self._instance is None:
|
||||
self._instance = self._factory()
|
||||
return self._instance
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the instance (mainly for testing purposes)."""
|
||||
with self._lock:
|
||||
self._instance = None
|
||||
|
||||
|
||||
def create_lazy_loader[T](factory: Callable[[], T]) -> ThreadSafeLazyLoader[T]:
|
||||
"""Create a thread-safe lazy loader for the given factory function.
|
||||
|
||||
Args:
|
||||
factory: Function that creates the instance when called
|
||||
|
||||
Returns:
|
||||
ThreadSafeLazyLoader instance
|
||||
|
||||
"""
|
||||
return ThreadSafeLazyLoader(factory)
|
||||
|
||||
|
||||
class LazyProxy[T]:
|
||||
"""Proxy object that forwards attribute access to a lazy-loaded instance."""
|
||||
|
||||
def __init__(self, loader: ThreadSafeLazyLoader[T]) -> None:
|
||||
"""Initialize the proxy with a lazy loader.
|
||||
|
||||
Args:
|
||||
loader: ThreadSafeLazyLoader instance
|
||||
|
||||
"""
|
||||
self._loader = loader
|
||||
|
||||
def __getattr__(self, name: str) -> object:
|
||||
"""Forward attribute access to the lazy-loaded instance."""
|
||||
return getattr(self._loader.get_instance(), name)
|
||||
|
||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||
"""Forward calls to the lazy-loaded instance."""
|
||||
instance = self._loader.get_instance()
|
||||
if not callable(instance):
|
||||
raise TypeError(f"Instance of type {type(instance)} is not callable")
|
||||
return instance(*args, **kwargs)
|
||||
@@ -11,14 +11,15 @@ Example:
|
||||
'split into chunks.']
|
||||
"""
|
||||
|
||||
from typing import Any, Protocol, cast, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
import nltk
|
||||
from nltk.tokenize.api import TokenizerI
|
||||
except ImportError:
|
||||
TokenizerI = None
|
||||
nltk = None
|
||||
else:
|
||||
try:
|
||||
import nltk
|
||||
except ImportError:
|
||||
nltk = None
|
||||
|
||||
# Optional: tiktoken import for future token estimation (not used in chunking)
|
||||
try:
|
||||
@@ -79,41 +80,8 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
|
||||
if tokenizer is None:
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
if tokenizer is None:
|
||||
# Fallback: split by whitespace tokens using SimpleTokenizer,
|
||||
# or by characters if only one token
|
||||
simple_tokenizer = SimpleTokenizer()
|
||||
tokens = list(simple_tokenizer.span_tokenize(text))
|
||||
if len(tokens) <= 1:
|
||||
# No spaces, fallback to character-based chunking
|
||||
chunks = []
|
||||
i = 0
|
||||
while i < len(text):
|
||||
chunk = text[i : i + chunk_size]
|
||||
chunks.append(chunk)
|
||||
i += chunk_size - overlap if (chunk_size - overlap) > 0 else chunk_size
|
||||
return chunks
|
||||
else:
|
||||
chunks = []
|
||||
step = chunk_size - overlap
|
||||
if step <= 0:
|
||||
step = chunk_size
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
chunk_tokens = tokens[i : i + chunk_size]
|
||||
if chunk_tokens:
|
||||
start_pos = chunk_tokens[0][0]
|
||||
end_pos = chunk_tokens[-1][1]
|
||||
chunks.append(text[start_pos:end_pos])
|
||||
i += step
|
||||
return chunks
|
||||
|
||||
if not isinstance(tokenizer, TokenizerProtocol):
|
||||
raise RuntimeError(
|
||||
"Tokenizer does not implement the required span_tokenize method."
|
||||
)
|
||||
|
||||
# Tokenize and create chunks
|
||||
# _get_tokenizer() always returns a tokenizer (with fallback), so no need to check
|
||||
# None
|
||||
tokens = list(tokenizer.span_tokenize(text))
|
||||
|
||||
# Fallback to character-based splitting if only a single token
|
||||
@@ -128,6 +96,9 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
|
||||
return char_chunks
|
||||
|
||||
chunks = []
|
||||
step = chunk_size - overlap
|
||||
if step <= 0:
|
||||
step = chunk_size
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
chunk_tokens = tokens[i : i + chunk_size]
|
||||
@@ -135,7 +106,7 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
|
||||
start_pos = chunk_tokens[0][0]
|
||||
end_pos = chunk_tokens[-1][1]
|
||||
chunks.append(text[start_pos:end_pos])
|
||||
i += chunk_size - overlap
|
||||
i += step
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -156,12 +127,9 @@ def _get_tokenizer() -> TokenizerProtocol:
|
||||
return None
|
||||
|
||||
# Download the required NLTK data
|
||||
if nltk is not None:
|
||||
nltk.download(tokenizer_path.split("/")[0], quiet=True)
|
||||
# Load the tokenizer
|
||||
tokenizer_obj: Any = nltk.data.load(tokenizer_path)
|
||||
else:
|
||||
return None
|
||||
nltk.download(tokenizer_path.split("/")[0], quiet=True)
|
||||
# Load the tokenizer
|
||||
tokenizer_obj: Any = nltk.data.load(tokenizer_path)
|
||||
|
||||
# Check if it's a tokenizer that implements span_tokenize
|
||||
if hasattr(tokenizer_obj, "span_tokenize") and callable(
|
||||
@@ -206,7 +174,7 @@ def _get_tokenizer() -> TokenizerProtocol:
|
||||
|
||||
# Final fallback to word tokenizer
|
||||
try:
|
||||
if nltk is not None:
|
||||
if nltk:
|
||||
nltk.download("punkt", quiet=True)
|
||||
from nltk.tokenize import word_tokenize
|
||||
|
||||
@@ -267,6 +235,7 @@ def split_into_sentences(text: str) -> list[str]:
|
||||
Returns:
|
||||
List of sentences
|
||||
"""
|
||||
# Check if nltk exists and is not None before checking hasattr
|
||||
if nltk and hasattr(nltk, "data"):
|
||||
try:
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
@@ -269,7 +269,7 @@ def detect_content_type(url: str, content: str) -> str:
|
||||
if isinstance(result, dict) and "type" in result:
|
||||
content_heuristic_type = str(result["type"])
|
||||
break
|
||||
if isinstance(result, str) and result is not None:
|
||||
if isinstance(result, str) and result:
|
||||
content_heuristic_type = result
|
||||
break
|
||||
except Exception:
|
||||
|
||||
@@ -176,7 +176,12 @@ def validate_content(content: str, min_length: int = MIN_CONTENT_LENGTH) -> bool
|
||||
Returns:
|
||||
True if content is valid, False otherwise.
|
||||
"""
|
||||
if not content or not isinstance(content, str):
|
||||
if not isinstance(content, str):
|
||||
warning_highlight(
|
||||
"Content is not a string.", category="content_validation"
|
||||
)
|
||||
return False
|
||||
if not content:
|
||||
warning_highlight(
|
||||
"Content is empty or not a string.", category="content_validation"
|
||||
)
|
||||
@@ -207,7 +212,7 @@ def preprocess_content(arg1: str | None, arg2: str) -> str:
|
||||
Returns:
|
||||
The cleaned and normalized content string.
|
||||
"""
|
||||
if not arg2 or not isinstance(arg2, str) or len(arg2.strip()) < MIN_CONTENT_LENGTH:
|
||||
if not isinstance(arg2, str) or not arg2 or len(arg2.strip()) < MIN_CONTENT_LENGTH:
|
||||
return ""
|
||||
content = arg2
|
||||
cleaned = re.sub(r"(?is)<.*?>", "", content)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec, TypeVar, cast
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from ..errors import ValidationError
|
||||
|
||||
@@ -28,7 +28,7 @@ def _check_type(value: object, expected_type: object) -> bool:
|
||||
|
||||
def validate_args(
|
||||
**validators: Callable[[object], tuple[bool, str | None]],
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""Decorator to validate function arguments.
|
||||
|
||||
Args:
|
||||
@@ -50,17 +50,14 @@ def validate_args(
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
# Get function signature
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(cast("Callable[..., object]", func))
|
||||
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
|
||||
bound_args = sig.bind(
|
||||
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
|
||||
)
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
# Validate each argument
|
||||
@@ -73,7 +70,7 @@ def validate_args(
|
||||
f"Invalid argument '{arg_name}': {error_msg}"
|
||||
)
|
||||
|
||||
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -82,7 +79,7 @@ def validate_args(
|
||||
|
||||
def validate_return(
|
||||
validator: Callable[[object], tuple[bool, str | None]],
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""Decorator to validate function return value.
|
||||
|
||||
Args:
|
||||
@@ -99,9 +96,9 @@ def validate_return(
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
result = func(*args, **kwargs)
|
||||
is_valid, error_msg = validator(result)
|
||||
if not is_valid:
|
||||
@@ -117,7 +114,7 @@ def validate_return(
|
||||
|
||||
def validate_not_none(
|
||||
*arg_names: str,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""Decorator to ensure specified arguments are not None.
|
||||
|
||||
Args:
|
||||
@@ -127,16 +124,13 @@ def validate_not_none(
|
||||
Decorated function
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(cast("Callable[..., object]", func))
|
||||
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
|
||||
bound_args = sig.bind(
|
||||
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
|
||||
)
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
for arg_name in arg_names:
|
||||
@@ -146,7 +140,7 @@ def validate_not_none(
|
||||
):
|
||||
raise ValidationError(f"Argument '{arg_name}' cannot be None")
|
||||
|
||||
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -155,7 +149,7 @@ def validate_not_none(
|
||||
|
||||
def validate_types(
|
||||
**expected_types: type,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""Decorator to validate argument types.
|
||||
|
||||
Args:
|
||||
@@ -172,16 +166,13 @@ def validate_types(
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(cast("Callable[..., object]", func))
|
||||
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
|
||||
bound_args = sig.bind(
|
||||
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
|
||||
)
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
for arg_name, expected_type in expected_types.items():
|
||||
@@ -196,7 +187,7 @@ def validate_types(
|
||||
f"got {actual_type}"
|
||||
)
|
||||
|
||||
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ async def process_document(url: str, content: bytes | None = None) -> tuple[str,
|
||||
cache_key = f"document_processing_{url_hash}"
|
||||
cached_result = _get_from_cache(cache_key)
|
||||
|
||||
if cached_result and isinstance(cached_result, dict):
|
||||
if cached_result:
|
||||
logger.info(f"Using cached document extraction for {url}")
|
||||
text = cached_result.get("text", "")
|
||||
content_type = cached_result.get("content_type", document_type)
|
||||
@@ -219,6 +219,8 @@ async def process_document(url: str, content: bytes | None = None) -> tuple[str,
|
||||
logger.info(
|
||||
f"Document processed in {time.time() - start_time:.2f}s [{document_type}]"
|
||||
)
|
||||
if extracted_text is None:
|
||||
raise ValueError("Extracted text is None after processing")
|
||||
return extracted_text, document_type
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -136,7 +136,7 @@ def mark_validated(func: Callable[..., Any]) -> None:
|
||||
|
||||
def has_errors(state: StateDict) -> bool:
|
||||
"""Check if the state contains errors."""
|
||||
return isinstance(state, dict) and "errors" in state and bool(state["errors"])
|
||||
return "errors" in state and bool(state["errors"])
|
||||
|
||||
|
||||
def add_exception_error(state: StateDict, exc: Exception, source: str) -> StateDict:
|
||||
@@ -177,21 +177,15 @@ def validate_node_input(input_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
return state
|
||||
try:
|
||||
validator(state)
|
||||
async_func = cast("Callable[..., Awaitable[StateDict]]", func)
|
||||
async_func = cast(Callable[..., Awaitable[StateDict]], func)
|
||||
result = await async_func(state, *args, **kwargs)
|
||||
return cast("StateDict", result)
|
||||
return cast(StateDict, result)
|
||||
except NodeValidationError as exc:
|
||||
validation_source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(state, exc, validation_source),
|
||||
)
|
||||
return add_exception_error(state, exc, validation_source)
|
||||
except Exception as exc:
|
||||
exception_source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(state, exc, exception_source),
|
||||
)
|
||||
return add_exception_error(state, exc, exception_source)
|
||||
|
||||
mark_validated(async_wrapper)
|
||||
return cast("F", async_wrapper)
|
||||
@@ -206,19 +200,13 @@ def validate_node_input(input_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
try:
|
||||
validator(state)
|
||||
result = func(state, *args, **kwargs)
|
||||
return cast("StateDict", result)
|
||||
return cast(StateDict, result)
|
||||
except NodeValidationError as exc:
|
||||
sync_source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(state, exc, sync_source),
|
||||
)
|
||||
return add_exception_error(state, exc, sync_source)
|
||||
except Exception as exc:
|
||||
sync_exc_source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(state, exc, sync_exc_source),
|
||||
)
|
||||
return add_exception_error(state, exc, sync_exc_source)
|
||||
|
||||
mark_validated(sync_wrapper)
|
||||
return cast("F", sync_wrapper)
|
||||
@@ -245,20 +233,16 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
) -> StateDict:
|
||||
async_func = cast("Callable[..., Awaitable[StateDict]]", func)
|
||||
result = await async_func(state, *args, **kwargs)
|
||||
result_dict: StateDict = result
|
||||
try:
|
||||
pydantic_validator = cast(
|
||||
"Callable[[dict[str, Any]], BaseModel]", validator
|
||||
)
|
||||
validated_result: BaseModel = pydantic_validator(
|
||||
cast("dict[str, Any]", result)
|
||||
)
|
||||
return cast("StateDict", validated_result.model_dump())
|
||||
validated_result: BaseModel = pydantic_validator(result_dict)
|
||||
return validated_result.model_dump()
|
||||
except NodeValidationError as e:
|
||||
source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(cast("StateDict", result), e, source),
|
||||
)
|
||||
return add_exception_error(result_dict, e, source)
|
||||
|
||||
mark_validated(async_wrapper)
|
||||
return cast("F", async_wrapper)
|
||||
@@ -268,21 +252,18 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
def sync_wrapper(
|
||||
state: StateDict, *args: object, **kwargs: object
|
||||
) -> StateDict:
|
||||
result = func(state, *args, **kwargs)
|
||||
sync_func = cast("Callable[..., StateDict]", func)
|
||||
result = sync_func(state, *args, **kwargs)
|
||||
result_dict: StateDict = result
|
||||
try:
|
||||
pydantic_validator = cast(
|
||||
"Callable[[dict[str, Any]], BaseModel]", validator
|
||||
)
|
||||
validated_result: BaseModel = pydantic_validator(
|
||||
cast("dict[str, Any]", result)
|
||||
)
|
||||
return cast("StateDict", validated_result.model_dump())
|
||||
validated_result: BaseModel = pydantic_validator(result_dict)
|
||||
return validated_result.model_dump()
|
||||
except NodeValidationError as e:
|
||||
source: str = f"{func.__module__}.{func.__name__}"
|
||||
return cast(
|
||||
"StateDict",
|
||||
add_exception_error(cast("StateDict", result), e, source),
|
||||
)
|
||||
return add_exception_error(result_dict, e, source)
|
||||
|
||||
mark_validated(sync_wrapper)
|
||||
return cast("F", sync_wrapper)
|
||||
@@ -290,14 +271,14 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
|
||||
return decorator
|
||||
|
||||
|
||||
def validated_node(
|
||||
_func: Any = None,
|
||||
def validated_node[F: Callable[..., object]](
|
||||
_func: F | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
input_model: type[BaseModel] | None = None,
|
||||
output_model: type[BaseModel] | None = None,
|
||||
**metadata: object,
|
||||
) -> Any:
|
||||
) -> F | Callable[[F], F]:
|
||||
"""Create a validated node with input and output models.
|
||||
|
||||
Decorates a function to create a node that validates both input and
|
||||
@@ -333,7 +314,7 @@ def validated_node(
|
||||
return decorator
|
||||
else:
|
||||
# Called without parameters: @validated_node
|
||||
return cast("F | Callable[[F], F]", decorator(_func))
|
||||
return decorator(_func)
|
||||
|
||||
|
||||
async def validate_graph(graph: object, graph_id: str = "unknown") -> bool:
|
||||
@@ -390,24 +371,31 @@ async def ensure_graph_compatibility(
|
||||
|
||||
|
||||
async def validate_all_graphs(
|
||||
graph_functions: dict[str, Callable[[], Awaitable[Any]]],
|
||||
graph_functions: dict[str, object],
|
||||
) -> bool:
|
||||
"""Validate all graph creation functions.
|
||||
|
||||
Performs batch validation of multiple graph creation functions.
|
||||
"""
|
||||
all_valid: bool = True
|
||||
for name, func in graph_functions.items():
|
||||
for name in graph_functions:
|
||||
try:
|
||||
func = graph_functions[name]
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
all_valid = False
|
||||
continue
|
||||
# Type narrowing for pyrefly - we know it's a coroutine function now
|
||||
# Use a more specific cast to help with type inference
|
||||
if callable(func):
|
||||
# Use type ignore to suppress the complex type inference issue
|
||||
graph = await func() # type: ignore[misc]
|
||||
await validate_graph(graph, name)
|
||||
# Call the function directly since we know it's a coroutine function
|
||||
# Use inspect to determine if function takes arguments
|
||||
sig = inspect.signature(func)
|
||||
if len(sig.parameters) == 0:
|
||||
# Cast to the expected coroutine function type
|
||||
coro_func = cast("Callable[[], Awaitable[object]]", func)
|
||||
graph = await coro_func()
|
||||
else:
|
||||
# Skip functions that require arguments for now
|
||||
all_valid = False
|
||||
continue
|
||||
await validate_graph(graph, name)
|
||||
except Exception:
|
||||
all_valid = False
|
||||
return all_valid
|
||||
|
||||
@@ -324,16 +324,5 @@ def _handle_list_extend(
|
||||
if sublist_str not in seen_items:
|
||||
seen_items.add(sublist_str)
|
||||
if not isinstance(deduped_sublist, list):
|
||||
deduped_sublist = [deduped_sublist] # type: ignore[unreachable]
|
||||
deduped_sublist = [deduped_sublist]
|
||||
merged_list.append(deduped_sublist)
|
||||
|
||||
|
||||
def _handle_average_collection(
|
||||
merged: dict[str, dict[str, list[float]]], key: str, value: float
|
||||
) -> None:
|
||||
"""Collect values for averaging under merged['values'][key]."""
|
||||
if "values" not in merged:
|
||||
merged["values"] = {}
|
||||
if key not in merged["values"]:
|
||||
merged["values"][key] = []
|
||||
merged["values"][key].append(value)
|
||||
|
||||
@@ -335,9 +335,7 @@ def count_recent_sources(sources: list["Source"], recency_threshold: int) -> int
|
||||
try:
|
||||
if published_date.endswith("Z"):
|
||||
published_date = published_date.replace("Z", "+00:00")
|
||||
date_candidate = datetime.fromisoformat(published_date)
|
||||
if isinstance(date_candidate, datetime):
|
||||
date = date_candidate
|
||||
date = datetime.fromisoformat(published_date)
|
||||
except ValueError:
|
||||
parsed_result = dateutil_parser.parse(published_date)
|
||||
# dateutil_parser.parse may rarely return a tuple
|
||||
@@ -345,20 +343,20 @@ def count_recent_sources(sources: list["Source"], recency_threshold: int) -> int
|
||||
if isinstance(parsed_result, tuple):
|
||||
# Take the first element if it's a datetime
|
||||
first_elem = parsed_result[0] if parsed_result else None
|
||||
if isinstance(first_elem, datetime):
|
||||
if first_elem:
|
||||
date = first_elem
|
||||
else:
|
||||
continue
|
||||
elif isinstance(parsed_result, datetime):
|
||||
elif parsed_result:
|
||||
date = parsed_result
|
||||
else:
|
||||
continue
|
||||
elif isinstance(published_date, datetime):
|
||||
elif published_date:
|
||||
date = published_date
|
||||
else:
|
||||
continue
|
||||
# Ensure date is a datetime object
|
||||
if not isinstance(date, datetime):
|
||||
# Ensure date is valid
|
||||
if not (date and isinstance(date, datetime)):
|
||||
continue
|
||||
if date.tzinfo is None:
|
||||
date = date.replace(tzinfo=UTC)
|
||||
@@ -400,12 +398,12 @@ def extract_topics_from_facts(facts: list["ExtractedFact"]) -> list[str]:
|
||||
def get_topics_in_fact(fact: "ExtractedFact") -> set[str]:
|
||||
"""Extract topics from a single fact."""
|
||||
topics = set()
|
||||
if "data" in fact and isinstance(fact["data"], dict):
|
||||
if "data" in fact and fact["data"]:
|
||||
data = fact["data"]
|
||||
if fact.get("type") == "vendor":
|
||||
if "vendor_name" in data:
|
||||
vendor_name = data["vendor_name"]
|
||||
if isinstance(vendor_name, str):
|
||||
if vendor_name and isinstance(vendor_name, str):
|
||||
topics.add(vendor_name.lower())
|
||||
elif fact.get("type") == "relationship":
|
||||
entities = data.get("entities", [])
|
||||
@@ -426,7 +424,7 @@ def get_topics_in_fact(fact: "ExtractedFact") -> set[str]:
|
||||
description = data["description"]
|
||||
if isinstance(description, str):
|
||||
extract_noun_phrases(description, topics)
|
||||
if "source_text" in fact and isinstance(fact["source_text"], str):
|
||||
if "source_text" in fact and fact["source_text"]:
|
||||
extract_noun_phrases(fact["source_text"], topics)
|
||||
return topics
|
||||
|
||||
@@ -449,10 +447,10 @@ def perform_statistical_validation(facts: list["ExtractedFact"]) -> float:
|
||||
pattern = re.compile(r"\b\d+(?:\.\d+)?\b")
|
||||
for fact in facts:
|
||||
source_text = fact.get("source_text", "")
|
||||
if isinstance(source_text, str):
|
||||
if source_text:
|
||||
found_numbers = pattern.findall(source_text)
|
||||
numeric_values.extend(float(n) for n in found_numbers)
|
||||
if "data" in fact and isinstance(fact["data"], dict):
|
||||
if "data" in fact and fact["data"]:
|
||||
for _key, val in fact["data"].items():
|
||||
if isinstance(val, int | float):
|
||||
numeric_values.append(float(val))
|
||||
|
||||
@@ -25,7 +25,7 @@ def validate_type[T](value: object, expected_type: type[T]) -> tuple[bool, str |
|
||||
return False, f"Expected {expected_name}, got {actual_type}"
|
||||
|
||||
|
||||
def is_valid_email(email: str) -> bool:
|
||||
def is_valid_email(email: object) -> bool:
|
||||
"""Check if string is a valid email address.
|
||||
|
||||
Args:
|
||||
@@ -42,7 +42,7 @@ def is_valid_email(email: str) -> bool:
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
def is_valid_url(url: object) -> bool:
|
||||
"""Check if string is a valid URL.
|
||||
|
||||
Args:
|
||||
@@ -61,7 +61,7 @@ def is_valid_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_valid_phone(phone: str) -> bool:
|
||||
def is_valid_phone(phone: object) -> bool:
|
||||
"""Check if string is a valid phone number.
|
||||
|
||||
Args:
|
||||
@@ -81,7 +81,7 @@ def is_valid_phone(phone: str) -> bool:
|
||||
|
||||
|
||||
def validate_string_length(
|
||||
value: str,
|
||||
value: object,
|
||||
min_length: int | None = None,
|
||||
max_length: int | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
@@ -110,7 +110,7 @@ def validate_string_length(
|
||||
|
||||
|
||||
def validate_number_range(
|
||||
value: int | float,
|
||||
value: object,
|
||||
min_value: int | float | None = None,
|
||||
max_value: int | float | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
@@ -124,7 +124,7 @@ def validate_number_range(
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
return False, "Value must be a number"
|
||||
|
||||
if min_value is not None and value < min_value:
|
||||
@@ -137,7 +137,7 @@ def validate_number_range(
|
||||
|
||||
|
||||
def validate_list_length(
|
||||
value: list[object],
|
||||
value: object,
|
||||
min_length: int | None = None,
|
||||
max_length: int | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
|
||||
@@ -17,7 +17,7 @@ from bb_core.caching import FileCache
|
||||
# NOTE: LLMCache is not available in bb_core
|
||||
# FileCache has been replaced with FileCache
|
||||
|
||||
UTC = UTC
|
||||
# UTC is already imported from datetime
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -68,12 +68,14 @@ class TestFileCache:
|
||||
async def test_ainit_idempotent(self, file_backend: FileCache) -> None:
|
||||
"""Test that initialization can happen multiple times safely."""
|
||||
# FileCache initializes on first use
|
||||
await file_backend._ensure_initialized()
|
||||
assert file_backend._initialized
|
||||
await file_backend.ensure_initialized()
|
||||
# Verify by checking if cache directory exists
|
||||
assert file_backend.cache_dir.exists()
|
||||
|
||||
# Should be safe to call again
|
||||
await file_backend._ensure_initialized()
|
||||
assert file_backend._initialized
|
||||
await file_backend.ensure_initialized()
|
||||
# Directory should still exist
|
||||
assert file_backend.cache_dir.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainit_failure(self, temp_cache_dir: str) -> None:
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestCacheIntegration:
|
||||
kwargs = {"flag": True, "data": {"key": "value"}, "custom_obj": TestClass()}
|
||||
|
||||
# Generate key and set value
|
||||
key = cache_instance._generate_key(args, cast("dict[str, object]", kwargs)) # type: ignore
|
||||
key = cache_instance._generate_key(args, cast("dict[str, object]", kwargs))
|
||||
test_value = {"result": "success", "count": 123}
|
||||
await cache_instance.set(key, test_value)
|
||||
|
||||
|
||||
@@ -54,9 +54,9 @@ class TestLLMCache:
|
||||
backend = MockBackend()
|
||||
cache = LLMCache(backend=cast("CacheBackend[Any]", backend))
|
||||
|
||||
# Backend should be set
|
||||
assert cache._backend is backend
|
||||
assert not cache._ainit_done
|
||||
# Backend should be set (accessing protected attributes in tests)
|
||||
assert cache._backend is backend # noqa: SLF001
|
||||
assert not cache._ainit_done # noqa: SLF001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_without_backend(self) -> None:
|
||||
@@ -64,11 +64,12 @@ class TestLLMCache:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = LLMCache[str](cache_dir=tmpdir, ttl=3600, serializer="pickle")
|
||||
|
||||
# Should create AsyncFileCacheBackend
|
||||
assert isinstance(cache._backend, AsyncFileCacheBackend)
|
||||
assert str(cache._backend.cache_dir) == tmpdir
|
||||
assert cache._backend.ttl == 3600
|
||||
assert cache._backend.serializer == "pickle"
|
||||
# Should create AsyncFileCacheBackend (accessing protected attributes in
|
||||
# tests)
|
||||
assert isinstance(cache._backend, AsyncFileCacheBackend) # noqa: SLF001
|
||||
assert str(cache._backend.cache_dir) == tmpdir # noqa: SLF001
|
||||
assert cache._backend.ttl == 3600 # noqa: SLF001
|
||||
assert cache._backend.serializer == "pickle" # noqa: SLF001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_backend_initialized(self) -> None:
|
||||
@@ -76,14 +77,14 @@ class TestLLMCache:
|
||||
backend = MockBackend()
|
||||
cache = LLMCache(backend=cast("CacheBackend[Any]", backend))
|
||||
|
||||
# First call should initialize
|
||||
await cache._ensure_backend_initialized()
|
||||
# First call should initialize (accessing protected method in tests)
|
||||
await cache._ensure_backend_initialized() # noqa: SLF001
|
||||
assert backend.ainit_called == 1
|
||||
assert cache._ainit_done
|
||||
assert cache._ainit_done # noqa: SLF001
|
||||
|
||||
# Subsequent calls should not reinitialize
|
||||
await cache._ensure_backend_initialized()
|
||||
await cache._ensure_backend_initialized()
|
||||
await cache._ensure_backend_initialized() # noqa: SLF001
|
||||
await cache._ensure_backend_initialized() # noqa: SLF001
|
||||
assert backend.ainit_called == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -108,25 +109,26 @@ class TestLLMCache:
|
||||
backend = SimpleBackend()
|
||||
cache: LLMCache[None] = LLMCache(backend=backend)
|
||||
|
||||
# Should handle backend without ainit gracefully
|
||||
await cache._ensure_backend_initialized()
|
||||
assert cache._ainit_done
|
||||
# Should handle backend without ainit gracefully (accessing protected method
|
||||
# in tests)
|
||||
await cache._ensure_backend_initialized() # noqa: SLF001
|
||||
assert cache._ainit_done # noqa: SLF001
|
||||
|
||||
def test_generate_key_basic(self) -> None:
|
||||
"""Test basic key generation."""
|
||||
cache = LLMCache()
|
||||
|
||||
# Test with simple args
|
||||
key1 = cache._generate_key(("hello", 42), {"flag": True})
|
||||
# Test with simple args (accessing protected method in tests)
|
||||
key1 = cache._generate_key(("hello", 42), {"flag": True}) # noqa: SLF001
|
||||
assert isinstance(key1, str)
|
||||
assert len(key1) == 64 # SHA-256 hex digest length
|
||||
|
||||
# Same args should generate same key
|
||||
key2 = cache._generate_key(("hello", 42), {"flag": True})
|
||||
key2 = cache._generate_key(("hello", 42), {"flag": True}) # noqa: SLF001
|
||||
assert key1 == key2
|
||||
|
||||
# Different args should generate different keys
|
||||
key3 = cache._generate_key(("hello", 43), {"flag": True})
|
||||
key3 = cache._generate_key(("hello", 43), {"flag": True}) # noqa: SLF001
|
||||
assert key1 != key3
|
||||
|
||||
def test_generate_key_complex_types(self) -> None:
|
||||
@@ -144,7 +146,7 @@ class TestLLMCache:
|
||||
)
|
||||
complex_kwargs = {"list": [4, 5, 6], "tuple": (7, 8, 9), "none": None}
|
||||
|
||||
key = cache._generate_key(
|
||||
key = cache._generate_key( # noqa: SLF001
|
||||
complex_args, cast("dict[str, object]", complex_kwargs)
|
||||
)
|
||||
assert isinstance(key, str)
|
||||
@@ -154,13 +156,13 @@ class TestLLMCache:
|
||||
"""Test key generation with empty arguments."""
|
||||
cache = LLMCache()
|
||||
|
||||
# Empty args and kwargs
|
||||
key = cache._generate_key((), {})
|
||||
# Empty args and kwargs (accessing protected method in tests)
|
||||
key = cache._generate_key((), {}) # noqa: SLF001
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 64
|
||||
|
||||
# The key should be consistent for empty args
|
||||
key2 = cache._generate_key((), {})
|
||||
key2 = cache._generate_key((), {}) # noqa: SLF001
|
||||
assert key == key2
|
||||
|
||||
def test_generate_key_json_serialization_error(self) -> None:
|
||||
@@ -172,8 +174,9 @@ class TestLLMCache:
|
||||
def __str__(self) -> str:
|
||||
return "non-serializable"
|
||||
|
||||
# This should fall back to string representation
|
||||
key = cache._generate_key((NonSerializable(),), {})
|
||||
# This should fall back to string representation (accessing protected method
|
||||
# in tests)
|
||||
key = cache._generate_key((NonSerializable(),), {}) # noqa: SLF001
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 64
|
||||
|
||||
@@ -186,8 +189,9 @@ class TestLLMCache:
|
||||
def __str__(self) -> str:
|
||||
raise RuntimeError("Cannot convert to string")
|
||||
|
||||
# Should still generate a key using error messages
|
||||
key = cache._generate_key((FailingObject(),), {})
|
||||
# Should still generate a key using error messages (accessing protected
|
||||
# method in tests)
|
||||
key = cache._generate_key((FailingObject(),), {}) # noqa: SLF001
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 64
|
||||
|
||||
@@ -290,20 +294,21 @@ class TestLLMCache:
|
||||
args = ("test", 123, [1, 2, 3])
|
||||
kwargs = {"flag": True, "data": {"nested": "value"}}
|
||||
|
||||
key1 = cache1._generate_key(args, cast("dict[str, object]", kwargs))
|
||||
key2 = cache2._generate_key(args, cast("dict[str, object]", kwargs))
|
||||
key1 = cache1._generate_key(args, cast("dict[str, object]", kwargs)) # noqa: SLF001
|
||||
key2 = cache2._generate_key(args, cast("dict[str, object]", kwargs)) # noqa: SLF001
|
||||
|
||||
assert key1 == key2
|
||||
|
||||
def test_generate_key_order_matters(self) -> None:
|
||||
"""Test that argument order affects key generation."""
|
||||
cache = LLMCache()
|
||||
# Different arg order should produce different keys
|
||||
key1 = cache._generate_key((1, 2), {})
|
||||
key2 = cache._generate_key((2, 1), {})
|
||||
# Different arg order should produce different keys (accessing protected
|
||||
# method in tests)
|
||||
key1 = cache._generate_key((1, 2), {}) # noqa: SLF001
|
||||
key2 = cache._generate_key((2, 1), {}) # noqa: SLF001
|
||||
assert key1 != key2
|
||||
|
||||
# Different kwarg order should produce same key (sorted)
|
||||
key3 = cache._generate_key((), {"a": 1, "b": 2})
|
||||
key4 = cache._generate_key((), {"b": 2, "a": 1})
|
||||
key3 = cache._generate_key((), {"a": 1, "b": 2}) # noqa: SLF001
|
||||
key4 = cache._generate_key((), {"b": 2, "a": 1}) # noqa: SLF001
|
||||
assert key3 == key4
|
||||
|
||||
@@ -57,8 +57,8 @@ class TestRedisCache:
|
||||
|
||||
assert cache.url == "redis://example.com:6379"
|
||||
assert cache.key_prefix == "myapp:"
|
||||
assert cache._decode_responses is True # pyright: ignore[reportPrivateUsage]
|
||||
assert cache._client is None # pyright: ignore[reportPrivateUsage]
|
||||
assert cache._decode_responses is True
|
||||
assert cache._client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_connected_success(
|
||||
@@ -66,10 +66,10 @@ class TestRedisCache:
|
||||
) -> None:
|
||||
"""Test successful Redis connection."""
|
||||
with patch("redis.asyncio.from_url", return_value=mock_redis_client):
|
||||
client = await cache._ensure_connected() # pyright: ignore[reportPrivateUsage]
|
||||
client = await cache._ensure_connected()
|
||||
|
||||
assert client == mock_redis_client
|
||||
assert cache._client == mock_redis_client # pyright: ignore[reportPrivateUsage]
|
||||
assert cache._client == mock_redis_client
|
||||
mock_redis_client.ping.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -82,13 +82,13 @@ class TestRedisCache:
|
||||
patch("redis.asyncio.from_url", return_value=mock_client),
|
||||
pytest.raises(ConfigurationError, match="Failed to connect to Redis"),
|
||||
):
|
||||
await cache._ensure_connected()
|
||||
await cache._ensure_connected() # noqa: SLF001
|
||||
|
||||
def test_make_key(self, cache: RedisCache) -> None:
|
||||
"""Test key prefixing."""
|
||||
assert cache._make_key("mykey") == "test:mykey"
|
||||
assert cache._make_key("") == "test:"
|
||||
assert cache._make_key("path/to/key") == "test:path/to/key"
|
||||
"""Test key prefixing (accessing protected method in tests)."""
|
||||
assert cache._make_key("mykey") == "test:mykey" # noqa: SLF001
|
||||
assert cache._make_key("") == "test:" # noqa: SLF001
|
||||
assert cache._make_key("path/to/key") == "test:path/to/key" # noqa: SLF001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_not_found(
|
||||
@@ -332,17 +332,18 @@ class TestRedisCache:
|
||||
await cache.close()
|
||||
|
||||
mock_redis_client.close.assert_called_once()
|
||||
assert cache._client is None
|
||||
assert cache._client is None # noqa: SLF001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_without_connection(self, cache: RedisCache) -> None:
|
||||
"""Test closing when no connection exists."""
|
||||
assert cache._client is None
|
||||
"""Test closing when no connection exists (accessing protected attribute
|
||||
in tests)."""
|
||||
assert cache._client is None # noqa: SLF001
|
||||
|
||||
# Should not raise any exception
|
||||
await cache.close()
|
||||
|
||||
assert cache._client is None
|
||||
assert cache._client is None # noqa: SLF001
|
||||
|
||||
|
||||
class TestRedisCacheIntegration:
|
||||
|
||||
13
packages/business-buddy-core/tests/edge_helpers/__init__.py
Normal file
13
packages/business-buddy-core/tests/edge_helpers/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Edge helpers test suite.
|
||||
|
||||
Comprehensive tests for all edge helper modules covering:
|
||||
- Core routing factory functions
|
||||
- Flow control helpers
|
||||
- Validation helpers
|
||||
- User interaction helpers
|
||||
- Error handling helpers
|
||||
- Monitoring helpers
|
||||
|
||||
These tests ensure the edge helpers handle all edge cases properly
|
||||
and are robust for use in LangGraph workflows.
|
||||
"""
|
||||
458
packages/business-buddy-core/tests/edge_helpers/test_core.py
Normal file
458
packages/business-buddy-core/tests/edge_helpers/test_core.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""Comprehensive tests for core edge helper routing factories.
|
||||
|
||||
These tests cover edge cases including:
|
||||
- None/empty state values
|
||||
- Invalid data types
|
||||
- Missing keys
|
||||
- Boundary conditions
|
||||
- Both dict and StateProtocol implementations
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from bb_core.edge_helpers.core import (
|
||||
create_bool_router,
|
||||
create_enum_router,
|
||||
create_field_presence_router,
|
||||
create_list_length_router,
|
||||
create_status_router,
|
||||
create_threshold_router,
|
||||
)
|
||||
|
||||
|
||||
class MockState:
|
||||
"""Mock state object implementing StateProtocol."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._data = data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return self._data.get(name)
|
||||
|
||||
|
||||
class TestCreateEnumRouter:
|
||||
"""Test create_enum_router with various edge cases."""
|
||||
|
||||
def test_basic_enum_routing_dict_state(self):
|
||||
"""Test basic enum routing with dict state."""
|
||||
enum_mapping = {
|
||||
"continue": "next",
|
||||
"retry": "retry_node",
|
||||
"error": "error_node",
|
||||
}
|
||||
router = create_enum_router(enum_mapping)
|
||||
|
||||
assert router({"routing_decision": "continue"}) == "next"
|
||||
assert router({"routing_decision": "retry"}) == "retry_node"
|
||||
assert router({"routing_decision": "error"}) == "error_node"
|
||||
|
||||
def test_basic_enum_routing_state_protocol(self):
|
||||
"""Test basic enum routing with StateProtocol."""
|
||||
enum_mapping = {
|
||||
"continue": "next",
|
||||
"retry": "retry_node",
|
||||
"error": "error_node",
|
||||
}
|
||||
router = create_enum_router(enum_mapping)
|
||||
|
||||
state = MockState({"routing_decision": "continue"})
|
||||
assert router(state) == "next"
|
||||
|
||||
state = MockState({"routing_decision": "retry"})
|
||||
assert router(state) == "retry_node"
|
||||
|
||||
def test_default_target_fallback(self):
|
||||
"""Test fallback to default target."""
|
||||
enum_mapping = {"continue": "next"}
|
||||
router = create_enum_router(enum_mapping, default_target="fallback")
|
||||
|
||||
# Missing key should use default
|
||||
assert router({}) == "fallback"
|
||||
assert router({"routing_decision": None}) == "fallback"
|
||||
assert router({"routing_decision": "unknown"}) == "fallback"
|
||||
|
||||
def test_none_values(self):
|
||||
"""Test handling of None values."""
|
||||
enum_mapping = {"continue": "next"}
|
||||
router = create_enum_router(enum_mapping)
|
||||
|
||||
assert router({"routing_decision": None}) == "end"
|
||||
assert router(MockState({"routing_decision": None})) == "end"
|
||||
|
||||
def test_numeric_enum_values(self):
|
||||
"""Test enum routing with numeric values."""
|
||||
enum_mapping = {"1": "step_one", "2": "step_two", "0": "reset"}
|
||||
router = create_enum_router(enum_mapping, state_key="step")
|
||||
|
||||
assert router({"step": 1}) == "step_one"
|
||||
assert router({"step": "2"}) == "step_two"
|
||||
assert router({"step": 0}) == "reset"
|
||||
|
||||
def test_custom_state_key(self):
|
||||
"""Test with custom state key."""
|
||||
enum_mapping = {"active": "process", "inactive": "skip"}
|
||||
router = create_enum_router(enum_mapping, state_key="status")
|
||||
|
||||
assert router({"status": "active"}) == "process"
|
||||
assert router({"status": "inactive"}) == "skip"
|
||||
|
||||
def test_missing_state_key(self):
|
||||
"""Test behavior when state key is missing."""
|
||||
enum_mapping = {"continue": "next"}
|
||||
router = create_enum_router(enum_mapping, state_key="missing_key")
|
||||
|
||||
assert router({"other_key": "value"}) == "end"
|
||||
assert router(MockState({"other_key": "value"})) == "end"
|
||||
|
||||
def test_empty_state(self):
|
||||
"""Test with empty state."""
|
||||
enum_mapping = {"continue": "next"}
|
||||
router = create_enum_router(enum_mapping)
|
||||
|
||||
assert router({}) == "end"
|
||||
assert router(MockState({})) == "end"
|
||||
|
||||
|
||||
class TestCreateBoolRouter:
|
||||
"""Test create_bool_router with various edge cases."""
|
||||
|
||||
def test_basic_bool_routing(self):
|
||||
"""Test basic boolean routing."""
|
||||
router = create_bool_router("success", "failure")
|
||||
|
||||
assert router({"condition": True}) == "success"
|
||||
assert router({"condition": False}) == "failure"
|
||||
assert router(MockState({"condition": True})) == "success"
|
||||
|
||||
def test_truthy_falsy_values(self):
|
||||
"""Test truthy/falsy value handling."""
|
||||
router = create_bool_router("yes", "no", state_key="flag")
|
||||
|
||||
# Truthy values
|
||||
assert router({"flag": 1}) == "yes"
|
||||
assert router({"flag": "true"}) == "yes"
|
||||
assert router({"flag": [1, 2, 3]}) == "yes"
|
||||
assert router({"flag": {"key": "value"}}) == "yes"
|
||||
|
||||
# Falsy values
|
||||
assert router({"flag": 0}) == "no"
|
||||
assert router({"flag": ""}) == "no"
|
||||
assert router({"flag": []}) == "no"
|
||||
assert router({"flag": {}}) == "no"
|
||||
assert router({"flag": None}) == "no"
|
||||
|
||||
def test_missing_condition_key(self):
|
||||
"""Test behavior when condition key is missing."""
|
||||
router = create_bool_router("true_path", "false_path", state_key="missing")
|
||||
|
||||
assert router({}) == "false_path"
|
||||
assert router({"other_key": True}) == "false_path"
|
||||
|
||||
def test_default_false_behavior(self):
|
||||
"""Test default false behavior."""
|
||||
router = create_bool_router("yes", "no")
|
||||
|
||||
# Missing key should default to False
|
||||
assert router({}) == "no"
|
||||
assert router(MockState({})) == "no"
|
||||
|
||||
|
||||
class TestCreateThresholdRouter:
|
||||
"""Test create_threshold_router with various edge cases."""
|
||||
|
||||
def test_basic_threshold_routing(self):
|
||||
"""Test basic threshold comparison."""
|
||||
router = create_threshold_router(0.5, "high", "low")
|
||||
|
||||
assert router({"score": 0.7}) == "high"
|
||||
assert router({"score": 0.3}) == "low"
|
||||
assert router({"score": 0.5}) == "high" # Equal goes to above_target
|
||||
|
||||
def test_equal_target_parameter(self):
|
||||
"""Test custom equal_target parameter."""
|
||||
router = create_threshold_router(0.5, "high", "low", equal_target="equal")
|
||||
|
||||
assert router({"score": 0.5}) == "equal"
|
||||
assert router(MockState({"score": 0.5})) == "equal"
|
||||
|
||||
def test_string_numeric_conversion(self):
|
||||
"""Test conversion of string numbers."""
|
||||
router = create_threshold_router(10.0, "above", "below", state_key="value")
|
||||
|
||||
assert router({"value": "15.5"}) == "above"
|
||||
assert router({"value": "5"}) == "below"
|
||||
assert router({"value": "10.0"}) == "above"
|
||||
|
||||
def test_invalid_numeric_values(self):
|
||||
"""Test handling of invalid numeric values."""
|
||||
router = create_threshold_router(0.5, "high", "low")
|
||||
|
||||
# Non-numeric values should route to below_target
|
||||
assert router({"score": "invalid"}) == "low"
|
||||
assert router({"score": None}) == "low"
|
||||
assert router({"score": []}) == "low"
|
||||
assert router({"score": {}}) == "low"
|
||||
|
||||
def test_missing_score_key(self):
|
||||
"""Test behavior when score key is missing."""
|
||||
router = create_threshold_router(0.5, "high", "low", state_key="missing")
|
||||
|
||||
# Should default to 0.0 and route to below_target
|
||||
assert router({}) == "low"
|
||||
assert router({"other_key": 0.8}) == "low"
|
||||
|
||||
def test_boundary_conditions(self):
|
||||
"""Test exact boundary conditions."""
|
||||
router = create_threshold_router(0.0, "positive", "negative")
|
||||
|
||||
assert router({"score": 0.0}) == "positive"
|
||||
assert router({"score": -0.1}) == "negative"
|
||||
assert router({"score": 0.1}) == "positive"
|
||||
|
||||
|
||||
class TestCreateFieldPresenceRouter:
|
||||
"""Test create_field_presence_router with various edge cases."""
|
||||
|
||||
def test_all_fields_present(self):
|
||||
"""Test when all required fields are present."""
|
||||
router = create_field_presence_router(
|
||||
["field1", "field2", "field3"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
state = {"field1": "value1", "field2": "value2", "field3": "value3"}
|
||||
assert router(state) == "complete"
|
||||
|
||||
def test_missing_fields(self):
|
||||
"""Test when some fields are missing."""
|
||||
router = create_field_presence_router(
|
||||
["field1", "field2", "field3"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
# Missing field2
|
||||
state = {"field1": "value1", "field3": "value3"}
|
||||
assert router(state) == "incomplete"
|
||||
|
||||
# All missing
|
||||
assert router({}) == "incomplete"
|
||||
|
||||
def test_empty_string_values(self):
|
||||
"""Test handling of empty string values."""
|
||||
router = create_field_presence_router(
|
||||
["field1", "field2"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
# Empty string should be considered missing
|
||||
state = {"field1": "value", "field2": ""}
|
||||
assert router(state) == "incomplete"
|
||||
|
||||
def test_none_values(self):
|
||||
"""Test handling of None values."""
|
||||
router = create_field_presence_router(
|
||||
["field1", "field2"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
state = {"field1": "value", "field2": None}
|
||||
assert router(state) == "incomplete"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = create_field_presence_router(
|
||||
["field1", "field2"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
# Complete case
|
||||
state = MockState({"field1": "value1", "field2": "value2"})
|
||||
assert router(state) == "complete"
|
||||
|
||||
# Incomplete case
|
||||
state = MockState({"field1": "value1"})
|
||||
assert router(state) == "incomplete"
|
||||
|
||||
def test_zero_and_false_values(self):
|
||||
"""Test that 0 and False are considered valid values."""
|
||||
router = create_field_presence_router(
|
||||
["number", "flag"], "complete", "incomplete"
|
||||
)
|
||||
|
||||
# 0 and False should be valid (not None or empty string)
|
||||
state = {"number": 0, "flag": False}
|
||||
assert router(state) == "complete"
|
||||
|
||||
|
||||
class TestCreateListLengthRouter:
|
||||
"""Test create_list_length_router with various edge cases."""
|
||||
|
||||
def test_sufficient_length(self):
|
||||
"""Test when list meets minimum length."""
|
||||
router = create_list_length_router(3, "enough", "not_enough")
|
||||
|
||||
assert router({"items": [1, 2, 3]}) == "enough"
|
||||
assert router({"items": [1, 2, 3, 4, 5]}) == "enough"
|
||||
|
||||
def test_insufficient_length(self):
|
||||
"""Test when list is too short."""
|
||||
router = create_list_length_router(3, "enough", "not_enough")
|
||||
|
||||
assert router({"items": [1, 2]}) == "not_enough"
|
||||
assert router({"items": []}) == "not_enough"
|
||||
|
||||
def test_exact_length(self):
|
||||
"""Test exact minimum length."""
|
||||
router = create_list_length_router(3, "enough", "not_enough")
|
||||
|
||||
assert router({"items": [1, 2, 3]}) == "enough"
|
||||
|
||||
def test_non_list_values(self):
|
||||
"""Test handling of non-list values."""
|
||||
router = create_list_length_router(3, "enough", "not_enough")
|
||||
|
||||
# Non-list values should route to insufficient
|
||||
assert router({"items": "string"}) == "not_enough"
|
||||
assert router({"items": None}) == "not_enough"
|
||||
assert router({"items": 123}) == "not_enough"
|
||||
assert router({"items": {"key": "value"}}) == "not_enough"
|
||||
|
||||
def test_missing_items_key(self):
|
||||
"""Test behavior when items key is missing."""
|
||||
router = create_list_length_router(
|
||||
3, "enough", "not_enough", state_key="missing"
|
||||
)
|
||||
|
||||
# Should default to empty list and route to insufficient
|
||||
assert router({}) == "not_enough"
|
||||
assert router({"other_key": [1, 2, 3, 4]}) == "not_enough"
|
||||
|
||||
def test_custom_state_key(self):
|
||||
"""Test with custom state key."""
|
||||
router = create_list_length_router(2, "ok", "too_few", state_key="results")
|
||||
|
||||
assert router({"results": [1, 2, 3]}) == "ok"
|
||||
assert router({"results": [1]}) == "too_few"
|
||||
|
||||
def test_string_as_iterable(self):
|
||||
"""Test string handling (has length but not a list)."""
|
||||
router = create_list_length_router(5, "enough", "not_enough")
|
||||
|
||||
# Strings have len() but should be treated as non-list
|
||||
assert router({"items": "hello world"}) == "not_enough"
|
||||
|
||||
def test_list_like_iterables(self):
|
||||
"""Test handling of list-like but non-list iterables (tuple, set)."""
|
||||
router = create_list_length_router(3, "ok", "not_ok")
|
||||
|
||||
# Tuple with enough elements
|
||||
assert router({"items": (1, 2, 3)}) == "ok"
|
||||
# Tuple with too few elements
|
||||
assert router({"items": (1,)}) == "not_ok"
|
||||
|
||||
# Set with enough elements
|
||||
assert router({"items": {1, 2, 3}}) == "ok"
|
||||
# Set with too few elements
|
||||
assert router({"items": {1}}) == "not_ok"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = create_list_length_router(2, "enough", "not_enough")
|
||||
|
||||
state = MockState({"items": [1, 2, 3]})
|
||||
assert router(state) == "enough"
|
||||
|
||||
state = MockState({"items": [1]})
|
||||
assert router(state) == "not_enough"
|
||||
|
||||
|
||||
class TestStatusRouter:
|
||||
"""Test create_status_router (alias for create_enum_router)."""
|
||||
|
||||
def test_status_routing(self):
|
||||
"""Test status-based routing."""
|
||||
status_mapping = {
|
||||
"pending": "wait",
|
||||
"running": "monitor",
|
||||
"completed": "finish",
|
||||
"failed": "error_handler",
|
||||
}
|
||||
router = create_status_router(status_mapping)
|
||||
|
||||
assert router({"status": "pending"}) == "wait"
|
||||
assert router({"status": "running"}) == "monitor"
|
||||
assert router({"status": "completed"}) == "finish"
|
||||
assert router({"status": "failed"}) == "error_handler"
|
||||
assert router({"status": "unknown"}) == "end" # default
|
||||
|
||||
|
||||
class TestEdgeCasesAndIntegration:
|
||||
"""Test various edge cases and integration scenarios."""
|
||||
|
||||
def test_extremely_large_values(self):
|
||||
"""Test with extremely large numeric values."""
|
||||
router = create_threshold_router(1e10, "huge", "normal")
|
||||
|
||||
assert router({"score": 1e15}) == "huge"
|
||||
assert router({"score": 1e5}) == "normal"
|
||||
|
||||
def test_extremely_small_values(self):
|
||||
"""Test with extremely small numeric values."""
|
||||
router = create_threshold_router(1e-10, "above", "below")
|
||||
|
||||
assert router({"score": 1e-15}) == "below"
|
||||
assert router({"score": 1e-5}) == "above"
|
||||
|
||||
def test_unicode_string_handling(self):
|
||||
"""Test handling of unicode strings."""
|
||||
enum_mapping = {"🔥": "hot", "❄️": "cold", "🌞": "warm"}
|
||||
router = create_enum_router(enum_mapping, state_key="weather")
|
||||
|
||||
assert router({"weather": "🔥"}) == "hot"
|
||||
assert router({"weather": "❄️"}) == "cold"
|
||||
assert router({"weather": "🌈"}) == "end" # default
|
||||
|
||||
def test_deeply_nested_state_access(self):
|
||||
"""Test that routers handle flat state keys only."""
|
||||
router = create_bool_router("yes", "no", state_key="nested.deep.value")
|
||||
|
||||
# Should look for literal key "nested.deep.value", not traverse
|
||||
state = {"nested.deep.value": True}
|
||||
assert router(state) == "yes"
|
||||
|
||||
# Should not traverse nested structures
|
||||
nested_state = {"nested": {"deep": {"value": True}}}
|
||||
assert router(nested_state) == "no" # key not found at top level
|
||||
|
||||
def test_concurrent_router_usage(self):
|
||||
"""Test that routers are stateless and thread-safe."""
|
||||
router = create_threshold_router(0.5, "high", "low")
|
||||
|
||||
# Multiple calls with different states should not interfere
|
||||
state1 = {"score": 0.8}
|
||||
state2 = {"score": 0.2}
|
||||
|
||||
assert router(state1) == "high"
|
||||
assert router(state2) == "low"
|
||||
assert router(state1) == "high" # Should still work
|
||||
|
||||
def test_memory_efficiency_large_lists(self):
|
||||
"""Test memory efficiency with large lists."""
|
||||
router = create_list_length_router(1000, "big", "small")
|
||||
|
||||
large_list = list(range(2000))
|
||||
small_list = list(range(500))
|
||||
|
||||
assert router({"items": large_list}) == "big"
|
||||
assert router({"items": small_list}) == "small"
|
||||
|
||||
def test_type_coercion_consistency(self):
|
||||
"""Test consistent type coercion across routers."""
|
||||
# Test that all routers handle type coercion consistently
|
||||
enum_router = create_enum_router({"1": "one", "2": "two"})
|
||||
threshold_router = create_threshold_router(1.5, "high", "low")
|
||||
|
||||
# Numeric as string
|
||||
assert enum_router({"routing_decision": 1}) == "one"
|
||||
assert threshold_router({"score": "2.0"}) == "high"
|
||||
|
||||
# String as number (where applicable)
|
||||
assert enum_router({"routing_decision": "2"}) == "two"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,614 @@
|
||||
"""Comprehensive tests for flow control edge helpers.
|
||||
|
||||
Tests cover edge cases including:
|
||||
- Various message formats for tool calls
|
||||
- Timeout conditions and edge cases
|
||||
- Multi-step progress tracking
|
||||
- Iteration limits and boundary conditions
|
||||
- Workflow state transitions
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from bb_core.edge_helpers.flow_control import (
|
||||
check_completion_criteria,
|
||||
check_iteration_limit,
|
||||
check_workflow_state,
|
||||
multi_step_progress,
|
||||
should_continue,
|
||||
timeout_check,
|
||||
)
|
||||
|
||||
|
||||
class MockState:
|
||||
"""Mock state object for testing."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._data = data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return self._data.get(name)
|
||||
|
||||
|
||||
class TestShouldContinue:
|
||||
"""Test should_continue function with various message formats."""
|
||||
|
||||
def test_no_messages_should_end(self):
|
||||
"""Test that empty messages list results in end."""
|
||||
assert should_continue({}) == "end"
|
||||
assert should_continue({"messages": []}) == "end"
|
||||
assert should_continue(MockState({})) == "end"
|
||||
|
||||
def test_dict_message_with_tool_calls(self):
|
||||
"""Test dict message format with tool_calls."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll help you with that.",
|
||||
"tool_calls": [{"function": {"name": "search"}}],
|
||||
}
|
||||
]
|
||||
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
def test_dict_message_with_additional_kwargs_tool_calls(self):
|
||||
"""Test dict message with tool_calls in additional_kwargs."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me search for that.",
|
||||
"additional_kwargs": {"tool_calls": [{"function": {"name": "search"}}]},
|
||||
}
|
||||
]
|
||||
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
def test_dict_message_with_function_call(self):
|
||||
"""Test dict message with legacy function_call format."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"function_call": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
def test_object_message_with_tool_calls_attribute(self):
|
||||
"""Test message object with tool_calls attribute."""
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, tool_calls=None):
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
# With tool calls
|
||||
messages = [MockMessage(tool_calls=[{"function": "search"}])]
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
# Without tool calls
|
||||
messages = [MockMessage(tool_calls=None)]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
# Empty tool calls
|
||||
messages = [MockMessage(tool_calls=[])]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
def test_object_message_with_additional_kwargs(self):
|
||||
"""Test message object with additional_kwargs containing tool_calls."""
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, additional_kwargs=None):
|
||||
self.additional_kwargs = additional_kwargs
|
||||
|
||||
# With tool calls in additional_kwargs
|
||||
messages = [
|
||||
MockMessage(additional_kwargs={"tool_calls": [{"function": "search"}]})
|
||||
]
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
# Without tool calls
|
||||
messages = [MockMessage(additional_kwargs={})]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
# None additional_kwargs
|
||||
messages = [MockMessage(additional_kwargs=None)]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
def test_message_without_tool_calls(self):
|
||||
"""Test message without any tool calls should end."""
|
||||
messages = [
|
||||
{"role": "assistant", "content": "Here's your answer without tool calls."}
|
||||
]
|
||||
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
def test_mixed_message_formats(self):
|
||||
"""Test with multiple messages in different formats."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Help me search"},
|
||||
{"role": "assistant", "content": "I'll search for you."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me use a tool.",
|
||||
"tool_calls": [{"function": "search"}],
|
||||
},
|
||||
]
|
||||
|
||||
# Should check only the last message
|
||||
assert should_continue({"messages": messages}) == "continue"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
messages = [{"tool_calls": [{"function": {"name": "search"}}]}]
|
||||
|
||||
state = MockState({"messages": messages})
|
||||
assert should_continue(state) == "continue"
|
||||
|
||||
def test_edge_case_empty_tool_calls(self):
|
||||
"""Test edge case with empty tool_calls list."""
|
||||
messages = [{"role": "assistant", "content": "Response", "tool_calls": []}]
|
||||
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
def test_edge_case_malformed_messages(self):
|
||||
"""Test with malformed message structures."""
|
||||
# Non-dict, non-object message
|
||||
messages = ["string_message"]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
# None message
|
||||
messages = [None]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
# Message with invalid tool_calls type
|
||||
messages = [{"tool_calls": "not_a_list"}]
|
||||
assert should_continue({"messages": messages}) == "end"
|
||||
|
||||
|
||||
class TestTimeoutCheck:
|
||||
"""Test timeout_check function with various edge cases."""
|
||||
|
||||
@patch("time.time")
|
||||
def test_no_timeout_within_limit(self, mock_time):
|
||||
"""Test no timeout when within time limit."""
|
||||
mock_time.return_value = 1000.0
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
# Start time 100 seconds ago
|
||||
state = {"start_time": 900.0}
|
||||
assert router(state) == "continue"
|
||||
|
||||
@patch("time.time")
|
||||
def test_timeout_exceeded(self, mock_time):
|
||||
"""Test timeout when time limit exceeded."""
|
||||
mock_time.return_value = 1000.0
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
# Start time 400 seconds ago (exceeds 300s limit)
|
||||
state = {"start_time": 600.0}
|
||||
assert router(state) == "timeout"
|
||||
|
||||
@patch("time.time")
|
||||
def test_exact_timeout_boundary(self, mock_time):
|
||||
"""Test exact timeout boundary."""
|
||||
mock_time.return_value = 1000.0
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
# Exactly at timeout limit
|
||||
state = {"start_time": 700.0}
|
||||
assert router(state) == "continue"
|
||||
|
||||
# Just over timeout limit
|
||||
state = {"start_time": 699.9}
|
||||
assert router(state) == "timeout"
|
||||
|
||||
def test_missing_start_time(self):
|
||||
"""Test behavior when start_time is missing."""
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
assert router({}) == "continue"
|
||||
assert router({"other_key": 123}) == "continue"
|
||||
|
||||
def test_none_start_time(self):
|
||||
"""Test behavior when start_time is None."""
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
assert router({"start_time": None}) == "continue"
|
||||
|
||||
def test_custom_start_time_key(self):
|
||||
"""Test with custom start time key."""
|
||||
router = timeout_check(timeout_seconds=60.0, start_time_key="begin_time")
|
||||
|
||||
with patch("time.time", return_value=1000.0):
|
||||
state = {"begin_time": 950.0} # 50 seconds ago
|
||||
assert router(state) == "continue"
|
||||
|
||||
state = {"begin_time": 930.0} # 70 seconds ago
|
||||
assert router(state) == "timeout"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = timeout_check(timeout_seconds=300.0)
|
||||
|
||||
with patch("time.time", return_value=1000.0):
|
||||
state = MockState({"start_time": 900.0})
|
||||
assert router(state) == "continue"
|
||||
|
||||
|
||||
class TestMultiStepProgress:
|
||||
"""Test multi_step_progress function with various edge cases."""
|
||||
|
||||
def test_normal_progression(self):
|
||||
"""Test normal step progression."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
assert router({"current_step": 0}) == "next_step"
|
||||
assert router({"current_step": 1}) == "next_step"
|
||||
assert router({"current_step": 4}) == "next_step"
|
||||
|
||||
def test_completion(self):
|
||||
"""Test completion when steps are done."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
assert router({"current_step": 5}) == "complete"
|
||||
assert router({"current_step": 6}) == "complete"
|
||||
assert router({"current_step": 100}) == "complete"
|
||||
|
||||
def test_negative_step_error(self):
|
||||
"""Test error for negative step numbers."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
assert router({"current_step": -1}) == "error"
|
||||
assert router({"current_step": -10}) == "error"
|
||||
|
||||
def test_invalid_step_types(self):
|
||||
"""Test error for invalid step types."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
assert router({"current_step": "invalid"}) == "error"
|
||||
assert router({"current_step": []}) == "error"
|
||||
assert router({"current_step": {}}) == "error"
|
||||
assert router({"current_step": None}) == "error"
|
||||
|
||||
def test_missing_step_key(self):
|
||||
"""Test behavior when step key is missing."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
# Should default to 0 and return next_step
|
||||
assert router({}) == "next_step"
|
||||
|
||||
def test_custom_step_key(self):
|
||||
"""Test with custom step key."""
|
||||
router = multi_step_progress(total_steps=3, step_key="progress")
|
||||
|
||||
assert router({"progress": 0}) == "next_step"
|
||||
assert router({"progress": 3}) == "complete"
|
||||
|
||||
def test_single_step_workflow(self):
|
||||
"""Test single-step workflow."""
|
||||
router = multi_step_progress(total_steps=1)
|
||||
|
||||
assert router({"current_step": 0}) == "next_step"
|
||||
assert router({"current_step": 1}) == "complete"
|
||||
|
||||
def test_zero_steps_workflow(self):
|
||||
"""Test zero-steps workflow edge case."""
|
||||
router = multi_step_progress(total_steps=0)
|
||||
|
||||
assert router({"current_step": 0}) == "complete"
|
||||
assert router({"current_step": -1}) == "error"
|
||||
|
||||
def test_string_numeric_conversion(self):
|
||||
"""Test conversion of string numbers."""
|
||||
router = multi_step_progress(total_steps=5)
|
||||
|
||||
assert router({"current_step": "2"}) == "next_step"
|
||||
assert router({"current_step": "5"}) == "complete"
|
||||
assert router({"current_step": "-1"}) == "error"
|
||||
|
||||
|
||||
class TestCheckIterationLimit:
|
||||
"""Test check_iteration_limit function with various edge cases."""
|
||||
|
||||
def test_within_limit(self):
|
||||
"""Test iterations within limit."""
|
||||
router = check_iteration_limit(max_iterations=10)
|
||||
|
||||
assert router({"iteration_count": 0}) == "continue"
|
||||
assert router({"iteration_count": 5}) == "continue"
|
||||
assert router({"iteration_count": 9}) == "continue"
|
||||
|
||||
def test_limit_reached(self):
|
||||
"""Test when iteration limit is reached."""
|
||||
router = check_iteration_limit(max_iterations=10)
|
||||
|
||||
assert router({"iteration_count": 10}) == "limit_reached"
|
||||
assert router({"iteration_count": 15}) == "limit_reached"
|
||||
|
||||
def test_missing_iteration_count(self):
|
||||
"""Test behavior when iteration count is missing."""
|
||||
router = check_iteration_limit(max_iterations=5)
|
||||
|
||||
# Should default to 0 and return continue
|
||||
assert router({}) == "continue"
|
||||
|
||||
def test_invalid_iteration_types(self):
|
||||
"""Test handling of invalid iteration count types."""
|
||||
router = check_iteration_limit(max_iterations=5)
|
||||
|
||||
# Invalid types should default to continue
|
||||
assert router({"iteration_count": "invalid"}) == "continue"
|
||||
assert router({"iteration_count": []}) == "continue"
|
||||
assert router({"iteration_count": None}) == "continue"
|
||||
|
||||
def test_custom_iteration_key(self):
|
||||
"""Test with custom iteration key."""
|
||||
router = check_iteration_limit(max_iterations=3, iteration_key="retry_count")
|
||||
|
||||
assert router({"retry_count": 2}) == "continue"
|
||||
assert router({"retry_count": 3}) == "limit_reached"
|
||||
|
||||
def test_zero_max_iterations(self):
|
||||
"""Test with zero max iterations."""
|
||||
router = check_iteration_limit(max_iterations=0)
|
||||
|
||||
assert router({"iteration_count": 0}) == "limit_reached"
|
||||
assert router({"iteration_count": -1}) == "continue"
|
||||
|
||||
def test_negative_iteration_count(self):
|
||||
"""Test with negative iteration count."""
|
||||
router = check_iteration_limit(max_iterations=5)
|
||||
|
||||
assert router({"iteration_count": -1}) == "continue"
|
||||
assert router({"iteration_count": -10}) == "continue"
|
||||
|
||||
|
||||
class TestCheckCompletionCriteria:
|
||||
"""Test check_completion_criteria function with various edge cases."""
|
||||
|
||||
def test_all_criteria_met(self):
|
||||
"""Test when all completion criteria are met."""
|
||||
router = check_completion_criteria(["task1", "task2", "task3"])
|
||||
|
||||
state = {
|
||||
"completed_task1": True,
|
||||
"completed_task2": True,
|
||||
"completed_task3": True,
|
||||
}
|
||||
assert router(state) == "complete"
|
||||
|
||||
def test_some_criteria_missing(self):
|
||||
"""Test when some criteria are not met."""
|
||||
router = check_completion_criteria(["task1", "task2", "task3"])
|
||||
|
||||
state = {
|
||||
"completed_task1": True,
|
||||
"completed_task2": False,
|
||||
"completed_task3": True,
|
||||
}
|
||||
assert router(state) == "continue"
|
||||
|
||||
def test_no_criteria_met(self):
|
||||
"""Test when no criteria are met."""
|
||||
router = check_completion_criteria(["task1", "task2"])
|
||||
|
||||
state = {}
|
||||
assert router(state) == "continue"
|
||||
|
||||
def test_custom_condition_prefix(self):
|
||||
"""Test with custom condition prefix."""
|
||||
router = check_completion_criteria(
|
||||
["validation", "processing"], condition_prefix="done_"
|
||||
)
|
||||
|
||||
state = {"done_validation": True, "done_processing": True}
|
||||
assert router(state) == "complete"
|
||||
|
||||
state = {"done_validation": True, "done_processing": False}
|
||||
assert router(state) == "continue"
|
||||
|
||||
def test_falsy_values_as_incomplete(self):
|
||||
"""Test that falsy values are treated as incomplete."""
|
||||
router = check_completion_criteria(["task1", "task2"])
|
||||
|
||||
# False, 0, None, empty string should be incomplete
|
||||
state = {"completed_task1": False, "completed_task2": True}
|
||||
assert router(state) == "continue"
|
||||
|
||||
state = {"completed_task1": 0, "completed_task2": True}
|
||||
assert router(state) == "continue"
|
||||
|
||||
state = {"completed_task1": None, "completed_task2": True}
|
||||
assert router(state) == "continue"
|
||||
|
||||
def test_truthy_values_as_complete(self):
|
||||
"""Test that truthy values are treated as complete."""
|
||||
router = check_completion_criteria(["task1", "task2"])
|
||||
|
||||
state = {"completed_task1": 1, "completed_task2": "done"}
|
||||
assert router(state) == "complete"
|
||||
|
||||
state = {"completed_task1": [1, 2, 3], "completed_task2": {"status": "done"}}
|
||||
assert router(state) == "complete"
|
||||
|
||||
def test_empty_criteria_list(self):
|
||||
"""Test with empty criteria list."""
|
||||
router = check_completion_criteria([])
|
||||
|
||||
# No criteria means always complete
|
||||
assert router({}) == "complete"
|
||||
assert router({"any_key": "any_value"}) == "complete"
|
||||
|
||||
|
||||
class TestCheckWorkflowState:
|
||||
"""Test check_workflow_state function with various edge cases."""
|
||||
|
||||
def test_valid_state_transitions(self):
|
||||
"""Test valid state transitions."""
|
||||
transitions = {
|
||||
"init": "process",
|
||||
"process": "validate",
|
||||
"validate": "complete",
|
||||
"complete": "end",
|
||||
}
|
||||
router = check_workflow_state(transitions)
|
||||
|
||||
assert router({"workflow_state": "init"}) == "process"
|
||||
assert router({"workflow_state": "process"}) == "validate"
|
||||
assert router({"workflow_state": "validate"}) == "complete"
|
||||
assert router({"workflow_state": "complete"}) == "end"
|
||||
|
||||
def test_unknown_state_default(self):
|
||||
"""Test unknown state uses default target."""
|
||||
transitions = {"known": "next"}
|
||||
router = check_workflow_state(transitions, default_target="fallback")
|
||||
|
||||
assert router({"workflow_state": "unknown"}) == "fallback"
|
||||
assert router({"workflow_state": None}) == "fallback"
|
||||
|
||||
def test_missing_workflow_state(self):
|
||||
"""Test missing workflow state key."""
|
||||
transitions = {"init": "process"}
|
||||
router = check_workflow_state(transitions)
|
||||
|
||||
assert router({}) == "error" # default_target="error"
|
||||
|
||||
def test_custom_state_key(self):
|
||||
"""Test with custom state key."""
|
||||
transitions = {"start": "middle", "middle": "end"}
|
||||
router = check_workflow_state(
|
||||
transitions, state_key="current_phase", default_target="unknown"
|
||||
)
|
||||
|
||||
assert router({"current_phase": "start"}) == "middle"
|
||||
assert router({"current_phase": "middle"}) == "end"
|
||||
assert router({"current_phase": "invalid"}) == "unknown"
|
||||
|
||||
def test_numeric_state_values(self):
|
||||
"""Test with numeric state values."""
|
||||
transitions = {"1": "stage_one", "2": "stage_two"}
|
||||
router = check_workflow_state(transitions)
|
||||
|
||||
# Numeric values should be converted to strings
|
||||
assert router({"workflow_state": 1}) == "stage_one"
|
||||
assert router({"workflow_state": "2"}) == "stage_two"
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test case sensitivity in state matching."""
|
||||
transitions = {"Running": "monitor", "running": "process"}
|
||||
router = check_workflow_state(transitions)
|
||||
|
||||
# Should be case sensitive
|
||||
assert router({"workflow_state": "Running"}) == "monitor"
|
||||
assert router({"workflow_state": "running"}) == "process"
|
||||
assert router({"workflow_state": "RUNNING"}) == "error"
|
||||
|
||||
def test_complex_state_names(self):
|
||||
"""Test with complex state names."""
|
||||
transitions = {
|
||||
"user_input_required": "wait_for_input",
|
||||
"processing_with_retry": "retry_handler",
|
||||
"validation_failed_critical": "escalate",
|
||||
}
|
||||
router = check_workflow_state(transitions)
|
||||
|
||||
assert router({"workflow_state": "user_input_required"}) == "wait_for_input"
|
||||
assert router({"workflow_state": "processing_with_retry"}) == "retry_handler"
|
||||
assert router({"workflow_state": "validation_failed_critical"}) == "escalate"
|
||||
|
||||
|
||||
class TestIntegrationAndEdgeCases:
|
||||
"""Test integration scenarios and edge cases."""
|
||||
|
||||
def test_state_protocol_consistency(self):
|
||||
"""Test that all functions work consistently with StateProtocol."""
|
||||
state_data = {
|
||||
"messages": [{"tool_calls": [{"function": "search"}]}],
|
||||
"start_time": time.time() - 100,
|
||||
"current_step": 2,
|
||||
"iteration_count": 3,
|
||||
"completed_task1": True,
|
||||
"completed_task2": False,
|
||||
"workflow_state": "processing",
|
||||
}
|
||||
|
||||
dict_state = state_data
|
||||
protocol_state = MockState(state_data)
|
||||
|
||||
# All functions should return same results for both state types
|
||||
assert should_continue(dict_state) == should_continue(protocol_state)
|
||||
|
||||
timeout_router = timeout_check(300.0)
|
||||
assert timeout_router(dict_state) == timeout_router(protocol_state)
|
||||
|
||||
progress_router = multi_step_progress(5)
|
||||
assert progress_router(dict_state) == progress_router(protocol_state)
|
||||
|
||||
limit_router = check_iteration_limit(10)
|
||||
assert limit_router(dict_state) == limit_router(protocol_state)
|
||||
|
||||
criteria_router = check_completion_criteria(["task1", "task2"])
|
||||
assert criteria_router(dict_state) == criteria_router(protocol_state)
|
||||
|
||||
workflow_router = check_workflow_state({"processing": "next"})
|
||||
assert workflow_router(dict_state) == workflow_router(protocol_state)
|
||||
|
||||
def test_performance_with_large_message_lists(self):
|
||||
"""Test performance with large message lists."""
|
||||
large_message_list = [
|
||||
{"role": "assistant", "content": f"Message {i}"} for i in range(1000)
|
||||
]
|
||||
# Add tool call to last message
|
||||
large_message_list[-1]["tool_calls"] = "search"
|
||||
|
||||
state = {"messages": large_message_list}
|
||||
|
||||
# Should efficiently check only the last message
|
||||
assert should_continue(state) == "continue"
|
||||
|
||||
def test_memory_safety_with_recursive_structures(self):
|
||||
"""Test memory safety with recursive data structures."""
|
||||
# Create a recursive structure (not recommended but should not crash)
|
||||
recursive_dict = {}
|
||||
recursive_dict["self"] = recursive_dict
|
||||
|
||||
state = {"workflow_state": recursive_dict}
|
||||
|
||||
# Should handle gracefully by converting to string
|
||||
workflow_router = check_workflow_state({"[object Object]": "next"})
|
||||
result = workflow_router(state)
|
||||
# Should not crash and return default
|
||||
assert result == "error"
|
||||
|
||||
def test_concurrent_router_instances(self):
|
||||
"""Test multiple router instances don't interfere."""
|
||||
router1 = timeout_check(timeout_seconds=100.0)
|
||||
router2 = timeout_check(timeout_seconds=200.0)
|
||||
|
||||
with patch("time.time", return_value=1000.0):
|
||||
state = {"start_time": 850.0} # 150 seconds ago
|
||||
|
||||
# Different timeouts should give different results
|
||||
assert router1(state) == "timeout" # 150 > 100
|
||||
assert router2(state) == "continue" # 150 < 200
|
||||
|
||||
def test_extreme_numeric_values(self):
|
||||
"""Test handling of extreme numeric values."""
|
||||
# Very large iteration count
|
||||
limit_router = check_iteration_limit(max_iterations=10)
|
||||
assert limit_router({"iteration_count": 2**63 - 1}) == "limit_reached"
|
||||
|
||||
# Very large step count
|
||||
progress_router = multi_step_progress(total_steps=5)
|
||||
assert progress_router({"current_step": 10**10}) == "complete"
|
||||
|
||||
# Very old start time
|
||||
with patch("time.time", return_value=1000.0):
|
||||
timeout_router = timeout_check(timeout_seconds=60.0)
|
||||
assert timeout_router({"start_time": -1000000.0}) == "timeout"
|
||||
1165
packages/business-buddy-core/tests/edge_helpers/test_monitoring.py
Normal file
1165
packages/business-buddy-core/tests/edge_helpers/test_monitoring.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,934 @@
|
||||
"""Comprehensive tests for user interaction edge helpers.
|
||||
|
||||
Tests cover edge cases including:
|
||||
- Various interrupt signal formats
|
||||
- Complex user authorization scenarios
|
||||
- Multiple feedback conditions
|
||||
- Escalation triggers and edge cases
|
||||
- Input collection edge cases
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from bb_core.edge_helpers.user_interaction import (
|
||||
check_user_authorization,
|
||||
collect_user_input,
|
||||
escalate_to_human,
|
||||
human_interrupt,
|
||||
pass_status_to_user,
|
||||
user_feedback_loop,
|
||||
)
|
||||
|
||||
|
||||
class MockState:
|
||||
"""Mock state object for testing."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._data = data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return self._data.get(name)
|
||||
|
||||
|
||||
class TestHumanInterrupt:
|
||||
"""Test human_interrupt function with various edge cases."""
|
||||
|
||||
def test_default_interrupt_signals(self):
|
||||
"""Test default interrupt signals."""
|
||||
router = human_interrupt()
|
||||
|
||||
# Default interrupt signals
|
||||
assert router({"human_interrupt": "stop"}) == "interrupt"
|
||||
assert router({"human_interrupt": "pause"}) == "interrupt"
|
||||
assert router({"human_interrupt": "cancel"}) == "interrupt"
|
||||
assert router({"human_interrupt": "abort"}) == "interrupt"
|
||||
assert router({"human_interrupt": "interrupt"}) == "interrupt"
|
||||
assert router({"human_interrupt": "halt"}) == "interrupt"
|
||||
|
||||
def test_non_interrupt_signals(self):
|
||||
"""Test signals that should not trigger interrupt."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({"human_interrupt": "continue"}) == "continue"
|
||||
assert router({"human_interrupt": "proceed"}) == "continue"
|
||||
assert router({"human_interrupt": "go"}) == "continue"
|
||||
assert router({"human_interrupt": "run"}) == "continue"
|
||||
|
||||
def test_case_insensitive_signals(self):
|
||||
"""Test case insensitive signal matching."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({"human_interrupt": "STOP"}) == "interrupt"
|
||||
assert router({"human_interrupt": "Cancel"}) == "interrupt"
|
||||
assert router({"human_interrupt": "ABORT"}) == "interrupt"
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test whitespace handling in signals."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({"human_interrupt": " stop "}) == "interrupt"
|
||||
assert router({"human_interrupt": "\\tcancel\\n"}) == "interrupt"
|
||||
assert router({"human_interrupt": " abort "}) == "interrupt"
|
||||
|
||||
def test_custom_interrupt_signals(self):
|
||||
"""Test custom interrupt signals."""
|
||||
router = human_interrupt(interrupt_signals=["quit", "exit", "emergency"])
|
||||
|
||||
assert router({"human_interrupt": "quit"}) == "interrupt"
|
||||
assert router({"human_interrupt": "exit"}) == "interrupt"
|
||||
assert router({"human_interrupt": "emergency"}) == "interrupt"
|
||||
|
||||
# Default signals should not work with custom list
|
||||
assert router({"human_interrupt": "stop"}) == "continue"
|
||||
|
||||
def test_missing_interrupt_key(self):
|
||||
"""Test behavior when interrupt key is missing."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({}) == "continue"
|
||||
assert router({"other_key": "stop"}) == "continue"
|
||||
|
||||
def test_none_interrupt_signal(self):
|
||||
"""Test behavior with None interrupt signal."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({"human_interrupt": None}) == "continue"
|
||||
|
||||
def test_custom_interrupt_key(self):
|
||||
"""Test with custom interrupt key."""
|
||||
router = human_interrupt(interrupt_key="user_command")
|
||||
|
||||
assert router({"user_command": "stop"}) == "interrupt"
|
||||
assert router({"user_command": "continue"}) == "continue"
|
||||
|
||||
def test_numeric_signals(self):
|
||||
"""Test numeric signals converted to strings."""
|
||||
router = human_interrupt(interrupt_signals=["0", "1"])
|
||||
|
||||
assert router({"human_interrupt": 0}) == "interrupt"
|
||||
assert router({"human_interrupt": "1"}) == "interrupt"
|
||||
assert router({"human_interrupt": 2}) == "continue"
|
||||
|
||||
def test_empty_signal(self):
|
||||
"""Test empty signal handling."""
|
||||
router = human_interrupt()
|
||||
|
||||
assert router({"human_interrupt": ""}) == "continue"
|
||||
assert router({"human_interrupt": " "}) == "continue" # Only whitespace
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = human_interrupt()
|
||||
|
||||
state = MockState({"human_interrupt": "stop"})
|
||||
assert router(state) == "interrupt"
|
||||
|
||||
state = MockState({"human_interrupt": "continue"})
|
||||
assert router(state) == "continue"
|
||||
|
||||
|
||||
class TestPassStatusToUser:
|
||||
"""Test pass_status_to_user function with various edge cases."""
|
||||
|
||||
def test_default_status_levels(self):
|
||||
"""Test default status level mappings."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
assert router({"status": "error", "notify_user": True}) == "urgent"
|
||||
assert router({"status": "failed", "notify_user": True}) == "urgent"
|
||||
assert router({"status": "warning", "notify_user": True}) == "medium"
|
||||
assert router({"status": "completed", "notify_user": True}) == "low"
|
||||
assert router({"status": "success", "notify_user": True}) == "low"
|
||||
assert router({"status": "in_progress", "notify_user": True}) == "info"
|
||||
|
||||
def test_notification_disabled(self):
|
||||
"""Test when notifications are disabled."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
assert router({"status": "error", "notify_user": False}) == "no_notification"
|
||||
assert router({"status": "failed", "notify_user": False}) == "no_notification"
|
||||
|
||||
def test_missing_notify_key(self):
|
||||
"""Test default notification behavior when notify key is missing."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
# Should default to True (notify)
|
||||
assert router({"status": "error"}) == "urgent"
|
||||
assert router({"status": "warning"}) == "medium"
|
||||
|
||||
def test_custom_status_levels(self):
|
||||
"""Test custom status level mappings."""
|
||||
custom_levels = {
|
||||
"critical": "immediate",
|
||||
"high": "urgent",
|
||||
"normal": "standard",
|
||||
}
|
||||
router = pass_status_to_user(status_levels=custom_levels)
|
||||
|
||||
assert router({"status": "critical", "notify_user": True}) == "immediate"
|
||||
assert router({"status": "high", "notify_user": True}) == "urgent"
|
||||
assert router({"status": "normal", "notify_user": True}) == "standard"
|
||||
|
||||
# Unknown status should return no_notification
|
||||
assert router({"status": "unknown", "notify_user": True}) == "no_notification"
|
||||
|
||||
def test_custom_keys(self):
|
||||
"""Test with custom status and notify keys."""
|
||||
router = pass_status_to_user(
|
||||
status_key="current_status", notify_key="send_notification"
|
||||
)
|
||||
|
||||
state = {"current_status": "error", "send_notification": True}
|
||||
assert router(state) == "urgent"
|
||||
|
||||
state = {"current_status": "error", "send_notification": False}
|
||||
assert router(state) == "no_notification"
|
||||
|
||||
def test_missing_status_key(self):
|
||||
"""Test behavior when status key is missing."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
assert router({"notify_user": True}) == "no_notification"
|
||||
assert router({}) == "no_notification"
|
||||
|
||||
def test_none_status(self):
|
||||
"""Test behavior with None status."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
assert router({"status": None, "notify_user": True}) == "no_notification"
|
||||
|
||||
def test_case_insensitive_status(self):
|
||||
"""Test case insensitive status matching."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
assert router({"status": "ERROR", "notify_user": True}) == "urgent"
|
||||
assert router({"status": "Warning", "notify_user": True}) == "medium"
|
||||
assert router({"status": "COMPLETED", "notify_user": True}) == "low"
|
||||
|
||||
def test_numeric_status(self):
|
||||
"""Test numeric status values."""
|
||||
custom_levels = {"1": "low", "2": "medium", "3": "high"}
|
||||
router = pass_status_to_user(status_levels=custom_levels)
|
||||
|
||||
assert router({"status": 1, "notify_user": True}) == "low"
|
||||
assert router({"status": "2", "notify_user": True}) == "medium"
|
||||
assert router({"status": 3, "notify_user": True}) == "high"
|
||||
|
||||
def test_complex_status_values(self):
|
||||
"""Test complex status values converted to strings."""
|
||||
router = pass_status_to_user()
|
||||
|
||||
# Complex objects should be converted to strings
|
||||
assert router({"status": ["error"], "notify_user": True}) == "no_notification"
|
||||
assert (
|
||||
router({"status": {"type": "error"}, "notify_user": True})
|
||||
== "no_notification"
|
||||
)
|
||||
|
||||
|
||||
class TestUserFeedbackLoop:
|
||||
"""Test user_feedback_loop function with various edge cases."""
|
||||
|
||||
def test_explicit_feedback_required(self):
|
||||
"""Test explicit feedback requirement."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
state = {"requires_feedback": True, "feedback_type": "validation"}
|
||||
assert router(state) == "feedback_validation"
|
||||
|
||||
state = {"requires_feedback": True, "feedback_type": "clarification"}
|
||||
assert router(state) == "feedback_clarification"
|
||||
|
||||
def test_default_feedback_type(self):
|
||||
"""Test default feedback type when not specified."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
state = {"requires_feedback": True}
|
||||
assert router(state) == "feedback_general"
|
||||
|
||||
def test_condition_based_feedback(self):
|
||||
"""Test feedback triggered by conditions."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
# Default conditions that trigger feedback
|
||||
assert router({"low_confidence": True}) == "feedback_low_confidence"
|
||||
assert router({"ambiguous_input": True}) == "feedback_ambiguous_input"
|
||||
assert router({"multiple_options": True}) == "feedback_multiple_options"
|
||||
assert router({"validation_failed": True}) == "feedback_validation_failed"
|
||||
|
||||
def test_no_feedback_needed(self):
|
||||
"""Test when no feedback is needed."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
assert router({}) == "no_feedback"
|
||||
assert router({"requires_feedback": False}) == "no_feedback"
|
||||
assert router({"low_confidence": False}) == "no_feedback"
|
||||
|
||||
def test_custom_feedback_conditions(self):
|
||||
"""Test custom feedback conditions."""
|
||||
custom_conditions = ["needs_review", "unclear_intent", "missing_data"]
|
||||
router = user_feedback_loop(feedback_required_conditions=custom_conditions)
|
||||
|
||||
assert router({"needs_review": True}) == "feedback_needs_review"
|
||||
assert router({"unclear_intent": True}) == "feedback_unclear_intent"
|
||||
assert router({"missing_data": True}) == "feedback_missing_data"
|
||||
|
||||
# Default conditions should not trigger with custom list
|
||||
assert router({"low_confidence": True}) == "no_feedback"
|
||||
|
||||
def test_custom_feedback_keys(self):
|
||||
"""Test with custom feedback keys."""
|
||||
router = user_feedback_loop(
|
||||
feedback_key="user_input_needed", feedback_type_key="input_type"
|
||||
)
|
||||
|
||||
state = {"user_input_needed": True, "input_type": "choice"}
|
||||
assert router(state) == "feedback_choice"
|
||||
|
||||
def test_multiple_conditions_triggered(self):
|
||||
"""Test when multiple conditions are triggered."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
# Should return feedback for first matching condition
|
||||
state = {"low_confidence": True, "ambiguous_input": True}
|
||||
result = router(state)
|
||||
assert result in ["feedback_low_confidence", "feedback_ambiguous_input"]
|
||||
|
||||
def test_priority_explicit_over_conditions(self):
|
||||
"""Test that explicit feedback takes priority over conditions."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
state = {
|
||||
"requires_feedback": True,
|
||||
"feedback_type": "explicit",
|
||||
"low_confidence": True, # This should be ignored
|
||||
}
|
||||
assert router(state) == "feedback_explicit"
|
||||
|
||||
def test_falsy_conditions(self):
|
||||
"""Test that falsy condition values don't trigger feedback."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
state = {
|
||||
"low_confidence": False,
|
||||
"ambiguous_input": 0,
|
||||
"multiple_options": None,
|
||||
"validation_failed": "",
|
||||
}
|
||||
assert router(state) == "no_feedback"
|
||||
|
||||
def test_truthy_non_boolean_conditions(self):
|
||||
"""Test truthy non-boolean values trigger feedback."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
assert router({"low_confidence": 1}) == "feedback_low_confidence"
|
||||
assert router({"ambiguous_input": "yes"}) == "feedback_ambiguous_input"
|
||||
assert router({"multiple_options": [1, 2, 3]}) == "feedback_multiple_options"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = user_feedback_loop()
|
||||
|
||||
state = MockState({"requires_feedback": True, "feedback_type": "test"})
|
||||
assert router(state) == "feedback_test"
|
||||
|
||||
state = MockState({"low_confidence": True})
|
||||
assert router(state) == "feedback_low_confidence"
|
||||
|
||||
|
||||
class TestEscalateToHuman:
|
||||
"""Test escalate_to_human function with various edge cases."""
|
||||
|
||||
def test_automatic_escalation(self):
|
||||
"""Test automatic escalation based on reason."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "critical_error"}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "security_issue"}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "multiple_failures"}
|
||||
assert router(state) == "escalate_urgent"
|
||||
|
||||
def test_trigger_based_escalation(self):
|
||||
"""Test escalation based on trigger flags."""
|
||||
router = escalate_to_human()
|
||||
|
||||
assert router({"critical_error": True}) == "escalate_immediate"
|
||||
assert router({"security_issue": True}) == "escalate_immediate"
|
||||
assert router({"multiple_failures": True}) == "escalate_urgent"
|
||||
assert router({"user_request": True}) == "escalate_standard"
|
||||
assert router({"manual_review": True}) == "escalate_standard"
|
||||
assert router({"complex_case": True}) == "escalate_expert"
|
||||
|
||||
def test_no_escalation_needed(self):
|
||||
"""Test when no escalation is needed."""
|
||||
router = escalate_to_human()
|
||||
|
||||
assert router({}) == "continue_automated"
|
||||
assert router({"auto_escalate": False}) == "continue_automated"
|
||||
assert router({"some_other_flag": True}) == "continue_automated"
|
||||
|
||||
def test_custom_escalation_triggers(self):
|
||||
"""Test custom escalation triggers."""
|
||||
custom_triggers = {
|
||||
"data_breach": "immediate",
|
||||
"system_down": "urgent",
|
||||
"user_complaint": "standard",
|
||||
}
|
||||
router = escalate_to_human(escalation_triggers=custom_triggers)
|
||||
|
||||
assert router({"data_breach": True}) == "escalate_immediate"
|
||||
assert router({"system_down": True}) == "escalate_urgent"
|
||||
assert router({"user_complaint": True}) == "escalate_standard"
|
||||
|
||||
# Default triggers should not work with custom triggers
|
||||
assert router({"critical_error": True}) == "continue_automated"
|
||||
|
||||
def test_custom_escalation_keys(self):
|
||||
"""Test with custom escalation keys."""
|
||||
router = escalate_to_human(
|
||||
auto_escalate_key="force_escalate", escalation_reason_key="reason"
|
||||
)
|
||||
|
||||
state = {"force_escalate": True, "reason": "critical_error"}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
def test_partial_reason_matching(self):
|
||||
"""Test partial reason matching in escalation reason."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = {
|
||||
"auto_escalate": True,
|
||||
"escalation_reason": "encountered critical_error in processing",
|
||||
}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
state = {
|
||||
"auto_escalate": True,
|
||||
"escalation_reason": "multiple_failures detected",
|
||||
}
|
||||
assert router(state) == "escalate_urgent"
|
||||
|
||||
def test_unknown_escalation_reason(self):
|
||||
"""Test unknown escalation reason defaults to standard."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "unknown_issue"}
|
||||
assert router(state) == "escalate_standard"
|
||||
|
||||
def test_missing_escalation_reason(self):
|
||||
"""Test auto escalate without reason."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = {"auto_escalate": True}
|
||||
assert router(state) == "continue_automated"
|
||||
|
||||
def test_case_insensitive_reason_matching(self):
|
||||
"""Test case insensitive reason matching."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "CRITICAL_ERROR"}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
state = {"auto_escalate": True, "escalation_reason": "Security_Issue"}
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
def test_multiple_triggers_precedence(self):
|
||||
"""Test precedence when multiple triggers are present."""
|
||||
router = escalate_to_human()
|
||||
|
||||
# Should find the first trigger in the dictionary order
|
||||
state = {"critical_error": True, "user_request": True, "complex_case": True}
|
||||
result = router(state)
|
||||
assert result.startswith("escalate_")
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = escalate_to_human()
|
||||
|
||||
state = MockState(
|
||||
{"auto_escalate": True, "escalation_reason": "critical_error"}
|
||||
)
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
state = MockState({"critical_error": True})
|
||||
assert router(state) == "escalate_immediate"
|
||||
|
||||
|
||||
class TestCheckUserAuthorization:
|
||||
"""Test check_user_authorization function with various edge cases."""
|
||||
|
||||
def test_authorized_user(self):
|
||||
"""Test user with all required permissions."""
|
||||
router = check_user_authorization(["read", "write"])
|
||||
|
||||
user = {"permissions": ["read", "write", "admin"]}
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
def test_unauthorized_user_missing_permissions(self):
|
||||
"""Test user missing required permissions."""
|
||||
router = check_user_authorization(["read", "write", "admin"])
|
||||
|
||||
user = {"permissions": ["read", "write"]} # Missing admin
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_unauthorized_user_no_permissions(self):
|
||||
"""Test user with no permissions."""
|
||||
router = check_user_authorization(["read"])
|
||||
|
||||
user = {"permissions": []}
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
user = {} # No permissions key
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_missing_user(self):
|
||||
"""Test when user is missing from state."""
|
||||
router = check_user_authorization(["read"])
|
||||
|
||||
assert router({}) == "unauthorized"
|
||||
assert router({"other_key": "value"}) == "unauthorized"
|
||||
|
||||
def test_none_user(self):
|
||||
"""Test when user is None."""
|
||||
router = check_user_authorization(["read"])
|
||||
|
||||
assert router({"user": None}) == "unauthorized"
|
||||
|
||||
def test_user_object_with_attributes(self):
|
||||
"""Test user object with permission attributes."""
|
||||
|
||||
class MockUser:
|
||||
def __init__(self, permissions):
|
||||
self.permissions = permissions
|
||||
|
||||
router = check_user_authorization(["read", "write"])
|
||||
|
||||
user = MockUser(["read", "write", "admin"])
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
user = MockUser(["read"]) # Missing write
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_custom_permission_keys(self):
|
||||
"""Test with custom user and permissions keys."""
|
||||
router = check_user_authorization(
|
||||
required_permissions=["access"],
|
||||
user_key="current_user",
|
||||
permissions_key="roles",
|
||||
)
|
||||
|
||||
user = {"roles": ["access", "modify"]}
|
||||
assert router({"current_user": user}) == "authorized"
|
||||
|
||||
user = {"roles": ["modify"]} # Missing access
|
||||
assert router({"current_user": user}) == "unauthorized"
|
||||
|
||||
def test_default_permissions(self):
|
||||
"""Test default required permissions."""
|
||||
router = check_user_authorization() # Default: ["basic_access"]
|
||||
|
||||
user = {"permissions": ["basic_access"]}
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
user = {"permissions": ["other_permission"]}
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_empty_required_permissions(self):
|
||||
"""Test with empty required permissions list."""
|
||||
router = check_user_authorization(required_permissions=[])
|
||||
|
||||
# No requirements means everyone is authorized
|
||||
user = {"permissions": []}
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
user = {}
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
def test_non_list_permissions(self):
|
||||
"""Test when user permissions is not a list."""
|
||||
router = check_user_authorization(["read"])
|
||||
|
||||
user = {"permissions": "read"} # String instead of list
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
user = {"permissions": {"read": True}} # Dict instead of list
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_case_sensitive_permissions(self):
|
||||
"""Test case sensitivity of permission matching."""
|
||||
router = check_user_authorization(["Read", "Write"])
|
||||
|
||||
user = {"permissions": ["read", "write"]} # Different case
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
user = {"permissions": ["Read", "Write"]} # Exact case
|
||||
assert router({"user": user}) == "authorized"
|
||||
|
||||
def test_user_object_missing_permissions_attribute(self):
|
||||
"""Test user object missing permissions attribute."""
|
||||
|
||||
class MockUser:
|
||||
def __init__(self):
|
||||
self.name = "John"
|
||||
|
||||
router = check_user_authorization(["read"])
|
||||
|
||||
user = MockUser() # No permissions attribute
|
||||
assert router({"user": user}) == "unauthorized"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = check_user_authorization(["read", "write"])
|
||||
|
||||
user = {"permissions": ["read", "write"]}
|
||||
state = MockState({"user": user})
|
||||
assert router(state) == "authorized"
|
||||
|
||||
user = {"permissions": ["read"]}
|
||||
state = MockState({"user": user})
|
||||
assert router(state) == "unauthorized"
|
||||
|
||||
|
||||
class TestCollectUserInput:
|
||||
"""Test collect_user_input function with various edge cases."""
|
||||
|
||||
def test_pending_input_with_default_types(self):
|
||||
"""Test pending input with default input types."""
|
||||
router = collect_user_input()
|
||||
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "text"}) == "text_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "choice"})
|
||||
== "choice_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "confirmation"})
|
||||
== "confirm_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "file"}) == "file_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "numeric"})
|
||||
== "number_input"
|
||||
)
|
||||
|
||||
def test_no_input_needed(self):
|
||||
"""Test when no input is needed."""
|
||||
router = collect_user_input()
|
||||
|
||||
assert router({"pending_user_input": False}) == "no_input_needed"
|
||||
assert router({}) == "no_input_needed" # Missing key defaults to False
|
||||
|
||||
def test_default_input_type(self):
|
||||
"""Test default input type when not specified."""
|
||||
router = collect_user_input()
|
||||
|
||||
assert router({"pending_user_input": True}) == "text_input" # Default to text
|
||||
|
||||
def test_custom_input_types(self):
|
||||
"""Test custom input type mappings."""
|
||||
custom_types = {
|
||||
"voice": "voice_recorder",
|
||||
"image": "image_capture",
|
||||
"location": "gps_picker",
|
||||
}
|
||||
router = collect_user_input(input_types=custom_types)
|
||||
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "voice"})
|
||||
== "voice_recorder"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "image"})
|
||||
== "image_capture"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "location"})
|
||||
== "gps_picker"
|
||||
)
|
||||
|
||||
# Unknown type should default to text_input
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "unknown"})
|
||||
== "text_input"
|
||||
)
|
||||
|
||||
def test_custom_input_keys(self):
|
||||
"""Test with custom input keys."""
|
||||
router = collect_user_input(
|
||||
pending_input_key="awaiting_input", input_type_key="required_input_type"
|
||||
)
|
||||
|
||||
state = {"awaiting_input": True, "required_input_type": "choice"}
|
||||
assert router(state) == "choice_input"
|
||||
|
||||
state = {"awaiting_input": False, "required_input_type": "text"}
|
||||
assert router(state) == "no_input_needed"
|
||||
|
||||
def test_case_insensitive_input_types(self):
|
||||
"""Test case insensitive input type matching."""
|
||||
router = collect_user_input()
|
||||
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "TEXT"}) == "text_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "Choice"})
|
||||
== "choice_input"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "FILE"}) == "file_input"
|
||||
)
|
||||
|
||||
def test_numeric_input_type(self):
|
||||
"""Test numeric input type values."""
|
||||
custom_types = {"1": "type_one", "2": "type_two"}
|
||||
router = collect_user_input(input_types=custom_types)
|
||||
|
||||
assert router({"pending_user_input": True, "input_type": 1}) == "type_one"
|
||||
assert router({"pending_user_input": True, "input_type": "2"}) == "type_two"
|
||||
|
||||
def test_truthy_falsy_pending_values(self):
|
||||
"""Test truthy/falsy values for pending input."""
|
||||
router = collect_user_input()
|
||||
|
||||
# Truthy values should indicate pending input
|
||||
assert router({"pending_user_input": 1, "input_type": "text"}) == "text_input"
|
||||
assert (
|
||||
router({"pending_user_input": "yes", "input_type": "text"}) == "text_input"
|
||||
)
|
||||
assert router({"pending_user_input": [1], "input_type": "text"}) == "text_input"
|
||||
|
||||
# Falsy values should indicate no input needed
|
||||
assert (
|
||||
router({"pending_user_input": 0, "input_type": "text"}) == "no_input_needed"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": "", "input_type": "text"})
|
||||
== "no_input_needed"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": [], "input_type": "text"})
|
||||
== "no_input_needed"
|
||||
)
|
||||
assert (
|
||||
router({"pending_user_input": None, "input_type": "text"})
|
||||
== "no_input_needed"
|
||||
)
|
||||
|
||||
def test_fallback_to_text_input(self):
|
||||
"""Test fallback to text input for unknown types."""
|
||||
router = collect_user_input()
|
||||
|
||||
assert (
|
||||
router({"pending_user_input": True, "input_type": "unknown_type"})
|
||||
== "text_input"
|
||||
)
|
||||
assert router({"pending_user_input": True, "input_type": None}) == "text_input"
|
||||
assert router({"pending_user_input": True, "input_type": ""}) == "text_input"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = collect_user_input()
|
||||
|
||||
state = MockState({"pending_user_input": True, "input_type": "choice"})
|
||||
assert router(state) == "choice_input"
|
||||
|
||||
state = MockState({"pending_user_input": False})
|
||||
assert router(state) == "no_input_needed"
|
||||
|
||||
|
||||
class TestIntegrationAndEdgeCases:
|
||||
"""Test integration scenarios and edge cases."""
|
||||
|
||||
def test_all_user_interaction_functions_consistency(self):
|
||||
"""Test all user interaction functions work consistently."""
|
||||
state = {
|
||||
"human_interrupt": "continue",
|
||||
"status": "warning",
|
||||
"notify_user": True,
|
||||
"requires_feedback": False,
|
||||
"low_confidence": False,
|
||||
"auto_escalate": False,
|
||||
"user": {"permissions": ["read", "write"]},
|
||||
"pending_user_input": False,
|
||||
}
|
||||
|
||||
interrupt_router = human_interrupt()
|
||||
status_router = pass_status_to_user()
|
||||
feedback_router = user_feedback_loop()
|
||||
escalation_router = escalate_to_human()
|
||||
auth_router = check_user_authorization(["read", "write"])
|
||||
input_router = collect_user_input()
|
||||
|
||||
assert interrupt_router(state) == "continue"
|
||||
assert status_router(state) == "medium"
|
||||
assert feedback_router(state) == "no_feedback"
|
||||
assert escalation_router(state) == "continue_automated"
|
||||
assert auth_router(state) == "authorized"
|
||||
assert input_router(state) == "no_input_needed"
|
||||
|
||||
def test_state_protocol_consistency_across_functions(self):
|
||||
"""Test StateProtocol consistency across all functions."""
|
||||
state_data = {
|
||||
"human_interrupt": "stop",
|
||||
"status": "error",
|
||||
"notify_user": True,
|
||||
"requires_feedback": True,
|
||||
"feedback_type": "urgent",
|
||||
"critical_error": True,
|
||||
"user": {"permissions": ["admin"]},
|
||||
"pending_user_input": True,
|
||||
"input_type": "choice",
|
||||
}
|
||||
|
||||
dict_state = state_data
|
||||
protocol_state = MockState(state_data)
|
||||
|
||||
# All functions should return same results for both state types
|
||||
interrupt_router = human_interrupt()
|
||||
assert interrupt_router(dict_state) == interrupt_router(protocol_state)
|
||||
|
||||
status_router = pass_status_to_user()
|
||||
assert status_router(dict_state) == status_router(protocol_state)
|
||||
|
||||
feedback_router = user_feedback_loop()
|
||||
assert feedback_router(dict_state) == feedback_router(protocol_state)
|
||||
|
||||
escalation_router = escalate_to_human()
|
||||
assert escalation_router(dict_state) == escalation_router(protocol_state)
|
||||
|
||||
auth_router = check_user_authorization(["admin"])
|
||||
assert auth_router(dict_state) == auth_router(protocol_state)
|
||||
|
||||
input_router = collect_user_input()
|
||||
assert input_router(dict_state) == input_router(protocol_state)
|
||||
|
||||
def test_cascading_user_interaction_scenario(self):
|
||||
"""Test cascading user interaction scenario."""
|
||||
# Start with normal operation
|
||||
state = {"status": "in_progress", "notify_user": True}
|
||||
|
||||
status_router = pass_status_to_user()
|
||||
assert status_router(state) == "info"
|
||||
|
||||
# User requests interrupt
|
||||
state["human_interrupt"] = "pause"
|
||||
interrupt_router = human_interrupt()
|
||||
assert interrupt_router(state) == "interrupt"
|
||||
|
||||
# System needs feedback due to interrupt
|
||||
state["requires_feedback"] = True
|
||||
state["feedback_type"] = "confirmation"
|
||||
feedback_router = user_feedback_loop()
|
||||
assert feedback_router(state) == "feedback_confirmation"
|
||||
|
||||
# Collect user input for feedback
|
||||
state["pending_user_input"] = True
|
||||
state["input_type"] = "confirmation"
|
||||
input_router = collect_user_input()
|
||||
assert input_router(state) == "confirm_input"
|
||||
|
||||
def test_security_and_authorization_flow(self):
|
||||
"""Test security and authorization flow."""
|
||||
# User with insufficient permissions
|
||||
unauthorized_user = {"permissions": ["read"]}
|
||||
state: dict[str, Any] = {"user": unauthorized_user}
|
||||
|
||||
auth_router = check_user_authorization(["read", "write", "admin"])
|
||||
assert auth_router(state) == "unauthorized"
|
||||
|
||||
# Should escalate due to unauthorized access
|
||||
state["security_issue"] = True
|
||||
escalation_router = escalate_to_human()
|
||||
assert escalation_router(state) == "escalate_immediate"
|
||||
|
||||
# Should notify user of security issue
|
||||
state["status"] = "failed"
|
||||
state["notify_user"] = True
|
||||
status_router = pass_status_to_user()
|
||||
assert status_router(state) == "urgent"
|
||||
|
||||
def test_error_recovery_interaction_flow(self):
|
||||
"""Test error recovery interaction flow."""
|
||||
# Multiple failures trigger escalation
|
||||
state: dict[str, Any] = {"multiple_failures": True}
|
||||
escalation_router = escalate_to_human()
|
||||
assert escalation_router(state) == "escalate_urgent"
|
||||
|
||||
# Low confidence requires feedback
|
||||
state["low_confidence"] = True
|
||||
feedback_router = user_feedback_loop()
|
||||
assert feedback_router(state) == "feedback_low_confidence"
|
||||
|
||||
# Need user input to resolve
|
||||
state["pending_user_input"] = True
|
||||
state["input_type"] = "choice"
|
||||
input_router = collect_user_input()
|
||||
assert input_router(state) == "choice_input"
|
||||
|
||||
# User might interrupt the process
|
||||
state["human_interrupt"] = "abort"
|
||||
interrupt_router = human_interrupt()
|
||||
assert interrupt_router(state) == "interrupt"
|
||||
|
||||
def test_performance_with_complex_user_objects(self):
|
||||
"""Test performance with complex user objects."""
|
||||
# Create complex user object
|
||||
complex_user = {
|
||||
"id": "user123",
|
||||
"profile": {
|
||||
"name": "John Doe",
|
||||
"department": "Engineering",
|
||||
"level": "Senior",
|
||||
},
|
||||
"permissions": ["read", "write", "deploy"]
|
||||
+ [f"perm_{i}" for i in range(100)],
|
||||
"metadata": {
|
||||
"login_count": 1500,
|
||||
"last_active": "2024-01-01",
|
||||
"preferences": {"theme": "dark", "notifications": True},
|
||||
},
|
||||
}
|
||||
|
||||
state = {"user": complex_user}
|
||||
auth_router = check_user_authorization(["read", "write"])
|
||||
|
||||
# Should handle complex user object efficiently
|
||||
assert auth_router(state) == "authorized"
|
||||
|
||||
def test_concurrent_router_usage(self):
|
||||
"""Test that routers are stateless and thread-safe."""
|
||||
interrupt_router = human_interrupt()
|
||||
|
||||
# Multiple calls with different states should not interfere
|
||||
state1 = {"human_interrupt": "stop"}
|
||||
state2 = {"human_interrupt": "continue"}
|
||||
|
||||
assert interrupt_router(state1) == "interrupt"
|
||||
assert interrupt_router(state2) == "continue"
|
||||
assert interrupt_router(state1) == "interrupt" # Should still work
|
||||
|
||||
def test_memory_efficiency_with_large_permission_lists(self):
|
||||
"""Test memory efficiency with large permission lists."""
|
||||
# Create user with many permissions
|
||||
large_permissions = [f"permission_{i}" for i in range(1000)]
|
||||
user = {"permissions": large_permissions}
|
||||
state = {"user": user}
|
||||
|
||||
# Should handle large permission lists efficiently
|
||||
auth_router = check_user_authorization(["permission_500"])
|
||||
assert auth_router(state) == "authorized"
|
||||
|
||||
auth_router = check_user_authorization(["permission_not_in_list"])
|
||||
assert auth_router(state) == "unauthorized"
|
||||
@@ -0,0 +1,813 @@
|
||||
"""Comprehensive tests for validation edge helpers.
|
||||
|
||||
Tests cover edge cases including:
|
||||
- Malformed JSON, XML, CSV formats
|
||||
- Invalid confidence and accuracy scores
|
||||
- Privacy pattern edge cases
|
||||
- Timestamp parsing failures
|
||||
- Required field validation edge cases
|
||||
- Length validation boundary conditions
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from bb_core.edge_helpers.validation import (
|
||||
check_accuracy,
|
||||
check_confidence_level,
|
||||
check_data_freshness,
|
||||
check_output_length,
|
||||
check_privacy_compliance,
|
||||
validate_output_format,
|
||||
validate_required_fields,
|
||||
)
|
||||
|
||||
|
||||
class MockState:
|
||||
"""Mock state object for testing."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._data = data
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return self._data.get(name)
|
||||
|
||||
|
||||
class TestCheckAccuracy:
|
||||
"""Test check_accuracy function with various edge cases."""
|
||||
|
||||
def test_high_accuracy_scores(self):
|
||||
"""Test scores above threshold."""
|
||||
router = check_accuracy(threshold=0.8)
|
||||
|
||||
assert router({"accuracy_score": 0.85}) == "high_accuracy"
|
||||
assert router({"accuracy_score": 1.0}) == "high_accuracy"
|
||||
assert router({"accuracy_score": 0.8}) == "high_accuracy" # Equal to threshold
|
||||
|
||||
def test_low_accuracy_scores(self):
|
||||
"""Test scores below threshold."""
|
||||
router = check_accuracy(threshold=0.8)
|
||||
|
||||
assert router({"accuracy_score": 0.75}) == "low_accuracy"
|
||||
assert router({"accuracy_score": 0.0}) == "low_accuracy"
|
||||
assert router({"accuracy_score": 0.79}) == "low_accuracy"
|
||||
|
||||
def test_string_numeric_conversion(self):
|
||||
"""Test conversion of string numbers."""
|
||||
router = check_accuracy(threshold=0.5)
|
||||
|
||||
assert router({"accuracy_score": "0.8"}) == "high_accuracy"
|
||||
assert router({"accuracy_score": "0.3"}) == "low_accuracy"
|
||||
assert router({"accuracy_score": "1"}) == "high_accuracy"
|
||||
assert router({"accuracy_score": "0"}) == "low_accuracy"
|
||||
|
||||
def test_invalid_accuracy_values(self):
|
||||
"""Test handling of invalid accuracy values."""
|
||||
router = check_accuracy(threshold=0.5)
|
||||
|
||||
# Non-numeric values should route to low_accuracy
|
||||
assert router({"accuracy_score": "invalid"}) == "low_accuracy"
|
||||
assert router({"accuracy_score": []}) == "low_accuracy"
|
||||
assert router({"accuracy_score": {}}) == "low_accuracy"
|
||||
assert router({"accuracy_score": None}) == "low_accuracy"
|
||||
|
||||
def test_missing_accuracy_key(self):
|
||||
"""Test behavior when accuracy key is missing."""
|
||||
router = check_accuracy(threshold=0.5)
|
||||
|
||||
# Should default to 0.0 and route to low_accuracy
|
||||
assert router({}) == "low_accuracy"
|
||||
assert router({"other_key": 0.9}) == "low_accuracy"
|
||||
|
||||
def test_custom_accuracy_key(self):
|
||||
"""Test with custom accuracy key."""
|
||||
router = check_accuracy(threshold=0.7, accuracy_key="confidence_score")
|
||||
|
||||
assert router({"confidence_score": 0.8}) == "high_accuracy"
|
||||
assert router({"confidence_score": 0.6}) == "low_accuracy"
|
||||
|
||||
def test_extreme_values(self):
|
||||
"""Test extreme accuracy values."""
|
||||
router = check_accuracy(threshold=0.5)
|
||||
|
||||
# Values above 1.0
|
||||
assert router({"accuracy_score": 1.5}) == "high_accuracy"
|
||||
assert router({"accuracy_score": 100.0}) == "high_accuracy"
|
||||
|
||||
# Negative values
|
||||
assert router({"accuracy_score": -0.1}) == "low_accuracy"
|
||||
assert router({"accuracy_score": -10.0}) == "low_accuracy"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = check_accuracy(threshold=0.8)
|
||||
|
||||
state = MockState({"accuracy_score": 0.9})
|
||||
assert router(state) == "high_accuracy"
|
||||
|
||||
state = MockState({"accuracy_score": 0.7})
|
||||
assert router(state) == "low_accuracy"
|
||||
|
||||
|
||||
class TestCheckConfidenceLevel:
|
||||
"""Test check_confidence_level function with various edge cases."""
|
||||
|
||||
def test_high_confidence_scores(self):
|
||||
"""Test scores above threshold."""
|
||||
router = check_confidence_level(threshold=0.7)
|
||||
|
||||
assert router({"confidence": 0.8}) == "high_confidence"
|
||||
assert router({"confidence": 1.0}) == "high_confidence"
|
||||
assert router({"confidence": 0.7}) == "high_confidence" # Equal to threshold
|
||||
|
||||
def test_low_confidence_scores(self):
|
||||
"""Test scores below threshold."""
|
||||
router = check_confidence_level(threshold=0.7)
|
||||
|
||||
assert router({"confidence": 0.6}) == "low_confidence"
|
||||
assert router({"confidence": 0.0}) == "low_confidence"
|
||||
assert router({"confidence": 0.69}) == "low_confidence"
|
||||
|
||||
def test_invalid_confidence_values(self):
|
||||
"""Test handling of invalid confidence values."""
|
||||
router = check_confidence_level(threshold=0.5)
|
||||
|
||||
# Non-numeric values should route to low_confidence
|
||||
assert router({"confidence": "high"}) == "low_confidence"
|
||||
assert router({"confidence": None}) == "low_confidence"
|
||||
assert router({"confidence": []}) == "low_confidence"
|
||||
|
||||
def test_custom_confidence_key(self):
|
||||
"""Test with custom confidence key."""
|
||||
router = check_confidence_level(threshold=0.6, confidence_key="certainty")
|
||||
|
||||
assert router({"certainty": 0.8}) == "high_confidence"
|
||||
assert router({"certainty": 0.5}) == "low_confidence"
|
||||
|
||||
def test_percentage_values(self):
|
||||
"""Test percentage-style values."""
|
||||
router = check_confidence_level(threshold=70.0)
|
||||
|
||||
assert router({"confidence": 80.0}) == "high_confidence"
|
||||
assert router({"confidence": 60.0}) == "low_confidence"
|
||||
assert router({"confidence": "75"}) == "high_confidence"
|
||||
|
||||
|
||||
class TestValidateOutputFormat:
|
||||
"""Test validate_output_format function with various edge cases."""
|
||||
|
||||
def test_valid_json_format(self):
|
||||
"""Test valid JSON format validation."""
|
||||
router = validate_output_format(expected_format="json")
|
||||
|
||||
valid_json = '{"key": "value", "number": 123}'
|
||||
assert router({"output": valid_json}) == "valid_format"
|
||||
|
||||
# Complex JSON
|
||||
complex_json = (
|
||||
'{"users": [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]}'
|
||||
)
|
||||
assert router({"output": complex_json}) == "valid_format"
|
||||
|
||||
def test_invalid_json_format(self):
|
||||
"""Test invalid JSON format validation."""
|
||||
router = validate_output_format(expected_format="json")
|
||||
|
||||
# Malformed JSON
|
||||
assert (
|
||||
router({"output": '{"key": value}'}) == "invalid_format"
|
||||
) # Missing quotes
|
||||
assert router({"output": '{key: "value"}'}) == "invalid_format" # Unquoted key
|
||||
assert (
|
||||
router({"output": '{"key": "value",}'}) == "invalid_format"
|
||||
) # Trailing comma
|
||||
assert router({"output": "not json at all"}) == "invalid_format"
|
||||
|
||||
def test_valid_xml_format(self):
|
||||
"""Test valid XML format validation."""
|
||||
router = validate_output_format(expected_format="xml")
|
||||
|
||||
assert router({"output": "<root><item>value</item></root>"}) == "valid_format"
|
||||
assert router({"output": "<simple/>"}) == "valid_format"
|
||||
assert (
|
||||
router({"output": " <root>content</root> "}) == "valid_format"
|
||||
) # With whitespace
|
||||
|
||||
def test_invalid_xml_format(self):
|
||||
"""Test invalid XML format validation."""
|
||||
router = validate_output_format(expected_format="xml")
|
||||
|
||||
assert router({"output": "not xml"}) == "invalid_format"
|
||||
assert router({"output": "<unclosed"}) == "invalid_format"
|
||||
assert router({"output": "unclosed>"}) == "invalid_format"
|
||||
assert router({"output": ""}) == "invalid_format"
|
||||
|
||||
def test_valid_csv_format(self):
|
||||
"""Test valid CSV format validation."""
|
||||
router = validate_output_format(expected_format="csv")
|
||||
|
||||
assert router({"output": "name,age,city\\nJohn,30,NYC"}) == "valid_format"
|
||||
assert router({"output": "col1,col2,col3"}) == "valid_format"
|
||||
assert router({"output": "a,b,c\\n1,2,3\\n4,5,6"}) == "valid_format"
|
||||
|
||||
def test_invalid_csv_format(self):
|
||||
"""Test invalid CSV format validation."""
|
||||
router = validate_output_format(expected_format="csv")
|
||||
|
||||
assert router({"output": "no commas here"}) == "invalid_format"
|
||||
assert router({"output": ""}) == "invalid_format"
|
||||
assert router({"output": "single_column"}) == "invalid_format"
|
||||
|
||||
def test_text_format_validation(self):
|
||||
"""Test text format validation."""
|
||||
router = validate_output_format(expected_format="text")
|
||||
|
||||
assert router({"output": "This is valid text"}) == "valid_format"
|
||||
assert router({"output": "123"}) == "valid_format"
|
||||
assert router({"output": "Multi\\nline\\ntext"}) == "valid_format"
|
||||
|
||||
# Empty or non-string should be invalid
|
||||
assert router({"output": ""}) == "invalid_format"
|
||||
assert router({"output": None}) == "invalid_format"
|
||||
|
||||
def test_custom_output_key(self):
|
||||
"""Test with custom output key."""
|
||||
router = validate_output_format(expected_format="json", output_key="result")
|
||||
|
||||
assert router({"result": '{"valid": true}'}) == "valid_format"
|
||||
assert router({"result": "invalid json"}) == "invalid_format"
|
||||
|
||||
def test_missing_output_key(self):
|
||||
"""Test behavior when output key is missing."""
|
||||
router = validate_output_format(expected_format="json")
|
||||
|
||||
assert router({}) == "invalid_format"
|
||||
assert router({"other_key": '{"valid": true}'}) == "invalid_format"
|
||||
|
||||
def test_case_insensitive_format(self):
|
||||
"""Test case insensitive format specification."""
|
||||
json_router = validate_output_format(expected_format="JSON")
|
||||
xml_router = validate_output_format(expected_format="XML")
|
||||
csv_router = validate_output_format(expected_format="CSV")
|
||||
|
||||
assert json_router({"output": '{"key": "value"}'}) == "valid_format"
|
||||
assert xml_router({"output": "<root/>"}) == "valid_format"
|
||||
assert csv_router({"output": "a,b,c"}) == "valid_format"
|
||||
|
||||
def test_numeric_output_handling(self):
|
||||
"""Test handling of numeric output values."""
|
||||
router = validate_output_format(expected_format="json")
|
||||
|
||||
# Numbers should be converted to strings for validation
|
||||
assert (
|
||||
router({"output": 123}) == "valid_format"
|
||||
) # Valid JSON format (123 is valid JSON)
|
||||
|
||||
text_router = validate_output_format(expected_format="text")
|
||||
assert text_router({"output": 123}) == "valid_format" # Valid as text
|
||||
|
||||
|
||||
class TestCheckPrivacyCompliance:
|
||||
"""Test check_privacy_compliance function with various edge cases."""
|
||||
|
||||
def test_ssn_detection(self):
|
||||
"""Test SSN pattern detection."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Valid SSN patterns
|
||||
assert router({"content": "SSN: 123-45-6789"}) == "privacy_violation"
|
||||
assert router({"content": "Contact info: 987-65-4321"}) == "privacy_violation"
|
||||
|
||||
# Invalid SSN patterns (but these might match other patterns)
|
||||
assert (
|
||||
router({"content": "Phone: 123-456-7890"}) == "privacy_violation"
|
||||
) # Matches phone pattern
|
||||
assert (
|
||||
router({"content": "Code: 12-34-567"}) == "compliant"
|
||||
) # Wrong format for any pattern
|
||||
|
||||
def test_credit_card_detection(self):
|
||||
"""Test credit card pattern detection."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Valid credit card patterns
|
||||
assert router({"content": "Card: 1234 5678 9012 3456"}) == "privacy_violation"
|
||||
assert router({"content": "CC: 1234-5678-9012-3456"}) == "privacy_violation"
|
||||
assert router({"content": "Number: 1234567890123456"}) == "privacy_violation"
|
||||
|
||||
# Invalid patterns
|
||||
assert router({"content": "Code: 123 456 789"}) == "compliant" # Too short
|
||||
|
||||
def test_email_detection(self):
|
||||
"""Test email pattern detection."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Valid email patterns
|
||||
assert router({"content": "Contact: john@example.com"}) == "privacy_violation"
|
||||
assert (
|
||||
router({"content": "Email: user.name+tag@domain.co.uk"})
|
||||
== "privacy_violation"
|
||||
)
|
||||
|
||||
# Invalid email patterns
|
||||
assert router({"content": "Not an email: john@"}) == "compliant"
|
||||
assert router({"content": "@example.com"}) == "compliant"
|
||||
|
||||
def test_phone_number_detection(self):
|
||||
"""Test phone number pattern detection."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Valid phone patterns
|
||||
assert router({"content": "Call: 123-456-7890"}) == "privacy_violation"
|
||||
assert router({"content": "Phone: 123 456 7890"}) == "privacy_violation"
|
||||
assert router({"content": "Contact: 1234567890"}) == "privacy_violation"
|
||||
|
||||
# Invalid patterns
|
||||
assert (
|
||||
router({"content": "Code: 123-45-6789"}) == "privacy_violation"
|
||||
) # Matches SSN pattern
|
||||
|
||||
def test_custom_patterns(self):
|
||||
"""Test custom sensitive patterns."""
|
||||
custom_patterns = [
|
||||
r"\bAPI[_-]?KEY[_-]?\w+", # API key pattern
|
||||
r"\bTOKEN[_-]?\w{10,}", # Token pattern
|
||||
]
|
||||
router = check_privacy_compliance(sensitive_patterns=custom_patterns)
|
||||
|
||||
assert router({"content": "API_KEY_abc123def456"}) == "privacy_violation"
|
||||
assert router({"content": "TOKEN_1234567890abcdef"}) == "privacy_violation"
|
||||
assert router({"content": "Normal text here"}) == "compliant"
|
||||
|
||||
def test_case_insensitive_matching(self):
|
||||
"""Test case insensitive pattern matching."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Email should match regardless of case
|
||||
assert router({"content": "CONTACT: JOHN@EXAMPLE.COM"}) == "privacy_violation"
|
||||
assert router({"content": "Email: John@Example.Com"}) == "privacy_violation"
|
||||
|
||||
def test_multiple_violations(self):
|
||||
"""Test content with multiple privacy violations."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
content = (
|
||||
"Contact John at john@example.com or call 123-456-7890. SSN: 123-45-6789"
|
||||
)
|
||||
assert router({"content": content}) == "privacy_violation"
|
||||
|
||||
def test_custom_content_key(self):
|
||||
"""Test with custom content key."""
|
||||
router = check_privacy_compliance(content_key="message")
|
||||
|
||||
assert router({"message": "Email: test@example.com"}) == "privacy_violation"
|
||||
assert router({"message": "Clean content"}) == "compliant"
|
||||
|
||||
def test_missing_content_key(self):
|
||||
"""Test behavior when content key is missing."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
assert router({}) == "compliant"
|
||||
assert router({"other_key": "john@example.com"}) == "compliant"
|
||||
|
||||
def test_none_content(self):
|
||||
"""Test behavior with None content."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
assert router({"content": None}) == "compliant"
|
||||
|
||||
def test_numeric_content(self):
|
||||
"""Test handling of numeric content."""
|
||||
router = check_privacy_compliance()
|
||||
|
||||
# Should convert to string for matching
|
||||
assert (
|
||||
router({"content": 1234567890123456}) == "privacy_violation"
|
||||
) # Credit card
|
||||
assert router({"content": 123456789}) == "compliant" # Too short
|
||||
|
||||
def test_empty_patterns_list(self):
|
||||
"""Test with empty patterns list."""
|
||||
router = check_privacy_compliance(sensitive_patterns=[])
|
||||
|
||||
# No patterns means everything is compliant
|
||||
assert router({"content": "john@example.com"}) == "compliant"
|
||||
assert router({"content": "123-45-6789"}) == "compliant"
|
||||
|
||||
|
||||
class TestCheckDataFreshness:
|
||||
"""Test check_data_freshness function with various edge cases."""
|
||||
|
||||
@patch("bb_core.edge_helpers.validation.time")
|
||||
def test_fresh_data_within_limit(self, mock_time_module):
|
||||
"""Test data within freshness limit."""
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
router = check_data_freshness(max_age_seconds=3600) # 1 hour
|
||||
|
||||
# Data from 30 minutes ago
|
||||
assert router({"timestamp": 800.0}) == "fresh"
|
||||
|
||||
@patch("bb_core.edge_helpers.validation.time")
|
||||
def test_stale_data_beyond_limit(self, mock_time_module):
|
||||
"""Test data beyond freshness limit."""
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
router = check_data_freshness(max_age_seconds=3600) # 1 hour
|
||||
|
||||
# Data from 2 hours ago
|
||||
assert router({"timestamp": -600.0}) == "stale"
|
||||
|
||||
@patch("bb_core.edge_helpers.validation.time")
|
||||
def test_exact_freshness_boundary(self, mock_time_module):
|
||||
"""Test exact freshness boundary."""
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
# Exactly at limit
|
||||
assert router({"timestamp": -2600.0}) == "fresh"
|
||||
|
||||
# Just over limit
|
||||
assert router({"timestamp": -2601.0}) == "stale"
|
||||
|
||||
def test_iso_timestamp_parsing(self):
|
||||
"""Test ISO format timestamp parsing."""
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
# Recent ISO timestamp
|
||||
recent_iso = datetime.now().isoformat()
|
||||
assert router({"timestamp": recent_iso}) == "fresh"
|
||||
|
||||
# Old ISO timestamp
|
||||
old_iso = "2020-01-01T00:00:00Z"
|
||||
assert router({"timestamp": old_iso}) == "stale"
|
||||
|
||||
def test_invalid_timestamp_formats(self):
|
||||
"""Test handling of invalid timestamp formats."""
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
# Invalid formats should be considered stale
|
||||
assert router({"timestamp": "invalid"}) == "stale"
|
||||
assert router({"timestamp": "not-a-date"}) == "stale"
|
||||
assert router({"timestamp": []}) == "stale"
|
||||
assert router({"timestamp": {}}) == "stale"
|
||||
|
||||
def test_missing_timestamp_key(self):
|
||||
"""Test behavior when timestamp key is missing."""
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
assert router({}) == "stale"
|
||||
assert router({"other_key": time.time()}) == "stale"
|
||||
|
||||
def test_none_timestamp(self):
|
||||
"""Test behavior with None timestamp."""
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
assert router({"timestamp": None}) == "stale"
|
||||
|
||||
def test_custom_timestamp_key(self):
|
||||
"""Test with custom timestamp key."""
|
||||
router = check_data_freshness(max_age_seconds=1800, timestamp_key="created_at")
|
||||
|
||||
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
assert router({"created_at": 500.0}) == "fresh" # 500 seconds ago
|
||||
assert router({"created_at": -1000.0}) == "stale" # 2000 seconds ago
|
||||
|
||||
def test_string_numeric_timestamp(self):
|
||||
"""Test string numeric timestamp conversion."""
|
||||
router = check_data_freshness(max_age_seconds=3600)
|
||||
|
||||
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
assert router({"timestamp": "800.0"}) == "fresh"
|
||||
assert router({"timestamp": "-600.0"}) == "stale"
|
||||
|
||||
def test_zero_max_age(self):
|
||||
"""Test with zero max age."""
|
||||
router = check_data_freshness(max_age_seconds=0)
|
||||
|
||||
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
|
||||
mock_time_module.time.return_value = 1000.0
|
||||
# Even current time should be stale with 0 max age
|
||||
assert router({"timestamp": 1000.0}) == "fresh" # Exactly current
|
||||
assert router({"timestamp": 999.9}) == "stale" # Slightly old
|
||||
|
||||
|
||||
class TestValidateRequiredFields:
|
||||
"""Test validate_required_fields function with various edge cases."""
|
||||
|
||||
def test_all_fields_present(self):
|
||||
"""Test when all required fields are present."""
|
||||
router = validate_required_fields(["field1", "field2", "field3"])
|
||||
|
||||
state = {"field1": "value1", "field2": "value2", "field3": "value3"}
|
||||
assert router(state) == "valid"
|
||||
|
||||
def test_missing_fields(self):
|
||||
"""Test when some fields are missing."""
|
||||
router = validate_required_fields(["field1", "field2", "field3"])
|
||||
|
||||
# Missing field2
|
||||
state = {"field1": "value1", "field3": "value3"}
|
||||
assert router(state) == "missing_fields"
|
||||
|
||||
# All missing
|
||||
assert router({}) == "missing_fields"
|
||||
|
||||
def test_strict_mode_empty_values(self):
|
||||
"""Test strict mode with empty values."""
|
||||
router = validate_required_fields(["field1", "field2"], strict_mode=True)
|
||||
|
||||
# Empty string should fail in strict mode
|
||||
state = {"field1": "value", "field2": ""}
|
||||
assert router(state) == "missing_fields"
|
||||
|
||||
# Empty list should fail in strict mode
|
||||
state = {"field1": "value", "field2": []}
|
||||
assert router(state) == "missing_fields"
|
||||
|
||||
def test_non_strict_mode_empty_values(self):
|
||||
"""Test non-strict mode with empty values."""
|
||||
router = validate_required_fields(["field1", "field2"], strict_mode=False)
|
||||
|
||||
# Empty string should pass in non-strict mode
|
||||
state = {"field1": "value", "field2": ""}
|
||||
assert router(state) == "valid"
|
||||
|
||||
# Empty list should pass in non-strict mode
|
||||
state = {"field1": "value", "field2": []}
|
||||
assert router(state) == "valid"
|
||||
|
||||
def test_none_values(self):
|
||||
"""Test handling of None values."""
|
||||
router = validate_required_fields(["field1", "field2"])
|
||||
|
||||
# None should always fail regardless of strict mode
|
||||
state = {"field1": "value", "field2": None}
|
||||
assert router(state) == "missing_fields"
|
||||
|
||||
def test_falsy_values_in_non_strict_mode(self):
|
||||
"""Test falsy values in non-strict mode."""
|
||||
router = validate_required_fields(["flag", "count"], strict_mode=False)
|
||||
|
||||
# False and 0 should be valid in non-strict mode
|
||||
state = {"flag": False, "count": 0}
|
||||
assert router(state) == "valid"
|
||||
|
||||
def test_falsy_values_in_strict_mode(self):
|
||||
"""Test falsy values in strict mode."""
|
||||
router = validate_required_fields(["flag", "count"], strict_mode=True)
|
||||
|
||||
# False and 0 should be valid even in strict mode (not empty)
|
||||
state = {"flag": False, "count": 0}
|
||||
assert router(state) == "valid"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = validate_required_fields(["field1", "field2"])
|
||||
|
||||
# Valid case
|
||||
state = MockState({"field1": "value1", "field2": "value2"})
|
||||
assert router(state) == "valid"
|
||||
|
||||
# Missing field case
|
||||
state = MockState({"field1": "value1"})
|
||||
assert router(state) == "missing_fields"
|
||||
|
||||
def test_complex_data_types(self):
|
||||
"""Test with complex data types as field values."""
|
||||
router = validate_required_fields(["data", "config"])
|
||||
|
||||
state = {
|
||||
"data": {"nested": {"value": 123}},
|
||||
"config": ["item1", "item2", {"key": "value"}],
|
||||
}
|
||||
assert router(state) == "valid"
|
||||
|
||||
def test_empty_required_fields_list(self):
|
||||
"""Test with empty required fields list."""
|
||||
router = validate_required_fields([])
|
||||
|
||||
# No requirements means always valid
|
||||
assert router({}) == "valid"
|
||||
assert router({"any_field": "any_value"}) == "valid"
|
||||
|
||||
|
||||
class TestCheckOutputLength:
|
||||
"""Test check_output_length function with various edge cases."""
|
||||
|
||||
def test_valid_length_within_bounds(self):
|
||||
"""Test output length within bounds."""
|
||||
router = check_output_length(min_length=5, max_length=20)
|
||||
|
||||
assert router({"output": "12345"}) == "valid_length" # Exactly min
|
||||
assert router({"output": "1234567890"}) == "valid_length" # Middle
|
||||
assert (
|
||||
router({"output": "12345678901234567890"}) == "valid_length"
|
||||
) # Exactly max
|
||||
|
||||
def test_output_too_short(self):
|
||||
"""Test output shorter than minimum."""
|
||||
router = check_output_length(min_length=10)
|
||||
|
||||
assert router({"output": "short"}) == "too_short"
|
||||
assert router({"output": ""}) == "too_short"
|
||||
assert router({"output": "123456789"}) == "too_short" # 9 chars, min is 10
|
||||
|
||||
def test_output_too_long(self):
|
||||
"""Test output longer than maximum."""
|
||||
router = check_output_length(min_length=1, max_length=10)
|
||||
|
||||
assert router({"output": "12345678901"}) == "too_long" # 11 chars, max is 10
|
||||
assert router({"output": "a" * 100}) == "too_long"
|
||||
|
||||
def test_no_max_length_limit(self):
|
||||
"""Test with no maximum length limit."""
|
||||
router = check_output_length(min_length=5, max_length=None)
|
||||
|
||||
assert router({"output": "12345"}) == "valid_length"
|
||||
assert (
|
||||
router({"output": "a" * 1000}) == "valid_length"
|
||||
) # Very long should be OK
|
||||
|
||||
def test_custom_content_key(self):
|
||||
"""Test with custom content key."""
|
||||
router = check_output_length(min_length=3, max_length=10, content_key="result")
|
||||
|
||||
assert router({"result": "hello"}) == "valid_length"
|
||||
assert router({"result": "hi"}) == "too_short"
|
||||
assert router({"result": "very long result text"}) == "too_long"
|
||||
|
||||
def test_missing_content_key(self):
|
||||
"""Test behavior when content key is missing."""
|
||||
router = check_output_length(min_length=5)
|
||||
|
||||
# Should default to empty string and be too short
|
||||
assert router({}) == "too_short"
|
||||
assert router({"other_key": "valid content"}) == "too_short"
|
||||
|
||||
def test_none_content(self):
|
||||
"""Test behavior with None content."""
|
||||
router = check_output_length(min_length=1)
|
||||
|
||||
# None should be converted to string "None" (4 chars)
|
||||
assert router({"output": None}) == "valid_length"
|
||||
|
||||
def test_numeric_content(self):
|
||||
"""Test handling of numeric content."""
|
||||
router = check_output_length(min_length=3, max_length=5)
|
||||
|
||||
assert router({"output": 123}) == "valid_length" # "123" = 3 chars
|
||||
assert router({"output": 12}) == "too_short" # "12" = 2 chars
|
||||
assert router({"output": 123456}) == "too_long" # "123456" = 6 chars
|
||||
|
||||
def test_zero_min_length(self):
|
||||
"""Test with zero minimum length."""
|
||||
router = check_output_length(min_length=0, max_length=5)
|
||||
|
||||
assert router({"output": ""}) == "valid_length"
|
||||
assert router({"output": "hello"}) == "valid_length"
|
||||
assert router({"output": "toolong"}) == "too_long"
|
||||
|
||||
def test_equal_min_max_length(self):
|
||||
"""Test with equal min and max length."""
|
||||
router = check_output_length(min_length=5, max_length=5)
|
||||
|
||||
assert router({"output": "12345"}) == "valid_length"
|
||||
assert router({"output": "1234"}) == "too_short"
|
||||
assert router({"output": "123456"}) == "too_long"
|
||||
|
||||
def test_unicode_character_counting(self):
|
||||
"""
|
||||
Test Unicode character counting.
|
||||
|
||||
Note: Length uses Python's built-in len(), which counts Unicode code points,
|
||||
not grapheme clusters. Some user-perceived characters (like certain emojis or
|
||||
emoji sequences) may be counted as more than one character.
|
||||
"""
|
||||
router = check_output_length(min_length=3, max_length=5)
|
||||
|
||||
# Unicode code points are counted; some emojis may be more than one.
|
||||
# For example, "🔥❄️🌞" is 4 code points: '🔥', '❄', '️' (variation), '🌞'
|
||||
assert (
|
||||
router({"output": "🔥❄️🌞"}) == "valid_length"
|
||||
) # 4 code points (note: '❄️' is two code points: '❄' + '️')
|
||||
assert router({"output": "café"}) == "valid_length" # 4 code points with é
|
||||
# "🔥❄️" is 3 code points: '🔥', '❄', '️'
|
||||
assert (
|
||||
router({"output": "🔥❄️"}) == "valid_length"
|
||||
) # 3 code points (emojis may be multi-codepoint)
|
||||
|
||||
def test_multiline_content(self):
|
||||
"""Test multiline content length calculation."""
|
||||
router = check_output_length(min_length=10, max_length=20)
|
||||
|
||||
multiline = "line1\\nline2\\nline3" # 17 chars including newlines
|
||||
assert router({"output": multiline}) == "valid_length"
|
||||
|
||||
def test_state_protocol_implementation(self):
|
||||
"""Test with StateProtocol implementation."""
|
||||
router = check_output_length(min_length=5, max_length=15)
|
||||
|
||||
state = MockState({"output": "valid output"})
|
||||
assert router(state) == "valid_length"
|
||||
|
||||
state = MockState({"output": "hi"})
|
||||
assert router(state) == "too_short"
|
||||
|
||||
|
||||
class TestIntegrationAndEdgeCases:
|
||||
"""Test integration scenarios and edge cases."""
|
||||
|
||||
def test_all_validators_with_same_state(self):
|
||||
"""Test all validators work consistently with same state."""
|
||||
state = {
|
||||
"accuracy_score": 0.85,
|
||||
"confidence": 0.9,
|
||||
"output": '{"result": "success", "data": [1, 2, 3]}',
|
||||
"content": "This is clean content without sensitive data",
|
||||
"timestamp": time.time() - 1800, # 30 minutes ago
|
||||
"user_id": "12345",
|
||||
"request_data": {"query": "test"},
|
||||
"response_time": "fast",
|
||||
}
|
||||
|
||||
accuracy_router = check_accuracy(threshold=0.8)
|
||||
confidence_router = check_confidence_level(threshold=0.8)
|
||||
format_router = validate_output_format(expected_format="json")
|
||||
privacy_router = check_privacy_compliance()
|
||||
freshness_router = check_data_freshness(max_age_seconds=3600)
|
||||
fields_router = validate_required_fields(["user_id", "request_data"])
|
||||
length_router = check_output_length(min_length=10, max_length=100)
|
||||
|
||||
assert accuracy_router(state) == "high_accuracy"
|
||||
assert confidence_router(state) == "high_confidence"
|
||||
assert format_router(state) == "valid_format"
|
||||
assert privacy_router(state) == "compliant"
|
||||
assert freshness_router(state) == "fresh"
|
||||
assert fields_router(state) == "valid"
|
||||
assert length_router(state) == "valid_length"
|
||||
|
||||
def test_chain_validation_failures(self):
|
||||
"""Test various validation failures in sequence."""
|
||||
problematic_state = {
|
||||
"accuracy_score": 0.3, # Low accuracy
|
||||
"confidence": "invalid", # Invalid confidence
|
||||
"output": "malformed json {", # Invalid JSON
|
||||
"content": "Contact me at john@example.com", # Privacy violation
|
||||
"timestamp": "2020-01-01", # Stale data
|
||||
"user_id": "", # Empty required field
|
||||
}
|
||||
|
||||
accuracy_router = check_accuracy(threshold=0.8)
|
||||
confidence_router = check_confidence_level(threshold=0.8)
|
||||
format_router = validate_output_format(expected_format="json")
|
||||
privacy_router = check_privacy_compliance()
|
||||
freshness_router = check_data_freshness(max_age_seconds=3600)
|
||||
fields_router = validate_required_fields(["user_id", "missing_field"])
|
||||
|
||||
assert accuracy_router(problematic_state) == "low_accuracy"
|
||||
assert confidence_router(problematic_state) == "low_confidence"
|
||||
assert format_router(problematic_state) == "invalid_format"
|
||||
assert privacy_router(problematic_state) == "privacy_violation"
|
||||
assert freshness_router(problematic_state) == "stale"
|
||||
assert fields_router(problematic_state) == "missing_fields"
|
||||
|
||||
def test_performance_with_large_content(self):
|
||||
"""Test performance with large content."""
|
||||
large_content = "x" * 100000 # 100KB of text
|
||||
|
||||
privacy_router = check_privacy_compliance()
|
||||
length_router = check_output_length(min_length=1000, max_length=200000)
|
||||
|
||||
state = {"content": large_content, "output": large_content}
|
||||
|
||||
# Should handle large content efficiently
|
||||
assert privacy_router(state) == "compliant"
|
||||
assert length_router(state) == "valid_length"
|
||||
|
||||
def test_concurrent_validator_usage(self):
|
||||
"""Test that validators are stateless and thread-safe."""
|
||||
router = check_accuracy(threshold=0.5)
|
||||
|
||||
# Multiple calls with different states should not interfere
|
||||
state1 = {"accuracy_score": 0.8}
|
||||
state2 = {"accuracy_score": 0.2}
|
||||
|
||||
assert router(state1) == "high_accuracy"
|
||||
assert router(state2) == "low_accuracy"
|
||||
assert router(state1) == "high_accuracy" # Should still work
|
||||
|
||||
def test_memory_efficiency_with_patterns(self):
|
||||
"""Test memory efficiency with regex patterns."""
|
||||
# Create router with many custom patterns
|
||||
many_patterns = [f"pattern{i}_\\w+" for i in range(100)]
|
||||
router = check_privacy_compliance(sensitive_patterns=many_patterns)
|
||||
|
||||
# Should handle many patterns without issues
|
||||
assert router({"content": "clean content"}) == "compliant"
|
||||
assert router({"content": "pattern50_sensitive_data"}) == "privacy_violation"
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for error aggregation and deduplication."""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -386,30 +386,35 @@ class TestErrorAggregator:
|
||||
assert summary["by_severity"]["critical"] == 3
|
||||
|
||||
def test_cleanup_old_errors(self):
|
||||
"""Test cleanup of old aggregated errors."""
|
||||
"""Test cleanup of old recent errors tracking."""
|
||||
reset_error_aggregator()
|
||||
aggregator = get_error_aggregator()
|
||||
|
||||
# Manually add old error
|
||||
# Manually add old entry to recent_errors
|
||||
old_error = create_error_info(message="Old error", error_type="Old")
|
||||
fingerprint = ErrorFingerprint.from_error_info(old_error)
|
||||
|
||||
# Create aggregated error with old timestamp
|
||||
old_time = datetime.now(UTC) - timedelta(hours=2)
|
||||
aggregated = AggregatedError(
|
||||
fingerprint=fingerprint,
|
||||
first_seen=old_time,
|
||||
last_seen=old_time,
|
||||
count=1,
|
||||
sample_errors=[old_error],
|
||||
)
|
||||
# Add old entry to recent_errors (this is what _cleanup_old_entries cleans)
|
||||
import time
|
||||
|
||||
aggregator.aggregated_errors[fingerprint.hash] = aggregated
|
||||
old_time = time.time() - (aggregator.dedup_window * 3) # Older than cutoff
|
||||
aggregator.recent_errors[fingerprint.hash] = old_time
|
||||
|
||||
# Cleanup should remove it
|
||||
# Add current entry that should remain
|
||||
current_error = create_error_info(message="Current error", error_type="Current")
|
||||
current_fingerprint = ErrorFingerprint.from_error_info(current_error)
|
||||
current_time = time.time()
|
||||
aggregator.recent_errors[current_fingerprint.hash] = current_time
|
||||
|
||||
# Should have 2 entries before cleanup
|
||||
assert len(aggregator.recent_errors) == 2
|
||||
|
||||
# Cleanup should remove the old entry but keep the current one
|
||||
aggregator._cleanup_old_entries()
|
||||
|
||||
assert len(aggregator.aggregated_errors) == 0
|
||||
assert len(aggregator.recent_errors) == 1
|
||||
assert current_fingerprint.hash in aggregator.recent_errors
|
||||
assert fingerprint.hash not in aggregator.recent_errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_access(self):
|
||||
|
||||
@@ -240,6 +240,7 @@ class TestErrorTelemetry:
|
||||
|
||||
# Add matching errors
|
||||
now = datetime.now(UTC)
|
||||
last_entry = None
|
||||
for i in range(4):
|
||||
entry = ErrorLogEntry(
|
||||
timestamp=(now - timedelta(seconds=i)).isoformat(),
|
||||
@@ -258,9 +259,11 @@ class TestErrorTelemetry:
|
||||
stack_trace=None,
|
||||
)
|
||||
telemetry.state.recent_errors.append(entry)
|
||||
last_entry = entry
|
||||
|
||||
# Check patterns
|
||||
telemetry._check_patterns(entry)
|
||||
assert last_entry is not None
|
||||
telemetry._check_patterns(last_entry)
|
||||
|
||||
# Should trigger alert
|
||||
alert_callback.assert_called_once()
|
||||
|
||||
@@ -155,7 +155,8 @@ class TestPerformanceFilter:
|
||||
|
||||
assert filter_obj.filter(record)
|
||||
assert hasattr(record, "timestamp")
|
||||
timestamp = record.timestamp
|
||||
timestamp = getattr(record, "timestamp", None)
|
||||
assert timestamp is not None
|
||||
assert isinstance(timestamp, str)
|
||||
# Should be ISO format timestamp
|
||||
assert "T" in timestamp
|
||||
|
||||
@@ -236,10 +236,6 @@ class TestRunAsyncChain:
|
||||
def sync_double(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
async def async_add_ten(x: int) -> int:
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + 10
|
||||
|
||||
# Note: This test may not work with the stricter typing in async_helpers
|
||||
# Let's use only sync functions for now
|
||||
def add_ten_sync(x: int) -> int:
|
||||
|
||||
@@ -222,7 +222,7 @@ class TestTypeValidation:
|
||||
entity: EntityReference = {
|
||||
"id": f"test_{entity_type}",
|
||||
"name": f"Test {entity_type}",
|
||||
"type": entity_type, # type: ignore
|
||||
"type": entity_type,
|
||||
}
|
||||
assert entity["type"] == entity_type
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ class TestPatternValidator:
|
||||
validator = PatternValidator(r"^\d+$")
|
||||
|
||||
# Should convert to string first
|
||||
is_valid, error = validator.validate(123)
|
||||
is_valid, _ = validator.validate(123)
|
||||
assert is_valid # "123" matches the pattern
|
||||
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class TestDownloadDocument:
|
||||
await download_document("https://example.com/doc.pdf", timeout=60)
|
||||
|
||||
# Verify timeout was passed to requests
|
||||
args, kwargs = mock_get.call_args
|
||||
_, kwargs = mock_get.call_args
|
||||
assert kwargs["timeout"] == 60
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class TestCreateTempDocumentFile:
|
||||
created_files: list[str] = []
|
||||
try:
|
||||
for ext in extensions:
|
||||
file_path, file_name = create_temp_document_file(content, ext)
|
||||
file_path, _ = create_temp_document_file(content, ext)
|
||||
created_files.append(file_path)
|
||||
|
||||
assert file_path.endswith(ext)
|
||||
|
||||
@@ -418,10 +418,10 @@ class TestValidatedNode:
|
||||
class OutputModel(BaseModel):
|
||||
greeting: str
|
||||
|
||||
@validated_node( # type: ignore[arg-type]
|
||||
@validated_node( # type: ignore[misc]
|
||||
name="test_node", input_model=InputModel, output_model=OutputModel
|
||||
)
|
||||
def test_node(state: Any) -> dict[str, str]:
|
||||
def test_node(state: dict[str, Any]) -> dict[str, str]:
|
||||
return {"greeting": f"Hello, {state.get('name')}!"}
|
||||
|
||||
state = {"name": "World"}
|
||||
@@ -435,8 +435,8 @@ class TestValidatedNode:
|
||||
class InputModel(BaseModel):
|
||||
name: str
|
||||
|
||||
@validated_node(input_model=InputModel) # type: ignore[arg-type]
|
||||
def test_node(state: Any) -> dict[str, str]:
|
||||
@validated_node(input_model=InputModel) # type: ignore[misc]
|
||||
def test_node(state: dict[str, Any]) -> dict[str, str]:
|
||||
return {"result": f"Processed {state.get('name')}"}
|
||||
|
||||
state = {"name": "test"}
|
||||
@@ -450,8 +450,8 @@ class TestValidatedNode:
|
||||
class OutputModel(BaseModel):
|
||||
result: str
|
||||
|
||||
@validated_node(output_model=OutputModel) # type: ignore[arg-type] # pyright: ignore[reportArgumentType,reportGeneralTypeIssues]
|
||||
def test_node(state: Any) -> dict[str, str]:
|
||||
@validated_node(output_model=OutputModel) # type: ignore[misc]
|
||||
def test_node(state: dict[str, Any]) -> dict[str, str]:
|
||||
return {"result": "success"}
|
||||
|
||||
result = test_node({})
|
||||
@@ -461,8 +461,8 @@ class TestValidatedNode:
|
||||
def test_validated_node_with_metadata(self):
|
||||
"""Test validated_node with metadata."""
|
||||
|
||||
@validated_node(name="test_node", custom_field="custom_value")
|
||||
def test_node(state):
|
||||
@validated_node(name="test_node", custom_field="custom_value") # type: ignore[misc]
|
||||
def test_node(state: dict[str, Any]) -> dict[str, Any]:
|
||||
return state
|
||||
|
||||
# The decorator should not modify the function behavior
|
||||
@@ -621,10 +621,14 @@ class TestValidateAllGraphs:
|
||||
async def test_validate_all_graphs_non_async_function(self):
|
||||
"""Test validation with non-async graph function."""
|
||||
|
||||
def non_async_function():
|
||||
async def non_async_function() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
graph_functions = {"non_async": non_async_function}
|
||||
|
||||
result = await validate_all_graphs(graph_functions)
|
||||
from typing import cast
|
||||
|
||||
result = await validate_all_graphs(
|
||||
cast(dict[str, Callable[[], Awaitable[Any]]], graph_functions)
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from bb_core.validation.merge import (
|
||||
_handle_average_collection,
|
||||
_handle_dict_update,
|
||||
_handle_list_extend,
|
||||
_handle_numeric_operation,
|
||||
@@ -291,16 +290,15 @@ class TestMergeHelperFunctions:
|
||||
result_list = merged_list[0]
|
||||
assert len(result_list) == 2 # Deduplicated
|
||||
|
||||
def test_handle_average_collection(self):
|
||||
"""Test average collection helper."""
|
||||
def test_handle_dict_update_simple(self):
|
||||
"""Test dict update helper."""
|
||||
merged = {}
|
||||
|
||||
_handle_average_collection(merged, "score", 85.5)
|
||||
_handle_average_collection(merged, "score", 92.0)
|
||||
_handle_average_collection(merged, "rating", 4.5)
|
||||
_handle_dict_update(merged, {"key1": "value1"})
|
||||
_handle_dict_update(merged, {"key2": "value2"})
|
||||
|
||||
assert merged["values"]["score"] == [85.5, 92.0]
|
||||
assert merged["values"]["rating"] == [4.5]
|
||||
assert merged["key1"] == "value1"
|
||||
assert merged["key2"] == "value2"
|
||||
|
||||
|
||||
class TestMergeEdgeCases:
|
||||
|
||||
@@ -232,7 +232,8 @@ class TestAssessAuthoritativeSources:
|
||||
# First 3 should be authoritative
|
||||
assert len(auth_sources) == 3
|
||||
assert all(
|
||||
source["url"] in [s["url"] for s in auth_sources] for source in sources[:3]
|
||||
source.get("url") in [s.get("url") for s in auth_sources]
|
||||
for source in sources[:3]
|
||||
)
|
||||
|
||||
def test_assess_authoritative_quality_score(self):
|
||||
@@ -247,7 +248,7 @@ class TestAssessAuthoritativeSources:
|
||||
|
||||
# Only high quality score should be authoritative
|
||||
assert len(auth_sources) == 1
|
||||
assert auth_sources[0]["quality_score"] == 0.9
|
||||
assert auth_sources[0].get("quality_score") == 0.9
|
||||
|
||||
def test_assess_authoritative_credibility_terms(self):
|
||||
"""Test assessment based on credibility terms."""
|
||||
|
||||
@@ -113,7 +113,7 @@ from .text import (
|
||||
try:
|
||||
from .tools import CategoryExtractionTool
|
||||
except ImportError:
|
||||
CategoryExtractionTool = None # type: ignore[assignment,misc]
|
||||
CategoryExtractionTool = None
|
||||
|
||||
__all__ = [
|
||||
# From core
|
||||
@@ -196,4 +196,4 @@ __all__ = [
|
||||
|
||||
# Conditionally add tools to __all__
|
||||
if CategoryExtractionTool is not None:
|
||||
__all__ += ["CategoryExtractionTool", "StatisticsExtractionTool"]
|
||||
__all__ += ["CategoryExtractionTool"]
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Base classes and interfaces for extraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Abstract base class for extractors."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, text: str) -> list[dict]:
|
||||
def extract(self, text: str) -> list[dict[str, Any]]:
|
||||
"""Extract information from text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -67,14 +67,14 @@ class ExtractionConfig(TypedDict, total=False):
|
||||
|
||||
|
||||
# Entity extraction types
|
||||
class EntityTypedDict(TypedDict, total=False):
|
||||
class EntityTypedDict(TypedDict):
|
||||
"""Type definition for extracted entities."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
value: str | None
|
||||
confidence: float
|
||||
metadata: dict[str, str | int | float | bool]
|
||||
value: NotRequired[str | None]
|
||||
confidence: NotRequired[float]
|
||||
metadata: NotRequired[dict[str, str | int | float | bool]]
|
||||
|
||||
|
||||
class RelationshipTypedDict(TypedDict, total=False):
|
||||
@@ -88,7 +88,7 @@ class RelationshipTypedDict(TypedDict, total=False):
|
||||
|
||||
|
||||
# Company extraction types
|
||||
class CompanyExtractionResultTypedDict(TypedDict, total=False):
|
||||
class CompanyExtractionResultTypedDict(TypedDict):
|
||||
"""Type definition for company extraction results."""
|
||||
|
||||
name: str
|
||||
|
||||
@@ -71,7 +71,9 @@ def extract_company_names(
|
||||
normalized = _normalize_company_name(name)
|
||||
if not normalized: # Skip if normalization results in empty string
|
||||
continue
|
||||
if normalized not in merged or result["confidence"] > merged[normalized]["confidence"]:
|
||||
if normalized not in merged or result.get("confidence", 0.0) > merged[normalized].get(
|
||||
"confidence", 0.0
|
||||
):
|
||||
merged[normalized] = result
|
||||
|
||||
# Convert to list and filter by confidence
|
||||
@@ -141,9 +143,9 @@ def add_source_metadata(
|
||||
url: str,
|
||||
) -> None:
|
||||
"""Add source metadata to company extraction result."""
|
||||
company["sources"] = company.get("sources", set())
|
||||
if isinstance(company["sources"], set):
|
||||
company["sources"].add(f"{source_type}_{result_index}")
|
||||
sources = company.get("sources", set())
|
||||
sources.add(f"{source_type}_{result_index}")
|
||||
company["sources"] = sources
|
||||
company["url"] = url
|
||||
|
||||
|
||||
@@ -347,23 +349,27 @@ def _handle_duplicate_company(
|
||||
) -> bool:
|
||||
"""Handle potential duplicate company entries."""
|
||||
for existing in final_results:
|
||||
existing_normalized = _normalize_company_name(existing["name"])
|
||||
existing_normalized = _normalize_company_name(existing.get("name", ""))
|
||||
|
||||
if normalized_name == existing_normalized or _is_substantial_name_overlap(
|
||||
normalized_name, existing_normalized
|
||||
):
|
||||
# Merge sources
|
||||
if (
|
||||
"sources" in existing
|
||||
and "sources" in company
|
||||
and isinstance(existing["sources"], set)
|
||||
and isinstance(company["sources"], set)
|
||||
(
|
||||
"sources" in existing
|
||||
and "sources" in company
|
||||
and isinstance(existing.get("sources"), set)
|
||||
and isinstance(company.get("sources"), set)
|
||||
)
|
||||
and existing.get("sources")
|
||||
and company.get("sources")
|
||||
):
|
||||
existing["sources"].update(company["sources"])
|
||||
|
||||
# Update confidence if higher
|
||||
if company["confidence"] > existing["confidence"]:
|
||||
existing["confidence"] = company["confidence"]
|
||||
if company.get("confidence", 0.0) > existing.get("confidence", 0.0):
|
||||
existing["confidence"] = company.get("confidence", 0.0)
|
||||
|
||||
return True
|
||||
|
||||
@@ -397,13 +403,13 @@ def _deduplicate_and_filter_results(
|
||||
final_results: list[CompanyExtractionResultTypedDict] = []
|
||||
|
||||
# Sort by confidence (descending) to process higher confidence first
|
||||
sorted_results = sorted(results, key=lambda x: x["confidence"], reverse=True)
|
||||
sorted_results = sorted(results, key=lambda x: x.get("confidence", 0.0), reverse=True)
|
||||
|
||||
for company in sorted_results:
|
||||
if company["confidence"] < min_confidence:
|
||||
if company.get("confidence", 0.0) < min_confidence:
|
||||
continue
|
||||
|
||||
normalized_name = _normalize_company_name(company["name"])
|
||||
normalized_name = _normalize_company_name(company.get("name", ""))
|
||||
|
||||
if not _handle_duplicate_company(normalized_name, company, final_results):
|
||||
final_results.append(company)
|
||||
@@ -412,82 +418,3 @@ def _deduplicate_and_filter_results(
|
||||
|
||||
|
||||
# === Functions for specific extraction needs ===
|
||||
|
||||
|
||||
def _extract_companies_from_single_result(
|
||||
result: dict[str, str],
|
||||
result_index: int,
|
||||
known_companies: set[str] | None,
|
||||
min_confidence: float,
|
||||
) -> list[CompanyExtractionResultTypedDict]:
|
||||
"""Extract companies from a single search result."""
|
||||
companies: list[CompanyExtractionResultTypedDict] = []
|
||||
|
||||
# Extract from title with higher confidence
|
||||
title_companies = extract_company_names(
|
||||
result.get("title", ""),
|
||||
known_companies=known_companies,
|
||||
min_confidence=min_confidence + 0.1,
|
||||
)
|
||||
|
||||
for company in title_companies:
|
||||
add_source_metadata("title", company, result_index, result.get("url", ""))
|
||||
companies.append(company)
|
||||
|
||||
# Extract from snippet
|
||||
snippet_companies = extract_company_names(
|
||||
result.get("snippet", ""),
|
||||
known_companies=known_companies,
|
||||
min_confidence=min_confidence,
|
||||
)
|
||||
|
||||
for company in snippet_companies:
|
||||
# Check if not duplicate from title
|
||||
is_duplicate = any(
|
||||
_is_substantial_name_overlap(company["name"], tc["name"]) for tc in title_companies
|
||||
)
|
||||
if not is_duplicate:
|
||||
add_source_metadata("snippet", company, result_index, result.get("url", ""))
|
||||
companies.append(company)
|
||||
|
||||
return companies
|
||||
|
||||
|
||||
def _update_existing_company(
|
||||
new_company: CompanyExtractionResultTypedDict,
|
||||
existing: CompanyExtractionResultTypedDict,
|
||||
) -> None:
|
||||
"""Update existing company with new information."""
|
||||
# Update confidence (weighted average)
|
||||
existing["confidence"] = (existing["confidence"] + new_company["confidence"]) / 2
|
||||
|
||||
# Merge sources
|
||||
if (
|
||||
"sources" in existing
|
||||
and "sources" in new_company
|
||||
and isinstance(existing["sources"], set)
|
||||
and isinstance(new_company["sources"], set)
|
||||
):
|
||||
existing["sources"].update(new_company["sources"])
|
||||
|
||||
|
||||
def _add_new_company(
|
||||
company: CompanyExtractionResultTypedDict,
|
||||
name_lower: str,
|
||||
merged_results: dict[str, CompanyExtractionResultTypedDict],
|
||||
) -> None:
|
||||
"""Add new company to results."""
|
||||
merged_results[name_lower] = company
|
||||
|
||||
|
||||
def _merge_company_into_results(
|
||||
company: CompanyExtractionResultTypedDict,
|
||||
merged_results: dict[str, CompanyExtractionResultTypedDict],
|
||||
) -> None:
|
||||
"""Merge company into consolidated results."""
|
||||
name_lower = company["name"].lower()
|
||||
|
||||
if name_lower in merged_results:
|
||||
_update_existing_company(company, merged_results[name_lower])
|
||||
else:
|
||||
_add_new_company(company, name_lower, merged_results)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from bb_core.logging import get_logger
|
||||
|
||||
@@ -171,7 +172,8 @@ class ComponentExtractor(BaseExtractor):
|
||||
if matches and isinstance(matches[0], tuple):
|
||||
items = [match[0] if match else "" for match in matches]
|
||||
else:
|
||||
items = list(matches)
|
||||
# Explicitly type the matches as strings to avoid Sized type
|
||||
items = [str(match) for match in matches] if matches else []
|
||||
best_pattern_idx = i
|
||||
|
||||
if best_pattern_idx >= 0:
|
||||
@@ -181,10 +183,8 @@ class ComponentExtractor(BaseExtractor):
|
||||
if not items:
|
||||
items = [line.strip() for line in text_block.split("\n") if line.strip()]
|
||||
|
||||
# Process each item
|
||||
# Ensure items is properly typed
|
||||
items_list: list[str] = items if isinstance(items, list) else []
|
||||
for item in items_list:
|
||||
# Process each item - cast to ensure proper typing
|
||||
for item in cast("list[str]", items):
|
||||
# Check if this item contains multiple ingredients separated by commas
|
||||
# e.g., "curry powder, ginger powder, allspice"
|
||||
if "," in item and not any(char.isdigit() for char in item[:5]):
|
||||
|
||||
@@ -34,7 +34,7 @@ async def extract_companies(text: str) -> list[SimpleNamespace]:
|
||||
"""Extract company names from text asynchronously."""
|
||||
raw = extract_company_names(text)
|
||||
if raw:
|
||||
names = [c["name"] for c in raw]
|
||||
names = [c.get("name", "") for c in raw if c.get("name")]
|
||||
else:
|
||||
words = re.findall(r"\b[A-Z][a-zA-Z]+\b", text)
|
||||
skip = {"The", "And", "Or", "For", "To", "In", "On", "Of"}
|
||||
@@ -273,7 +273,7 @@ def extract_thought_action_pairs(text: str) -> list[JsonDict]:
|
||||
and len(content) == 2
|
||||
):
|
||||
action_name, args_str = content
|
||||
action = {
|
||||
action: JsonDict = {
|
||||
"thought": current_thought,
|
||||
"action": action_name,
|
||||
"args": parse_action_args(args_str) if args_str else {},
|
||||
@@ -315,7 +315,7 @@ def extract_entities(text: str) -> dict[str, list[str]]:
|
||||
|
||||
# Extract companies using existing functionality
|
||||
companies = extract_company_names(text)
|
||||
company_names = [c["name"] for c in companies]
|
||||
company_names = [c.get("name", "") for c in companies if c.get("name")]
|
||||
|
||||
# Extract URLs
|
||||
url_pattern = re.compile(
|
||||
@@ -453,7 +453,7 @@ class ExtractedEntity:
|
||||
class EntityExtractor(BaseExtractor):
|
||||
"""Extract entities from text."""
|
||||
|
||||
def extract(self, text: str) -> list[dict]:
|
||||
def extract(self, text: str) -> list[dict[str, Any]]:
|
||||
"""Extract entities from text.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -12,7 +12,7 @@ try:
|
||||
from bb_core.validation.statistics import HIGH_CREDIBILITY_TERMS
|
||||
except ImportError:
|
||||
# Fallback definition if not available
|
||||
HIGH_CREDIBILITY_TERMS = [
|
||||
_fallback_credibility_terms = [
|
||||
"university",
|
||||
"research",
|
||||
"institute",
|
||||
@@ -30,6 +30,7 @@ except ImportError:
|
||||
"bureau",
|
||||
"department",
|
||||
]
|
||||
HIGH_CREDIBILITY_TERMS = _fallback_credibility_terms # pyright: ignore[reportConstantRedefinition]
|
||||
|
||||
from .numeric import extract_year
|
||||
|
||||
@@ -91,8 +92,9 @@ def extract_credibility_terms(text: str) -> list[str]:
|
||||
>>> extract_credibility_terms("Reported by a renowned research institute")
|
||||
['research', 'institute']
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError("text must be a string")
|
||||
# Type is guaranteed by signature; no need for isinstance
|
||||
if not text:
|
||||
return []
|
||||
return [term for term in HIGH_CREDIBILITY_TERMS if term.lower() in text.lower()]
|
||||
|
||||
|
||||
@@ -112,8 +114,9 @@ def assess_source_quality(text: str) -> float:
|
||||
>>> assess_source_quality("According to a study by a .edu institution, ...")
|
||||
0.8
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError("text must be a string")
|
||||
# Type is guaranteed by signature; no need for isinstance
|
||||
if not text:
|
||||
return 0.0
|
||||
score = 0.5
|
||||
|
||||
# Count credibility terms
|
||||
|
||||
@@ -58,10 +58,6 @@ def extract_json_from_text(text: str) -> JsonDict | None:
|
||||
|
||||
# This is a simplified approach - find strings and escape newlines in them
|
||||
# Match JSON strings (basic pattern, not perfect but should work for most cases)
|
||||
def replace_newlines_in_match(match: re.Match[str]) -> str:
|
||||
string_content = match.group(0)
|
||||
# Replace actual newlines with \n, being careful not to double-escape
|
||||
return string_content.replace("\n", "\\n").replace("\r", "\\r")
|
||||
|
||||
# Pattern to match strings in JSON (handles escaped quotes)
|
||||
string_pattern = r'"(?:[^"\\]|\\.)*"'
|
||||
@@ -445,7 +441,7 @@ def parse_action_args(text: str) -> ActionArgsDict:
|
||||
"""
|
||||
# Try JSON first
|
||||
json_data = extract_json_from_text(text)
|
||||
if json_data and isinstance(json_data, dict):
|
||||
if json_data:
|
||||
# Convert to ActionArgsDict format
|
||||
result: dict[str, Any] = {}
|
||||
for k, v in json_data.items():
|
||||
|
||||
@@ -281,11 +281,8 @@ async def _process_chunked_content(
|
||||
chunk_results.append(result)
|
||||
|
||||
# Merge chunk results
|
||||
merged_result = merge_chunk_results(chunk_results, category=category)
|
||||
# Ensure the result is a JsonDict
|
||||
if isinstance(merged_result, dict):
|
||||
return merged_result
|
||||
return _default_extraction_result(category)
|
||||
return merge_chunk_results(chunk_results, category=category)
|
||||
# Return the merged result
|
||||
|
||||
|
||||
def _default_extraction_result(category: str) -> JsonDict:
|
||||
|
||||
@@ -78,6 +78,7 @@ def benchmark():
|
||||
kwargs = {}
|
||||
|
||||
total_time = 0.0
|
||||
result = None
|
||||
for _ in range(rounds):
|
||||
start = time.time()
|
||||
for _ in range(iterations):
|
||||
|
||||
@@ -119,8 +119,8 @@ class TestFactTypedDict:
|
||||
assert main_fact["text"] == "Revenue analysis shows growth"
|
||||
assert main_fact["type"] == "financial"
|
||||
assert len(main_fact["statistics"]) == 1
|
||||
assert main_fact["statistics"][0]["text"] == "Sub-statistic"
|
||||
assert main_fact["statistics"][0]["value"] == 1.2
|
||||
assert main_fact["statistics"][0].get("text") == "Sub-statistic"
|
||||
assert main_fact["statistics"][0].get("value") == 1.2
|
||||
|
||||
def test_fact_confidence_types(self) -> None:
|
||||
"""Test fact with different confidence types."""
|
||||
@@ -175,8 +175,8 @@ class TestCategoryFactTypedDict:
|
||||
}
|
||||
|
||||
assert category_fact["type"] == "market_analysis"
|
||||
assert category_fact["data"]["text"] == "Market share is 45%"
|
||||
assert category_fact["data"]["value"] == 45.0
|
||||
assert category_fact["data"].get("text") == "Market share is 45%"
|
||||
assert category_fact["data"].get("value") == 45.0
|
||||
|
||||
def test_category_fact_with_complex_data(self) -> None:
|
||||
"""Test category fact with complex nested data."""
|
||||
@@ -201,9 +201,9 @@ class TestCategoryFactTypedDict:
|
||||
}
|
||||
|
||||
assert category_fact["type"] == "comprehensive_analysis"
|
||||
assert category_fact["data"]["quality_score"] == 0.92
|
||||
assert len(category_fact["data"]["citations"]) == 1
|
||||
assert category_fact["data"]["citations"][0]["source"] == "Research Institute"
|
||||
assert category_fact["data"].get("quality_score") == 0.92
|
||||
assert len(category_fact["data"].get("citations", [])) == 1
|
||||
assert category_fact["data"].get("citations", [])[0].get("source") == "Research Institute"
|
||||
|
||||
|
||||
class TestExtractionResult:
|
||||
@@ -216,7 +216,7 @@ class TestExtractionResult:
|
||||
result: ExtractionResult = {"facts": facts}
|
||||
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["text"] == "Simple fact"
|
||||
assert result["facts"][0].get("text") == "Simple fact"
|
||||
|
||||
def test_extraction_result_complete(self) -> None:
|
||||
"""Test complete extraction result."""
|
||||
@@ -366,11 +366,13 @@ class TestTypeCompatibility:
|
||||
result: ExtractionResult = {"facts": [main_fact], "relevance_score": 0.95}
|
||||
|
||||
# Verify all levels work correctly
|
||||
assert category_fact["data"]["text"] == "Main finding"
|
||||
assert len(category_fact["data"]["statistics"]) == 1
|
||||
assert category_fact["data"]["statistics"][0]["text"] == "Sub-analysis"
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["type"] == "research"
|
||||
assert category_fact["data"].get("text") == "Main finding"
|
||||
statistics = category_fact["data"].get("statistics", [])
|
||||
assert len(statistics) == 1
|
||||
assert statistics[0].get("text") == "Sub-analysis"
|
||||
facts = result.get("facts", [])
|
||||
assert len(facts) == 1
|
||||
assert facts[0].get("type") == "research"
|
||||
|
||||
def test_optional_field_handling(self) -> None:
|
||||
"""Test handling of optional fields across types."""
|
||||
@@ -404,7 +406,7 @@ class TestTypeCompatibility:
|
||||
facts_list.append(fact)
|
||||
|
||||
assert len(facts_list) == 1
|
||||
assert isinstance(facts_list[0]["citations"], list)
|
||||
assert isinstance(facts_list[0].get("citations"), list)
|
||||
|
||||
def test_union_type_handling(self) -> None:
|
||||
"""Test handling of union types in FactTypedDict."""
|
||||
@@ -469,8 +471,9 @@ class TestEdgeCasesAndValidation:
|
||||
"statistics": nested_stats,
|
||||
}
|
||||
|
||||
assert len(main_fact["statistics"]) == 10
|
||||
assert main_fact["statistics"][5]["value"] == 50.0
|
||||
statistics = main_fact.get("statistics", [])
|
||||
assert len(statistics) == 10
|
||||
assert statistics[5].get("value") == 50.0
|
||||
|
||||
def test_special_string_values(self) -> None:
|
||||
"""Test handling of special string values."""
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_extract_company_names_basic() -> None:
|
||||
results = extract_company_names(text)
|
||||
|
||||
assert len(results) >= 1
|
||||
company_names = [r["name"] for r in results]
|
||||
company_names = [r.get("name", "") for r in results]
|
||||
assert any("Acme Inc" in name for name in company_names)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_extract_company_names_with_indicators() -> None:
|
||||
text = "Apple Corporation and Microsoft Technologies are tech giants."
|
||||
results = extract_company_names(text)
|
||||
|
||||
company_names = [r["name"] for r in results]
|
||||
company_names = [r.get("name", "") for r in results]
|
||||
assert any("Apple Corporation" in name for name in company_names)
|
||||
assert any("Microsoft Technologies" in name for name in company_names)
|
||||
|
||||
@@ -72,7 +72,7 @@ def test_extract_companies_from_search_results() -> None:
|
||||
results = extract_companies_from_search_results(search_results)
|
||||
|
||||
assert len(results) >= 2
|
||||
company_names = [r["name"] for r in results]
|
||||
company_names = [r.get("name", "") for r in results]
|
||||
assert any("Acme Inc" in name for name in company_names)
|
||||
assert any("Beta LLC" in name for name in company_names)
|
||||
|
||||
@@ -98,7 +98,7 @@ def test_add_source_metadata() -> None:
|
||||
assert company["confidence"] == 0.9
|
||||
assert "sources" in company
|
||||
assert "search_0" in company["sources"]
|
||||
assert company["url"] == "https://example.com"
|
||||
assert company.get("url") == "https://example.com"
|
||||
|
||||
|
||||
def test_add_source_metadata_multiple_sources() -> None:
|
||||
@@ -114,7 +114,7 @@ def test_add_source_metadata_multiple_sources() -> None:
|
||||
assert "sources" in company
|
||||
assert "indicator_match" in company["sources"]
|
||||
assert "search_1" in company["sources"]
|
||||
assert company["url"] == "https://beta.com"
|
||||
assert company.get("url") == "https://beta.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import ClassVar, final
|
||||
from typing import ClassVar, cast
|
||||
from xml.etree.ElementTree import Element
|
||||
|
||||
import defusedxml.ElementTree as ET
|
||||
@@ -263,7 +263,6 @@ class ArxivPaper(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class ArxivSearchOptions(BaseModel):
|
||||
"""Options for searching arXiv.
|
||||
|
||||
@@ -276,6 +275,12 @@ class ArxivSearchOptions(BaseModel):
|
||||
sort_order: Sort order (ascending or descending).
|
||||
"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def __init__(self, **data: object) -> None:
|
||||
"""Initialize ArxivSearchOptions with proper Pydantic setup."""
|
||||
super().__init__(**data)
|
||||
|
||||
query: str = Field(
|
||||
default="",
|
||||
description="The search query string.",
|
||||
@@ -330,14 +335,13 @@ class ArxivSearchOptions(BaseModel):
|
||||
"""Validate and convert sort_by to enum."""
|
||||
if isinstance(value, SortBy):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
if isinstance(value, str) and value:
|
||||
try:
|
||||
from typing import cast
|
||||
|
||||
return cast("SortBy", SortBy(value))
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid sort_by value: {value}")
|
||||
raise ValueError(f"sort_by must be a SortBy enum or string, got {type(value)}")
|
||||
# Default to RELEVANCE if value is falsy
|
||||
return SortBy.RELEVANCE
|
||||
|
||||
@field_validator("sort_order", mode="before")
|
||||
@classmethod
|
||||
@@ -345,19 +349,15 @@ class ArxivSearchOptions(BaseModel):
|
||||
"""Validate and convert sort_order to enum."""
|
||||
if isinstance(value, SortOrder):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
if isinstance(value, str) and value:
|
||||
try:
|
||||
from typing import cast
|
||||
|
||||
return cast("SortOrder", SortOrder(value))
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid sort_order value: {value}")
|
||||
raise ValueError(
|
||||
f"sort_order must be a SortOrder enum or string, got {type(value)}"
|
||||
)
|
||||
# Default to DESCENDING if value is falsy
|
||||
return SortOrder.DESCENDING
|
||||
|
||||
|
||||
@final
|
||||
class ArxivClient(BaseAPIClient):
|
||||
"""ArXiv API client using bb_core infrastructure."""
|
||||
|
||||
@@ -397,19 +397,18 @@ class ArxivClient(BaseAPIClient):
|
||||
|
||||
with error_context(operation="arxiv_search", query=query):
|
||||
# Convert string parameters to enum types if needed
|
||||
from typing import cast
|
||||
|
||||
if isinstance(sort_by, SortBy):
|
||||
sort_by_enum = sort_by
|
||||
elif sort_by:
|
||||
sort_by_enum = cast("SortBy", SortBy(sort_by))
|
||||
sort_by_enum = SortBy(sort_by)
|
||||
else:
|
||||
sort_by_enum = SortBy.RELEVANCE
|
||||
|
||||
if isinstance(sort_order, SortOrder):
|
||||
sort_order_enum = sort_order
|
||||
elif sort_order:
|
||||
sort_order_enum = cast("SortOrder", SortOrder(sort_order))
|
||||
sort_order_enum = SortOrder(sort_order)
|
||||
else:
|
||||
sort_order_enum = SortOrder.DESCENDING
|
||||
|
||||
|
||||
@@ -95,8 +95,6 @@ class BaseAPIClient(ABC):
|
||||
url: str | None = None,
|
||||
**kwargs: object,
|
||||
) -> dict[str, object]:
|
||||
from typing import Any
|
||||
|
||||
from bb_core.networking import RequestMethod
|
||||
|
||||
# Build full URL if relative
|
||||
@@ -112,23 +110,29 @@ class BaseAPIClient(ABC):
|
||||
full_url = url
|
||||
|
||||
# Convert method string to RequestMethod enum
|
||||
method_enum = (
|
||||
RequestMethod(method.upper()) if method else RequestMethod.GET
|
||||
)
|
||||
if isinstance(method, RequestMethod):
|
||||
method_enum: RequestMethod = method
|
||||
else:
|
||||
method_str = method.upper() if method else "GET"
|
||||
# Use getattr for safe enum access with fallback
|
||||
method_enum = getattr(RequestMethod, method_str, RequestMethod.GET)
|
||||
|
||||
# Extract parameters for the request method with proper typing
|
||||
params_raw = kwargs.get("params")
|
||||
params: dict[str, Any] | None = (
|
||||
params: dict[str, str | int | float | bool] | None = (
|
||||
params_raw if isinstance(params_raw, dict) else None
|
||||
)
|
||||
|
||||
json_raw = kwargs.get("json_data") or kwargs.get("json")
|
||||
json_data: dict[str, Any] | None = (
|
||||
json_raw if isinstance(json_raw, dict) else None
|
||||
)
|
||||
json_data: (
|
||||
dict[
|
||||
str, str | int | float | bool | list[object] | dict[str, object]
|
||||
]
|
||||
| None
|
||||
) = json_raw if isinstance(json_raw, dict) else None
|
||||
|
||||
data_raw = kwargs.get("data")
|
||||
data: dict[str, Any] | None = (
|
||||
data: dict[str, str | int | float | bool | bytes] | None = (
|
||||
data_raw if isinstance(data_raw, dict) else None
|
||||
)
|
||||
|
||||
@@ -141,11 +145,9 @@ class BaseAPIClient(ABC):
|
||||
headers = {**self._headers_dict, **headers}
|
||||
|
||||
# Call the HTTPClient request method
|
||||
# Cast method to RequestMethod to satisfy type checker
|
||||
from typing import cast
|
||||
|
||||
response = await self._http_client.request(
|
||||
method=cast("RequestMethod", method_enum),
|
||||
method=method_enum,
|
||||
url=full_url,
|
||||
params=params,
|
||||
json_data=json_data,
|
||||
@@ -249,8 +251,8 @@ class BaseAPIClient(ABC):
|
||||
raise ValueError(f"Invalid response from {url}")
|
||||
|
||||
async def _post(
|
||||
self, endpoint: str, json: dict | None = None, **kwargs: object
|
||||
) -> dict:
|
||||
self, endpoint: str, json: dict[str, object] | None = None, **kwargs: object
|
||||
) -> dict[str, object]:
|
||||
"""Make a POST request to the API.
|
||||
|
||||
Args:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user