feat: add Paperless NGX agent with robust error handling (#42)

* feat: add Paperless NGX agent with robust error handling

- Implement ReAct agent for document management with Paperless NGX
- Add comprehensive error handling to prevent crashes on missing credentials
- Use global ServiceFactory singleton pattern for dependency injection
- Integrate edge helpers for error routing and retry logic
- Add user-friendly error messages when Paperless is not configured
- Support runtime configuration through RunnableConfig
- Include all Paperless tools (search, update, tags, correspondents, etc.)
- Add factory function for LangGraph API compatibility
- Ensure graceful degradation when credentials are missing

* Apply suggestions from code review

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* feat: enhance custom tool node for Paperless NGX with improved error handling

- Introduced ReActAgentState for better state management in custom tool node.
- Added explicit configuration validation and user-friendly error messages for missing Paperless NGX credentials.
- Enhanced logging for configuration usage and error reporting.
- Updated agent factory to ensure proper configuration handling for tool execution.

* fix: update and add files, resolve pre-commit and safe.directory issues

* chore: update project configurations and enhance documentation

- Added new commands to CLAUDE.local.md for formatting and type checking.
- Updated Makefile to use 'basedpyright' for linting instead of 'pyright'.
- Included 'pytest' as a new dependency in pyproject.toml.
- Cleaned up pyrightconfig.json by removing unnecessary stubPath.
- Deleted obsolete test_fixes_summary.md file.
- Introduced new test-failures.txt and type_errors.txt files for better error tracking.
- Added new devcontainer configuration files for improved development environment setup.

These changes collectively enhance project organization, improve documentation clarity, and streamline the development workflow.

* chore: add pre-commit as a dependency in pyproject.toml

- Included "pre-commit>=4.2.0" in the dependencies section of pyproject.toml to enhance code quality checks and maintainability.

* chore: update project configurations and enhance documentation

- Added new commands to CLAUDE.local.md for formatting and type checking.
- Updated Makefile to use `basedpyright` for linting instead of `pyright`.
- Included `pytest` as a dependency in pyproject.toml for testing.
- Cleaned up pyrightconfig.json by removing unnecessary stubPath.
- Deleted obsolete test_fixes_summary.md and added new test-failures.txt for tracking test failures.
- Introduced new devcontainer configuration files for improved development environment setup.

These changes collectively enhance project organization, improve testing capabilities, and provide clearer development guidelines.

* chore: update devcontainer configuration to include host .venv bind mount

- Added a bind mount for the host's .venv directory to the devcontainer configuration for shared access, avoiding conflicts with the existing .venv volume.

* chore: add pre-commit as a development dependency in pyproject.toml

- Included "pre-commit>=4.2.0" in the dev dependencies section of pyproject.toml to enhance code quality checks and maintainability.

* chore: remove obsolete basedpyright_output.json and update LLMConfig initialization

- Deleted the obsolete basedpyright_output.json file.
- Updated LLMConfig initialization in app.py to set a default profile using LLMProfile.LARGE for improved configuration management.

* fix: improve formatting and consistency in Paperless NGX client and tests

- Added missing newlines for better readability in paperless.py and test_paperless.py.
- Ensured consistent formatting in function calls and docstrings across the codebase.
- Enhanced clarity in custom field query examples within the documentation.

* fix: improve output formatting and assertions in crash tests

- Adjusted print statements in run_crash_tests.py for consistent newline usage.
- Enhanced assertions in test_concurrency_races.py to clarify expected behavior during concurrent operations.
- Improved readability of markdown output in summary reports.

* fix: refactor paperless_ngx_agent_factory to be asynchronous

- Changed paperless_ngx_agent_factory from a synchronous to an asynchronous function.
- Removed unnecessary nested asyncio handling and streamlined agent creation process.
- Updated comments for clarity regarding checkpointer usage.

* fix: resolve ruff linting errors

- Fix B027 errors by adding return statements to non-abstract ainit methods
- Fix E501 line length errors in embeddings.py
- Fix ANN204, D105, ANN202, ANN401 type annotation errors
- Add proper type annotations for __await__ and service_factory

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update project configurations and enhance documentation

- Added examples directory to .pyreflyignore for improved file management.
- Updated Makefile to include new test_watch target for better testing workflow.
- Refined pyproject.toml and pyrightconfig.json for enhanced project configuration.
- Introduced new check_singletons.md documentation for clarity on singleton checks.
- Removed obsolete type_errors.txt file to streamline error tracking.

These changes collectively improve project organization, enhance documentation clarity, and refine testing processes.

* fix: enhance graph validation and improve type handling in extraction service

- Updated `validate_all_graphs` to handle coroutine functions more effectively and skip those requiring arguments.
- Improved type casting in `ComponentExtractor` to ensure proper handling of matches.
- Enhanced message handling in `call_model_node` to ensure proper indexing and debugging output.
- Refactored `scrape_status_summary_node` for clearer progress calculation and improved handling of R2R information.
- Streamlined error handling and type casting in `SemanticExtractionService` for better response processing.
- Improved test readability and structure in `test_check_duplicate_error_handling.py` with consistent patching style.

These changes collectively enhance the robustness and clarity of the validation and extraction processes.

* fix: improve scraper name validation in extraction orchestrator

- Updated validation for `scraper_name` in `extract_key_information` and `process_single_url` functions to ensure it is a string.
- Simplified logic to default to "beautifulsoup" if `scraper_name` is not a valid string.

These changes enhance the robustness of the extraction process by ensuring proper type handling for scraper names.

* fix: restore type safety in response formatting

- Updated `format_response_for_caller` to ensure `validation_issues_list` is a list of strings and refined URL extraction from sources.
- These changes enhance type safety and improve the robustness of response data handling.

* fix: enhance type safety in content validation and service factory

- Added type checks to `validate_content` and `preprocess_content` functions to ensure inputs are strings.
- Improved date validation in `count_recent_sources` to check for both presence and type.
- Updated condition for updating `llm_kwargs` in `ServiceFactory` to ensure `kwargs` is a dictionary.

These changes collectively enhance type safety and robustness in validation and service handling.

* fix: enhance type safety in content extraction and analysis

- Added comments to clarify that `extracted_content` and `catalog_metadata` are always dictionaries, improving type safety in `identify_component_focus_node`, `find_affected_catalog_items_node`, and `batch_analyze_components_node`.
- Simplified checks for `catalog_items` to ensure they are lists, enhancing robustness in handling extracted content.

These changes collectively improve type handling and clarity in the content extraction and analysis processes.

* Fix error handling for mixed error types in validation checks

Co-authored-by: travis.vas <travis.vas@gmail.com>

* fix: enhance type checking in URL filtering and human feedback processing

- Added type checks to ensure results are dictionaries in `filter_search_results` and human feedback functions.
- Improved robustness by skipping non-dictionary results, enhancing type safety in data handling.

These changes collectively improve type handling and clarity in the scraping and validation processes.

* fix: enhance type checking in configuration handling

- Updated `check_existing_content_node` and `store_processing_metadata` functions to check if `config` is an instance of `AppConfig` before validation.
- Improved type safety by ensuring that the correct type is used for `app_config`.

These changes collectively enhance type handling and robustness in configuration processing.

* fix: improve code readability and type safety in various nodes

- Refactored conditional expressions in multiple nodes to enhance readability and maintainability.
- Ensured consistent type handling for extracted content and configuration parameters across several functions.
- These changes collectively improve clarity and robustness in data processing and validation.

* fix: improve type safety in RouterConfig category and severity handling

- Refactored category and severity processing in RouterConfig to use direct enumeration instead of casting.
- Added type checks to raise errors for invalid types, enhancing robustness in configuration handling.

These changes collectively improve type safety and error handling in the RouterConfig class.

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Committed by GitHub on 2025-07-16 16:31:44 -04:00
Parent: aaa9fa285d · Commit: 62314d77b0
320 changed files with 22219 additions and 3718 deletions


@@ -0,0 +1,173 @@
Ensure that the modules, functions, and classes in $ARGUMENTS have adopted my global singleton patterns
### **Tier 1: Core Architectural Pillars**
These are the most critical, high-level abstractions that every developer should use to interact with the application's core services and configuration.
---
#### **1. The Global Service Factory**
This is the **single most important singleton** in your application. It provides centralized, asynchronous, and cached access to all major services.
* **Primary Accessor**: `get_global_factory(config: AppConfig | None = None)`
* **Location**: `src/biz_bud/services/factory.py`
* **Purpose**: To provide a singleton instance of the `ServiceFactory`, which in turn creates and manages the lifecycle of essential services like LLM clients, database connections, and caches.
* **Usage Pattern**:
```python
# In any async part of your application
from src.biz_bud.services.factory import get_global_factory
service_factory = await get_global_factory()
llm_client = await service_factory.get_llm_client()
vector_store = await service_factory.get_vector_store()
db_service = await service_factory.get_db_service()
```
* **Instead of**:
* Instantiating services like `LangchainLLMClient` or `PostgresStore` directly.
* Managing database connection pools or API clients manually in different parts of the code.
---
#### **2. Application Configuration Loading**
Configuration is managed centrally and should be accessed through these standardized functions to ensure all overrides (from environment variables or runtime) are correctly applied.
* **Primary Accessors**: `load_config()` and `load_config_async()`
* **Location**: `src/biz_bud/config/loader.py`
* **Purpose**: To load the `AppConfig` from `config.yaml`, merge it with environment variables, and return a validated Pydantic model. The async version is for use within an existing event loop.
* **Runtime Override Helper**: `resolve_app_config_with_overrides(runnable_config: RunnableConfig)`
* **Purpose**: **This is the standard pattern for graphs.** It takes the base `AppConfig` and intelligently merges it with runtime parameters passed in a `RunnableConfig` (e.g., `llm_profile_override`, `temperature`).
* **Usage Pattern (Inside a Graph Factory)**:
```python
from src.biz_bud.config.loader import resolve_app_config_with_overrides
from src.biz_bud.services.factory import ServiceFactory
def my_graph_factory(config: dict[str, Any]) -> Pregel:
    runnable_config = RunnableConfig(configurable=config.get("configurable", {}))
    app_config = resolve_app_config_with_overrides(runnable_config=runnable_config)
    service_factory = ServiceFactory(app_config)
    # ... inject factory into a graph or use it to build nodes
```
* **Instead of**:
* Manually reading `config.yaml` with `pyyaml`.
* Using `os.getenv()` scattered throughout the codebase.
* Manually parsing `RunnableConfig` inside every node.
---
### **Tier 2: Standardized Interaction Patterns**
These are the common patterns and helpers for core tasks like AI model interaction, caching, and error handling.
---
#### **3. LLM Interaction**
All interactions with Large Language Models should go through standardized nodes or clients to ensure consistency in configuration, message handling, and error management.
* **Primary Graph Node**: `call_model_node(state: dict, config: RunnableConfig)`
* **Location**: `src/biz_bud/nodes/llm/call.py`
* **Purpose**: This is the **standard node for all LLM calls** within a LangGraph workflow. It correctly resolves the LLM profile (`tiny`, `small`, `large`), handles message history, parses tool calls, and manages exceptions.
* **Service Factory Method**: `ServiceFactory.get_llm_for_node(node_context: str, llm_profile_override: str | None)`
* **Location**: `src/biz_bud/services/factory.py`
* **Purpose**: For custom nodes that require more complex logic, this method provides a pre-configured, wrapped LLM client from the factory. The `node_context` helps select an appropriate default model size.
* **Instead of**:
* Directly importing and using `ChatOpenAI`, `ChatAnthropic`, etc.
* Manually constructing message lists or handling API errors for each LLM call.
* Implementing your own retry logic for LLM calls.
---
#### **4. Caching System**
The project provides a default, asynchronous, in-memory cache and a Redis-backed cache. Direct interaction should be minimal; prefer the decorator.
* **Primary Decorator**: `@cache_async(ttl: int)`
* **Location**: `packages/business-buddy-core/src/bb_core/caching/decorators.py`
* **Purpose**: The standard way to cache the results of any `async` function. It automatically generates a key based on the function and its arguments.
* **Singleton Accessors**:
* `get_default_cache_async()`: Gets the default in-memory cache instance.
* `ServiceFactory.get_redis_cache()`: Gets the Redis cache backend if configured.
* **Usage Pattern**:
```python
from bb_core.caching import cache_async
@cache_async(ttl=3600) # Cache for 1 hour
async def my_expensive_api_call(arg1: str, arg2: int) -> dict:
    # ... implementation
```
* **Instead of**:
* Implementing your own caching logic with dictionaries or files.
* Instantiating `InMemoryCache` or `RedisCache` manually.
---
#### **5. Error Handling & Lifecycle Subsystem**
This is a comprehensive, singleton-based system for robust error management and application lifecycle.
* **Global Singletons**:
* `get_error_aggregator()`: (`errors/aggregator.py`) Use to report errors for deduplication and rate-limiting.
* `get_error_router()`: (`errors/router.py`) Use to define and apply routing logic for different error types.
* `get_error_logger()`: (`errors/logger.py`) Use for consistent, structured error logging.
* **Primary Decorator**: `@handle_errors(error_type)`
* **Location**: `packages/business-buddy-core/src/bb_core/errors/base.py`
* **Purpose**: Wraps functions to automatically catch common exceptions (`httpx` errors, `pydantic` validation errors) and convert them into standardized `BusinessBuddyError` types.
* **Lifecycle Management**:
* `get_singleton_manager()`: (`services/singleton_manager.py`) Main accessor for the lifecycle manager.
* `cleanup_all_singletons()`: Use this at application shutdown to gracefully close all registered services (DB pools, HTTP sessions, etc.).
* **Instead of**:
* Using generic `try...except Exception` blocks.
* Manually logging error details with `logger.error()`.
* Forgetting to close resources like database connections.
---
### **Tier 3: Reusable Helpers & Utilities**
These are specific tools and helpers for common, recurring tasks across the codebase.
---
#### **6. Asynchronous and Networking Utilities**
* **Location**: `packages/business-buddy-core/src/bb_core/networking/async_utils.py`
* **Key Helpers**:
* `gather_with_concurrency(n, *tasks)`: Runs multiple awaitables concurrently with a semaphore to limit parallelism.
* `retry_async(...)`: A decorator to add exponential backoff retry logic to any `async` function.
* `RateLimiter(calls_per_second)`: An `async` context manager to enforce rate limits.
* `HTTPClient`: The base client for making robust HTTP requests, managed by the `ServiceFactory`.
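The semantics of `gather_with_concurrency` can be sketched with a semaphore — a minimal stand-in for the helper in `async_utils.py`, showing only the contract (run all awaitables, at most `n` in flight at once), not the project's actual implementation:

```python
import asyncio
from collections.abc import Awaitable
from typing import TypeVar

T = TypeVar("T")


async def gather_with_concurrency(n: int, *tasks: Awaitable[T]) -> list[T]:
    """Run awaitables concurrently, with at most `n` in flight at a time."""
    semaphore = asyncio.Semaphore(n)

    async def bounded(task: Awaitable[T]) -> T:
        async with semaphore:
            return await task

    return await asyncio.gather(*(bounded(t) for t in tasks))


async def main() -> list[int]:
    async def work(i: int) -> int:
        await asyncio.sleep(0.01)  # simulate I/O
        return i * i

    # Five tasks, but never more than two running at once.
    return await gather_with_concurrency(2, *(work(i) for i in range(5)))


results = asyncio.run(main())
```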
---
#### **7. Graph State & Node Helpers**
* **State Management**: `StateUpdater(base_state)`
* **Location**: `packages/business-buddy-core/src/bb_core/langgraph/state_immutability.py`
* **Purpose**: Provides a **safe, fluent API** for updating graph state immutably (e.g., `updater.set("key", val).append("list_key", item).build()`).
* **Instead of**: Directly mutating the state dictionary (`state["key"] = value`), which can cause difficult-to-debug issues in concurrent or resumable graphs.
* **Node Validation Decorators**: `@validate_node_input(Model)` and `@validate_node_output(Model)`
* **Location**: `packages/business-buddy-core/src/bb_core/validation/graph_validation.py`
* **Purpose**: Ensures that the input to a node and the output from a node conform to a specified Pydantic model, automatically adding errors to the state if validation fails.
* **Edge Helpers**:
* **Location**: `packages/business-buddy-core/src/bb_core/edge_helpers/`
* **Purpose**: A suite of pre-built, configurable routing functions for common graph conditions (e.g., `handle_error`, `retry_on_failure`, `check_critical_error`, `should_continue`). Use these to define conditional edges in your graphs.
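The fluent `StateUpdater` API can be demonstrated with a toy stand-in. The class below is illustrative only — it mimics the documented `set(...).append(...).build()` chain and the immutability guarantee (the base state is never mutated), but is not the implementation in `state_immutability.py`:

```python
from typing import Any


class StateUpdater:
    """Illustrative stand-in for the fluent, immutable state updater."""

    def __init__(self, base_state: dict[str, Any]) -> None:
        # Copy up front so the caller's state dict is never mutated.
        self._state: dict[str, Any] = dict(base_state)

    def set(self, key: str, value: Any) -> "StateUpdater":
        self._state[key] = value
        return self

    def append(self, key: str, item: Any) -> "StateUpdater":
        # Build a new list rather than mutating the one in the base state.
        self._state[key] = [*self._state.get(key, []), item]
        return self

    def build(self) -> dict[str, Any]:
        return dict(self._state)


base = {"messages": ["hi"], "step": 0}
updated = StateUpdater(base).set("step", 1).append("messages", "ok").build()
```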
---
#### **8. High-Level Workflow Tools**
These tools abstract away complex, multi-step processes like searching and scraping.
* **Unified Scraper**: `UnifiedScraper(config: ScrapeConfig)`
* **Location**: `packages/business-buddy-tools/src/bb_tools/scrapers/unified.py`
* **Purpose**: The single entry point for all web scraping. It automatically selects the best strategy (BeautifulSoup, Firecrawl, Jina) based on the URL and configuration.
* **Unified Search**: `UnifiedSearchTool(config: SearchConfig)`
* **Location**: `packages/business-buddy-tools/src/bb_tools/search/unified.py`
* **Purpose**: The single entry point for web searches. It can query multiple providers (Tavily, Jina, Arxiv) in parallel and deduplicates results.
* **Search & Scrape Fetcher**: `WebContentFetcher(search_tool, scraper)`
* **Location**: `packages/business-buddy-tools/src/bb_tools/actions/fetch.py`
* **Purpose**: A high-level tool that combines the `UnifiedSearchTool` and `UnifiedScraper` to perform a complete "search and scrape" workflow.
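The search-and-scrape composition can be sketched with `Protocol` stubs. Everything below is a hypothetical stand-in — the `SearchTool`/`Scraper` protocols and fakes only illustrate how `WebContentFetcher` chains the two steps (search for URLs, then scrape each concurrently), not the real `bb_tools` interfaces:

```python
import asyncio
from typing import Protocol


class SearchTool(Protocol):
    async def search(self, query: str) -> list[str]: ...


class Scraper(Protocol):
    async def scrape(self, url: str) -> str: ...


class WebContentFetcher:
    """Compose a search tool and a scraper into one search-and-scrape step."""

    def __init__(self, search_tool: SearchTool, scraper: Scraper) -> None:
        self._search = search_tool
        self._scraper = scraper

    async def fetch(self, query: str) -> dict[str, str]:
        urls = await self._search.search(query)
        pages = await asyncio.gather(*(self._scraper.scrape(u) for u in urls))
        return dict(zip(urls, pages))


class FakeSearch:
    async def search(self, query: str) -> list[str]:
        return [f"https://example.com/{query}/1"]


class FakeScraper:
    async def scrape(self, url: str) -> str:
        return f"content of {url}"


result = asyncio.run(WebContentFetcher(FakeSearch(), FakeScraper()).fetch("widgets"))
```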
By consistently using these established patterns and utilities, you will improve code quality, reduce bugs, and make the entire project easier to understand and maintain.


@@ -0,0 +1,88 @@
Check the following code for quality and strict type safety: $ARGUMENTS
# Type Safety and Type Checking Rules
## 1. Type Safety
- All function parameters and variables must have explicit types declared or be inferable with strict typing.
- **No usage of `Any` is allowed**, either explicitly or implicitly.
- Make sure all types are explicit and specific. Avoid overly broad types like `object`.
- Use typing features from the standard library (`typing` or `typing_extensions`), preferably Python 3.9+ syntax (`list[str]` instead of `List[str]`).
- Avoid legacy/deprecated typing constructs, e.g., no `NotRequired` from older `typing-extensions` versions or Pydantic v1.0 conventions.
- Use Pydantic v2+ patterns if necessary or alternatives that are modern and well-maintained.
---
## 2. Return Types
- Every function and method must have an explicit return type annotation, even if returning `None`.
- For async functions, annotate the awaited value directly: `async def func(...) -> ReturnType`. Do not write `-> Awaitable[ReturnType]` on an `async def` — the coroutine wrapper is implicit in the signature.
- Avoid implicit returns or assumed returns.
- If the function returns multiple types (e.g., `Union`), make it clear and as narrow as possible.
---
## 3. Type Hints
- Use inline type hints for variables where applicable and useful to improve readability.
- Use `TypedDict`, `Literal`, `Protocol`, `Final`, `TypeAlias` and other modern typing features appropriately.
- **No shortcuts:** do not disable or skip type checking rules or warnings anywhere.
- Errors and warnings raised by pyrefly, ruff, and basedpyright must be resolved, not ignored.
- Don't use third-party typing decorators or hacks that circumvent standard Python type hints.
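A short sketch of the modern typing features listed above (`Final`, `Literal`, `TypedDict`, `Protocol`) used together — illustrative names, stdlib only:

```python
from typing import Final, Literal, Protocol, TypedDict

MAX_RETRIES: Final[int] = 3  # Final: constant, reassignment is a type error


class UserRecord(TypedDict):
    """TypedDict: precise dict shape instead of dict[str, Any]."""

    name: str
    role: Literal["admin", "member"]  # Literal: only these two values allowed


class Notifier(Protocol):
    """Protocol: structural typing without inheritance."""

    def notify(self, message: str) -> None: ...


class ConsoleNotifier:
    def notify(self, message: str) -> None:
        print(message)


def describe(user: UserRecord) -> str:
    return f"{user['name']} ({user['role']})"


record: UserRecord = {"name": "Ada", "role": "admin"}
```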
---
## 4. No Ignoring or Shortcuts
- Do not add `# type: ignore` or lint ignores.
- All issues flagged by linters/type-checkers must be handled by fixing the root problem.
- Don't commit code with ignored errors or warnings.
- Avoid complexity that makes typing impossible or too complex; refactor instead.
---
## 5. Tools and Integration
- Use **pyrefly** for autocompletion and precise inline type hints while coding.
- Use **ruff** to enforce style and typing rules:
- Enable rules that enforce function annotations (flake8-annotations or pyright analogs).
- Enable rules to disallow `Any` and enforce consistent type hints.
- Use **basedpyright** for type checking:
- Treat all warnings/errors as errors.
- Set strict mode enabled.
- Fix all type errors reported.
---
## Example Enforcement Sample
```python
def get_usernames(users: list[str]) -> list[str]:
    # Explicit input and output types, using modern built-in generics
    return [user.strip().lower() for user in users]

async def fetch_data(url: str) -> dict[str, str]:
    # Explicit return type for an async function (no Awaitable wrapper)
    response = await some_http_client.get(url)
    return response.json()
```
- No use of `Any` anywhere.
- All functions/types explicitly annotated.
- No places where errors are ignored.
---
## Summary Checklist
- All functions/methods have explicit return types.
- No use of `Any` types.
- No `# type: ignore` or lint ignores.
- No legacy/disallowed typing patterns (`NotRequired`, Pydantic v1, etc.).
- Correct use of modern python typing features.
- Passes all pyrefly suggestions without disabling checks.
- Passes all ruff lint rules, especially regarding annotations.
- Passes all basedpyright type checks with strict mode and zero errors.


@@ -4,18 +4,45 @@
{
"hooks": [
{
-"command": "cd $(git rev-parse --show-toplevel) && /home/vasceannie/repos/biz-budz/scripts/black-file.sh",
+"command": "cd $(git rev-parse --show-toplevel) && ./scripts/black-file.sh",
"type": "command"
}
],
"matcher": "Write|Edit|MultiEdit"
},
{
"hooks": [
{
"command": "echo 'BLOCKED: Using SKIP= with git commit is forbidden. Run proper pre-commit hooks instead.' && exit 1",
"type": "command"
}
],
"matcher": "Bash.*SKIP=.*git commit"
},
{
"hooks": [
{
"command": "echo 'BLOCKED: Using --no-verify with git commit is forbidden. Pre-commit hooks must run.' && exit 1",
"type": "command"
}
],
"matcher": "Bash.*git commit.*--no-verify"
},
{
"hooks": [
{
"command": "echo 'BLOCKED: Using git commit with -n flag (no-verify) is forbidden. Pre-commit hooks must run.' && exit 1",
"type": "command"
}
],
"matcher": "Bash.*git commit.*-n"
}
],
"PostToolUse": [
{
"hooks": [
{
-"command": "cd $(git rev-parse --show-toplevel) && /home/vasceannie/repos/biz-budz/scripts/lint-file.sh",
+"command": "cd $(git rev-parse --show-toplevel) && ./scripts/lint-file.sh",
"type": "command"
}
],

.devcontainer/Dockerfile (new file, 57 lines)

@@ -0,0 +1,57 @@
FROM mcr.microsoft.com/devcontainers/python:1-3.12-bookworm
# Install additional system dependencies
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends \
postgresql-client \
redis-tools \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install Node.js and npm
RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install UV package manager globally
RUN pip install --no-cache-dir uv
# Install global npm packages
RUN npm install -g task-master-ai repomix @anthropic-ai/claude-code
# Create cache directories with proper permissions
RUN mkdir -p /home/vscode/.cache/uv \
&& chown -R vscode:vscode /home/vscode/.cache
# Set up Python environment paths
ENV PYTHONPATH="/workspace/src:/workspace:$PYTHONPATH"
# Allow vscode user to use sudo without password for permission fixes
RUN echo "vscode ALL=(ALL) NOPASSWD: /bin/chown, /bin/chmod, /usr/bin/chown, /usr/bin/chmod" >> /etc/sudoers
# Set working directory
WORKDIR /workspace
# Create a startup script to fix permissions on volumes and ensure group membership
RUN echo '#!/bin/bash\n\
# Fix ownership of workspace directories\n\
if [ -d "/workspace/.venv" ]; then\n\
sudo chown -R vscode:vscode /workspace/.venv\n\
fi\n\
# Ensure vscode user is in root group for shared permissions\n\
if ! groups vscode | grep -q "\broot\b"; then\n\
sudo usermod -a -G root vscode\n\
fi\n\
# Fix permissions on config and cache directories if they exist\n\
if [ -d "/home/vscode/.config" ]; then\n\
sudo chown -R vscode:vscode /home/vscode/.config 2>/dev/null || true\n\
fi\n\
if [ -d "/home/vscode/.cache" ]; then\n\
sudo chown -R vscode:vscode /home/vscode/.cache 2>/dev/null || true\n\
fi\n\
exec "$@"' > /usr/local/bin/docker-entrypoint.sh \
&& chmod +x /usr/local/bin/docker-entrypoint.sh
ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"]
CMD ["sleep", "infinity"]


@@ -0,0 +1,95 @@
{
"name": "Business Buddy Dev Container",
"dockerComposeFile": "docker-compose.yml",
"service": "app",
"workspaceFolder": "/workspace",
"features": {
"ghcr.io/devcontainers/features/python:1": {
"version": "3.12",
"installJupyterlab": false
},
"ghcr.io/devcontainers/features/node:1": {
"version": "lts"
},
"ghcr.io/devcontainers/features/git:1": {},
"ghcr.io/devcontainers/features/github-cli:1": {},
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": true,
"configureZshAsDefaultShell": true,
"installOhMyZsh": true,
"installOhMyZshConfig": true,
"upgradePackages": true
}
},
"customizations": {
"vscode": {
"settings": {
"python.defaultInterpreterPath": "/workspace/.venv/bin/python",
"python.terminal.activateEnvironment": true,
"python.linting.enabled": true,
"python.linting.pylintEnabled": false,
"python.formatting.provider": "none",
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
"source.fixAll": "explicit"
}
},
"terminal.integrated.defaultProfile.linux": "zsh",
"git.enabled": true,
"git.path": "/usr/local/bin/git",
"git.allowNoVerifyCommit": false,
"git.alwaysSignOff": false,
"git.useEditorAsCommitInput": true,
"git.enableCommitSigning": false
},
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"charliermarsh.ruff",
"ms-azuretools.vscode-docker",
"GitHub.copilot",
"ms-vscode.makefile-tools",
"donjayamanne.githistory",
"streetsidesoftware.code-spell-checker"
]
}
},
"forwardPorts": [2024, 5432, 6379, 6333],
"postCreateCommand": "bash .devcontainer/setup.sh",
"containerEnv": {
"PYTHONPATH": "/workspace/src:/workspace",
"DATABASE_URL": "postgres://user:password@postgres:5432/langgraph_db",
"REDIS_URL": "redis://redis:6379/0",
"QDRANT_URL": "http://qdrant:6333",
"CHOKIDAR_USEPOLLING": "true",
"WATCHPACK_POLLING": "true",
"LANGGRAPH_HOST": "0.0.0.0",
"LANGGRAPH_PORT": "2024",
"LANGCHAIN_TRACING_V2": "true",
"LANGCHAIN_PROJECT": "biz-budz-dev",
"LANGSMITH_TRACING": "true"
},
"mounts": [
"source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"source=${localEnv:HOME}/.config,target=/home/vscode/.config,type=bind,consistency=cached",
"source=${localEnv:HOME}/.cache,target=/home/vscode/.cache,type=bind,consistency=cached"
],
"remoteUser": "vscode",
"runArgs": [
"--user=1000:1000",
"--group-add=0"
],
"updateContentCommand": "bash .devcontainer/fix-permissions.sh"
}


@@ -0,0 +1,125 @@
services:
app:
build:
context: ..
dockerfile: .devcontainer/Dockerfile
ports:
- "2024:2024" # LangGraph Studio
user: "1000:1000"
volumes:
# Main workspace with delegated consistency for better performance
- type: bind
source: ..
target: /workspace
consistency: delegated
# Mount .venv to a volume to hide host's .venv
- venv-volume:/workspace/.venv
# Bind mount host .venv for shared access (different name to avoid conflicts)
- ./.venv-host:/workspace/.venv-host
# UV cache volume
- uv-cache:/home/vscode/.cache/uv
# Additional bind mounts for hot reload
- type: bind
source: ../src
target: /workspace/src
consistency: consistent
- type: bind
source: ../tests
target: /workspace/tests
consistency: consistent
- type: bind
source: ../packages
target: /workspace/packages
consistency: consistent
# Mount host config and cache directories for sharing permissions and API access
- type: bind
source: ${HOME}/.config
target: /home/vscode/.config
consistency: cached
- type: bind
source: ${HOME}/.cache
target: /home/vscode/.cache
consistency: cached
environment:
# Python configuration
PYTHONPATH: /workspace/src:/workspace
# Database connections
DATABASE_URL: postgres://user:password@postgres:5432/langgraph_db
REDIS_URL: redis://redis:6379/0
QDRANT_URL: http://qdrant:6333
# Development environment
NODE_ENV: development
PYTHON_ENV: development
# LangGraph configuration
LANGGRAPH_HOST: 0.0.0.0
LANGGRAPH_PORT: 2024
# LangSmith/LangChain tracing configuration
LANGCHAIN_TRACING_V2: "true"
LANGCHAIN_PROJECT: "biz-budz-dev"
LANGSMITH_TRACING: "true"
# Pass through API keys from host environment
LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY:-}
LANGSMITH_API_KEY: ${LANGSMITH_API_KEY:-}
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
TAVILY_API_KEY: ${TAVILY_API_KEY:-}
PERPLEXITY_API_KEY: ${PERPLEXITY_API_KEY:-}
depends_on:
- postgres
- redis
- qdrant
networks:
- devcontainer-network
postgres:
image: postgres:14
restart: unless-stopped
volumes:
- postgres-data:/var/lib/postgresql/data
- ./init-items.sql:/docker-entrypoint-initdb.d/init-items.sql:ro
environment:
POSTGRES_USER: user
POSTGRES_PASSWORD: password
POSTGRES_DB: langgraph_db
networks:
- devcontainer-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U user -d langgraph_db"]
interval: 5s
timeout: 5s
retries: 5
redis:
image: redis:7
restart: unless-stopped
volumes:
- redis-data:/data
networks:
- devcontainer-network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
qdrant:
image: qdrant/qdrant:latest
restart: unless-stopped
volumes:
- qdrant-storage:/qdrant/storage
environment:
QDRANT__SERVICE__HTTP_PORT: 6333
networks:
- devcontainer-network
volumes:
postgres-data:
redis-data:
qdrant-storage:
venv-volume:
uv-cache:
networks:
devcontainer-network:
driver: bridge


@@ -0,0 +1,40 @@
#!/bin/bash
# Quick fix script for permissions and API access in running container
echo "🔧 Fixing permissions and configuring API access..."
# Fix ownership of cache directories
sudo chown -R vscode:vscode /home/vscode/.cache 2>/dev/null || true
sudo chown -R vscode:vscode /home/vscode/.config 2>/dev/null || true
# Create cache directory with proper permissions
mkdir -p /home/vscode/.cache/uv
touch /home/vscode/.cache/uv/CACHEDIR.TAG
chmod -R 755 /home/vscode/.cache/uv
# Create config directories for LangSmith/LangChain
mkdir -p /home/vscode/.config/langchain
mkdir -p /home/vscode/.config/langsmith
# Ensure vscode user is in root group for shared permissions
if ! groups vscode | grep -q "\broot\b"; then
sudo usermod -a -G root vscode
echo "✓ Added vscode user to root group"
fi
# Copy shell profiles
if [ -f /home/vscode/.zshrc.host ]; then
cp /home/vscode/.zshrc.host /home/vscode/.zshrc
sed -i 's|/home/vasceannie|/home/vscode|g' /home/vscode/.zshrc
echo "✓ Copied .zshrc"
fi
if [ -f /home/vscode/.bashrc.host ]; then
cp /home/vscode/.bashrc.host /home/vscode/.bashrc
echo "✓ Copied .bashrc"
fi
# Set proper permissions for workspace
sudo chown -R vscode:vscode /workspace/.venv 2>/dev/null || true
echo "✅ Permissions and API access configured!"

View File

@@ -0,0 +1,146 @@
CREATE TABLE IF NOT EXISTS items (
id VARCHAR(32) PRIMARY KEY,
name TEXT NOT NULL,
online_name TEXT,
hidden BOOLEAN DEFAULT FALSE,
available BOOLEAN DEFAULT TRUE,
auto_manage BOOLEAN DEFAULT FALSE,
price INTEGER NOT NULL,
price_type VARCHAR(16) NOT NULL,
default_tax_rates BOOLEAN,
is_revenue BOOLEAN,
modified_time BIGINT NOT NULL,
deleted BOOLEAN DEFAULT FALSE,
enabled_online BOOLEAN DEFAULT TRUE,
alternate_name TEXT,
code TEXT,
sku TEXT,
unit_name TEXT,
cost INTEGER,
price_without_vat INTEGER
);
INSERT INTO items (
id, name, online_name, hidden, available, auto_manage, price, price_type, default_tax_rates, is_revenue,
modified_time, deleted, enabled_online, alternate_name, code, sku, unit_name, cost, price_without_vat
) VALUES
('8FWKTEEY952NJ','LGE Can Tastee Cheese 4.8-Lbs','LGE Can Tastee Cheese 4.8-Lbs',FALSE,TRUE,FALSE,7000,'FIXED',TRUE,TRUE,1742946876000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('AES43GBJYVMHE','Chili Wings 3 for $4.00','Chili Wings 3 for $4.00',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1741900252000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('0PEJY11DE4K1Y','LG Brown Stew Chicken Only','LG Brown Stew Chicken Only',FALSE,TRUE,FALSE,2500,'FIXED',TRUE,TRUE,1741546970000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('94TSPNQ53MHT2','LG Beef Ribs','LG Beef Ribs',FALSE,TRUE,FALSE,1800,'FIXED',TRUE,TRUE,1740336153000,FALSE,TRUE,'','','','',0,0),
('TRMEF1K37ZVBY','Reg Beef Ribs','Reg Beef Ribs',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1740336132000,FALSE,TRUE,'','','','',0,0),
('XDVX9NQGGNHWM','LG Cookup Salmon','LG Cookup Salmon',FALSE,TRUE,FALSE,1450,'FIXED',TRUE,TRUE,1736307052000,FALSE,TRUE,'','','','',0,0),
('3ZSY0QBJXADYP','Reg Cookup Salmon','Reg Cookup Salmon',FALSE,TRUE,FALSE,1250,'FIXED',TRUE,TRUE,1736307028000,FALSE,TRUE,'','','','',0,0),
('FVE90V5KANJMC','Mini Cookup Salmon','Mini Cookup Salmon',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1741888827000,FALSE,TRUE,'','','','',0,0),
('Y5S0TVYTEPMPY','Christmas 2024 Package For 2-4 People','Christmas 2024 Package For 2-4 People',TRUE,FALSE,FALSE,22500,'FIXED',TRUE,TRUE,1740882286000,FALSE,TRUE,'','','','',0,0),
('JCRTY9PR0PWGM','Christmas 2024 Package For 6-8 People','Christmas 2024 Package For 6-8 People',TRUE,FALSE,FALSE,55000,'FIXED',TRUE,TRUE,1740882290000,FALSE,TRUE,'','','','',0,0),
('550XACY435GEG','Christmas 2024 Package For 10-15 People','Christmas 2024 Package For 10-15 People',TRUE,FALSE,FALSE,75000,'FIXED',TRUE,TRUE,1740882282000,FALSE,TRUE,'','','','',0,0),
('EDKQXCK0M7SQW','Christmas 2024 Package For 20-25 People','Christmas 2024 Package For 20-25 People',TRUE,FALSE,FALSE,95000,'FIXED',TRUE,TRUE,1740882288000,FALSE,TRUE,'','','','',0,0),
('0TRKQ08BW6RNY','SM Tray Spinach Rice','SM Tray Spinach Rice',FALSE,TRUE,FALSE,6000,'FIXED',TRUE,TRUE,1732042296000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('3J7VACG51026P','SM Tray Spinach','SM Tray Spinach',FALSE,TRUE,FALSE,7000,'FIXED',TRUE,TRUE,1732042206000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('8XZS0RP2YAXGP','LG Tray Veggie Fried Rice','LG Tray Veggie Fried Rice',FALSE,TRUE,FALSE,12500,'FIXED',TRUE,TRUE,1732041639000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('VRS51MX1PY98E','LG Veggie Fried Rice','LG Veggie Fried Rice',FALSE,TRUE,FALSE,1400,'FIXED',TRUE,TRUE,1732040104000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('KWMNPW3RTVJ6Y','Reg Veggie Fried Rice','Reg Veggie Fried Rice',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1732040028000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('QX3077QP7ECAP','Mini Veggie Fried Rice','Mini Veggie Fried Rice',FALSE,TRUE,FALSE,800,'FIXED',TRUE,TRUE,1732039964000,FALSE,TRUE,'','','','',0,0),
('FPBG4P3PMDZYC','LG TRAY ITAL STEW','LG TRAY ITAL STEW',FALSE,TRUE,FALSE,13000,'FIXED',TRUE,TRUE,1737753773000,FALSE,TRUE,'','','','',0,0),
('KSM43DRQMAWZT','SM TRAY ITAL STEW','SM TRAY ITAL STEW',FALSE,TRUE,FALSE,6500,'FIXED',TRUE,TRUE,1737753765000,FALSE,TRUE,'','','','',0,0),
('RG1KS6KPK9TKR','SM Tray Jerk Chicken Lo Mein','SM Tray Jerk Chicken Lo Mein',FALSE,TRUE,FALSE,7500,'FIXED',TRUE,TRUE,1731973151000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('CPQ44F6PXZ232','SM TRAY CANDID YAMS','SM TRAY CANDID YAMS',FALSE,TRUE,FALSE,4000,'FIXED',TRUE,TRUE,1732045788000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('26MBWTPXPB2B6','PACKAGE FOR 6-8 PEOPLE','PACKAGE FOR 6-8 PEOPLE',TRUE,FALSE,FALSE,32500,'FIXED',TRUE,TRUE,1740894161000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('7YTB27AP3VENC','PACKAGE FOR 2-4 PEOPLE','PACKAGE FOR 2-4 PEOPLE',TRUE,FALSE,FALSE,15000,'FIXED',TRUE,TRUE,1740894155000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('FJM2JJKF1WAE0','PACKAGE FOR 10-15 PEOPLE','PACKAGE FOR 10-15 PEOPLE',TRUE,FALSE,FALSE,47500,'FIXED',TRUE,TRUE,1741890382000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('WKKVKJ6SK0G46','PACKAGE FOR 20-25 PEOPLE','PACKAGE FOR 20-25 PEOPLE',TRUE,FALSE,FALSE,90000,'FIXED',TRUE,TRUE,1740894159000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('EQ3AN5BC9Z9GE','SEAFOOD DINNER PACKAGE 8-12 PEOPLE','SEAFOOD DINNER PACKAGE 8-12 PEOPLE',TRUE,FALSE,FALSE,45000,'FIXED',TRUE,TRUE,1741892428000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('6CZDMYB936MP6','LG TRAY JERK CHICKEN AND SHRIMP LOMEIN','LG TRAY JERK CHICKEN AND SHRIMP LOMEIN',FALSE,TRUE,FALSE,19000,'FIXED',TRUE,TRUE,1730731144000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('Q40XKXMYZSYVR','SM TRAY JERK CHICKEN AND SHRIMP LOMEIN','SM TRAY JERK CHICKEN AND SHRIMP LOMEIN',FALSE,TRUE,FALSE,9500,'FIXED',TRUE,TRUE,1730731082000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('1ENMJYCQCVBJ4','LG TRAY JERK CHICKEN AND SHRIMP RASTA PASTA','LG TRAY JERK CHICKEN AND SHRIMP RASTA PASTA',FALSE,TRUE,FALSE,24000,'FIXED',TRUE,TRUE,1730730985000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('RQGMC7Q52E8VT','SM TRAY JERK CHICKEN AND SHRIMP RASTA PASTA','SM TRAY JERK CHICKEN AND SHRIMP RASTA PASTA',FALSE,TRUE,FALSE,12000,'FIXED',TRUE,TRUE,1730730724000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('DEDEAPYWRV60P','SMALL TRAY COOKUP CODFISH','SMALL TRAY COOKUP CODFISH',FALSE,TRUE,FALSE,12500,'FIXED',TRUE,TRUE,1729783003000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('7JWZG5VQ2FNW0','1 DOZEN CODFISH FRITTER','1 DOZEN CODFISH FRITTER',FALSE,TRUE,FALSE,2400,'FIXED',TRUE,TRUE,1740879802000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('Z2592B3GP15PA','1 DOZEN FESTIVAL','1 DOZEN FESTIVAL',FALSE,TRUE,FALSE,2400,'FIXED',TRUE,TRUE,1740879835000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('J70AD21A88E5C','I DOZEN FRIED DUMPLING','I DOZEN FRIED DUMPLING',TRUE,FALSE,FALSE,1800,'FIXED',TRUE,TRUE,1740890803000,FALSE,TRUE,NULL,NULL,NULL,NULL,NULL,NULL),
('TPK211AZHRN8E','Extra Pumpkin Or Spinach Rice-LG','Extra Pumpkin Or Spinach Rice-LG',FALSE,TRUE,FALSE,300,'FIXED',TRUE,TRUE,1726181679000,FALSE,TRUE,'','','','',0,0),
('Z2R5SFT2FK2QP','Extra Pumpkin Or Spinach Rice-Reg','Extra Pumpkin Or Spinach Rice-Reg',FALSE,TRUE,FALSE,200,'FIXED',TRUE,TRUE,1726181558000,FALSE,TRUE,'','','','',0,0),
('7KDDT1QWR57GC','Quinoi','Quinoi',TRUE,FALSE,FALSE,300,'FIXED',TRUE,TRUE,1741891643000,FALSE,TRUE,'','','','',0,0),
('MGR82XDRZKFAE','Whole Cheese Cake','Whole Cheese Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726006044000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('F1TB52KPDAFMG','Whole Chocolate Cake','Whole Chocolate Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726005980000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('AFXX5PGWX3GK8','Whole Red Velvet Cake','Whole Red Velvet Cake',FALSE,TRUE,FALSE,5500,'FIXED',TRUE,TRUE,1726005743000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('WTJ801A75ZAK0','LG Tray Esco Chicken','LG Tray Esco Chicken',FALSE,TRUE,FALSE,14500,'FIXED',TRUE,TRUE,1725666148000,FALSE,TRUE,'','','','',0,0),
('DMD91V58QH1FY','Esco Snapper Meal','Esco Snapper Meal',FALSE,TRUE,FALSE,2000,'FIXED',TRUE,TRUE,1741545187000,FALSE,TRUE,'','','','',0,0),
('3867387970RC0','Steam Butter Fish','Steam Butter Fish',TRUE,FALSE,FALSE,1500,'FIXED',TRUE,TRUE,1741893336000,FALSE,TRUE,'','','','',0,0),
('JF3GTKY0392YW','Bammy','Bammy',FALSE,TRUE,FALSE,200,'FIXED',TRUE,TRUE,1729341142000,FALSE,TRUE,'','','','',0,0),
('DXBBZ2FCVGX02','Steam Snapper Fish','Steam Snapper Fish',FALSE,TRUE,FALSE,2200,'FIXED',TRUE,TRUE,1741893681000,FALSE,TRUE,'','','','',0,0),
('1P1F01N96ZWJG','Roots Man','Roots Man',FALSE,TRUE,FALSE,600,'FIXED',TRUE,TRUE,1722967143000,FALSE,TRUE,'','','','',0,0),
('SGC3DAXDH9ZBG','Fiwi Lemongrass Tea','Fiwi Lemongrass Tea',FALSE,TRUE,FALSE,499,'FIXED',TRUE,TRUE,1721353508000,FALSE,TRUE,'','','','',0,0),
('8PJ2VCBEV4T1Y','Extra Mac & Cheese','Extra Mac & Cheese',FALSE,TRUE,FALSE,150,'FIXED',TRUE,TRUE,1717690516000,FALSE,TRUE,'','','','',0,0),
('0Y0FJHHCWZKTT','LG Butter Shrimp (9 Shrimp)','LG Butter Shrimp (9 Shrimp)',FALSE,TRUE,FALSE,2100,'FIXED',TRUE,TRUE,1741547134000,FALSE,TRUE,'','','','',0,0),
('7Y2H5RX0XSX3Y','Extra Rasta Pasta-LG','Extra Rasta Pasta-LG',FALSE,TRUE,FALSE,600,'FIXED',TRUE,TRUE,1716147863000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('1V7H5DXSCC2Y6','LG Shrimp LoMein','LG Shrimp LoMein',FALSE,TRUE,FALSE,22500,'FIXED',TRUE,TRUE,1715966030000,FALSE,TRUE,'','','','',0,0),
('AWBNJDJQQBSD4','Stakehouse Cheesecake','Stakehouse Cheesecake',FALSE,TRUE,FALSE,550,'FIXED',TRUE,TRUE,1715048741000,FALSE,TRUE,'','','','',0,0),
('V7EYC23J5C63J','Fruit Of Life Coconut Water','Fruit Of Life Coconut Water',TRUE,FALSE,FALSE,499,'FIXED',TRUE,TRUE,1740889976000,FALSE,TRUE,'','','','',0,0),
('BFWXPV6B6Z2VC','Esco Saltfish Meal-Large','Esco Saltfish Meal-Large',FALSE,TRUE,FALSE,1700,'FIXED',TRUE,TRUE,1740887749000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('8A1HWX6V88F2Y','Esco Saltfish Meal-Regular ','Esco Saltfish Meal-Regular ',FALSE,TRUE,FALSE,1400,'FIXED',TRUE,TRUE,1712318630000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('3JP9245D50S9Y','Esco Saltfish Meal-Mini','Esco Saltfish Meal-Mini',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1740887717000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('VECQK0YJZZEZ4','Soursop With Lime','Soursop With Lime',FALSE,TRUE,FALSE,599,'FIXED',TRUE,TRUE,1714774724000,FALSE,TRUE,'','','','',0,0),
('2H4XS57ZD9D9T','SOURSOP LEAF W/LEMON & GINGER','SOURSOP LEAF W/LEMON & GINGER',FALSE,TRUE,FALSE,499,'FIXED',TRUE,TRUE,1711841482000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('YHDKA6KPQFXRC','Mackerel In Tomato Sauce-Large','Mackerel In Tomato Sauce-Large',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1741888772000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('66JY11P9BC4BE','Mackerel In Tomato Sauce-Regular ','Mackerel In Tomato Sauce-Regular ',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1741888756000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('DFYHFNPWVV9F2','Mackerel In Tomato Sauce-Mini','Mackerel In Tomato Sauce-Mini',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1741888738000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('8NEBXD6RF0QSM','HTB Easter Bun','HTB Easter Bun',TRUE,FALSE,FALSE,1850,'FIXED',TRUE,TRUE,1740890798000,FALSE,TRUE,'','','','',0,0),
('DM6RTJH8KVKS6','SM CAN TASTEE CHEESE 1.1 LB','SM CAN TASTEE CHEESE 1.1 LB',TRUE,FALSE,FALSE,2199,'FIXED',TRUE,TRUE,1742946771000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('28AKKR5RY1YJ8','SM CAN TASTEE CHEESE 8.8OZ','SM CAN TASTEE CHEESE 8.8OZ',TRUE,FALSE,FALSE,1399,'FIXED',TRUE,TRUE,1742946780000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('132FXN33YJB1P','9" Fruit Cake','9" Fruit Cake',FALSE,TRUE,FALSE,8500,'FIXED',TRUE,TRUE,1734399306000,FALSE,TRUE,'','','','',0,0),
('4X9KC9YZV9HG8','8" Fruit Cake','8" Fruit Cake',FALSE,TRUE,FALSE,6500,'FIXED',TRUE,TRUE,1734399293000,FALSE,TRUE,'','','','',0,0),
('4133GEYQJMDZ4','Extra shrimp Fried Rice-Mini','Extra shrimp Fried Rice-Mini',FALSE,TRUE,FALSE,250,'FIXED',TRUE,TRUE,1740889316000,FALSE,TRUE,'','','','',0,0),
('WSWZD03ZPJXBA','Extra shrimp Fried Rice-Lge','Extra shrimp Fried Rice-Lge',FALSE,TRUE,FALSE,500,'FIXED',TRUE,TRUE,1741545377000,FALSE,TRUE,'','','','',0,0),
('YW7Q7B1R1RBBT','Extra shrimp Fried Rice-Reg','Extra shrimp Fried Rice-Reg',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1709922064000,FALSE,TRUE,'','','','',0,0),
('N7BQK2H3HJQ8Y','Mini Oxtail Rasta Pasta','Mini Oxtail Rasta Pasta',TRUE,FALSE,FALSE,1300,'FIXED',TRUE,TRUE,1741888888000,FALSE,TRUE,'','','','',0,0),
('SGNFMEHEDME44','Oxtail Rasta Pasta-Large','Oxtail Rasta Pasta-Large',TRUE,FALSE,FALSE,2100,'FIXED',TRUE,TRUE,1742955126000,FALSE,TRUE,'','','','',0,0),
('FXJRNBKNXPWN4','Oxtail Rasta Pasta-Regular','Oxtail Rasta Pasta-Regular',TRUE,FALSE,FALSE,1800,'FIXED',TRUE,TRUE,1741890372000,FALSE,TRUE,'','','','',0,0),
('SCKJSJBP6655W','Gift card','Gift card',FALSE,TRUE,FALSE,0,'VARIABLE',FALSE,FALSE,1707342109000,FALSE,TRUE,NULL,'CLOVER_GIFT_CARD',NULL,NULL,NULL,NULL),
('18RWNDTG7Q8H4','Bag of Cookies','Bag of Cookies',TRUE,FALSE,FALSE,895,'FIXED',TRUE,TRUE,1740880223000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('YGV43SZXQVJZ0','Extra Veggie Rice-Large','Extra Veggie Rice-Large',FALSE,TRUE,FALSE,400,'FIXED',TRUE,TRUE,1741545489000,FALSE,TRUE,'','','','',0,0),
('E0YYQ8NBQ6K32','Extra Veggie Rice-Mini','Extra Veggie Rice-Mini',FALSE,TRUE,FALSE,150,'FIXED',TRUE,TRUE,1705960154000,FALSE,TRUE,'','','','',0,0),
('TEDBEC74W0788','Extra Veggie Rice-Reg','Extra Veggie Rice-Reg',FALSE,TRUE,FALSE,250,'FIXED',TRUE,TRUE,1705960103000,FALSE,TRUE,'','','','',0,0),
('7K83E2CGP5A72','Veggie Soup for 200 People','Veggie Soup for 200 People',TRUE,FALSE,FALSE,30000,'FIXED',TRUE,TRUE,1741894706000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('KZRXCM1SJY9RE','FRUIT CAKE W/ICING SLICE','FRUIT CAKE W/ICING SLICE',FALSE,TRUE,FALSE,495,'FIXED',TRUE,TRUE,1704486470000,FALSE,TRUE,'','','','',0,0),
('66QNN6HH1Z28C','Jerk Chicken Fried Rice-LG','Jerk Chicken Fried Rice-LG',FALSE,TRUE,FALSE,1550,'FIXED',TRUE,TRUE,1732040452000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('GT6DEXTX8B57P','Jerk Chicken Fried Rice-Reg','Jerk Chicken Fried Rice-Reg',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1740890941000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('ZJ6MH0DXTZ5VG','Jerk Chicken Fried Rice-Mini','Jerk Chicken Fried Rice-Mini',FALSE,TRUE,FALSE,1100,'FIXED',TRUE,TRUE,1740890971000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('79AZD841FFGYE','2024 Thanksgiving Package -20-25 People','2024 Thanksgiving Package -20-25 People',TRUE,FALSE,FALSE,90000,'FIXED',TRUE,TRUE,1740879977000,FALSE,TRUE,'','','','',0,0),
('GM610GMEHNJ74','One Gallon Sorrel','One Gallon Sorrel',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1699824929000,FALSE,TRUE,'','','','',0,0),
('NDZV105PFGPD6','One Gallon Lemonade','One Gallon Lemonade',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1699824899000,FALSE,TRUE,'','','','',0,0),
('ZYP2T4VBEGYBR','One Gallon Carrot Juice','One Gallon Carrot Juice',FALSE,TRUE,FALSE,3000,'FIXED',TRUE,TRUE,1741890338000,FALSE,TRUE,'','','','',0,0),
('TR67T34YF7HKA','2024 Thanksgiving Package 10-15 People','2024 Thanksgiving Package 10-15 People',TRUE,FALSE,FALSE,47500,'FIXED',TRUE,TRUE,1740879984000,FALSE,TRUE,'','','','',0,0),
('TX3NFGDBEH7RT','Veggie LoMein Mini','Veggie LoMein Mini',FALSE,TRUE,FALSE,900,'FIXED',TRUE,TRUE,1699729415000,FALSE,TRUE,'','','','',0,0),
('X1T0JPPCWR5MG','Veggie LoMein Reg','Veggie LoMein Reg',FALSE,TRUE,FALSE,1200,'FIXED',TRUE,TRUE,1699729309000,FALSE,TRUE,'','','','',0,0),
('FFW13B7P1QX6A','Veggie LoMein Lge','Veggie LoMein Lge',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1699729351000,FALSE,TRUE,'','','','',0,0),
('D7K8GCZVRJH4Y','Reg Cow Foot Only','Reg Cow Foot Only',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1697315969000,FALSE,TRUE,'','','','',0,0),
('K09YHN1WEQ0RY','Reg Stew Peas Only','Reg Stew Peas Only',FALSE,TRUE,FALSE,1500,'FIXED',TRUE,TRUE,1697315886000,FALSE,TRUE,'','','','',0,0),
('CPCQEYF5F3M8Y','CURRY SHRIMP & LOBSTER COMBO','CURRY SHRIMP & LOBSTER COMBO',TRUE,FALSE,FALSE,2200,'FIXED',TRUE,TRUE,1740887468000,FALSE,TRUE,'','','','',0,0),
('KE17PY026J59W','1 Piece Crab Leg','1 Piece Crab Leg',FALSE,TRUE,FALSE,1000,'FIXED',TRUE,TRUE,1710544918000,FALSE,TRUE,NULL,NULL,NULL,NULL,0,NULL),
('D8V3PSJTF79VP','Regular Steam King Fish','Regular Steam King Fish',FALSE,TRUE,FALSE,1600,'FIXED',TRUE,TRUE,1741901309000,FALSE,TRUE,'','','','',0,0),
('NJ2AA5WVYEE9A','SHRIMP LO MEIN LG','SHRIMP LO MEIN LG',FALSE,TRUE,FALSE,1975,'FIXED',TRUE,TRUE,1741892531000,FALSE,TRUE,'','','','',0,0),
('8KJ4A7K88XDNA','SHRIMP LO MEIN REG','SHRIMP LO MEIN REG',FALSE,TRUE,FALSE,1675,'FIXED',TRUE,TRUE,1741892562000,FALSE,TRUE,'','','','',0,0),
('V3FGX8ZBXY052','SHRIMP LO MEIN-MINI','SHRIMP LO MEIN-MINI',FALSE,TRUE,FALSE,1300,'FIXED',TRUE,TRUE,1717700888000,FALSE,TRUE,'','','','',0,0),
('9VP3RZ7EQFY84','This Is It Watermelon','This Is It Watermelon',FALSE,TRUE,FALSE,450,'FIXED',TRUE,TRUE,1701397722000,FALSE,TRUE,'','','','',0,0)
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
online_name = EXCLUDED.online_name,
hidden = EXCLUDED.hidden,
available = EXCLUDED.available,
auto_manage = EXCLUDED.auto_manage,
price = EXCLUDED.price,
price_type = EXCLUDED.price_type,
default_tax_rates = EXCLUDED.default_tax_rates,
is_revenue = EXCLUDED.is_revenue,
modified_time = EXCLUDED.modified_time,
deleted = EXCLUDED.deleted,
enabled_online = EXCLUDED.enabled_online,
alternate_name = EXCLUDED.alternate_name,
code = EXCLUDED.code,
sku = EXCLUDED.sku,
unit_name = EXCLUDED.unit_name,
cost = EXCLUDED.cost,
price_without_vat = EXCLUDED.price_without_vat
;
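The `modified_time` values in the seed rows above are Unix epoch timestamps in milliseconds (BIGINT). A quick sketch for reading one back as a calendar date, assuming GNU `date`; the value is taken from the first INSERT row:

```shell
# modified_time is epoch milliseconds: divide by 1000 for seconds,
# then let GNU date render the UTC calendar date.
ms=1742946876000
date -u -d "@$((ms / 1000))" +%Y-%m-%d   # -> 2025-03-25
```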

View File

@@ -0,0 +1,118 @@
#!/bin/bash
set -e
echo "🚀 Setting up Business Buddy development environment..."
# Set up shell configuration
echo "🐚 Configuring shell environment..."
# Ensure zsh is the default shell
if command -v zsh &> /dev/null; then
sudo chsh -s "$(which zsh)" vscode || true
fi
# Create necessary directories for zsh plugins if they don't exist
mkdir -p ~/.cache ~/.local/bin ~/.config
# Fix UV cache permissions
if [ -d ~/.cache/uv ]; then
sudo chown -R vscode:vscode ~/.cache/uv
fi
mkdir -p ~/.cache/uv
touch ~/.cache/uv/CACHEDIR.TAG
chmod -R 755 ~/.cache/uv
# Set up zsh history
touch ~/.zsh_history
chmod 600 ~/.zsh_history
# Copy host shell profiles if they exist
if [ -f ~/.zshrc.host ]; then
cp ~/.zshrc.host ~/.zshrc
# Update paths in .zshrc for container environment
sed -i 's|/home/vasceannie|/home/vscode|g' ~/.zshrc
fi
if [ -f ~/.bashrc.host ]; then
cp ~/.bashrc.host ~/.bashrc
fi
# Source P10k config if it exists
if [ -f ~/.p10k.zsh ]; then
echo "✓ Powerlevel10k configuration found"
fi
# Ensure we're in the workspace directory
cd /workspace
# Create virtual environment if it doesn't exist
if [ ! -d ".venv" ]; then
echo "📦 Creating Python virtual environment..."
python3.12 -m venv .venv
fi
# Activate virtual environment
source .venv/bin/activate
# Install UV in the virtual environment if not already installed
if ! command -v uv &> /dev/null; then
echo "📦 Installing UV package manager..."
pip install --upgrade pip
pip install uv
fi
# Install project dependencies
echo "📦 Installing Python dependencies..."
uv pip install -e ".[dev]"
# Install pre-commit hooks
echo "🔗 Installing pre-commit hooks..."
pre-commit install || true
# Create .env file from example if it doesn't exist
if [ ! -f ".env" ]; then
if [ -f ".env.example" ]; then
cp .env.example .env
echo "📝 Created .env file from .env.example - please add your API keys"
else
echo "⚠️ No .env.example file found - please create a .env file with your API keys"
fi
fi
# Initialize Task Master if not already initialized
if [ ! -d ".taskmaster" ]; then
echo "📋 Initializing Task Master AI..."
task-master init --yes --rules claude cursor || true
fi
# Wait for services to be ready
echo "⏳ Waiting for services to start..."
sleep 5
# Check database connection
echo "🔍 Checking database connection..."
until pg_isready -h postgres -p 5432 -U user; do
echo "Waiting for PostgreSQL..."
sleep 2
done
# Check Redis connection
echo "🔍 Checking Redis connection..."
until redis-cli -h redis ping; do
echo "Waiting for Redis..."
sleep 2
done
echo "✅ Development environment setup complete!"
echo ""
echo "📌 Next steps:"
echo " 1. Add your API keys to the .env file"
echo " 2. Run 'make test' to verify the setup"
echo " 3. Start coding! 🎉"
echo ""
echo "📚 Useful commands:"
echo " - make test : Run tests with coverage"
echo " - make lint-all : Run all linters"
echo " - make format : Format code"
echo " - task-master list : View Task Master tasks"
echo ""

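The PostgreSQL and Redis readiness loops in the script above share the same shape: retry a probe until it succeeds. A hypothetical refactor into one helper (the `wait_for` name and the retry cap are illustrative, not part of the script):

```shell
# Generalized readiness loop: run a probe command until it succeeds,
# giving up after a bounded number of attempts instead of hanging forever.
wait_for() {
  name="$1"; shift
  tries=0
  until "$@" > /dev/null 2>&1; do
    tries=$((tries + 1))
    if [ "$tries" -ge 30 ]; then
      echo "✗ $name never became ready"
      return 1
    fi
    echo "Waiting for $name..."
    sleep 2
  done
  echo "✓ $name is ready"
}

# Usage mirroring the two loops in the script:
#   wait_for PostgreSQL pg_isready -h postgres -p 5432 -U user
#   wait_for Redis redis-cli -h redis ping
```

Unlike the open-ended `until` loops above, this variant fails after a fixed timeout, which keeps a misconfigured service from stalling container setup indefinitely.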
44
.devcontainer/setup.sh Normal file
View File

@@ -0,0 +1,44 @@
#!/bin/bash
set -e
echo "🚀 Setting up Business Buddy development environment..."
# Ensure we're in the workspace directory
cd /workspace
# Create virtual environment if it doesn't exist
if [ ! -d ".venv" ]; then
echo "📦 Creating Python virtual environment..."
uv venv .venv --python 3.12
fi
# Activate virtual environment
source .venv/bin/activate
# Install project dependencies using UV
echo "📦 Installing Python dependencies..."
uv pip install -e ".[dev]" || echo "⚠️ Some dependencies failed to install"
# Install pre-commit hooks
echo "🔗 Installing pre-commit hooks..."
pre-commit install || true
# Create .env file from example if it doesn't exist
if [ ! -f ".env" ]; then
if [ -f ".env.example" ]; then
cp .env.example .env
echo "📝 Created .env file from .env.example"
fi
fi
# Fix permissions - ensure all directories exist first
echo "🔧 Fixing permissions..."
mkdir -p /home/vscode/.cache/uv
sudo chown -R vscode:vscode /home/vscode/.cache || true
echo "✅ Development environment setup complete!"
echo ""
echo "📌 Next steps:"
echo " 1. Add your API keys to the .env file"
echo " 2. Run 'source .venv/bin/activate' to activate the virtual environment"
echo " 3. Run 'make test' to verify the setup"

3
.gitignore vendored
View File

@@ -68,7 +68,8 @@ cover/
# Translations
*.mo
*.pot
.cenv/
.venv-host/
# Django stuff:
*.log
local_settings.py

View File

@@ -2,17 +2,13 @@
# Install with: pip install pre-commit && pre-commit install
repos:
# Ruff - Fast Python linter and formatter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.3
# Black - Python code formatter
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: ruff
name: ruff (linter)
args: [--fix, --unsafe-fixes]
types_or: [python, pyi]
- id: ruff-format
name: ruff (formatter)
types_or: [python, pyi]
- id: black
language_version: python3
types: [python, pyi]
# Basic file checks
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -44,29 +40,6 @@ repos:
stages: [pre-commit]
verbose: true
# Optional: Spell checking for documentation
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
- id: codespell
name: codespell
description: Checks for common misspellings in text files
entry: codespell
language: python
types: [text]
exclude: |
(?x)^(
\.git/.*|
\.pytest_cache/.*|
\.venv/.*|
__pycache__/.*|
.*\.pyc|
\.codespellignore|
tests/cassettes/.*
)$
args: [--ignore-words=.codespellignore]
# Configuration for pre-commit
default_stages: [pre-commit]
fail_fast: false
minimum_pre_commit_version: '3.0.0'

View File

@@ -4,6 +4,15 @@ dist/
__pycache__/
.venv/
venv/
.cenv/
.venv-host/
**/.venv/
**/venv/
**/site-packages/
**/lib/python*/
**/bin/
**/include/
**/share/
.git/
.pytest_cache/
.mypy_cache/
@@ -14,3 +23,6 @@ htmlcov/
*.pyo
.archive/
**/.archive/
node_modules/
cache/
examples/

View File

@@ -16,7 +16,7 @@
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.pythonPath": "${workspaceFolder}/.venv/bin/python3.12",
"python.pythonPath": "",
"mypy-type-checker.interpreter": [
"${workspaceFolder}/.venv/bin/python3.12"
],

View File

@@ -1,4 +1,4 @@
# CLAUDE.md
# CLAUDE.local.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
@@ -17,11 +17,20 @@ make test TEST_FILE=tests/unit_tests/nodes/llm/test_unit_call.py
# Run single test function
pytest tests/path/to/test.py::test_function_name -v
# Run crash tests for resilience testing
python tests/crash_tests/run_crash_tests.py
# Run specific test categories
pytest -m "not slow" # Skip slow tests
pytest -m integration # Only integration tests
pytest -m e2e # Only end-to-end tests
pytest -m unit # Only unit tests
```
### Code Quality
```bash
# Run all linters (ruff, mypy, pyrefly, codespell) - ALWAYS run before committing
# Run all linters (ruff, basedpyright, pyrefly, codespell) - ALWAYS run before committing
make lint-all
# Format code with ruff
@@ -32,6 +41,15 @@ make pre-commit
# Advanced type checking with Pyrefly
pyrefly check .
# Lint single file
make lint-file FILE_PATH=path/to/file.py
# Format single file
make black FILE_PATH=path/to/file.py
# Generate coverage report
make coverage-report
```
### Development Services
@@ -43,13 +61,142 @@ make start
make stop
# Quick setup for new developers
make setup
# Complete environment setup
./scripts/setup-dev.sh
```
### Package Management
The project uses UV for all package management:
```bash
# Install main project with all packages
uv pip install -e ".[dev]"
# Install individual packages for development
uv pip install -e packages/business-buddy-core
uv pip install -e packages/business-buddy-extraction
uv pip install -e packages/business-buddy-tools
# Sync dependencies
uv sync
```
## Architecture
This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for business research and analysis.
### Project Structure
```
biz-budz/ # Main project root (monorepo)
├── src/biz_bud/ # Main application
│ ├── agents/ # Specialized agent implementations
│ ├── config/ # Configuration system
│ ├── graphs/ # LangGraph workflow orchestration
│ ├── nodes/ # Modular processing units
│ ├── services/ # External dependency abstractions
│ ├── states/ # TypedDict state management
│ └── utils/ # Utility functions and helpers
├── packages/ # Modular utility packages
│ ├── business-buddy-core/ # Core utilities & helpers
│ │ └── src/bb_core/
│ │ ├── caching/ # Cache management system
│ │ ├── edge_helpers/ # LangGraph edge routing utilities
│ │ ├── errors/ # Error handling system
│ │ ├── langgraph/ # LangGraph-specific utilities
│ │ ├── logging/ # Logging system
│ │ ├── networking/ # Network utilities
│ │ ├── validation/ # Validation system
│ │ └── utils/ # General utilities
│ ├── business-buddy-extraction/ # Entity & data extraction
│ │ └── src/bb_extraction/
│ │ ├── core/ # Core extraction framework
│ │ ├── domain/ # Domain-specific extractors
│ │ ├── numeric/ # Numeric data extraction
│ │ ├── statistics/ # Statistical extraction
│ │ ├── text/ # Text processing
│ │ └── tools.py # Extraction tools
│ └── business-buddy-tools/ # Web tools, scrapers, API clients
│ └── src/bb_tools/
│ ├── actions/ # High-level action workflows
│ ├── api_clients/ # API client implementations
│ │ ├── arxiv.py # ArXiv API client
│ │ ├── base.py # Base API client
│ │ ├── firecrawl.py # Firecrawl API client
│ │ ├── jina.py # Jina API client
│ │ ├── paperless.py # Paperless NGX client
│ │ ├── r2r.py # R2R API client
│ │ └── tavily.py # Tavily search client
│ ├── apis/ # API abstractions
│ │ ├── arxiv.py # ArXiv API abstraction
│ │ ├── firecrawl.py # Firecrawl API abstraction
│ │ └── jina/ # Jina API modules
│ ├── browser/ # Browser automation
│ │ ├── base.py # Base browser interface
│ │ ├── browser.py # Selenium browser implementation
│ │ ├── browser_helper.py # Browser utilities
│ │ ├── driverless_browser.py # Driverless browser
│ │ └── js/ # JavaScript utilities
│ │ └── overlay.js # Browser overlay script
│ ├── flows/ # Workflow implementations
│ │ ├── agent_creator.py # Agent creation workflows
│ │ ├── catalog_inspect.py # Catalog inspection
│ │ ├── fetch.py # Content fetching flows
│ │ ├── human_assistance.py # Human interaction flows
│ │ ├── md_processing.py # Markdown processing
│ │ ├── query_processing.py # Query processing
│ │ ├── report_gen.py # Report generation
│ │ └── scrape.py # Scraping workflows
│ ├── interfaces/ # Protocol definitions
│ │ └── web_tools.py # Web tools protocols
│ ├── loaders/ # Data loaders
│ │ └── web_base_loader.py # Web content loader
│ ├── r2r/ # R2R integration
│ │ └── tools.py # R2R tools
│ ├── scrapers/ # Web scraping implementations
│ │ ├── base.py # Base scraper interface
│ │ ├── beautiful_soup.py # BeautifulSoup scraper
│ │ ├── pymupdf.py # PyMuPDF scraper
│ │ ├── strategies/ # Scraping strategies
│ │ │ ├── beautifulsoup.py # BeautifulSoup strategy
│ │ │ ├── firecrawl.py # Firecrawl strategy
│ │ │ └── jina.py # Jina strategy
│ │ ├── unified.py # Unified scraper
│ │ ├── unified_scraper.py # Alternative unified scraper
│ │ └── utils/ # Scraping utilities
│ │ └── __init__.py # Scraping utilities
│ ├── search/ # Search implementations
│ │ ├── base.py # Base search interface
│ │ ├── providers/ # Search providers
│ │ │ ├── arxiv.py # ArXiv search provider
│ │ │ ├── jina.py # Jina search provider
│ │ │ └── tavily.py # Tavily search provider
│ │ ├── unified.py # Unified search tool
│ │ └── web_search.py # Web search implementation
│ ├── stores/ # Data storage
│ │ └── database.py # Database storage utilities
│ ├── stubs/ # Type stubs
│ │ ├── langgraph.pyi # LangGraph stubs
│ │ └── r2r.pyi # R2R stubs
│ ├── utils/ # Tool utilities
│ │ └── html_utils.py # HTML processing utilities
│ ├── constants.py # Tool constants
│ ├── interfaces.py # Tool interfaces
│ └── models.py # Data models
├── tests/ # Comprehensive test suite
│ ├── unit_tests/ # Unit tests with mocks
│ ├── integration_tests/ # Integration tests
│ ├── e2e/ # End-to-end tests
│ ├── crash_tests/ # Resilience & failure tests
│ └── manual/ # Manual test scripts
├── docker/ # Docker configurations
├── scripts/ # Development and deployment scripts
└── .taskmaster/ # Task Master AI project files
```
### Core Components
1. **Graphs** (`src/biz_bud/graphs/`): Define workflow orchestration using LangGraph state machines
@@ -64,6 +211,7 @@ This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for
- `llm/`: LLM interaction layer
- `research/`: Web search, extraction, synthesis with optimization
- `validation/`: Content and logic validation, human feedback
- `integrations/`: External service integrations (Firecrawl, Repomix, etc.)
3. **States** (`src/biz_bud/states/`): TypedDict-based state management for type safety across workflows
@@ -72,11 +220,23 @@ This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for
- Database (PostgreSQL via asyncpg)
- Vector store (Qdrant)
- Cache (Redis)
- Singleton management for expensive resources
5. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
- Pydantic models for validation
5. **Agents** (`src/biz_bud/agents/`): Specialized agent implementations
- `ngx_agent.py`: Paperless NGX integration agent
- `rag_agent.py`: RAG workflow agent
- `research_agent.py`: Research automation agent
6. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
- Schema-based configuration (`schemas/`): Typed configuration models
- Environment variables override `config.yaml` defaults
- LLM profiles (tiny, small, large, reasoning)
- Service-specific configurations
7. **Packages** (`packages/`): Modular utility libraries
- **business-buddy-core**: Core utilities, error handling, edge helpers
- **business-buddy-extraction**: Entity and data extraction tools
- **business-buddy-tools**: Web tools, scrapers, API clients
### Key Design Patterns
@@ -227,166 +387,106 @@ uv run pre-commit install
- Use RunnableConfig to make services accessible and configurable at runtime.
- Employ decorators and wrappers to add cross-cutting concerns like logging, caching, or metrics without cluttering core logic.
## Commands
- Never use `pip` directly - always use `uv`
- Don't modify `tasks.json` manually when using Task Master
- Always run `make lint-all` before committing
- Use absolute imports from package roots
- Avoid circular imports between packages
- Always use the centralized ServiceFactory for external services
- Never hardcode API keys - use environment variables
### Testing
```bash
# Run all tests with coverage (uses pytest-xdist for parallel execution)
make test

# Run tests in watch mode
make test_watch

# Run specific test file
make test TEST_FILE=tests/unit_tests/nodes/llm/test_unit_call.py

# Run single test function
pytest tests/path/to/test.py::test_function_name -v
```

## Configuration System

Business Buddy uses a sophisticated configuration system:

### Schema-based Configuration (`src/biz_bud/config/schemas/`)
- `analysis.py`: Analysis configuration schemas
- `app.py`: Application-wide settings
- `core.py`: Core configuration types
- `llm.py`: LLM provider configurations
- `research.py`: Research workflow settings
- `services.py`: External service configurations
- `tools.py`: Tool-specific settings

### Usage
```python
from biz_bud.config.loader import load_config

config = load_config()

# Access typed configurations
llm_config = config.llm_config
research_config = config.research_config
service_config = config.service_config
```
### Code Quality
```bash
# Run all linters (ruff, mypy, pyrefly, codespell) - ALWAYS run before committing
make lint-all

# Format code with ruff
make format

# Run pre-commit hooks (recommended)
make pre-commit

# Advanced type checking with Pyrefly
pyrefly check .
```

## Docker Services

The project requires these services (via Docker):
- **PostgreSQL**: Main database with asyncpg
- **Redis**: Caching and session management
- **Qdrant**: Vector database for embeddings

Start all services:
```bash
make start  # Uses docker/compose-dev.yaml
make stop   # Stop and clean up
```
## Import Guidelines

```python
# From main application
from biz_bud.nodes.analysis import data_node
from biz_bud.services.factory import ServiceFactory
from biz_bud.states.research import ResearchState

# From packages (always use full path)
from bb_core.edge_helpers import create_conditional_edge
from bb_extraction.domain import CompanyExtractor
from bb_tools.api_clients import TavilyClient

# Never use relative imports across packages
```

## Architecture

This is a LangGraph-based ReAct (Reasoning and Action) agent system designed for business research and analysis.

### Core Components
1. **Graphs** (`src/biz_bud/graphs/`): Define workflow orchestration using LangGraph state machines
- `research.py`: Market research workflow implementation
- `graph.py`: Main agent graph with reasoning and action cycles
- `research_agent.py`: Research-specific agent workflow
- `menu_intelligence.py`: Menu analysis subgraph
2. **Nodes** (`src/biz_bud/nodes/`): Modular processing units
- `analysis/`: Data analysis, interpretation, planning, visualization
- `core/`: Input/output handling, error management
- `llm/`: LLM interaction layer
- `research/`: Web search, extraction, synthesis with optimization
- `validation/`: Content and logic validation, human feedback
3. **States** (`src/biz_bud/states/`): TypedDict-based state management for type safety across workflows
4. **Services** (`src/biz_bud/services/`): Abstract external dependencies
- LLM providers (Anthropic, OpenAI, Google, Cohere, etc.)
- Database (PostgreSQL via asyncpg)
- Vector store (Qdrant)
- Cache (Redis)
5. **Configuration** (`src/biz_bud/config/`): Multi-source configuration system
- Pydantic models for validation
- Environment variables override `config.yaml` defaults
- LLM profiles (tiny, small, large, reasoning)
### Key Design Patterns
- **State-Driven Workflows**: All graphs use TypedDict states for type-safe data flow
- **Decorator Pattern**: `@log_config` and `@error_handling` for cross-cutting concerns
- **Service Abstraction**: Clean interfaces for external dependencies
- **Modular Nodes**: Each node has a single responsibility and can be tested independently
- **Parallel Processing**: Search and extraction operations utilize asyncio for performance
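The state-driven pattern above can be sketched with a minimal TypedDict state and reducer. The names below are illustrative only, not the project's actual schemas (the real ones live in `src/biz_bud/states/`):

```python
from operator import add
from typing import Annotated, TypedDict


class DemoResearchState(TypedDict):
    """Illustrative state; real schemas live in src/biz_bud/states/."""

    query: str
    # Annotated reducer: updates are appended by the runtime, not replaced
    search_results: Annotated[list[str], add]


def search_node(state: DemoResearchState) -> dict:
    """A node returns a partial update; the reducer appends the results."""
    return {"search_results": [f"result for {state['query']}"]}


# Simulate two sequential node updates the way a graph runtime would merge them
state: DemoResearchState = {"query": "coffee market", "search_results": []}
for _ in range(2):
    update = search_node(state)
    state["search_results"] = add(state["search_results"], update["search_results"])

print(len(state["search_results"]))  # 2 accumulated results
```

Because each node returns a partial update, nodes stay independently testable while the reducer controls how concurrent updates merge.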
### Testing Strategy
- Unit tests in `tests/unit_tests/` with mocked dependencies
- Integration tests in `tests/integration_tests/` for full workflows
- E2E tests in `tests/e2e/` for complete system validation
- VCR cassettes for API mocking in `tests/cassettes/`
- Test markers: `slow`, `integration`, `unit`, `e2e`, `web`, `browser`
- Coverage requirement: 70% minimum
### Test Architecture
#### Test Organization
- **Naming Convention**: All test files follow `test_*.py` pattern
- Unit tests: `test_<module_name>.py`
- Integration tests: `test_<feature>_integration.py`
- E2E tests: `test_<workflow>_e2e.py`
- Manual tests: `test_<feature>_manual.py`
#### Test Helpers (`tests/helpers/`)
- **Assertions** (`assertions/custom_assertions.py`): Reusable assertion functions
- **Factories** (`factories/state_factories.py`): State builders for creating test data
- **Fixtures** (`fixtures/`): Shared pytest fixtures
- `config_fixtures.py`: Configuration mocks and test configs
- `mock_fixtures.py`: Common mock objects
- **Mocks** (`mocks/mock_builders.py`): Builder classes for complex mocks
- `MockLLMBuilder`: Creates mock LLM clients with configurable responses
- `StateBuilder`: Creates typed state objects for workflows
#### Key Testing Patterns
1. **Async Testing**: Use `@pytest.mark.asyncio` for async functions
2. **Mock Builders**: Use builder pattern for complex mocks
```python
mock_llm = (
    MockLLMBuilder()
    .with_model("gpt-4")
    .with_response("Test response")
    .build()
)
```
3. **State Factories**: Create valid state objects easily
```python
state = (
    StateBuilder.research_state()
    .with_query("test query")
    .with_search_results([...])
    .build()
)
```
4. **Service Factory Mocking**: Mock the service factory for dependency injection
```python
with patch("biz_bud.utils.service_helpers.get_service_factory",
return_value=mock_service_factory):
# Test code here
```
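Patterns 1 and 2 combine in practice roughly as follows. The `summarize` node here is a made-up stand-in, not a project API; the point is that `AsyncMock` makes `ainvoke` awaitable, mirroring how `@pytest.mark.asyncio` tests drive async nodes:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def summarize(llm, text: str) -> str:
    """Stand-in async node under test (illustrative, not a real project node)."""
    response = await llm.ainvoke(f"Summarize: {text}")
    return response.upper()


def test_summarize() -> None:
    # AsyncMock makes ainvoke awaitable, as @pytest.mark.asyncio tests expect
    mock_llm = MagicMock()
    mock_llm.ainvoke = AsyncMock(return_value="short summary")

    result = asyncio.run(summarize(mock_llm, "long document"))
    assert result == "SHORT SUMMARY"
    mock_llm.ainvoke.assert_awaited_once()


test_summarize()
```

Under pytest the `asyncio.run` call would be replaced by the `@pytest.mark.asyncio` decorator on an `async def` test.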
#### Common Test Patterns
- **E2E Workflow Tests**: Test complete workflows with mocked external services
- **Resilient Node Tests**: Nodes should handle failures gracefully
- Extraction continues even if vector storage fails
- Partial results are returned when some operations fail
- **Configuration Tests**: Validate Pydantic models and config schemas
- **Import Testing**: Ensure all public APIs are importable
### Environment Setup
```bash
# Prerequisites: Python 3.12+, UV package manager, Docker
# Create and activate virtual environment
uv venv
source .venv/bin/activate # Always use this activation path
# Install dependencies with UV
uv pip install -e ".[dev]"
# Install pre-commit hooks
uv run pre-commit install
# Create .env file with required API keys:
# TAVILY_API_KEY=your_key
# OPENAI_API_KEY=your_key (or other LLM provider keys)
```
## Development Principles

- **Type Safety**: No `Any` types or `# type: ignore` annotations allowed
- **Documentation**: Imperative docstrings with punctuation
- **Package Management**: Always use UV, not pip
- **Pre-commit**: Never skip pre-commit checks
- **Testing**: Write tests for new functionality, maintain 70%+ coverage
- **Error Handling**: Use centralized decorators for consistency

## Development Warnings

- Do not try to launch `langgraph dev` or any variation

## Architectural Patterns

### State Management
- All states are TypedDict-based for type safety
- States are immutable within nodes
- Use reducer functions for state updates

### Service Factory Pattern
- Centralized service creation via `ServiceFactory`
- Singleton management for expensive resources
- Dependency injection throughout

### Error Handling
- Centralized error aggregation
- Namespace-based error routing
- Comprehensive telemetry integration

### Async-First Design
- All I/O operations are async
- Proper connection pooling
- Graceful degradation on failures
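The service-factory pattern above can be sketched minimally. The class and method names below are illustrative only, not the real `ServiceFactory` API:

```python
from collections.abc import Callable


class DemoServiceFactory:
    """Create services once and reuse them (singleton per factory instance)."""

    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def get(self, name: str, builder: Callable[[], object]) -> object:
        # Expensive resources (DB pools, vector stores) are built at most once
        if name not in self._cache:
            self._cache[name] = builder()
        return self._cache[name]


factory = DemoServiceFactory()
a = factory.get("db", lambda: object())
b = factory.get("db", lambda: object())
print(a is b)  # True: the same instance is reused
```

Injecting the factory (rather than importing services directly) is what makes the mocking pattern in the testing section possible.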
## Development Tools
### Pyrefly Configuration (`pyrefly.toml`)
- Advanced type checking beyond mypy
- Monorepo-aware with package path resolution
- Custom import handling for external libraries
### Pre-commit Hooks (`.pre-commit-config.yaml`)
- Automated code quality checks
- Includes ruff, pyrefly, codespell
- File size and merge conflict checks
### Additional Make Commands
```bash
make setup # Complete setup for new machines
make lint-file FILE_PATH=path/to/file.py # Single file linting
make black FILE_PATH=path/to/file.py # Format single file
make pyrefly # Run pyrefly type checking
make coverage-report # Generate HTML coverage report
```


@@ -13,13 +13,19 @@ TEST_FILE ?= tests/
ifeq ($(OS),Windows_NT)
ACTIVATE = .venv\Scripts\activate
PYTHON = python
else ifneq (,$(wildcard .venv-host))
ACTIVATE = source .venv-host/bin/activate
PYTHON = python3
else ifneq (,$(wildcard .cenv))
ACTIVATE = source .cenv/bin/activate
PYTHON = python3
else
ACTIVATE = source .venv/bin/activate
PYTHON = python3
endif
test:
@bash -c "$(ACTIVATE) && PYTHONPATH=src coverage run --source=biz_bud -m pytest -n 4 tests/"
@bash -c "$(ACTIVATE) && PYTHONPATH=src coverage run --source=biz_bud -m pytest -n 8 tests/"
@bash -c "$(ACTIVATE) && coverage report --show-missing"
test_watch:
@@ -112,7 +118,7 @@ lint_tests: MYPY_CACHE=.mypy_cache_test
# Legacy lint targets - now use pre-commit
lint lint_diff lint_package lint_tests:
@echo "Running linting via pre-commit hooks..."
@bash -c "$(ACTIVATE) && pre-commit run ruff --all-files"
@bash -c "$(ACTIVATE) && pre-commit run black --all-files"
pyrefly:
@bash -c "$(ACTIVATE) && pyrefly check ."
@@ -121,33 +127,23 @@ pyrefly:
lint-all: pre-commit
@echo "\n🔍 Running additional type checks..."
@bash -c "$(ACTIVATE) && pyrefly check . || true"
@bash -c "$(ACTIVATE) && ruff check . || true"
# Run pre-commit hooks (single source of truth for linting)
pre-commit:
@echo "🔧 Running pre-commit hooks..."
@echo "This includes: ruff (lint + format), pyrefly, codespell, and file checks"
@echo "This includes: black (format), pyrefly, and file checks"
@bash -c "$(ACTIVATE) && pre-commit run --all-files"
# Format code using pre-commit
format format_diff:
@echo "Formatting code via pre-commit..."
@bash -c "$(ACTIVATE) && pre-commit run ruff-format --all-files || true"
@bash -c "$(ACTIVATE) && pre-commit run ruff --all-files || true"
spell_check:
@bash -c "$(ACTIVATE) && codespell --toml pyproject.toml"
spell_fix:
@bash -c "$(ACTIVATE) && codespell --toml pyproject.toml -w"
@bash -c "$(ACTIVATE) && pre-commit run black --all-files || true"
# Single file linting for hooks (expects FILE_PATH variable)
lint-file:
ifdef FILE_PATH
@echo "🔍 Linting $(FILE_PATH)..."
@bash -c "$(ACTIVATE) && pyrefly check '$(FILE_PATH)'"
@bash -c "$(ACTIVATE) && ruff check '$(FILE_PATH)' --fix"
@bash -c "$(ACTIVATE) && pyright '$(FILE_PATH)'"
@echo "✅ Linting complete"
else
@echo "❌ FILE_PATH not provided"
@@ -164,6 +160,20 @@ else
@exit 1
endif
######################
# LANGGRAPH DEVELOPMENT
######################
# Start LangGraph development server with proper container configuration
langgraph-dev:
@echo "🚀 Starting LangGraph development server..."
@bash -c "$(ACTIVATE) && langgraph dev --host 0.0.0.0 --port 2024"
# Start LangGraph development server with local studio
langgraph-dev-local:
@echo "🚀 Starting LangGraph development server with local studio..."
@bash -c "$(ACTIVATE) && langgraph dev --host 0.0.0.0 --port 2024 --studio-local"
######################
# HELP
######################
@@ -174,8 +184,8 @@ help:
@echo 'start - start Docker services (postgres, redis, qdrant)'
@echo 'stop - stop Docker services'
@echo 'format - run code formatters'
@echo 'lint - run ruff linter'
@echo 'lint-all - run all linters and type checkers (ruff, mypy, pyrefly, codespell)'
@echo 'lint - run black formatter via pre-commit'
@echo 'lint-all - run all type checkers (pyrefly)'
@echo 'pre-commit - run all pre-commit hooks'
@echo 'pyrefly - run pyrefly type checking'
@echo 'test - run unit tests in parallel (pytest-xdist) with coverage'
@@ -183,6 +193,8 @@ help:
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'test_watch - run unit tests in watch mode'
@echo 'coverage-report - generate HTML coverage report at htmlcov/index.html'
@echo 'langgraph-dev - start LangGraph dev server (for containers/devcontainer)'
@echo 'langgraph-dev-local - start LangGraph dev server with local studio'
@echo 'tree - show tree of .py files in src/'
coverage-report:


@@ -49,6 +49,7 @@ logging:
# LLM profiles configuration
# Env Override: e.g., TINY_LLM_NAME, LARGE_LLM_TEMPERATURE
llm_config:
default_profile: "large" # Options: tiny, small, large, reasoning
tiny:
name: "openai/gpt-4.1-mini"
temperature: 0.7

docker/entrypoint.sh Executable file → Normal file
examples/crawl_r2r_docs.py Executable file → Normal file
examples/crawl_r2r_docs_fixed.py Executable file → Normal file

@@ -84,7 +84,7 @@ async def example_search_and_scrape():
print(f"Found and scraped {len(results)} search results")
for i, result in enumerate(results):
if isinstance(result, dict):
if result:
print(f"\n{i + 1}. {result.get('title', 'No title')}")
print(f" URL: {result.get('url', 'No URL')}")
markdown = result.get("markdown")


@@ -50,7 +50,7 @@ async def test_rag_agent_with_firecrawl():
# Show processing result
if result.get("processing_result"):
processing_result = result["processing_result"]
if processing_result and isinstance(processing_result, dict):
if processing_result:
if processing_result.get("skipped"):
print(f"\nSkipped: {processing_result.get('reason')}")
else:


@@ -8,7 +8,8 @@
"catalog_research": "./src/biz_bud/graphs/catalog_research.py:catalog_research_factory",
"url_to_r2r": "./src/biz_bud/graphs/url_to_r2r.py:url_to_r2r_graph_factory",
"rag_agent": "./src/biz_bud/agents/rag_agent.py:create_rag_agent_for_api",
"error_handling": "./src/biz_bud/graphs/error_handling.py:error_handling_graph_factory"
"error_handling": "./src/biz_bud/graphs/error_handling.py:error_handling_graph_factory",
"paperless_ngx_agent": "./src/biz_bud/agents/ngx_agent.py:paperless_ngx_agent_factory"
},
"env": ".env"
}


@@ -105,10 +105,10 @@ from bb_core.errors import (
# Helpers
from bb_core.helpers import (
_is_sensitive_field,
_redact_sensitive_data,
create_error_details,
is_sensitive_field,
preserve_url_fields,
redact_sensitive_data,
safe_serialize_response,
)
@@ -266,8 +266,8 @@ __all__ = [
"Tone",
# Helpers
"preserve_url_fields",
"_is_sensitive_field",
"_redact_sensitive_data",
"is_sensitive_field",
"redact_sensitive_data",
"safe_serialize_response",
# Networking
"gather_with_concurrency",


@@ -106,3 +106,14 @@ class CacheBackend(ABC):
"""
for key in keys:
await self.delete(key)
async def ainit(self) -> None:
"""Initialize the cache backend.
This method can be overridden by implementations that need
async initialization. The default implementation does nothing.
Note: This is intentionally non-abstract to provide a default
implementation for backends that don't need initialization.
"""
return None
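The non-abstract default above lets simple backends skip initialization while async-heavy ones opt in. A toy sketch of both sides (illustrative classes, not the project's real ones):

```python
import asyncio


class DemoCacheBackend:
    """Toy stand-in for the abstract base shown above."""

    async def ainit(self) -> None:
        # Default: no initialization required
        return None


class PoolBackedCache(DemoCacheBackend):
    """A backend that does need async setup, e.g. opening a connection pool."""

    def __init__(self) -> None:
        self.pool: object | None = None

    async def ainit(self) -> None:
        await asyncio.sleep(0)  # stand-in for an async connect call
        self.pool = object()


async def main() -> None:
    simple, pooled = DemoCacheBackend(), PoolBackedCache()
    await simple.ainit()  # harmless no-op
    await pooled.ainit()  # actually initializes
    print(pooled.pool is not None)  # True


asyncio.run(main())
```

Callers can therefore `await backend.ainit()` unconditionally, without checking which backend they hold.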


@@ -52,7 +52,7 @@ class AsyncFileCacheBackend[T](CacheBackend[T]):
async def ainit(self) -> None:
"""Async initialization method for compatibility."""
if not self._initialized:
await self._file_cache._ensure_initialized()
await self._file_cache.ensure_initialized()
self._initialized = True
async def get(self, key: str) -> T | None:


@@ -48,7 +48,7 @@ class CacheBackend[T](ABC):
"""Clear all cache entries."""
...
async def ainit(self) -> None: # noqa: B027
async def ainit(self) -> None:
"""Initialize the cache backend.
This method can be overridden by implementations that need
@@ -57,4 +57,4 @@ class CacheBackend[T](ABC):
Note: This is intentionally non-abstract to provide a default
implementation for backends that don't need initialization.
"""
pass
return None


@@ -191,7 +191,7 @@ def cache_async(
# Convert ParamSpecKwargs to dict for processing
kwargs_dict = cast("dict[str, object]", kwargs)
force_refresh = kwargs_dict.pop("force_refresh", False)
kwargs = cast("P.kwargs", kwargs_dict)
# Note: kwargs remains unchanged since we work with kwargs_dict
# Generate cache key (excluding force_refresh from key generation)
try:
@@ -207,7 +207,7 @@ def cache_async(
cache_key = _generate_cache_key(
func.__name__,
cast("tuple[object, ...]", args),
cast("dict[str, object]", kwargs),
kwargs_dict,
key_prefix,
)
except Exception:


@@ -55,11 +55,18 @@ class FileCache(CacheBackend):
self.serializer = serializer
self.key_prefix = key_prefix
self._initialized = False
self._init_lock = asyncio.Lock()
self.logger = get_logger(__name__)
async def _ensure_initialized(self) -> None:
async def ensure_initialized(self) -> None:
"""Ensure cache directory exists."""
if not self._initialized:
if self._initialized:
return
async with self._init_lock:
if self._initialized:
return
try:
await asyncio.to_thread(
self.cache_dir.mkdir, parents=True, exist_ok=True
@@ -90,7 +97,7 @@ class FileCache(CacheBackend):
Returns:
Cached bytes or None if not found or expired
"""
await self._ensure_initialized()
await self.ensure_initialized()
file_path = self._get_file_path(key)
try:
@@ -165,7 +172,7 @@ class FileCache(CacheBackend):
value: Value to store as bytes
ttl: Time-to-live in seconds (None uses default TTL)
"""
await self._ensure_initialized()
await self.ensure_initialized()
file_path = self._get_file_path(key)
# Use default TTL if not specified and cache has TTL
@@ -178,18 +185,20 @@ class FileCache(CacheBackend):
"ttl": effective_ttl,
}
temp_path = None
try:
# Serialize the entire structure
if self.serializer == "pickle":
content = pickle.dumps(cache_data)
else:
# For JSON, ensure bytes are encoded properly
if isinstance(value, bytes):
cache_data["value"] = value.decode("utf-8", errors="replace")
cache_data["value"] = value.decode("utf-8", errors="replace")
content = json.dumps(cache_data).encode("utf-8")
# Write atomically using a temporary file
temp_path = file_path.with_suffix(".tmp")
# Write atomically using a temporary file with unique name
import uuid
temp_path = file_path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}")
async with aiofiles.open(temp_path, "wb") as f:
await f.write(content)
@@ -199,11 +208,12 @@ class FileCache(CacheBackend):
except (OSError, pickle.PickleError, json.JSONDecodeError) as e:
self.logger.error(f"Failed to write cache file {file_path}: {e}")
# Clean up temp file if it exists
try:
if temp_path.exists():
await asyncio.to_thread(os.remove, temp_path)
except OSError:
pass
if temp_path is not None:
try:
if temp_path.exists():
await asyncio.to_thread(os.remove, temp_path)
except OSError:
pass
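The unique-suffix temp-file write above follows a standard atomic-write idiom. Stripped of the cache bookkeeping it looks roughly like this (`atomic_write` is an illustrative name, not the module's actual code):

```python
import os
import tempfile
import uuid
from pathlib import Path


def atomic_write(path: Path, content: bytes) -> None:
    """Write bytes so readers never observe a half-written file."""
    # Unique suffix keeps concurrent writers from clobbering each other's temp file
    tmp = path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}")
    try:
        tmp.write_bytes(content)
        os.replace(tmp, path)  # atomic rename on the same filesystem
    except OSError:
        if tmp.exists():
            os.remove(tmp)
        raise


target = Path(tempfile.mkdtemp()) / "entry.cache"
atomic_write(target, b"payload")
print(target.read_bytes())  # b'payload'
```

`os.replace` is atomic on both POSIX and Windows when source and destination sit on the same filesystem, which is why the temp file is created next to the target rather than in a system temp directory.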
async def delete(self, key: str) -> None:
"""Remove value from cache.
@@ -211,25 +221,33 @@ class FileCache(CacheBackend):
Args:
key: Cache key
"""
await self._ensure_initialized()
await self.ensure_initialized()
file_path = self._get_file_path(key)
await self._delete_file(file_path)
async def clear(self) -> None:
"""Clear all cache entries."""
await self._ensure_initialized()
await self.ensure_initialized()
try:
# List all cache files
cache_files = [
f
for f in self.cache_dir.iterdir()
if f.is_file() and f.suffix == ".cache"
]
# List all cache files in a thread-safe way
def list_cache_files() -> list[Path]:
if not self.cache_dir.exists():
return []
return [
f
for f in self.cache_dir.iterdir()
if f.is_file() and f.suffix == ".cache"
]
# Delete all cache files
for file_path in cache_files:
await self._delete_file(file_path)
cache_files = await asyncio.to_thread(list_cache_files)
# Delete all cache files concurrently
if cache_files:
await asyncio.gather(
*[self._delete_file(file_path) for file_path in cache_files],
return_exceptions=True,
)
except OSError as e:
self.logger.error(f"Failed to clear cache directory: {e}")
@@ -289,8 +307,14 @@ class FileCache(CacheBackend):
items: Dictionary mapping keys to values
ttl: Time-to-live in seconds (None uses default TTL)
"""
# Use asyncio.gather for parallel storage
# Use asyncio.gather for parallel storage with limited concurrency
semaphore = asyncio.Semaphore(10) # Limit concurrent file operations
async def set_with_semaphore(key: str, value: bytes) -> None:
async with semaphore:
await self.set(key, value, ttl)
await asyncio.gather(
*[self.set(key, value, ttl) for key, value in items.items()],
*[set_with_semaphore(key, value) for key, value in items.items()],
return_exceptions=True,
)
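The semaphore pattern above generalizes to any fan-out of async I/O. A standalone sketch of the same idea (`bounded_gather` is an illustrative helper, not part of the package):

```python
import asyncio


async def bounded_gather(coros, limit: int = 10):
    """Run coroutines concurrently with at most `limit` in flight at once."""
    semaphore = asyncio.Semaphore(limit)

    async def run_one(coro):
        async with semaphore:
            return await coro

    # return_exceptions=True keeps one failure from cancelling the rest
    return await asyncio.gather(*(run_one(c) for c in coros), return_exceptions=True)


async def main() -> None:
    async def work(i: int) -> int:
        await asyncio.sleep(0)
        return i * 2

    results = await bounded_gather([work(i) for i in range(5)], limit=2)
    print(results)  # [0, 2, 4, 6, 8]


asyncio.run(main())
```

`asyncio.gather` preserves input order, so bounding concurrency does not reorder results.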


@@ -0,0 +1,82 @@
"""Edge helpers for LangGraph routing and conditional logic.
This module provides factory functions and edge helpers for creating flexible
routing logic in LangGraph workflows. Edge helpers are functions that accept
state and return enum values or strings indicating the next node(s) to route to.
Example:
from bb_core.edge_helpers import create_enum_router, should_continue
# Create a custom router
router = create_enum_router({"yes": "continue_node", "no": "end_node"})
graph.add_conditional_edges("source_node", router)
# Use built-in edge helpers
graph.add_conditional_edges("agent", should_continue)
"""
from bb_core.edge_helpers.core import (
create_bool_router,
create_enum_router,
create_status_router,
create_threshold_router,
)
from bb_core.edge_helpers.error_handling import (
fallback_to_default,
handle_error,
retry_on_failure,
)
from bb_core.edge_helpers.flow_control import (
multi_step_progress,
should_continue,
timeout_check,
)
from bb_core.edge_helpers.monitoring import (
check_resource_availability,
log_and_monitor,
trigger_notifications,
)
from bb_core.edge_helpers.user_interaction import (
escalate_to_human,
human_interrupt,
pass_status_to_user,
user_feedback_loop,
)
from bb_core.edge_helpers.validation import (
check_accuracy,
check_confidence_level,
check_data_freshness,
check_privacy_compliance,
validate_output_format,
)
__all__ = [
# Core factories
"create_enum_router",
"create_bool_router",
"create_status_router",
"create_threshold_router",
# Flow control
"should_continue",
"timeout_check",
"multi_step_progress",
# Error handling
"handle_error",
"retry_on_failure",
"fallback_to_default",
# Validation
"check_accuracy",
"check_confidence_level",
"validate_output_format",
"check_privacy_compliance",
"check_data_freshness",
# User interaction
"human_interrupt",
"pass_status_to_user",
"user_feedback_loop",
"escalate_to_human",
# Monitoring
"log_and_monitor",
"check_resource_availability",
"trigger_notifications",
]


@@ -0,0 +1,248 @@
"""Core routing factory functions for edge helpers.
This module provides factory functions for creating flexible routing functions
that can be used as conditional edges in LangGraph workflows.
"""
from collections.abc import Callable
from typing import Any, Literal, Protocol, TypeVar
# Generic state type for edge functions
StateT = TypeVar("StateT")
class StateProtocol(Protocol):
"""Protocol for state objects that can be used with edge helpers."""
def get(self, key: str, default: Any = None) -> Any:
"""Get a value from the state."""
...
def create_enum_router(
enum_to_target: dict[str, str],
state_key: str = "routing_decision",
default_target: str = "end",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that maps enum values to target nodes.
Args:
enum_to_target: Mapping of enum values to target node names
state_key: Key in state to check for routing decision
default_target: Default target if no match found
Returns:
Router function that accepts state and returns target node name
Example:
router = create_enum_router({
"continue": "next_step",
"retry": "retry_node",
"error": "error_handler"
})
graph.add_conditional_edges("source_node", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
enum_value = state.get(state_key)
else:
enum_value = getattr(state, state_key, None)
return enum_to_target.get(str(enum_value), default_target)
return router
def create_bool_router(
true_target: str,
false_target: str,
state_key: str = "condition",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on a boolean condition in state.
Args:
true_target: Target node when condition is True
false_target: Target node when condition is False
state_key: Key in state to check for boolean value
Returns:
Router function that accepts state and returns target node name
Example:
router = create_bool_router("success_node", "failure_node", "is_valid")
graph.add_conditional_edges("validator", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
condition = state.get(state_key, False)
else:
condition = getattr(state, state_key, False)
return true_target if condition else false_target
return router
def create_status_router(
status_mapping: dict[str, str],
state_key: str = "status",
default_target: str = "end",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on operation status.
Args:
status_mapping: Mapping of status values to target nodes
state_key: Key in state containing status value
default_target: Default target if status not found in mapping
Returns:
Router function that accepts state and returns target node name
Example:
router = create_status_router({
"pending": "wait_node",
"running": "monitor_node",
"completed": "success_node",
"failed": "error_node"
})
graph.add_conditional_edges("task_checker", router)
"""
return create_enum_router(status_mapping, state_key, default_target)
def create_threshold_router(
threshold: float,
above_target: str,
below_target: str,
state_key: str = "score",
equal_target: str | None = None,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on numeric threshold comparison.
Args:
threshold: Threshold value to compare against
above_target: Target when value is above threshold
below_target: Target when value is below threshold
state_key: Key in state containing numeric value
equal_target: Optional target when value equals threshold
(defaults to above_target)
Returns:
Router function that accepts state and returns target node name
Example:
router = create_threshold_router(
0.8, "high_confidence", "low_confidence", "confidence"
)
graph.add_conditional_edges("scorer", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
value = state.get(state_key, 0.0)
else:
value = getattr(state, state_key, 0.0)
try:
numeric_value = float(value)
if numeric_value > threshold:
return above_target
elif numeric_value < threshold:
return below_target
else:
return equal_target or above_target
except (ValueError, TypeError):
return below_target
return router
def create_field_presence_router(
required_fields: list[str],
complete_target: str,
incomplete_target: str,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on presence of required fields in state.
Args:
required_fields: List of field names that must be present
complete_target: Target when all fields are present
incomplete_target: Target when any field is missing
Returns:
Router function that accepts state and returns target node name
Example:
router = create_field_presence_router(
["user_input", "processed_data", "validation_result"],
"proceed",
"gather_missing_data"
)
graph.add_conditional_edges("data_checker", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
for field in required_fields:
if hasattr(state, "get") or isinstance(state, dict):
value = state.get(field)
else:
value = getattr(state, field, None)
if value is None or value == "":
return incomplete_target
return complete_target
return router
def create_list_length_router(
min_length: int,
sufficient_target: str,
insufficient_target: str,
state_key: str = "items",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on list length in state.
Args:
min_length: Minimum required length
sufficient_target: Target when list meets minimum length
insufficient_target: Target when list is too short
state_key: Key in state containing the list
Returns:
Router function that accepts state and returns target node name
Example:
router = create_list_length_router(
3, "process_results", "gather_more", "search_results"
)
graph.add_conditional_edges("result_checker", router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
items = state.get(state_key, [])
else:
items = getattr(state, state_key, [])
# Only process actual lists/sequences, not strings or other types
if not isinstance(items, (list, tuple)):
return insufficient_target
try:
return (
sufficient_target if len(items) >= min_length else insufficient_target
)
except TypeError:
return insufficient_target
return router
# Type aliases for common router patterns
BoolRouter = Callable[[StateT], Literal["true", "false"]]
StatusRouter = Callable[[StateT], Literal["pending", "running", "completed", "failed"]]
ContinueRouter = Callable[[StateT], Literal["continue", "end"]]


@@ -0,0 +1,120 @@
What it would look like:

    /packages/bb_core/
        edge_helpers/
            __init__.py
            routing.py
            error_handling.py
            accuracy_checks.py
            ...
Edge helper functions would be generic routing functions that accept the current state and return an enum or string indicating the next node(s) to route to.
These functions can be parameterized or configured at runtime with mappings of enum values to target node names, allowing flexible routing without hardcoding targets.
The package can provide factory functions or classes that generate routing functions given a source node and a dynamic mapping of enum-to-target nodes.
The routing functions would follow a consistent interface, e.g., def route(state: State) -> Literal[...] or def route(state: State) -> str.
Example conceptual snippet:
```python
from typing import Callable, Dict, Literal, TypeVar

State = TypeVar("State")

def create_enum_router(
    enum_to_target: Dict[str, str],
) -> Callable[[State], str]:
    def router(state: State) -> str:
        # Determine enum value from state (custom logic)
        enum_value = determine_enum_value(state)
        return enum_to_target.get(enum_value, "default_node")
    return router
```
How to organize it at runtime
At graph construction or initialization time, you instantiate these routing functions with the current source node context and the dynamic mapping of enum values to target nodes.
You then register these routing functions as conditional edges on the graph, e.g.:
```python
graph.add_conditional_edges("source_node", create_enum_router(mapping))
```
This allows the graph to remain flexible and extensible, with routing logic decoupled from static graph definitions.
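Making the conceptual snippet concrete — reading the enum value from a state key rather than the undefined `determine_enum_value` helper — gives a runnable sketch:

```python
from collections.abc import Callable
from typing import Any


def create_enum_router(
    enum_to_target: dict[str, str],
    state_key: str = "routing_decision",
    default_target: str = "end",
) -> Callable[[dict[str, Any]], str]:
    """Close over the mapping so the graph definition stays declarative."""

    def router(state: dict[str, Any]) -> str:
        return enum_to_target.get(str(state.get(state_key)), default_target)

    return router


router = create_enum_router({"retry": "retry_node", "done": "summary_node"})

print(router({"routing_decision": "retry"}))  # retry_node
print(router({"routing_decision": "bogus"}))  # end (falls through to default)

# At graph construction time this would be registered as:
# graph.add_conditional_edges("source_node", router)
```

Because the mapping is captured at construction time, the same factory serves any source node with any enum-to-target mapping.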
- `should_continue`: Decide whether to continue or end based on tool calls in the last AI message.
- `check_accuracy`: Route based on whether the agent's last response meets a defined accuracy threshold.
- `handle_error`: Detect error messages or exceptions and route to error handling or retry nodes.
- `append_visual_aids`: Determine if visual aids (images, charts) should be appended to the response.
- `pass_status_to_user`: Route to nodes that update the human user with status or progress messages.
- `human_interrupt`: Detect if a human user has requested to interrupt or stop the agent.
- `timeout_check`: Route based on whether a timeout has occurred during processing.
- `retry_on_failure`: Decide to retry a failed step or escalate after a certain number of attempts.
- `fallback_to_default`: Route to a default or safe response if the agent output is empty or nonsensical.
- `check_confidence_level`: Route based on the confidence score of the LLM's output.
- `validate_output_format`: Check whether the output matches the expected schema or format; route to correction if not.
- `multi_step_progress`: Route based on progress through a multi-step workflow (e.g., step 1, step 2, etc.).
- `user_feedback_loop`: Route to nodes that solicit or process user feedback on the agent's response.
- `escalate_to_human`: Route to a human agent or supervisor if automated handling fails or is insufficient.
- `check_resource_availability`: Route based on availability of external resources or APIs needed for next steps.
- `handle_ambiguous_input`: Detect ambiguous or unclear user input and route to clarification nodes.
- `language_detection`: Route based on the detected language of user input for multilingual support.
- `check_privacy_compliance`: Route to privacy checks or redaction if sensitive data is detected.
- `log_and_monitor`: Route to logging or monitoring nodes for audit or debugging purposes.
- `load_balance`: Route requests to different nodes or agents for load balancing.
- `check_user_authentication`: Route based on user authentication or authorization status.
- `handle_rate_limiting`: Detect API rate limits and route to wait or fallback nodes.
- `check_data_freshness`: Route based on whether the data used is up-to-date or stale.
- `trigger_notifications`: Route to nodes that send notifications or alerts to users or admins.
- `conditional_branching_by_intent`: Route based on detected user intent or topic classification.


@@ -0,0 +1,365 @@
"""Error handling edge helpers for robust workflow management.
This module provides edge helpers for detecting errors, implementing retry logic,
and routing to appropriate error handling or recovery nodes.
"""
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from bb_core.edge_helpers.core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def handle_error(
error_types: dict[str, str] | None = None,
error_key: str = "error",
default_target: str = "generic_error_handler",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that handles different types of errors.
Args:
error_types: Mapping of error type/class names to handler nodes
error_key: Key in state containing error information
default_target: Default target if error type not in mapping
Returns:
Router function that routes based on error type, or returns "no_error" when the state contains no error
Example:
error_router = handle_error({
"ValidationError": "validation_recovery",
"NetworkError": "network_retry",
"AuthenticationError": "auth_failure_handler"
})
graph.add_conditional_edges("error_detector", error_router)
"""
if error_types is None:
error_types = {
"ValidationError": "validation_error_handler",
"NetworkError": "network_error_handler",
"TimeoutError": "timeout_error_handler",
"AuthenticationError": "auth_error_handler",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
else:
error = getattr(state, error_key, None)
if error is None:
return "no_error"
# Handle different error formats
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type:
return error_types.get(error_type, default_target)
return default_target
return router
def retry_on_failure(
max_retries: int = 3,
retry_count_key: str = "retry_count",
error_key: str = "error",
) -> Callable[
[dict[str, Any] | StateProtocol],
Literal["retry", "max_retries_exceeded", "success"],
]:
"""Create a router that implements retry logic for failures.
Args:
max_retries: Maximum number of retry attempts
retry_count_key: Key in state tracking retry attempts
error_key: Key in state containing error information
Returns:
Router function that returns "retry", "max_retries_exceeded", or "success"
Example:
retry_router = retry_on_failure(max_retries=5)
graph.add_conditional_edges("failure_handler", retry_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["retry", "max_retries_exceeded", "success"]:
# Check if there's an error
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
retry_count = state.get(retry_count_key, 0)
else:
error = getattr(state, error_key, None)
retry_count = getattr(state, retry_count_key, 0)
# No error means success
if error is None:
return "success"
try:
current_retries = int(retry_count)
if current_retries >= max_retries:
return "max_retries_exceeded"
else:
return "retry"
except (ValueError, TypeError):
return "retry"
return router
def fallback_to_default(
fallback_conditions: list[str] | None = None,
output_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["use_output", "fallback"]]:
"""Create a router that falls back to default when output is inadequate.
Args:
fallback_conditions: List of conditions that trigger fallback
output_key: Key in state containing output to check
Returns:
Router function that returns "use_output" or "fallback"
Example:
fallback_router = fallback_to_default([
"empty_output", "low_confidence", "invalid_format"
])
graph.add_conditional_edges("output_validator", fallback_router)
"""
if fallback_conditions is None:
fallback_conditions = ["empty_output", "low_confidence", "error"]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["use_output", "fallback"]:
if hasattr(state, "get") or isinstance(state, dict):
output = state.get(output_key)
else:
output = getattr(state, output_key, None)
# Check for empty or null output
if output is None or output == "":
return "fallback"
# Check for other fallback conditions in state
for condition in fallback_conditions:
if hasattr(state, "get") or isinstance(state, dict):
condition_value = state.get(condition, False)
else:
condition_value = getattr(state, condition, False)
if condition_value:
return "fallback"
return "use_output"
return router
def check_critical_error(
critical_error_types: list[str] | None = None,
error_key: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["critical", "non_critical"]]:
"""Create a router that identifies critical errors requiring immediate attention.
Args:
critical_error_types: List of error types considered critical
error_key: Key in state containing error information
Returns:
Router function that returns "critical" or "non_critical"
Example:
critical_router = check_critical_error([
"SecurityError", "DataCorruptionError", "SystemFailure"
])
graph.add_conditional_edges("error_classifier", critical_router)
"""
if critical_error_types is None:
critical_error_types = [
"SecurityError",
"AuthenticationError",
"AuthorizationError",
"DataCorruptionError",
"SystemFailure",
"OutOfMemoryError",
]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["critical", "non_critical"]:
if hasattr(state, "get") or isinstance(state, dict):
error = state.get(error_key)
else:
error = getattr(state, error_key, None)
if error is None:
return "non_critical"
# Determine error type
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type and error_type in critical_error_types:
return "critical"
return "non_critical"
return router
def escalation_policy(
escalation_threshold: int = 3,
failure_count_key: str = "consecutive_failures",
escalation_types: dict[str, str] | None = None,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that implements error escalation policies.
Args:
escalation_threshold: Number of failures before escalation
failure_count_key: Key in state tracking consecutive failures
escalation_types: Mapping of failure types to escalation targets
Returns:
Router function that returns "continue_monitoring", a matching escalation target, or "generic_escalation"
Example:
escalation_router = escalation_policy(
escalation_threshold=3,
escalation_types={
"timeout": "timeout_escalation",
"validation": "validation_escalation"
}
)
graph.add_conditional_edges("failure_monitor", escalation_router)
"""
if escalation_types is None:
escalation_types = {
"timeout": "timeout_escalation",
"network": "network_escalation",
"validation": "validation_escalation",
"auth": "security_escalation",
"security": "security_escalation",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
failure_count = state.get(failure_count_key, 0)
error = state.get("error")
else:
failure_count = getattr(state, failure_count_key, 0)
error = getattr(state, "error", None)
try:
failures = int(failure_count)
if failures < escalation_threshold:
return "continue_monitoring"
# Determine escalation type based on error
if error:
error_type = None
if isinstance(error, dict):
error_type = error.get("type") or error.get("error_type")
elif isinstance(error, str):
error_type = error.lower()
elif hasattr(error, "__class__"):
error_type = error.__class__.__name__
if error_type:
for escalation_key, target in escalation_types.items():
if escalation_key.lower() in error_type.lower():
return target
return "generic_escalation"
except (ValueError, TypeError):
return "continue_monitoring"
return router
def circuit_breaker(
failure_threshold: int = 5,
recovery_timeout: int = 60,
state_key: str = "circuit_breaker_state",
failure_count_key: str = "failure_count",
last_failure_key: str = "last_failure_time",
) -> Callable[[dict[str, Any] | StateProtocol], Literal["allow", "reject", "probe"]]:
"""Create a router that implements circuit breaker pattern.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before trying to close circuit
state_key: Key storing circuit state (closed/open/half_open)
failure_count_key: Key tracking failure count
last_failure_key: Key storing last failure timestamp
Returns:
Router function implementing circuit breaker logic
Example:
breaker_router = circuit_breaker(failure_threshold=3, recovery_timeout=30)
graph.add_conditional_edges("service_call", breaker_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["allow", "reject", "probe"]:
import time
if hasattr(state, "get") or isinstance(state, dict):
circuit_state = state.get(state_key, "closed")
failure_count = state.get(failure_count_key, 0)
last_failure = state.get(last_failure_key, 0)
else:
circuit_state = getattr(state, state_key, "closed")
failure_count = getattr(state, failure_count_key, 0)
last_failure = getattr(state, last_failure_key, 0)
current_time = time.time()
# Convert numeric values with error handling
try:
failure_count = int(failure_count)
except (ValueError, TypeError):
failure_count = 0
try:
last_failure = float(last_failure)
except (ValueError, TypeError):
last_failure = 0.0
if circuit_state == "open":
# Check if recovery timeout has passed
if current_time - last_failure >= recovery_timeout:
return "probe" # Try half-open state
else:
return "reject" # Still in timeout
elif circuit_state == "half_open":
return "probe" # Allow one request to test
else: # closed state
if failure_count >= failure_threshold:
return "reject" # Open circuit
else:
return "allow" # Normal operation
return router
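Because these routers operate purely on state, their behaviour can be checked without building a graph. A dict-only sketch of the circuit-breaker decision (same keys and defaults as above, the `StateProtocol` branch and threshold handling simplified, with an injectable `now` added purely for testability):

```python
import time
from typing import Any

def circuit_breaker_route(
    state: dict[str, Any],
    failure_threshold: int = 5,
    recovery_timeout: int = 60,
) -> str:
    circuit_state = state.get("circuit_breaker_state", "closed")
    failure_count = int(state.get("failure_count", 0) or 0)
    last_failure = float(state.get("last_failure_time", 0) or 0)
    now = state.get("now", time.time())  # injectable clock for testing

    if circuit_state == "open":
        # Only probe once the recovery timeout has elapsed.
        return "probe" if now - last_failure >= recovery_timeout else "reject"
    if circuit_state == "half_open":
        return "probe"  # allow a single test request
    # Closed: open (reject) once failures reach the threshold.
    return "reject" if failure_count >= failure_threshold else "allow"

print(circuit_breaker_route({}))                                   # allow
print(circuit_breaker_route({"circuit_breaker_state": "open",
                             "last_failure_time": 0, "now": 100})) # probe
print(circuit_breaker_route({"failure_count": 6}))                 # reject
```

Note that the router only decides; a downstream node still has to write the new circuit state and failure counters back into state.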


@@ -0,0 +1,262 @@
"""Flow control edge helpers for managing workflow progression.
This module provides edge helpers for controlling the flow of execution
in LangGraph workflows, including continuation logic, timeout checks,
and multi-step progress tracking.
"""
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from bb_core.edge_helpers.core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def should_continue(
state: dict[str, Any] | StateProtocol,
) -> Literal["continue", "end"]:
"""Decide whether to continue or end based on tool calls in the last AI message.
Checks for the presence of tool calls in the last message to determine
if the workflow should continue processing or end.
Args:
state: State containing messages with potential tool calls
Returns:
"continue" if tool calls are present, "end" otherwise
Example:
graph.add_conditional_edges("agent", should_continue)
"""
# Check for messages in state
if hasattr(state, "get") or isinstance(state, dict):
messages = state.get("messages", [])
else:
messages = getattr(state, "messages", [])
if not messages:
return "end"
last_message = messages[-1]
# Check various formats for tool calls
if isinstance(last_message, dict):
# Check for tool_calls in additional_kwargs (must be a list)
additional_kwargs = last_message.get("additional_kwargs", {})
tool_calls = additional_kwargs.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check for tool_calls directly (must be a list)
tool_calls = last_message.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check for function_call
if last_message.get("function_call"):
return "continue"
# Check if message object has tool_calls attribute (must be a list)
if hasattr(last_message, "tool_calls"):
tool_calls = getattr(last_message, "tool_calls", None)
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
# Check if message object has additional_kwargs with tool_calls (must be a list)
if hasattr(last_message, "additional_kwargs"):
additional_kwargs = getattr(last_message, "additional_kwargs", None) or {}
if isinstance(additional_kwargs, dict):
tool_calls = additional_kwargs.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
return "continue"
return "end"
def timeout_check(
timeout_seconds: float = 300.0,
start_time_key: str = "start_time",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks for timeout conditions.
Args:
timeout_seconds: Maximum allowed execution time in seconds
start_time_key: Key in state containing start timestamp
Returns:
Router function that returns "timeout" or "continue"
Example:
timeout_router = timeout_check(timeout_seconds=60.0)
graph.add_conditional_edges("long_task", timeout_router)
"""
def router(state: dict[str, Any] | StateProtocol) -> Literal["timeout", "continue"]:
if hasattr(state, "get") or isinstance(state, dict):
start_time = state.get(start_time_key)
else:
start_time = getattr(state, start_time_key, None)
if start_time is None:
# No start time recorded, assume we should continue
return "continue"
current_time = time.time()
elapsed = current_time - start_time
return "timeout" if elapsed > timeout_seconds else "continue"
return router
def multi_step_progress(
total_steps: int,
step_key: str = "current_step",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router for multi-step workflow progress tracking.
Args:
total_steps: Total number of steps in the workflow
step_key: Key in state containing current step number
Returns:
Router function that returns "next_step", "complete", or "error"
Example:
progress_router = multi_step_progress(total_steps=5)
graph.add_conditional_edges("step_processor", progress_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["next_step", "complete", "error"]:
if hasattr(state, "get") or isinstance(state, dict):
current_step = state.get(step_key, 0)
else:
current_step = getattr(state, step_key, 0)
try:
step_num = int(current_step)
if step_num < 0:
return "error"
elif step_num >= total_steps:
return "complete"
else:
return "next_step"
except (ValueError, TypeError):
return "error"
return router
def check_iteration_limit(
max_iterations: int = 10,
iteration_key: str = "iteration_count",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that enforces iteration limits to prevent infinite loops.
Args:
max_iterations: Maximum allowed iterations
iteration_key: Key in state containing iteration counter
Returns:
Router function that returns "continue" or "limit_reached"
Example:
limit_router = check_iteration_limit(max_iterations=5)
graph.add_conditional_edges("retry_node", limit_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["continue", "limit_reached"]:
if hasattr(state, "get") or isinstance(state, dict):
iterations = state.get(iteration_key, 0)
else:
iterations = getattr(state, iteration_key, 0)
try:
iter_count = int(iterations)
return "limit_reached" if iter_count >= max_iterations else "continue"
except (ValueError, TypeError):
return "continue"
return router
def check_completion_criteria(
required_conditions: list[str],
condition_prefix: str = "completed_",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks if all completion criteria are met.
Args:
required_conditions: List of condition names that must be true
condition_prefix: Prefix for condition keys in state
Returns:
Router function that returns "complete" or "continue"
Example:
completion_router = check_completion_criteria([
"data_validated", "report_generated", "notifications_sent"
])
graph.add_conditional_edges("final_check", completion_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["complete", "continue"]:
for condition in required_conditions:
condition_key = f"{condition_prefix}{condition}"
if hasattr(state, "get") or isinstance(state, dict):
is_complete = state.get(condition_key, False)
else:
is_complete = getattr(state, condition_key, False)
if not is_complete:
return "continue"
return "complete"
return router
def check_workflow_state(
state_transitions: dict[str, str],
state_key: str = "workflow_state",
default_target: str = "error",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on workflow state transitions.
Args:
state_transitions: Mapping of current states to next targets
state_key: Key in state containing current workflow state
default_target: Default target if state not found in transitions
Returns:
Router function that returns target based on workflow state
Example:
workflow_router = check_workflow_state({
"initialized": "processing",
"processing": "validation",
"validation": "completion",
"completion": "end"
})
graph.add_conditional_edges("state_manager", workflow_router)
"""
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
current_state = state.get(state_key)
else:
current_state = getattr(state, state_key, None)
return state_transitions.get(str(current_state), default_target)
return router
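The flow-control routers are similarly easy to exercise in isolation. A standalone sketch of the `multi_step_progress` logic, restricted to dict states for brevity:

```python
from typing import Any, Callable, Literal

def multi_step_progress(
    total_steps: int, step_key: str = "current_step"
) -> Callable[[dict[str, Any]], Literal["next_step", "complete", "error"]]:
    def router(state: dict[str, Any]) -> Literal["next_step", "complete", "error"]:
        try:
            step_num = int(state.get(step_key, 0))
        except (ValueError, TypeError):
            return "error"  # non-numeric step counter
        if step_num < 0:
            return "error"
        return "complete" if step_num >= total_steps else "next_step"
    return router

router = multi_step_progress(total_steps=3)
print(router({"current_step": 1}))       # next_step
print(router({"current_step": 3}))       # complete
print(router({"current_step": -1}))      # error
print(router({"current_step": "oops"}))  # error
```

The defensive `int()` conversion matters in practice: state values often arrive as strings after serialization, and a malformed counter should route to an error handler rather than raise inside the graph.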


@@ -0,0 +1,394 @@
"""Monitoring and operational edge helpers for system management.
This module provides edge helpers for monitoring system health, checking
resource availability, triggering notifications, and managing operational
concerns like load balancing and rate limiting.
"""
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from bb_core.edge_helpers.core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def log_and_monitor(
log_levels: dict[str, str] | None = None,
log_level_key: str = "log_level",
should_log_key: str = "should_log",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that determines logging and monitoring actions.
Args:
log_levels: Mapping of log levels to monitoring actions
log_level_key: Key in state containing log level
should_log_key: Key in state indicating if logging is enabled
Returns:
Router function that returns monitoring action or "no_logging"
Example:
monitor_router = log_and_monitor({
"debug": "debug_monitor",
"info": "info_monitor",
"warning": "alert_monitor",
"error": "urgent_monitor"
})
graph.add_conditional_edges("logger", monitor_router)
"""
if log_levels is None:
log_levels = {
"debug": "debug_log",
"info": "info_log",
"warning": "warning_alert",
"error": "error_alert",
"critical": "critical_alert",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
should_log = state.get(should_log_key, True)
log_level = state.get(log_level_key, "info")
else:
should_log = getattr(state, should_log_key, True)
log_level = getattr(state, log_level_key, "info")
if not should_log:
return "no_logging"
level_str = str(log_level).lower()
return log_levels.get(level_str, "info_log")
return router
def check_resource_availability(
resource_thresholds: dict[str, float] | None = None,
resources_key: str = "resource_usage",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks system resource availability.
Args:
resource_thresholds: Mapping of resource names to threshold percentages
resources_key: Key in state containing resource usage data
Returns:
Router function that returns "resources_available", "resources_unknown", or "<resource>_constrained" for the first constrained resource
Example:
resource_router = check_resource_availability({
"cpu": 0.8, # 80% threshold
"memory": 0.9, # 90% threshold
"disk": 0.95 # 95% threshold
})
graph.add_conditional_edges("resource_check", resource_router)
"""
if resource_thresholds is None:
resource_thresholds = {
"cpu": 0.8,
"memory": 0.9,
"disk": 0.95,
"network": 0.8,
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
resources = state.get(resources_key)
else:
resources = getattr(state, resources_key, None)
if resources is None or not isinstance(resources, dict):
return "resources_unknown"
# Check each resource against its threshold
for resource_name, threshold in resource_thresholds.items():
usage = resources.get(resource_name, 0.0)
try:
usage_percent = float(usage)
if usage_percent >= threshold:
return f"{resource_name}_constrained"
except (ValueError, TypeError):
continue
return "resources_available"
return router
def trigger_notifications(
notification_rules: dict[str, str] | None = None,
alert_level_key: str = "alert_level",
notification_enabled_key: str = "notifications_enabled",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that triggers appropriate notifications.
Args:
notification_rules: Mapping of alert levels to notification types
alert_level_key: Key in state containing alert level
notification_enabled_key: Key in state for notification toggle
Returns:
Router function that returns notification type or "no_notification"
Example:
notify_router = trigger_notifications({
"low": "email_notification",
"medium": "slack_notification",
"high": "sms_notification",
"critical": "phone_notification"
})
graph.add_conditional_edges("alert_handler", notify_router)
"""
if notification_rules is None:
notification_rules = {
"low": "email_notification",
"medium": "slack_notification",
"high": "sms_notification",
"critical": "phone_notification",
"emergency": "all_channels",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
notifications_enabled = state.get(notification_enabled_key, True)
alert_level = state.get(alert_level_key)
else:
notifications_enabled = getattr(state, notification_enabled_key, True)
alert_level = getattr(state, alert_level_key, None)
if not notifications_enabled or alert_level is None:
return "no_notification"
level_str = str(alert_level).lower()
return notification_rules.get(level_str, "no_notification")
return router
def check_rate_limiting(
rate_limits: dict[str, dict[str, float]] | None = None,
request_counts_key: str = "request_counts",
time_window_key: str = "time_window",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that enforces rate limiting policies.
Args:
rate_limits: Mapping of endpoints/operations to rate limit configs
request_counts_key: Key in state containing request count data
time_window_key: Key in state containing time window info
Returns:
Router function that returns "allowed" or "rate_limited"
Example:
rate_router = check_rate_limiting({
"api_calls": {"max_requests": 100, "window_seconds": 60},
"file_uploads": {"max_requests": 10, "window_seconds": 60}
})
graph.add_conditional_edges("rate_limiter", rate_router)
"""
if rate_limits is None:
rate_limits = {
"default": {"max_requests": 100, "window_seconds": 60},
"api_calls": {"max_requests": 100, "window_seconds": 60},
"heavy_operations": {"max_requests": 10, "window_seconds": 300},
}
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["allowed", "rate_limited"]:
if hasattr(state, "get") or isinstance(state, dict):
request_counts = state.get(request_counts_key, {})
time_window_data = state.get(time_window_key, {})
else:
request_counts = getattr(state, request_counts_key, {})
time_window_data = getattr(state, time_window_key, {})
# Use time from time_window_data if available, otherwise use system time
current_time = time_window_data.get("current_time", time.time())
# Check rate limits for each configured endpoint
for endpoint, limits in rate_limits.items():
if endpoint not in request_counts:
continue
max_requests = limits.get("max_requests", 100)
window_seconds = limits.get("window_seconds", 60)
# Get request history for this endpoint
endpoint_data = request_counts.get(endpoint, {})
request_count = endpoint_data.get("count", 0)
window_start = endpoint_data.get("window_start", current_time)
# Check if we're still in the same time window
if (
current_time - window_start < window_seconds
and request_count >= max_requests
):
return "rate_limited"
return "allowed"
return router
def load_balance(
load_balancing_strategy: str = "round_robin",
available_nodes: list[str] | None = None,
node_status_key: str = "node_status",
current_node_key: str = "current_node_index",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that implements load balancing across nodes.
Args:
load_balancing_strategy: Strategy to use ("round_robin", "least_loaded")
available_nodes: List of available node names
node_status_key: Key in state containing node health status
current_node_key: Key in state tracking current node index
Returns:
Router function that returns the selected node name, or "no_available_nodes" if no healthy node remains
Example:
balance_router = load_balance(
load_balancing_strategy="round_robin",
available_nodes=["node1", "node2", "node3"]
)
graph.add_conditional_edges("load_balancer", balance_router)
"""
if available_nodes is None:
available_nodes = ["primary_node", "secondary_node"]
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
node_status = state.get(node_status_key, {})
current_index = state.get(current_node_key, 0)
else:
node_status = getattr(state, node_status_key, {})
current_index = getattr(state, current_node_key, 0)
# Filter healthy nodes
healthy_nodes = []
for node in available_nodes:
node_health = node_status.get(node, {})
if node_health.get("healthy", True): # Default to healthy
healthy_nodes.append(node)
if not healthy_nodes:
return "no_available_nodes"
if load_balancing_strategy == "round_robin":
try:
index = int(current_index) % len(healthy_nodes)
return healthy_nodes[index]
except (ValueError, TypeError):
return healthy_nodes[0]
elif load_balancing_strategy == "least_loaded":
# Find node with lowest load
min_load = float("inf")
selected_node = healthy_nodes[0]
for node in healthy_nodes:
node_data = node_status.get(node, {})
load = node_data.get("load", 0.0)
try:
if float(load) < min_load:
min_load = float(load)
selected_node = node
except (ValueError, TypeError):
continue
return selected_node
else: # Default to first healthy node
return healthy_nodes[0]
return router
def health_check(
health_criteria: dict[str, Any] | None = None,
health_status_key: str = "health_status",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that performs system health checks.
Args:
health_criteria: Criteria for determining system health
health_status_key: Key in state containing health check results
Returns:
Router function that returns "healthy", "degraded", or "unhealthy"
Example:
health_router = health_check({
"response_time_ms": 1000,
"error_rate": 0.05,
"uptime_percent": 0.99
})
graph.add_conditional_edges("health_monitor", health_router)
"""
if health_criteria is None:
health_criteria = {
"response_time_ms": 1000,
"error_rate": 0.05,
"cpu_usage": 0.8,
"memory_usage": 0.9,
}
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["healthy", "degraded", "unhealthy"]:
if hasattr(state, "get") or isinstance(state, dict):
health_status = state.get(health_status_key, {})
else:
health_status = getattr(state, health_status_key, {})
if not isinstance(health_status, dict):
return "unhealthy"
unhealthy_count = 0
degraded_count = 0
for metric, threshold in health_criteria.items():
current_value = health_status.get(metric)
if current_value is None:
continue
try:
value = float(current_value)
threshold_val = float(threshold)
# Different metrics have different "good" directions
if metric in [
"error_rate",
"response_time_ms",
"cpu_usage",
"memory_usage",
"latency",
]:
# Lower is better
if value > threshold_val * 1.5: # 150% of threshold
unhealthy_count += 1
elif value > threshold_val:
degraded_count += 1
else:
# Higher is better (like uptime_percent)
if value < threshold_val * 0.8: # 80% of threshold
unhealthy_count += 1
elif value < threshold_val:
degraded_count += 1
except (ValueError, TypeError):
continue
if unhealthy_count > 0:
return "unhealthy"
elif degraded_count > 0:
return "degraded"
else:
return "healthy"
return router
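The round-robin branch of `load_balance` can be demonstrated on its own. This dict-only sketch keeps the same state keys and the default-to-healthy behaviour; the node names are illustrative:

```python
from typing import Any

def round_robin_route(nodes: list[str], state: dict[str, Any]) -> str:
    """Pick the next healthy node by round-robin over the healthy subset."""
    status = state.get("node_status", {})
    # Nodes with no status entry default to healthy, as in the helper above.
    healthy = [n for n in nodes if status.get(n, {}).get("healthy", True)]
    if not healthy:
        return "no_available_nodes"
    try:
        index = int(state.get("current_node_index", 0)) % len(healthy)
    except (ValueError, TypeError):
        index = 0
    return healthy[index]

nodes = ["node1", "node2", "node3"]
state = {"node_status": {"node2": {"healthy": False}}, "current_node_index": 4}
print(round_robin_route(nodes, state))  # node1
state["current_node_index"] = 5
print(round_robin_route(nodes, state))  # node3
```

One subtlety worth noting: the modulo is taken over the *healthy* subset, so the mapping from index to node shifts whenever a node's health flips; a stable assignment would require hashing on node identity instead.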


@@ -0,0 +1,326 @@
"""User interaction edge helpers for human-in-the-loop workflows.
This module provides edge helpers for managing user interactions, interrupts,
feedback loops, and escalation to human operators.
"""
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from bb_core.edge_helpers.core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def human_interrupt(
interrupt_signals: list[str] | None = None,
interrupt_key: str = "human_interrupt",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that detects human interruption requests.
Args:
interrupt_signals: List of signals that indicate interrupt request
interrupt_key: Key in state containing interrupt signal
Returns:
Router function that returns "interrupt" or "continue"
Example:
interrupt_router = human_interrupt([
"stop", "pause", "cancel", "abort"
])
graph.add_conditional_edges("user_input_check", interrupt_router)
"""
if interrupt_signals is None:
interrupt_signals = ["stop", "pause", "cancel", "abort", "interrupt", "halt"]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["interrupt", "continue"]:
if hasattr(state, "get") or isinstance(state, dict):
signal = state.get(interrupt_key)
else:
signal = getattr(state, interrupt_key, None)
if signal is None:
return "continue"
signal_str = str(signal).lower().strip()
normalized_interrupt_signals = {s.lower() for s in interrupt_signals}
return "interrupt" if signal_str in normalized_interrupt_signals else "continue"
return router
def pass_status_to_user(
status_levels: dict[str, str] | None = None,
status_key: str = "status",
notify_key: str = "notify_user",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that determines when to notify users of status.
Args:
status_levels: Mapping of status values to notification urgency
status_key: Key in state containing current status
notify_key: Key in state indicating if user should be notified
Returns:
Router function that returns notification priority or "no_notification"
Example:
status_router = pass_status_to_user({
"error": "urgent",
"warning": "medium",
"completed": "low"
})
graph.add_conditional_edges("status_monitor", status_router)
"""
if status_levels is None:
status_levels = {
"error": "urgent",
"failed": "urgent",
"warning": "medium",
"completed": "low",
"success": "low",
"in_progress": "info",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
status = state.get(status_key)
should_notify = state.get(notify_key, True)
else:
status = getattr(state, status_key, None)
should_notify = getattr(state, notify_key, True)
if not should_notify or status is None:
return "no_notification"
status_str = str(status).lower()
return status_levels.get(status_str, "no_notification")
return router
def user_feedback_loop(
feedback_required_conditions: list[str] | None = None,
feedback_key: str = "requires_feedback",
feedback_type_key: str = "feedback_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that manages user feedback collection.
Args:
feedback_required_conditions: Conditions that trigger feedback request
feedback_key: Key in state indicating if feedback is required
feedback_type_key: Key in state specifying type of feedback needed
Returns:
Router function that returns feedback type or "no_feedback"
Example:
feedback_router = user_feedback_loop([
"low_confidence", "ambiguous_input", "multiple_options"
])
graph.add_conditional_edges("feedback_check", feedback_router)
"""
if feedback_required_conditions is None:
feedback_required_conditions = [
"low_confidence",
"ambiguous_input",
"multiple_options",
"validation_failed",
]
def router(state: dict[str, Any] | StateProtocol) -> str:
# Check if feedback is explicitly required
if hasattr(state, "get") or isinstance(state, dict):
requires_feedback = state.get(feedback_key, False)
feedback_type = state.get(feedback_type_key, "general")
else:
requires_feedback = getattr(state, feedback_key, False)
feedback_type = getattr(state, feedback_type_key, "general")
if requires_feedback:
return f"feedback_{feedback_type}"
# Check for conditions that trigger feedback
for condition in feedback_required_conditions:
if hasattr(state, "get") or isinstance(state, dict):
condition_met = state.get(condition, False)
else:
condition_met = getattr(state, condition, False)
if condition_met:
return f"feedback_{condition}"
return "no_feedback"
return router
def escalate_to_human(
escalation_triggers: dict[str, str] | None = None,
auto_escalate_key: str = "auto_escalate",
escalation_reason_key: str = "escalation_reason",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that escalates to human operators when needed.
Args:
escalation_triggers: Mapping of trigger conditions to escalation types
auto_escalate_key: Key in state for automatic escalation flag
escalation_reason_key: Key in state containing escalation reason
Returns:
Router function that returns escalation type or "continue_automated"
Example:
escalation_router = escalate_to_human({
"critical_error": "immediate",
"multiple_failures": "urgent",
"user_request": "standard"
})
graph.add_conditional_edges("escalation_check", escalation_router)
"""
if escalation_triggers is None:
escalation_triggers = {
"critical_error": "immediate",
"security_issue": "immediate",
"multiple_failures": "urgent",
"user_request": "standard",
"manual_review": "standard",
"complex_case": "expert",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
# Check for automatic escalation flag
if hasattr(state, "get") or isinstance(state, dict):
auto_escalate = state.get(auto_escalate_key, False)
escalation_reason = state.get(escalation_reason_key)
else:
auto_escalate = getattr(state, auto_escalate_key, False)
escalation_reason = getattr(state, escalation_reason_key, None)
if auto_escalate and escalation_reason:
reason_str = str(escalation_reason).lower()
for trigger, escalation_type in escalation_triggers.items():
if trigger in reason_str:
return f"escalate_{escalation_type}"
return "escalate_standard"
# Check for specific escalation triggers in state
for trigger, escalation_type in escalation_triggers.items():
if hasattr(state, "get") or isinstance(state, dict):
trigger_present = state.get(trigger, False)
else:
trigger_present = getattr(state, trigger, False)
if trigger_present:
return f"escalate_{escalation_type}"
return "continue_automated"
return router
def check_user_authorization(
required_permissions: list[str] | None = None,
user_key: str = "user",
permissions_key: str = "permissions",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks user authorization levels.
Args:
required_permissions: List of permissions required for access
user_key: Key in state containing user information
permissions_key: Key in user data containing permissions list
Returns:
Router function that returns "authorized" or "unauthorized"
Example:
auth_router = check_user_authorization([
"read_data", "write_reports", "admin_access"
])
graph.add_conditional_edges("permission_check", auth_router)
"""
if required_permissions is None:
required_permissions = ["basic_access"]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["authorized", "unauthorized"]:
if hasattr(state, "get") or isinstance(state, dict):
user = state.get(user_key)
else:
user = getattr(state, user_key, None)
if user is None:
return "unauthorized"
# Extract user permissions
user_permissions = []
if isinstance(user, dict):
user_permissions = user.get(permissions_key, [])
elif hasattr(user, permissions_key):
user_permissions = getattr(user, permissions_key, [])
if not isinstance(user_permissions, list):
return "unauthorized"
# Check if user has all required permissions
for permission in required_permissions:
if permission not in user_permissions:
return "unauthorized"
return "authorized"
return router
def collect_user_input(
input_types: dict[str, str] | None = None,
pending_input_key: str = "pending_user_input",
input_type_key: str = "input_type",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that manages user input collection.
Args:
input_types: Mapping of input types to collection methods
pending_input_key: Key in state indicating pending input
input_type_key: Key in state specifying type of input needed
Returns:
Router function that returns input collection method or "no_input_needed"
Example:
input_router = collect_user_input({
"text": "text_input_form",
"choice": "multiple_choice",
"file": "file_upload"
})
graph.add_conditional_edges("input_manager", input_router)
"""
if input_types is None:
input_types = {
"text": "text_input",
"choice": "choice_input",
"confirmation": "confirm_input",
"file": "file_input",
"numeric": "number_input",
}
def router(state: dict[str, Any] | StateProtocol) -> str:
if hasattr(state, "get") or isinstance(state, dict):
pending = state.get(pending_input_key, False)
input_type = state.get(input_type_key, "text")
else:
pending = getattr(state, pending_input_key, False)
input_type = getattr(state, input_type_key, "text")
if not pending:
return "no_input_needed"
input_type_str = str(input_type).lower()
return input_types.get(input_type_str, "text_input")
return router


@@ -0,0 +1,320 @@
"""Validation edge helpers for quality control and data integrity.
This module provides edge helpers for validating outputs, checking accuracy,
confidence levels, format compliance, and data privacy requirements.
"""
import json
import re
import time
from collections.abc import Callable
from typing import Any, Literal, TypeVar
from bb_core.edge_helpers.core import StateProtocol
StateT = TypeVar("StateT", bound=StateProtocol)
def check_accuracy(
threshold: float = 0.8,
accuracy_key: str = "accuracy_score",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks if accuracy meets threshold.
Args:
threshold: Minimum accuracy threshold (0.0 to 1.0)
accuracy_key: Key in state containing accuracy score
Returns:
Router function that returns "high_accuracy" or "low_accuracy"
Example:
accuracy_router = check_accuracy(threshold=0.85)
graph.add_conditional_edges("quality_check", accuracy_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["high_accuracy", "low_accuracy"]:
if hasattr(state, "get") or isinstance(state, dict):
accuracy = state.get(accuracy_key, 0.0)
else:
accuracy = getattr(state, accuracy_key, 0.0)
try:
score = float(accuracy)
return "high_accuracy" if score >= threshold else "low_accuracy"
except (ValueError, TypeError):
return "low_accuracy"
return router
def check_confidence_level(
threshold: float = 0.7,
confidence_key: str = "confidence",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router based on confidence score threshold.
Args:
threshold: Minimum confidence threshold (0.0 to 1.0)
confidence_key: Key in state containing confidence score
Returns:
Router function that returns "high_confidence" or "low_confidence"
Example:
confidence_router = check_confidence_level(threshold=0.75)
graph.add_conditional_edges("llm_output", confidence_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["high_confidence", "low_confidence"]:
if hasattr(state, "get") or isinstance(state, dict):
confidence = state.get(confidence_key, 0.0)
else:
confidence = getattr(state, confidence_key, 0.0)
try:
score = float(confidence)
return "high_confidence" if score >= threshold else "low_confidence"
except (ValueError, TypeError):
return "low_confidence"
return router
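`check_accuracy` and `check_confidence_level` share one pattern: coerce the score to `float` and compare against the threshold, treating any coercion failure (or a missing key, which defaults to 0.0) as the low branch. A generic self-contained sketch of that shared logic (factory and label names hypothetical):

```python
from collections.abc import Callable
from typing import Any

# Generic threshold router underlying the accuracy and confidence helpers.
def make_threshold_router(
    threshold: float,
    score_key: str,
    high: str,
    low: str,
) -> Callable[[dict[str, Any]], str]:
    def router(state: dict[str, Any]) -> str:
        try:
            score = float(state.get(score_key, 0.0))
        except (ValueError, TypeError):
            # Unparseable scores fail safe into the low branch.
            return low
        return high if score >= threshold else low
    return router

accuracy_router = make_threshold_router(0.8, "accuracy_score", "high_accuracy", "low_accuracy")
```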
def validate_output_format(
expected_format: str = "json",
output_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that validates output format.
Args:
expected_format: Expected format ("json", "xml", "csv", "text")
output_key: Key in state containing output to validate
Returns:
Router function that returns "valid_format" or "invalid_format"
Example:
format_router = validate_output_format(expected_format="json")
graph.add_conditional_edges("formatter", format_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["valid_format", "invalid_format"]:
if hasattr(state, "get") or isinstance(state, dict):
output = state.get(output_key, "")
else:
output = getattr(state, output_key, "")
if not output:
return "invalid_format"
try:
if expected_format.lower() == "json":
json.loads(str(output))
return "valid_format"
elif expected_format.lower() == "xml":
# Basic XML validation
if str(output).strip().startswith("<") and str(output).strip().endswith(
">"
):
return "valid_format"
return "invalid_format"
elif expected_format.lower() == "csv":
# Basic CSV validation - check for comma separation
lines = str(output).strip().split("\n")
if len(lines) > 0 and "," in lines[0]:
return "valid_format"
return "invalid_format"
else: # Default to text validation
if isinstance(output, str) and len(output.strip()) > 0:
return "valid_format"
return "invalid_format"
except Exception:
return "invalid_format"
return router
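Of the four formats, only JSON gets a real parse; XML and CSV use cheap heuristics (angle-bracket delimiters, a comma in the first line) rather than full validation. A self-contained sketch of the per-format checks as a boolean predicate (function name hypothetical):

```python
import json

# Standalone sketch of the per-format checks in validate_output_format.
def is_valid_format(output: object, expected_format: str = "json") -> bool:
    if not output:
        return False
    text = str(output).strip()
    try:
        if expected_format.lower() == "json":
            json.loads(text)  # real parse; raises on invalid JSON
            return True
        if expected_format.lower() == "xml":
            # Heuristic only: checks delimiters, not well-formedness.
            return text.startswith("<") and text.endswith(">")
        if expected_format.lower() == "csv":
            # Heuristic only: a comma somewhere in the first line.
            return "," in text.split("\n")[0]
        return len(text) > 0  # plain text: any non-empty string
    except Exception:
        return False
```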
def check_privacy_compliance(
sensitive_patterns: list[str] | None = None,
content_key: str = "content",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks for privacy compliance.
Args:
sensitive_patterns: List of regex patterns for sensitive data
content_key: Key in state containing content to check
Returns:
Router function that returns "compliant" or "privacy_violation"
Example:
privacy_router = check_privacy_compliance([
r'\\b\\d{3}-\\d{2}-\\d{4}\\b', # SSN
r'\\b\\d{4}[- ]?\\d{4}[- ]?\\d{4}[- ]?\\d{4}\\b' # Credit card
])
graph.add_conditional_edges("privacy_check", privacy_router)
"""
if sensitive_patterns is None:
sensitive_patterns = [
r"\b\d{3}-\d{2}-\d{4}\b", # SSN pattern
r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b", # Credit card pattern
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
r"\b\d{3}[- ]?\d{3}[- ]?\d{4}\b", # Phone number
]
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["compliant", "privacy_violation"]:
if hasattr(state, "get") or isinstance(state, dict):
content = state.get(content_key, "")
else:
content = getattr(state, content_key, "")
content_str = str(content)
for pattern in sensitive_patterns:
if re.search(pattern, content_str, re.IGNORECASE):
return "privacy_violation"
return "compliant"
return router
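The compliance check is a short-circuiting scan: the first regex hit anywhere in the stringified content routes to "privacy_violation". A self-contained sketch as a boolean predicate (patterns adapted from the helper's defaults, with the email TLD class simplified to `[A-Za-z]`; function name hypothetical):

```python
import re

# Default sensitive-data patterns, adapted from check_privacy_compliance.
SENSITIVE = [
    r"\b\d{3}-\d{2}-\d{4}\b",                               # SSN
    r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",             # credit card
    r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",  # email
    r"\b\d{3}[- ]?\d{3}[- ]?\d{4}\b",                       # phone
]

def is_compliant(content: object) -> bool:
    text = str(content)
    # A single match of any pattern counts as a violation.
    return not any(re.search(p, text, re.IGNORECASE) for p in SENSITIVE)
```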
def check_data_freshness(
max_age_seconds: int = 3600, # 1 hour default
timestamp_key: str = "timestamp",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that checks if data is fresh enough.
Args:
max_age_seconds: Maximum age in seconds before data is stale
timestamp_key: Key in state containing data timestamp
Returns:
Router function that returns "fresh" or "stale"
Example:
freshness_router = check_data_freshness(max_age_seconds=1800) # 30 minutes
graph.add_conditional_edges("data_check", freshness_router)
"""
def router(state: dict[str, Any] | StateProtocol) -> Literal["fresh", "stale"]:
if hasattr(state, "get") or isinstance(state, dict):
timestamp = state.get(timestamp_key)
else:
timestamp = getattr(state, timestamp_key, None)
if timestamp is None:
return "stale"
try:
# Handle different timestamp formats
if isinstance(timestamp, str):
# Try parsing ISO format
from datetime import datetime
dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
timestamp_value = dt.timestamp()
else:
timestamp_value = float(timestamp)
current_time = time.time()
age = current_time - timestamp_value
return "fresh" if age <= max_age_seconds else "stale"
except (ValueError, TypeError):
return "stale"
return router
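The freshness router's main subtlety is timestamp normalization: ISO 8601 strings (including a trailing "Z", which `datetime.fromisoformat` on older Pythons rejects) are converted to epoch seconds, numbers are taken as epoch seconds directly, and anything unparseable counts as stale. A self-contained sketch (function name hypothetical):

```python
import time
from datetime import datetime, timezone

# Standalone sketch of check_data_freshness's timestamp normalization.
def is_fresh(timestamp: object, max_age_seconds: int = 3600) -> bool:
    if timestamp is None:
        return False
    try:
        if isinstance(timestamp, str):
            # ISO 8601; a "Z" suffix is normalized for fromisoformat.
            dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            ts = dt.timestamp()
        else:
            ts = float(timestamp)
    except (ValueError, TypeError):
        return False  # unparseable timestamps count as stale
    return (time.time() - ts) <= max_age_seconds

now_iso = datetime.now(timezone.utc).isoformat()
```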
def validate_required_fields(
required_fields: list[str],
strict_mode: bool = True,
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that validates presence of required fields.
Args:
required_fields: List of field names that must be present
strict_mode: If True, fields must be non-empty; if False, just present
Returns:
Router function that returns "valid" or "missing_fields"
Example:
field_router = validate_required_fields([
"user_id", "request_data", "timestamp"
])
graph.add_conditional_edges("input_validator", field_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["valid", "missing_fields"]:
for field in required_fields:
if hasattr(state, "get") or isinstance(state, dict):
value = state.get(field)
else:
value = getattr(state, field, None)
if value is None:
return "missing_fields"
if strict_mode and (
value == "" or (isinstance(value, list) and len(value) == 0)
):
return "missing_fields"
return "valid"
return router
def check_output_length(
min_length: int = 1,
max_length: int | None = None,
content_key: str = "output",
) -> Callable[[dict[str, Any] | StateProtocol], str]:
"""Create a router that validates output length constraints.
Args:
min_length: Minimum required length
max_length: Maximum allowed length (None for no limit)
content_key: Key in state containing content to check
Returns:
Router function that returns "valid_length", "too_short", or "too_long"
Example:
length_router = check_output_length(min_length=10, max_length=1000)
graph.add_conditional_edges("length_validator", length_router)
"""
def router(
state: dict[str, Any] | StateProtocol,
) -> Literal["valid_length", "too_short", "too_long"]:
if hasattr(state, "get") or isinstance(state, dict):
content = state.get(content_key, "")
else:
content = getattr(state, content_key, "")
content_length = len(str(content))
if content_length < min_length:
return "too_short"
elif max_length is not None and content_length > max_length:
return "too_long"
else:
return "valid_length"
return router


@@ -13,8 +13,6 @@ def get_embedding_client() -> Any:
ImportError: If main application dependencies not available
"""
try:
from biz_bud.services.factory import ServiceFactory # noqa: F401
# This is a placeholder - actual implementation would require ServiceFactory
raise ImportError(
"Embedding client dependencies not available. "
@@ -61,20 +59,10 @@ def get_embeddings_instance(
Raises:
ImportError: If main application dependencies not available
"""
# Try to create the appropriate embeddings instance
try:
if embedding_provider == "openai":
from langchain_openai import OpenAIEmbeddings
# Create OpenAIEmbeddings with explicit model parameter
embeddings_kwargs = {"model": model or "text-embedding-3-small", **kwargs}
return OpenAIEmbeddings(**embeddings_kwargs)
else:
raise ValueError(f"Unsupported embedding provider: {embedding_provider}")
except ImportError as e:
# If dependencies not available, raise our custom error
raise ImportError(
"Embeddings instance not available in bb_core. "
"This function requires the main biz_bud application "
"with langchain dependencies."
) from e
# This is a placeholder implementation - actual embedding functionality
# requires the main biz_bud application with proper dependency injection
raise ImportError(
"Embedding service not available in bb_core. "
"This function requires the main biz_bud application with "
"proper dependency injection."
)


@@ -10,6 +10,7 @@ This module provides mechanisms for:
from __future__ import annotations
import hashlib
import threading
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
@@ -129,6 +130,7 @@ class RateLimitWindow:
window_size: int = 60 # seconds
max_errors: int = 10
timestamps: deque[float] = field(default_factory=deque)
_lock: threading.Lock = field(default_factory=threading.Lock, init=False)
def is_allowed(self) -> bool:
"""Check if error reporting is allowed within rate limit.
@@ -136,18 +138,21 @@ class RateLimitWindow:
Returns:
True if within rate limit, False otherwise
"""
current_time = time.time()
with self._lock:
current_time = time.time()
# Remove old timestamps outside the window
while self.timestamps and self.timestamps[0] < current_time - self.window_size:
self.timestamps.popleft()
# Remove old timestamps outside the window
while (
self.timestamps and self.timestamps[0] < current_time - self.window_size
):
self.timestamps.popleft()
# Check if we're within the limit
if len(self.timestamps) < self.max_errors:
self.timestamps.append(current_time)
return True
# Check if we're within the limit
if len(self.timestamps) < self.max_errors:
self.timestamps.append(current_time)
return True
return False
return False
def time_until_allowed(self) -> float:
"""Calculate time until next error is allowed.
@@ -155,14 +160,15 @@ class RateLimitWindow:
Returns:
Seconds until next error can be reported
"""
if not self.timestamps:
return 0.0
with self._lock:
if not self.timestamps:
return 0.0
current_time = time.time()
oldest_timestamp = self.timestamps[0]
time_until_expiry = (oldest_timestamp + self.window_size) - current_time
current_time = time.time()
oldest_timestamp = self.timestamps[0]
time_until_expiry = (oldest_timestamp + self.window_size) - current_time
return max(0.0, time_until_expiry)
return max(0.0, time_until_expiry)
class ErrorAggregator:
@@ -185,6 +191,7 @@ class ErrorAggregator:
"""
self.dedup_window = dedup_window
self.aggregate_similar = aggregate_similar
self._lock = threading.RLock() # Use RLock for potential recursive calls
# Storage for aggregated errors by fingerprint
self.aggregated_errors: dict[str, AggregatedError] = {}
@@ -222,47 +229,48 @@ class ErrorAggregator:
Returns:
Tuple of (should_report, reason_if_not)
"""
fingerprint = ErrorFingerprint.from_error_info(error)
fingerprint_hash = fingerprint.hash
current_time = time.time()
with self._lock:
fingerprint = ErrorFingerprint.from_error_info(error)
fingerprint_hash = fingerprint.hash
current_time = time.time()
# Check deduplication
if fingerprint_hash in self.recent_errors:
time_since_last = current_time - self.recent_errors[fingerprint_hash]
if time_since_last < self.dedup_window:
# Check deduplication
if fingerprint_hash in self.recent_errors:
time_since_last = current_time - self.recent_errors[fingerprint_hash]
if time_since_last < self.dedup_window:
return (
False,
f"Duplicate error suppressed ({time_since_last:.1f}s ago)",
)
# Check rate limiting
details = error.get("details", {})
severity = details.get("severity", ErrorSeverity.ERROR.value)
# Check severity-specific rate limit
if (
severity in self.rate_limiters
and not self.rate_limiters[severity].is_allowed()
):
wait_time = self.rate_limiters[severity].time_until_allowed()
return (
False,
f"Duplicate error suppressed ({time_since_last:.1f}s ago)",
f"Rate limit exceeded for {severity} errors ({wait_time:.1f}s)",
)
# Check rate limiting
details = error.get("details", {})
severity = details.get("severity", ErrorSeverity.ERROR.value)
# Check global rate limit
if not self.global_rate_limiter.is_allowed():
wait_time = self.global_rate_limiter.time_until_allowed()
return False, f"Global rate limit exceeded (wait {wait_time:.1f}s)"
# Check severity-specific rate limit
if (
severity in self.rate_limiters
and not self.rate_limiters[severity].is_allowed()
):
wait_time = self.rate_limiters[severity].time_until_allowed()
return (
False,
f"Rate limit exceeded for {severity} errors ({wait_time:.1f}s)",
)
# Update tracking
self.recent_errors[fingerprint_hash] = current_time
# Check global rate limit
if not self.global_rate_limiter.is_allowed():
wait_time = self.global_rate_limiter.time_until_allowed()
return False, f"Global rate limit exceeded (wait {wait_time:.1f}s)"
# Clean old entries periodically
if len(self.recent_errors) > 1000:
self._cleanup_old_entries()
# Update tracking
self.recent_errors[fingerprint_hash] = current_time
# Clean old entries periodically
if len(self.recent_errors) > 1000:
self._cleanup_old_entries()
return True, None
return True, None
def add_error(self, error: ErrorInfo) -> AggregatedError:
"""Add an error to aggregation.
@@ -273,19 +281,20 @@ class ErrorAggregator:
Returns:
Aggregated error information
"""
fingerprint = ErrorFingerprint.from_error_info(error)
with self._lock:
fingerprint = ErrorFingerprint.from_error_info(error)
if self.aggregate_similar and fingerprint.hash in self.aggregated_errors:
# Update existing aggregation
aggregated = self.aggregated_errors[fingerprint.hash]
aggregated.add_occurrence(error)
else:
# Create new aggregation
aggregated = AggregatedError(fingerprint=fingerprint)
aggregated.add_occurrence(error)
self.aggregated_errors[fingerprint.hash] = aggregated
if self.aggregate_similar and fingerprint.hash in self.aggregated_errors:
# Update existing aggregation
aggregated = self.aggregated_errors[fingerprint.hash]
aggregated.add_occurrence(error)
else:
# Create new aggregation
aggregated = AggregatedError(fingerprint=fingerprint)
aggregated.add_occurrence(error)
self.aggregated_errors[fingerprint.hash] = aggregated
return aggregated
return aggregated
def get_aggregated_errors(
self,
@@ -301,24 +310,27 @@ class ErrorAggregator:
Returns:
List of aggregated errors
"""
current_time = datetime.now(UTC)
results = []
with self._lock:
current_time = datetime.now(UTC)
results = []
for aggregated in self.aggregated_errors.values():
if aggregated.count < min_count:
continue
if time_window:
time_since_last = (current_time - aggregated.last_seen).total_seconds()
if time_since_last > time_window:
for aggregated in self.aggregated_errors.values():
if aggregated.count < min_count:
continue
results.append(aggregated)
if time_window:
time_since_last = (
current_time - aggregated.last_seen
).total_seconds()
if time_since_last > time_window:
continue
# Sort by count (descending) and recency
results.sort(key=lambda x: (x.count, x.last_seen), reverse=True)
results.append(aggregated)
return results
# Sort by count (descending) and recency
results.sort(key=lambda x: (x.count, x.last_seen), reverse=True)
return results
def get_error_summary(self) -> dict[str, Any]:
"""Get summary of aggregated errors.
@@ -326,58 +338,61 @@ class ErrorAggregator:
Returns:
Summary statistics
"""
total_errors = sum(agg.count for agg in self.aggregated_errors.values())
unique_errors = len(self.aggregated_errors)
with self._lock:
total_errors = sum(agg.count for agg in self.aggregated_errors.values())
unique_errors = len(self.aggregated_errors)
# Group by category
by_category = defaultdict(int)
by_severity = defaultdict(int)
by_node = defaultdict(int)
# Group by category
by_category = defaultdict(int)
by_severity = defaultdict(int)
by_node = defaultdict(int)
for aggregated in self.aggregated_errors.values():
by_category[aggregated.fingerprint.category] += aggregated.count
for aggregated in self.aggregated_errors.values():
by_category[aggregated.fingerprint.category] += aggregated.count
# Get severity from sample
if aggregated.sample_errors:
severity = (
aggregated.sample_errors[0]
.get("details", {})
.get("severity", "unknown")
)
by_severity[severity] += aggregated.count
# Get severity from sample
if aggregated.sample_errors:
severity = (
aggregated.sample_errors[0]
.get("details", {})
.get("severity", "unknown")
)
by_severity[severity] += aggregated.count
if aggregated.fingerprint.node:
by_node[aggregated.fingerprint.node] += aggregated.count
if aggregated.fingerprint.node:
by_node[aggregated.fingerprint.node] += aggregated.count
return {
"total_errors": total_errors,
"unique_errors": unique_errors,
"by_category": dict(by_category),
"by_severity": dict(by_severity),
"by_node": dict(by_node),
"top_errors": [
{
"fingerprint": agg.fingerprint.hash,
"type": agg.fingerprint.error_type,
"category": agg.fingerprint.category,
"count": agg.count,
"first_seen": agg.first_seen.isoformat(),
"last_seen": agg.last_seen.isoformat(),
}
for agg in self.get_aggregated_errors(min_count=2)[:10]
],
}
return {
"total_errors": total_errors,
"unique_errors": unique_errors,
"by_category": dict(by_category),
"by_severity": dict(by_severity),
"by_node": dict(by_node),
"top_errors": [
{
"fingerprint": agg.fingerprint.hash,
"type": agg.fingerprint.error_type,
"category": agg.fingerprint.category,
"count": agg.count,
"first_seen": agg.first_seen.isoformat(),
"last_seen": agg.last_seen.isoformat(),
}
for agg in self.get_aggregated_errors(min_count=2)[:10]
],
}
def reset(self) -> None:
"""Reset all aggregation state."""
self.aggregated_errors.clear()
self.recent_errors.clear()
for limiter in self.rate_limiters.values():
limiter.timestamps.clear()
self.global_rate_limiter.timestamps.clear()
with self._lock:
self.aggregated_errors.clear()
self.recent_errors.clear()
for limiter in self.rate_limiters.values():
limiter.timestamps.clear()
self.global_rate_limiter.timestamps.clear()
def _cleanup_old_entries(self) -> None:
"""Remove old entries from recent errors tracking."""
# Lock is already held by caller
current_time = time.time()
cutoff_time = current_time - (self.dedup_window * 2)
@@ -388,6 +403,7 @@ class ErrorAggregator:
# Global instance for easy access
_error_aggregator: ErrorAggregator | None = None
_error_aggregator_lock = threading.Lock()
def get_error_aggregator() -> ErrorAggregator:
@@ -397,13 +413,19 @@ def get_error_aggregator() -> ErrorAggregator:
Global ErrorAggregator instance
"""
global _error_aggregator
if _error_aggregator is None:
if _error_aggregator is not None:
return _error_aggregator
with _error_aggregator_lock:
if _error_aggregator is not None:
return _error_aggregator
_error_aggregator = ErrorAggregator()
return _error_aggregator
return _error_aggregator
def reset_error_aggregator() -> None:
"""Reset the global error aggregator."""
global _error_aggregator
if _error_aggregator is not None:
_error_aggregator.reset()
with _error_aggregator_lock:
if _error_aggregator is not None:
_error_aggregator.reset()


@@ -177,7 +177,7 @@ class ErrorRegistry:
instance = cast("ErrorRegistry", super().__new__(cls))
cls._instance = instance
instance._initialize_registry()
return cast("ErrorRegistry", cls._instance)
return cls._instance
def _initialize_registry(self) -> None:
"""Initialize the error registry with default mappings."""


@@ -290,13 +290,17 @@ def create_formatted_error(
# Auto-categorize if no category or error_code provided
if not category and not error_code:
auto_category, auto_code = categorize_error(exception)
category = auto_category.value
category = auto_category
if auto_code and not error_code:
error_code = auto_code
# Get ErrorCategory enum
try:
error_category = ErrorCategory(category or "unknown")
if isinstance(category, ErrorCategory):
error_category: ErrorCategory = category
else:
category_str = category or "unknown"
error_category = cast(ErrorCategory, ErrorCategory(category_str))
except ValueError:
error_category = ErrorCategory.UNKNOWN
@@ -304,7 +308,7 @@ def create_formatted_error(
formatted_message = ErrorMessageFormatter.format_error_message(
message=message,
error_code=error_code,
category=cast("ErrorCategory", error_category),
category=error_category,
template_type=template_type,
**context,
)


@@ -160,17 +160,27 @@ class RouterConfig:
# Categories
if "categories" in config:
condition.categories = cast(
"list[ErrorCategory]",
[ErrorCategory(cat) for cat in config["categories"]],
)
categories_list: list[ErrorCategory] = []
for cat in config["categories"]:
if isinstance(cat, str):
categories_list.append(cast(ErrorCategory, ErrorCategory(cat)))
elif isinstance(cat, ErrorCategory):
categories_list.append(cat)
else:
raise TypeError(f"Invalid type for category: {cat!r} (type: {type(cat)})")
condition.categories = categories_list
# Severities
if "severities" in config:
condition.severities = cast(
"list[ErrorSeverity]",
[ErrorSeverity(sev) for sev in config["severities"]],
)
severities_list: list[ErrorSeverity] = []
for sev in config["severities"]:
if isinstance(sev, str):
severities_list.append(cast(ErrorSeverity, ErrorSeverity(sev)))
elif isinstance(sev, ErrorSeverity):
severities_list.append(sev)
else:
raise TypeError(f"Invalid type for severity: {sev!r} (type: {type(sev)})")
condition.severities = severities_list
# Nodes
if "nodes" in config:
@@ -204,10 +214,13 @@ class RouterConfig:
path.parent.mkdir(parents=True, exist_ok=True)
# Build config dict
config = {"default_action": self.router.default_action.value, "routes": []}
config: dict[str, Any] = {
"default_action": self.router.default_action.value,
"routes": [],
}
for route in self.router.routes:
route_config = {
route_config: dict[str, Any] = {
"name": route.name,
"action": route.action.value,
"priority": route.priority,


@@ -375,12 +375,10 @@ class ErrorTelemetry:
else:
# Default: log the alert
import logging
from typing import cast
logger = cast("logging.Logger", logging.getLogger(__name__))
level = cast(
"int",
logging.ERROR if severity in ["critical", "error"] else logging.WARNING,
logger = logging.getLogger(__name__)
level = (
logging.ERROR if severity in ["critical", "error"] else logging.WARNING
)
logger.log(
level,


@@ -79,7 +79,7 @@ def create_error_details(
}
def _is_sensitive_field(field_name: str) -> bool:
def is_sensitive_field(field_name: str) -> bool:
"""Check if a field name indicates sensitive data.
Args:
@@ -88,16 +88,13 @@ def _is_sensitive_field(field_name: str) -> bool:
Returns:
True if the field name indicates sensitive data
"""
if not isinstance(field_name, str):
return False
field_lower = field_name.lower()
# Check against all sensitive patterns
return any(re.match(pattern, field_lower) for pattern in SENSITIVE_PATTERNS)
def _redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
def redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
"""Recursively redact sensitive data from nested structures.
Args:
@@ -113,13 +110,13 @@ def _redact_sensitive_data(data: Any, max_depth: int = 10) -> Any:
if isinstance(data, dict):
result = {}
for key, value in data.items():
if _is_sensitive_field(key):
if is_sensitive_field(key):
result[key] = REDACTED_VALUE
else:
result[key] = _redact_sensitive_data(value, max_depth - 1)
result[key] = redact_sensitive_data(value, max_depth - 1)
return result
elif isinstance(data, list):
return [_redact_sensitive_data(item, max_depth - 1) for item in data]
return [redact_sensitive_data(item, max_depth - 1) for item in data]
else:
return data
@@ -165,7 +162,7 @@ def safe_serialize_response(response: Any) -> dict[str, Any]:
# Handle lists directly (return as-is after redaction)
elif isinstance(response, list):
# Apply redaction directly to the list and return it
return _redact_sensitive_data(response)
return redact_sensitive_data(response)
# Handle built-in types without __dict__
elif response is None or isinstance(response, (str, int, float, bool)):
data = {"type": type(response).__name__, "value": str(response)}
@@ -180,7 +177,7 @@ def safe_serialize_response(response: Any) -> dict[str, Any]:
data = {"type": type(response).__name__, "value": str(response)}
# Redact sensitive data from the result
return _redact_sensitive_data(data)
return redact_sensitive_data(data)
except Exception as e:
return {


@@ -9,7 +9,7 @@ import asyncio
import functools
import time
from collections.abc import Callable
from datetime import datetime
from datetime import UTC, datetime
from typing import Any, TypedDict, cast
from ..logging import get_logger
@@ -188,8 +188,7 @@ def track_metrics(
)
metric = cast("NodeMetric", metrics[metric_name])
if metric is not None:
metric["count"] = (metric["count"] or 0) + 1
metric["count"] = (metric["count"] or 0) + 1
try:
result = await func(*args, **kwargs)
@@ -205,7 +204,7 @@ def track_metrics(
metric["avg_duration_ms"] = (
metric["total_duration_ms"] or 0.0
) / count
metric["last_execution"] = datetime.utcnow().isoformat()
metric["last_execution"] = datetime.now(UTC).isoformat()
return result
@@ -221,7 +220,7 @@ def track_metrics(
metric["avg_duration_ms"] = (
metric["total_duration_ms"] or 0.0
) / count
metric["last_execution"] = datetime.utcnow().isoformat()
metric["last_execution"] = datetime.now(UTC).isoformat()
metric["last_error"] = str(e)
raise
@@ -251,8 +250,7 @@ def track_metrics(
)
metric = cast("NodeMetric", metrics[metric_name])
if metric is not None:
metric["count"] = (metric["count"] or 0) + 1
metric["count"] = (metric["count"] or 0) + 1
try:
result = func(*args, **kwargs)
@@ -267,7 +265,7 @@ def track_metrics(
metric["avg_duration_ms"] = (
metric["total_duration_ms"] or 0.0
) / count
metric["last_execution"] = datetime.utcnow().isoformat()
metric["last_execution"] = datetime.now(UTC).isoformat()
return result
@@ -282,7 +280,7 @@ def track_metrics(
metric["avg_duration_ms"] = (
metric["total_duration_ms"] or 0.0
) / count
metric["last_execution"] = datetime.utcnow().isoformat()
metric["last_execution"] = datetime.now(UTC).isoformat()
metric["last_error"] = str(e)
raise
@@ -339,7 +337,7 @@ def handle_errors(
"node": func.__name__,
"error": str(e),
"type": type(e).__name__,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
}
)
@@ -373,7 +371,7 @@ def handle_errors(
"node": func.__name__,
"error": str(e),
"type": type(e).__name__,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
}
)

View File

@@ -7,7 +7,7 @@ enabling consistent configuration injection across all nodes and tools.
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
@@ -53,17 +53,18 @@ def configure_graph_with_injection(
continue
# Create wrapper that injects config
wrapped_node = create_config_injected_node(node_func, base_config)
# Replace the node
graph_builder.nodes[node_name] = wrapped_node
if callable(node_func):
callable_node = cast(Callable[..., object], node_func)
wrapped_node = create_config_injected_node(callable_node, base_config)
# Replace the node
graph_builder.nodes[node_name] = wrapped_node
return graph_builder
def create_config_injected_node(
node_func: Callable[..., object], base_config: RunnableConfig
) -> Callable[..., object]:
) -> Any:
"""Create a node wrapper that injects RunnableConfig.
Args:
@@ -86,7 +87,7 @@ def create_config_injected_node(
# Node already expects config, wrap to provide it
@wraps(node_func)
async def config_aware_wrapper(
state: dict, config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig | None = None
) -> Any:
# Merge base config with runtime config
if config:
@@ -102,7 +103,7 @@ def create_config_injected_node(
else:
return node_func(state, config=merged_config)
wrapped = RunnableLambda(config_aware_wrapper).with_config(base_config)
wrapped: Any = RunnableLambda(config_aware_wrapper).with_config(base_config)
else:
# Node doesn't expect config, just wrap with config
wrapped = RunnableLambda(node_func).with_config(base_config)
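Stripped of the LangChain types, the injection wrapper above boils down to merging a base config under the runtime one. A minimal sketch with plain dicts — `RunnableLambda`, the async branch, and real `RunnableConfig` objects are omitted, so this is only the merge logic, not the production wrapper:

```python
import functools

def create_config_injected_node(node_func, base_config):
    """Wrap a node so every call sees base_config merged with runtime config."""
    @functools.wraps(node_func)
    def wrapper(state, config=None):
        merged = {**base_config, **(config or {})}  # runtime keys win over base keys
        return node_func(state, config=merged)
    return wrapper

def node(state, config=None):
    return {"model": config["model"], **state}

wrapped = create_config_injected_node(node, {"model": "base", "tag": "x"})
print(wrapped({"q": 1}, {"model": "override"}))
```

The key design point is merge order: spreading the runtime `config` second means per-invocation settings override graph-level defaults, which matches the `merged_config` behavior in the hunk.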

View File

@@ -166,12 +166,12 @@ class ConfigurationProvider:
# Copy metadata
self_metadata = getattr(self._config, "metadata", {})
other_metadata = getattr(other, "metadata", {})
merged.metadata = {**self_metadata, **other_metadata}
merged.metadata = {**self_metadata, **other_metadata} # type: ignore[attr-defined]
# Copy configurable
self_configurable = getattr(self._config, "configurable", {})
other_configurable = getattr(other, "configurable", {})
merged.configurable = {**self_configurable, **other_configurable}
merged.configurable = {**self_configurable, **other_configurable} # type: ignore[attr-defined]
# Copy other attributes
for attr in ["tags", "callbacks", "recursion_limit"]:
@@ -207,13 +207,14 @@ class ConfigurationProvider:
config = RunnableConfig()
# Set configurable values
config.configurable = {
configurable_dict = {
"app_config": app_config,
"service_factory": service_factory,
}
config.configurable = configurable_dict # type: ignore[attr-defined]
# Set metadata
config.metadata = metadata
config.metadata = metadata # type: ignore[attr-defined]
return cls(config)
@@ -273,7 +274,7 @@ def create_runnable_config(
if max_tokens_override is not None:
configurable["max_tokens_override"] = max_tokens_override
config.configurable = configurable
config.configurable = configurable # type: ignore[attr-defined]
# Set metadata
metadata = {}
@@ -287,6 +288,6 @@ def create_runnable_config(
if session_id is not None:
metadata["session_id"] = session_id
config.metadata = metadata
config.metadata = metadata # type: ignore[attr-defined]
return config

View File

@@ -91,6 +91,7 @@ def setup_logging(
root.setLevel(numeric_level)
# Add console handler
console_handler = None
if use_rich:
console_handler = SafeRichHandler(
console=_console,
@@ -114,11 +115,12 @@ def setup_logging(
if hasattr(console_handler, "setFormatter"):
console_handler.setFormatter(formatter)
# Set handler level and add to root
if hasattr(console_handler, "setLevel"):
console_handler.setLevel(numeric_level)
if hasattr(root, "addHandler"):
root.addHandler(console_handler)
# Set handler level and add to root (only if handler was created)
if console_handler is not None:
if hasattr(console_handler, "setLevel"):
console_handler.setLevel(numeric_level)
if hasattr(root, "addHandler"):
root.addHandler(console_handler)
# Add file handler if specified
if log_file:

View File

@@ -47,6 +47,9 @@ def log_function_call(
if logger is None:
logger = get_logger(func.__module__)
# Assert logger is not None after assignment
assert logger is not None
# Log function call
message_parts = [f"Calling {func.__name__}"]
if include_args and (args or kwargs):

View File

@@ -337,8 +337,9 @@ class APIClient:
) -> APIResponse:
"""Make request with retry logic."""
# Normalize method to RequestMethod
method_val = RequestMethod(method) if isinstance(method, str) else method
method_val = cast(RequestMethod, method_val)
method_val = (
method if isinstance(method, RequestMethod) else RequestMethod(method)
)
# Merge headers
request_headers: dict[str, str] = {**self.headers}
if headers:
@@ -351,10 +352,10 @@ class APIClient:
try:
# Log request
logger.info(
f"Making {method_val.value} request to: {url}",
f"Making {method_val} request to: {url}",
extra={
"operation": "api_request",
"method": method_val.value,
"method": method_val,
"url": url,
"attempt": attempt + 1,
},
@@ -600,8 +601,10 @@ def create_api_client(
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
config_kwargs = {"timeout": timeout, "max_retries": max_retries}
config = RequestConfig(**config_kwargs)
config: RequestConfig = RequestConfig.model_validate({
"timeout": timeout,
"max_retries": max_retries
})
if client_type == "basic":
return APIClient(base_url=base_url, headers=headers, config=config)
@@ -649,8 +652,10 @@ def proxied_rate_limited_request(
import asyncio
async def _make_request() -> object:
config_kwargs = {"timeout": timeout_val}
async with APIClient(config=RequestConfig(**config_kwargs)) as client:
config: RequestConfig = RequestConfig.model_validate({
"timeout": timeout_val
})
async with APIClient(config=config) as client:
# Map method string to RequestMethod enum
method_upper: str = method.upper()
try:

View File

@@ -4,10 +4,13 @@ import asyncio
import functools
import time
from collections import defaultdict, deque
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, overload
if TYPE_CHECKING:
pass
T = TypeVar("T")
@@ -149,7 +152,7 @@ def exponential_backoff(
async def retry_with_backoff(
func: Callable, # Simplified type to avoid overload issues
func: Callable[..., Any],
config: RetryConfig,
*args,
**kwargs,
@@ -174,12 +177,14 @@ async def retry_with_backoff(
try:
if asyncio.iscoroutinefunction(func):
# For async functions
result = await func(*args, **kwargs)
return cast("T", result)
async_func = cast(Callable[..., Awaitable[T]], func)
result = await async_func(*args, **kwargs)
return result
else:
# For sync functions
result = func(*args, **kwargs)
return cast("T", result)
sync_func = cast(Callable[..., T], func)
result = sync_func(*args, **kwargs)
return result
except Exception as e:
# Check if exception is in the allowed list
if not any(isinstance(e, exc_type) for exc_type in config.exceptions):
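The `cast` changes above let the retry helper await async callables with a precise `Awaitable[T]` type. The control flow itself is simple; a stripped-down async-only sketch, where the delay schedule and exception filter are assumptions rather than the module's `RetryConfig`:

```python
import asyncio

async def retry_with_backoff(func, retries=3, base_delay=0.01):
    """Retry an async callable, doubling the delay after each failure."""
    for attempt in range(retries):
        try:
            return await func()
        except ValueError:  # stand-in for config.exceptions
            if attempt == retries - 1:
                raise
            await asyncio.sleep(base_delay * (2 ** attempt))  # exponential backoff

attempts = []

async def flaky():
    attempts.append(1)
    if len(attempts) < 3:
        raise ValueError("transient")
    return "ok"

print(asyncio.run(retry_with_backoff(flaky)))  # "ok" after two failed attempts
```

As in the hunk, exceptions outside the allowed list should propagate immediately; here that falls out of catching only `ValueError`.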

View File

@@ -1,6 +1,6 @@
"""Type definitions for networking module."""
from typing import Literal, TypedDict
from typing import Any, Literal, TypedDict
try:
from typing import NotRequired
@@ -17,7 +17,7 @@ class HTTPResponse(TypedDict):
headers: dict[str, str]
content: bytes
text: NotRequired[str]
json: NotRequired[dict | list]
json: NotRequired[dict[str, Any] | list[Any]]
class RequestOptions(TypedDict):
@@ -27,7 +27,7 @@ class RequestOptions(TypedDict):
url: str
headers: NotRequired[dict[str, str]]
params: NotRequired[dict[str, str]]
json: NotRequired[dict | list]
json: NotRequired[dict[str, Any] | list[Any]]
data: NotRequired[bytes | str]
timeout: NotRequired[float | tuple[float, float]]
follow_redirects: NotRequired[bool]

View File

@@ -74,12 +74,12 @@ class ServiceHelperRemovedError(ImportError):
)
def get_service_factory(*args, **kwargs): # type: ignore
def get_service_factory(*args, **kwargs):
"""REMOVED: Use biz_bud.services.factory.get_global_factory() instead."""
raise ServiceHelperRemovedError("get_service_factory")
def get_service_factory_sync(*args, **kwargs): # type: ignore
def get_service_factory_sync(*args, **kwargs):
"""REMOVED: Use biz_bud.services.factory.get_global_factory() instead."""
raise ServiceHelperRemovedError("get_service_factory_sync")

View File

@@ -183,7 +183,7 @@ class WebSearchHistoryEntry(TypedDict, total=False):
query: str
timestamp: str
results: list[dict]
results: list[dict[str, Any]]
summary: str
@@ -203,7 +203,7 @@ class ApiResponseDataTypedDict(TypedDict, total=False):
validation_passed: bool | None
validation_issues: list[str]
entities: NotRequired[list[str]]
report_metadata: NotRequired[dict]
report_metadata: NotRequired[dict[str, Any]]
class ApiResponseTypedDict(TypedDict, total=False):
@@ -246,7 +246,7 @@ class ToolCallTypedDict(TypedDict):
class ParsedInputTypedDict(TypedDict, total=False):
"""Structured input payload with 'raw_payload' and 'user_query' keys."""
raw_payload: dict
raw_payload: dict[str, Any]
user_query: str

View File

@@ -1,5 +1,10 @@
"""Core utilities for Business Buddy framework."""
from bb_core.utils.lazy_loader import (
LazyProxy,
ThreadSafeLazyLoader,
create_lazy_loader,
)
from bb_core.utils.url_normalizer import URLNormalizer
__all__ = ["URLNormalizer"]
__all__ = ["URLNormalizer", "ThreadSafeLazyLoader", "LazyProxy", "create_lazy_loader"]

View File

@@ -0,0 +1,80 @@
"""Thread-safe lazy loading utilities for singleton pattern implementations.
This module provides thread-safe utilities for lazy loading expensive resources
like graphs and configurations in a multi-threaded environment.
"""
import threading
from collections.abc import Callable
class ThreadSafeLazyLoader[T]:
"""Thread-safe lazy loader for singleton instances."""
def __init__(self, factory: Callable[[], T]) -> None:
"""Initialize the lazy loader with a factory function.
Args:
factory: Function that creates the instance when called
"""
self._factory = factory
self._instance: T | None = None
self._lock = threading.Lock()
def get_instance(self) -> T:
"""Get the singleton instance, creating it if necessary.
Returns:
The singleton instance
"""
# Double-checked locking pattern for thread safety
if self._instance is None:
with self._lock:
# Check again after acquiring lock
if self._instance is None:
self._instance = self._factory()
return self._instance
def reset(self) -> None:
"""Reset the instance (mainly for testing purposes)."""
with self._lock:
self._instance = None
def create_lazy_loader[T](factory: Callable[[], T]) -> ThreadSafeLazyLoader[T]:
"""Create a thread-safe lazy loader for the given factory function.
Args:
factory: Function that creates the instance when called
Returns:
ThreadSafeLazyLoader instance
"""
return ThreadSafeLazyLoader(factory)
class LazyProxy[T]:
"""Proxy object that forwards attribute access to a lazy-loaded instance."""
def __init__(self, loader: ThreadSafeLazyLoader[T]) -> None:
"""Initialize the proxy with a lazy loader.
Args:
loader: ThreadSafeLazyLoader instance
"""
self._loader = loader
def __getattr__(self, name: str) -> object:
"""Forward attribute access to the lazy-loaded instance."""
return getattr(self._loader.get_instance(), name)
def __call__(self, *args: object, **kwargs: object) -> object:
"""Forward calls to the lazy-loaded instance."""
instance = self._loader.get_instance()
if not callable(instance):
raise TypeError(f"Instance of type {type(instance)} is not callable")
return instance(*args, **kwargs)
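The double-checked locking in `get_instance` is the heart of this new module: the lock-free first check keeps the hot path cheap, and the re-check under the lock guarantees the factory runs exactly once even under contention. An untyped, runnable sketch of the same class (the real version uses PEP 695 generics):

```python
import threading

class ThreadSafeLazyLoader:
    """Minimal mirror of the loader above, without type parameters."""
    def __init__(self, factory):
        self._factory = factory
        self._instance = None
        self._lock = threading.Lock()

    def get_instance(self):
        if self._instance is None:          # fast path: no lock once created
            with self._lock:
                if self._instance is None:  # re-check after acquiring the lock
                    self._instance = self._factory()
        return self._instance

calls = []
loader = ThreadSafeLazyLoader(lambda: calls.append(1) or object())

# Hammer the loader from several threads; the factory must still run once.
threads = [threading.Thread(target=loader.get_instance) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls))
```

This pattern is safe in CPython because the assignment inside the lock happens-before any later lock-free read observes a non-`None` instance.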

View File

@@ -11,14 +11,15 @@ Example:
'split into chunks.']
"""
from typing import Any, Protocol, cast, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable
try:
if TYPE_CHECKING:
import nltk
from nltk.tokenize.api import TokenizerI
except ImportError:
TokenizerI = None
nltk = None
else:
try:
import nltk
except ImportError:
nltk = None
# Optional: tiktoken import for future token estimation (not used in chunking)
try:
@@ -79,41 +80,8 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
if tokenizer is None:
tokenizer = _get_tokenizer()
if tokenizer is None:
# Fallback: split by whitespace tokens using SimpleTokenizer,
# or by characters if only one token
simple_tokenizer = SimpleTokenizer()
tokens = list(simple_tokenizer.span_tokenize(text))
if len(tokens) <= 1:
# No spaces, fallback to character-based chunking
chunks = []
i = 0
while i < len(text):
chunk = text[i : i + chunk_size]
chunks.append(chunk)
i += chunk_size - overlap if (chunk_size - overlap) > 0 else chunk_size
return chunks
else:
chunks = []
step = chunk_size - overlap
if step <= 0:
step = chunk_size
i = 0
while i < len(tokens):
chunk_tokens = tokens[i : i + chunk_size]
if chunk_tokens:
start_pos = chunk_tokens[0][0]
end_pos = chunk_tokens[-1][1]
chunks.append(text[start_pos:end_pos])
i += step
return chunks
if not isinstance(tokenizer, TokenizerProtocol):
raise RuntimeError(
"Tokenizer does not implement the required span_tokenize method."
)
# Tokenize and create chunks
# _get_tokenizer() always returns a tokenizer (with fallback), so no need to check
# None
tokens = list(tokenizer.span_tokenize(text))
# Fallback to character-based splitting if only a single token
@@ -128,6 +96,9 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
return char_chunks
chunks = []
step = chunk_size - overlap
if step <= 0:
step = chunk_size
i = 0
while i < len(tokens):
chunk_tokens = tokens[i : i + chunk_size]
@@ -135,7 +106,7 @@ def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[st
start_pos = chunk_tokens[0][0]
end_pos = chunk_tokens[-1][1]
chunks.append(text[start_pos:end_pos])
i += chunk_size - overlap
i += step
return chunks
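The `step` guard added above prevents an infinite loop when `overlap >= chunk_size`. The same windowing logic is easiest to see on raw characters — a simplification of the span-based version, with sizes in characters rather than tokens:

```python
def chunk_chars(text: str, chunk_size: int = 1000, overlap: int = 100) -> list[str]:
    """Split text into fixed-size windows that overlap by `overlap` characters."""
    step = chunk_size - overlap
    if step <= 0:          # mirror of the diff's guard: never step backwards
        step = chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

print(chunk_chars("abcdefghijklmno", chunk_size=6, overlap=2))
```

With `chunk_size=6` and `overlap=2` the windows advance 4 characters at a time, so each chunk repeats the last 2 characters of its predecessor — useful so a sentence split at a boundary still appears whole in one chunk.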
@@ -156,12 +127,9 @@ def _get_tokenizer() -> TokenizerProtocol:
return None
# Download the required NLTK data
if nltk is not None:
nltk.download(tokenizer_path.split("/")[0], quiet=True)
# Load the tokenizer
tokenizer_obj: Any = nltk.data.load(tokenizer_path)
else:
return None
nltk.download(tokenizer_path.split("/")[0], quiet=True)
# Load the tokenizer
tokenizer_obj: Any = nltk.data.load(tokenizer_path)
# Check if it's a tokenizer that implements span_tokenize
if hasattr(tokenizer_obj, "span_tokenize") and callable(
@@ -206,7 +174,7 @@ def _get_tokenizer() -> TokenizerProtocol:
# Final fallback to word tokenizer
try:
if nltk is not None:
if nltk:
nltk.download("punkt", quiet=True)
from nltk.tokenize import word_tokenize
@@ -267,6 +235,7 @@ def split_into_sentences(text: str) -> list[str]:
Returns:
List of sentences
"""
# Check if nltk exists and is not None before checking hasattr
if nltk and hasattr(nltk, "data"):
try:
nltk.download("punkt", quiet=True)

View File

@@ -269,7 +269,7 @@ def detect_content_type(url: str, content: str) -> str:
if isinstance(result, dict) and "type" in result:
content_heuristic_type = str(result["type"])
break
if isinstance(result, str) and result is not None:
if isinstance(result, str) and result:
content_heuristic_type = result
break
except Exception:

View File

@@ -176,7 +176,12 @@ def validate_content(content: str, min_length: int = MIN_CONTENT_LENGTH) -> bool
Returns:
True if content is valid, False otherwise.
"""
if not content or not isinstance(content, str):
if not isinstance(content, str):
warning_highlight(
"Content is not a string.", category="content_validation"
)
return False
if not content:
warning_highlight(
"Content is empty.", category="content_validation"
)

@@ -207,7 +212,7 @@ def preprocess_content(arg1: str | None, arg2: str) -> str:
Returns:
The cleaned and normalized content string.
"""
if not arg2 or not isinstance(arg2, str) or len(arg2.strip()) < MIN_CONTENT_LENGTH:
if not isinstance(arg2, str) or not arg2 or len(arg2.strip()) < MIN_CONTENT_LENGTH:
return ""
content = arg2
cleaned = re.sub(r"(?is)<.*?>", "", content)

View File

@@ -2,7 +2,7 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from typing import Any, ParamSpec, TypeVar, cast
from ..errors import ValidationError
@@ -28,7 +28,7 @@ def _check_type(value: object, expected_type: object) -> bool:
def validate_args(
**validators: Callable[[object], tuple[bool, str | None]],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Decorator to validate function arguments.
Args:
@@ -50,17 +50,14 @@ def validate_args(
```
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def wrapper(*args: Any, **kwargs: Any) -> T:
# Get function signature
import inspect
sig = inspect.signature(cast("Callable[..., object]", func))
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
bound_args = sig.bind(
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
)
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
# Validate each argument
@@ -73,7 +70,7 @@ def validate_args(
f"Invalid argument '{arg_name}': {error_msg}"
)
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
return func(*args, **kwargs)
return wrapper
@@ -82,7 +79,7 @@ def validate_args(
def validate_return(
validator: Callable[[object], tuple[bool, str | None]],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Decorator to validate function return value.
Args:
@@ -99,9 +96,9 @@ def validate_return(
```
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def wrapper(*args: Any, **kwargs: Any) -> T:
result = func(*args, **kwargs)
is_valid, error_msg = validator(result)
if not is_valid:
@@ -117,7 +114,7 @@ def validate_return(
def validate_not_none(
*arg_names: str,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Decorator to ensure specified arguments are not None.
Args:
@@ -127,16 +124,13 @@ def validate_not_none(
Decorated function
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def wrapper(*args: Any, **kwargs: Any) -> T:
import inspect
sig = inspect.signature(cast("Callable[..., object]", func))
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
bound_args = sig.bind(
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
)
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
for arg_name in arg_names:
@@ -146,7 +140,7 @@ def validate_not_none(
):
raise ValidationError(f"Argument '{arg_name}' cannot be None")
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
return func(*args, **kwargs)
return wrapper
@@ -155,7 +149,7 @@ def validate_not_none(
def validate_types(
**expected_types: type,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Decorator to validate argument types.
Args:
@@ -172,16 +166,13 @@ def validate_types(
```
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def wrapper(*args: Any, **kwargs: Any) -> T:
import inspect
sig = inspect.signature(cast("Callable[..., object]", func))
# Convert ParamSpec args/kwargs to regular tuple/dict for binding
bound_args = sig.bind(
*cast("tuple[object, ...]", args), **cast("dict[str, object]", kwargs)
)
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
for arg_name, expected_type in expected_types.items():
@@ -196,7 +187,7 @@ def validate_types(
f"got {actual_type}"
)
return func(*cast("P.args", args), **cast("P.kwargs", kwargs))
return func(*args, **kwargs)
return wrapper
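Dropping the `ParamSpec` casts leaves these decorators doing plain signature binding. A compact sketch of the same pattern, using the module's validator protocol of `(value) -> (ok, message)` but raising `ValueError` in place of its `ValidationError`:

```python
import functools
import inspect

def validate_args(**validators):
    """Bind call args to the signature, then run each named validator."""
    def decorator(func):
        sig = inspect.signature(func)
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, check in validators.items():
                ok, msg = check(bound.arguments[name])
                if not ok:
                    raise ValueError(f"Invalid argument '{name}': {msg}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

@validate_args(n=lambda v: (v >= 0, "must be non-negative"))
def sqrt_floor(n: int) -> int:
    return int(n ** 0.5)

print(sqrt_floor(9))
```

`sig.bind(...)` plus `apply_defaults()` is what lets the validator see arguments by name regardless of whether the caller passed them positionally, by keyword, or left them to default.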

View File

@@ -169,7 +169,7 @@ async def process_document(url: str, content: bytes | None = None) -> tuple[str,
cache_key = f"document_processing_{url_hash}"
cached_result = _get_from_cache(cache_key)
if cached_result and isinstance(cached_result, dict):
if cached_result:
logger.info(f"Using cached document extraction for {url}")
text = cached_result.get("text", "")
content_type = cached_result.get("content_type", document_type)
@@ -219,6 +219,8 @@ async def process_document(url: str, content: bytes | None = None) -> tuple[str,
logger.info(
f"Document processed in {time.time() - start_time:.2f}s [{document_type}]"
)
if extracted_text is None:
raise ValueError("Extracted text is None after processing")
return extracted_text, document_type
except Exception as e:

View File

@@ -136,7 +136,7 @@ def mark_validated(func: Callable[..., Any]) -> None:
def has_errors(state: StateDict) -> bool:
"""Check if the state contains errors."""
return isinstance(state, dict) and "errors" in state and bool(state["errors"])
return "errors" in state and bool(state["errors"])
def add_exception_error(state: StateDict, exc: Exception, source: str) -> StateDict:
@@ -177,21 +177,15 @@ def validate_node_input(input_model: type[BaseModel]) -> Callable[[F], F]:
return state
try:
validator(state)
async_func = cast("Callable[..., Awaitable[StateDict]]", func)
async_func = cast(Callable[..., Awaitable[StateDict]], func)
result = await async_func(state, *args, **kwargs)
return cast("StateDict", result)
return cast(StateDict, result)
except NodeValidationError as exc:
validation_source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(state, exc, validation_source),
)
return add_exception_error(state, exc, validation_source)
except Exception as exc:
exception_source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(state, exc, exception_source),
)
return add_exception_error(state, exc, exception_source)
mark_validated(async_wrapper)
return cast("F", async_wrapper)
@@ -206,19 +200,13 @@ def validate_node_input(input_model: type[BaseModel]) -> Callable[[F], F]:
try:
validator(state)
result = func(state, *args, **kwargs)
return cast("StateDict", result)
return cast(StateDict, result)
except NodeValidationError as exc:
sync_source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(state, exc, sync_source),
)
return add_exception_error(state, exc, sync_source)
except Exception as exc:
sync_exc_source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(state, exc, sync_exc_source),
)
return add_exception_error(state, exc, sync_exc_source)
mark_validated(sync_wrapper)
return cast("F", sync_wrapper)
@@ -245,20 +233,16 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
) -> StateDict:
async_func = cast("Callable[..., Awaitable[StateDict]]", func)
result = await async_func(state, *args, **kwargs)
result_dict: StateDict = result
try:
pydantic_validator = cast(
"Callable[[dict[str, Any]], BaseModel]", validator
)
validated_result: BaseModel = pydantic_validator(
cast("dict[str, Any]", result)
)
return cast("StateDict", validated_result.model_dump())
validated_result: BaseModel = pydantic_validator(result_dict)
return validated_result.model_dump()
except NodeValidationError as e:
source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(cast("StateDict", result), e, source),
)
return add_exception_error(result_dict, e, source)
mark_validated(async_wrapper)
return cast("F", async_wrapper)
@@ -268,21 +252,18 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
def sync_wrapper(
state: StateDict, *args: object, **kwargs: object
) -> StateDict:
result = func(state, *args, **kwargs)
sync_func = cast("Callable[..., StateDict]", func)
result = sync_func(state, *args, **kwargs)
result_dict: StateDict = result
try:
pydantic_validator = cast(
"Callable[[dict[str, Any]], BaseModel]", validator
)
validated_result: BaseModel = pydantic_validator(
cast("dict[str, Any]", result)
)
return cast("StateDict", validated_result.model_dump())
validated_result: BaseModel = pydantic_validator(result_dict)
return validated_result.model_dump()
except NodeValidationError as e:
source: str = f"{func.__module__}.{func.__name__}"
return cast(
"StateDict",
add_exception_error(cast("StateDict", result), e, source),
)
return add_exception_error(result_dict, e, source)
mark_validated(sync_wrapper)
return cast("F", sync_wrapper)
@@ -290,14 +271,14 @@ def validate_node_output(output_model: type[BaseModel]) -> Callable[[F], F]:
return decorator
def validated_node(
_func: Any = None,
def validated_node[F: Callable[..., object]](
_func: F | None = None,
*,
name: str | None = None,
input_model: type[BaseModel] | None = None,
output_model: type[BaseModel] | None = None,
**metadata: object,
) -> Any:
) -> F | Callable[[F], F]:
"""Create a validated node with input and output models.
Decorates a function to create a node that validates both input and
@@ -333,7 +314,7 @@ def validated_node(
return decorator
else:
# Called without parameters: @validated_node
return cast("F | Callable[[F], F]", decorator(_func))
return decorator(_func)
async def validate_graph(graph: object, graph_id: str = "unknown") -> bool:
@@ -390,24 +371,31 @@ async def ensure_graph_compatibility(
async def validate_all_graphs(
graph_functions: dict[str, Callable[[], Awaitable[Any]]],
graph_functions: dict[str, object],
) -> bool:
"""Validate all graph creation functions.
Performs batch validation of multiple graph creation functions.
"""
all_valid: bool = True
for name, func in graph_functions.items():
for name in graph_functions:
try:
func = graph_functions[name]
if not asyncio.iscoroutinefunction(func):
all_valid = False
continue
# Type narrowing for pyrefly - we know it's a coroutine function now
# Use a more specific cast to help with type inference
if callable(func):
# Use type ignore to suppress the complex type inference issue
graph = await func() # type: ignore[misc]
await validate_graph(graph, name)
# Call the function directly since we know it's a coroutine function
# Use inspect to determine if function takes arguments
sig = inspect.signature(func)
if len(sig.parameters) == 0:
# Cast to the expected coroutine function type
coro_func = cast("Callable[[], Awaitable[object]]", func)
graph = await coro_func()
else:
# Skip functions that require arguments for now
all_valid = False
continue
await validate_graph(graph, name)
except Exception:
all_valid = False
return all_valid

View File

@@ -324,16 +324,5 @@ def _handle_list_extend(
if sublist_str not in seen_items:
seen_items.add(sublist_str)
if not isinstance(deduped_sublist, list):
deduped_sublist = [deduped_sublist] # type: ignore[unreachable]
deduped_sublist = [deduped_sublist]
merged_list.append(deduped_sublist)
def _handle_average_collection(
merged: dict[str, dict[str, list[float]]], key: str, value: float
) -> None:
"""Collect values for averaging under merged['values'][key]."""
if "values" not in merged:
merged["values"] = {}
if key not in merged["values"]:
merged["values"][key] = []
merged["values"][key].append(value)

View File

@@ -335,9 +335,7 @@ def count_recent_sources(sources: list["Source"], recency_threshold: int) -> int
try:
if published_date.endswith("Z"):
published_date = published_date.replace("Z", "+00:00")
date_candidate = datetime.fromisoformat(published_date)
if isinstance(date_candidate, datetime):
date = date_candidate
date = datetime.fromisoformat(published_date)
except ValueError:
parsed_result = dateutil_parser.parse(published_date)
# dateutil_parser.parse may rarely return a tuple
@@ -345,20 +343,20 @@ def count_recent_sources(sources: list["Source"], recency_threshold: int) -> int
if isinstance(parsed_result, tuple):
# Take the first element if it's a datetime
first_elem = parsed_result[0] if parsed_result else None
if isinstance(first_elem, datetime):
if first_elem:
date = first_elem
else:
continue
elif isinstance(parsed_result, datetime):
elif parsed_result:
date = parsed_result
else:
continue
elif isinstance(published_date, datetime):
elif published_date:
date = published_date
else:
continue
# Ensure date is a datetime object
if not isinstance(date, datetime):
# Ensure date is valid
if not (date and isinstance(date, datetime)):
continue
if date.tzinfo is None:
date = date.replace(tzinfo=UTC)
@@ -400,12 +398,12 @@ def extract_topics_from_facts(facts: list["ExtractedFact"]) -> list[str]:
def get_topics_in_fact(fact: "ExtractedFact") -> set[str]:
"""Extract topics from a single fact."""
topics = set()
if "data" in fact and isinstance(fact["data"], dict):
if "data" in fact and fact["data"]:
data = fact["data"]
if fact.get("type") == "vendor":
if "vendor_name" in data:
vendor_name = data["vendor_name"]
if isinstance(vendor_name, str):
if vendor_name and isinstance(vendor_name, str):
topics.add(vendor_name.lower())
elif fact.get("type") == "relationship":
entities = data.get("entities", [])
@@ -426,7 +424,7 @@ def get_topics_in_fact(fact: "ExtractedFact") -> set[str]:
description = data["description"]
if isinstance(description, str):
extract_noun_phrases(description, topics)
if "source_text" in fact and isinstance(fact["source_text"], str):
if "source_text" in fact and fact["source_text"]:
extract_noun_phrases(fact["source_text"], topics)
return topics
@@ -449,10 +447,10 @@ def perform_statistical_validation(facts: list["ExtractedFact"]) -> float:
pattern = re.compile(r"\b\d+(?:\.\d+)?\b")
for fact in facts:
source_text = fact.get("source_text", "")
if isinstance(source_text, str):
if source_text:
found_numbers = pattern.findall(source_text)
numeric_values.extend(float(n) for n in found_numbers)
if "data" in fact and isinstance(fact["data"], dict):
if "data" in fact and fact["data"]:
for _key, val in fact["data"].items():
if isinstance(val, int | float):
numeric_values.append(float(val))

View File

@@ -25,7 +25,7 @@ def validate_type[T](value: object, expected_type: type[T]) -> tuple[bool, str |
return False, f"Expected {expected_name}, got {actual_type}"
-def is_valid_email(email: str) -> bool:
+def is_valid_email(email: object) -> bool:
"""Check if string is a valid email address.
Args:
@@ -42,7 +42,7 @@ def is_valid_email(email: str) -> bool:
return bool(re.match(pattern, email))
-def is_valid_url(url: str) -> bool:
+def is_valid_url(url: object) -> bool:
"""Check if string is a valid URL.
Args:
@@ -61,7 +61,7 @@ def is_valid_url(url: str) -> bool:
return False
-def is_valid_phone(phone: str) -> bool:
+def is_valid_phone(phone: object) -> bool:
"""Check if string is a valid phone number.
Args:
@@ -81,7 +81,7 @@ def is_valid_phone(phone: str) -> bool:
def validate_string_length(
-value: str,
+value: object,
min_length: int | None = None,
max_length: int | None = None,
) -> tuple[bool, str | None]:
@@ -110,7 +110,7 @@ def validate_string_length(
def validate_number_range(
-value: int | float,
+value: object,
min_value: int | float | None = None,
max_value: int | float | None = None,
) -> tuple[bool, str | None]:
@@ -124,7 +124,7 @@ def validate_number_range(
Returns:
Tuple of (is_valid, error_message)
"""
-if not isinstance(value, int | float):
+if not isinstance(value, (int, float)):
return False, "Value must be a number"
if min_value is not None and value < min_value:
@@ -137,7 +137,7 @@ def validate_number_range(
def validate_list_length(
-value: list[object],
+value: object,
min_length: int | None = None,
max_length: int | None = None,
) -> tuple[bool, str | None]:
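The validator signatures in this diff are widened from `str` (or `int | float`, `list[object]`) to `object` so callers can pass unvalidated data; each function then guards with `isinstance` instead of trusting the annotation. A sketch of the pattern, with an illustrative (not the project's actual) email regex:

```python
import re

# Illustrative pattern only; the project's real email regex is not shown
# in this diff.
_EMAIL_RE = re.compile(r"^[\w.+-]+@[\w-]+\.[\w.-]+$")


def is_valid_email(email: object) -> bool:
    """Check if a value is a string containing a valid email address.

    Accepting ``object`` lets callers pass raw, untyped data; non-strings
    simply fail validation instead of raising TypeError.
    """
    if not isinstance(email, str):
        return False
    return bool(_EMAIL_RE.match(email))
```

The same guard-first shape applies to the other widened validators: check the runtime type, return a failure result early, then run the original logic on the narrowed value.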


@@ -17,7 +17,7 @@ from bb_core.caching import FileCache
# NOTE: LLMCache is not available in bb_core
# LLMCache has been replaced with FileCache
-UTC = UTC
+# UTC is already imported from datetime
@pytest.fixture
@@ -68,12 +68,14 @@ class TestFileCache:
async def test_ainit_idempotent(self, file_backend: FileCache) -> None:
"""Test that initialization can happen multiple times safely."""
# FileCache initializes on first use
-await file_backend._ensure_initialized()
-assert file_backend._initialized
+await file_backend.ensure_initialized()
+# Verify by checking if cache directory exists
+assert file_backend.cache_dir.exists()
# Should be safe to call again
-await file_backend._ensure_initialized()
-assert file_backend._initialized
+await file_backend.ensure_initialized()
+# Directory should still exist
+assert file_backend.cache_dir.exists()
@pytest.mark.asyncio
async def test_ainit_failure(self, temp_cache_dir: str) -> None:


@@ -48,7 +48,7 @@ class TestCacheIntegration:
kwargs = {"flag": True, "data": {"key": "value"}, "custom_obj": TestClass()}
# Generate key and set value
-key = cache_instance._generate_key(args, cast("dict[str, object]", kwargs)) # type: ignore
+key = cache_instance._generate_key(args, cast("dict[str, object]", kwargs))
test_value = {"result": "success", "count": 123}
await cache_instance.set(key, test_value)


@@ -54,9 +54,9 @@ class TestLLMCache:
backend = MockBackend()
cache = LLMCache(backend=cast("CacheBackend[Any]", backend))
-# Backend should be set
-assert cache._backend is backend
-assert not cache._ainit_done
+# Backend should be set (accessing protected attributes in tests)
+assert cache._backend is backend  # noqa: SLF001
+assert not cache._ainit_done  # noqa: SLF001
@pytest.mark.asyncio
async def test_init_without_backend(self) -> None:
@@ -64,11 +64,12 @@ class TestLLMCache:
with tempfile.TemporaryDirectory() as tmpdir:
cache = LLMCache[str](cache_dir=tmpdir, ttl=3600, serializer="pickle")
-# Should create AsyncFileCacheBackend
-assert isinstance(cache._backend, AsyncFileCacheBackend)
-assert str(cache._backend.cache_dir) == tmpdir
-assert cache._backend.ttl == 3600
-assert cache._backend.serializer == "pickle"
+# Should create AsyncFileCacheBackend (accessing protected attributes in
+# tests)
+assert isinstance(cache._backend, AsyncFileCacheBackend)  # noqa: SLF001
+assert str(cache._backend.cache_dir) == tmpdir  # noqa: SLF001
+assert cache._backend.ttl == 3600  # noqa: SLF001
+assert cache._backend.serializer == "pickle"  # noqa: SLF001
@pytest.mark.asyncio
async def test_ensure_backend_initialized(self) -> None:
@@ -76,14 +77,14 @@ class TestLLMCache:
backend = MockBackend()
cache = LLMCache(backend=cast("CacheBackend[Any]", backend))
-# First call should initialize
-await cache._ensure_backend_initialized()
+# First call should initialize (accessing protected method in tests)
+await cache._ensure_backend_initialized()  # noqa: SLF001
assert backend.ainit_called == 1
-assert cache._ainit_done
+assert cache._ainit_done  # noqa: SLF001
# Subsequent calls should not reinitialize
-await cache._ensure_backend_initialized()
-await cache._ensure_backend_initialized()
+await cache._ensure_backend_initialized()  # noqa: SLF001
+await cache._ensure_backend_initialized()  # noqa: SLF001
assert backend.ainit_called == 1
@pytest.mark.asyncio
@@ -108,25 +109,26 @@ class TestLLMCache:
backend = SimpleBackend()
cache: LLMCache[None] = LLMCache(backend=backend)
-# Should handle backend without ainit gracefully
-await cache._ensure_backend_initialized()
-assert cache._ainit_done
+# Should handle backend without ainit gracefully (accessing protected method
+# in tests)
+await cache._ensure_backend_initialized()  # noqa: SLF001
+assert cache._ainit_done  # noqa: SLF001
def test_generate_key_basic(self) -> None:
"""Test basic key generation."""
cache = LLMCache()
-# Test with simple args
-key1 = cache._generate_key(("hello", 42), {"flag": True})
+# Test with simple args (accessing protected method in tests)
+key1 = cache._generate_key(("hello", 42), {"flag": True})  # noqa: SLF001
assert isinstance(key1, str)
assert len(key1) == 64  # SHA-256 hex digest length
# Same args should generate same key
-key2 = cache._generate_key(("hello", 42), {"flag": True})
+key2 = cache._generate_key(("hello", 42), {"flag": True})  # noqa: SLF001
assert key1 == key2
# Different args should generate different keys
-key3 = cache._generate_key(("hello", 43), {"flag": True})
+key3 = cache._generate_key(("hello", 43), {"flag": True})  # noqa: SLF001
assert key1 != key3
def test_generate_key_complex_types(self) -> None:
@@ -144,7 +146,7 @@ class TestLLMCache:
)
complex_kwargs = {"list": [4, 5, 6], "tuple": (7, 8, 9), "none": None}
-key = cache._generate_key(
+key = cache._generate_key(  # noqa: SLF001
complex_args, cast("dict[str, object]", complex_kwargs)
)
assert isinstance(key, str)
@@ -154,13 +156,13 @@ class TestLLMCache:
"""Test key generation with empty arguments."""
cache = LLMCache()
-# Empty args and kwargs
-key = cache._generate_key((), {})
+# Empty args and kwargs (accessing protected method in tests)
+key = cache._generate_key((), {})  # noqa: SLF001
assert isinstance(key, str)
assert len(key) == 64
# The key should be consistent for empty args
-key2 = cache._generate_key((), {})
+key2 = cache._generate_key((), {})  # noqa: SLF001
assert key == key2
def test_generate_key_json_serialization_error(self) -> None:
@@ -172,8 +174,9 @@ class TestLLMCache:
def __str__(self) -> str:
return "non-serializable"
-# This should fall back to string representation
-key = cache._generate_key((NonSerializable(),), {})
+# This should fall back to string representation (accessing protected method
+# in tests)
+key = cache._generate_key((NonSerializable(),), {})  # noqa: SLF001
assert isinstance(key, str)
assert len(key) == 64
@@ -186,8 +189,9 @@ class TestLLMCache:
def __str__(self) -> str:
raise RuntimeError("Cannot convert to string")
-# Should still generate a key using error messages
-key = cache._generate_key((FailingObject(),), {})
+# Should still generate a key using error messages (accessing protected
+# method in tests)
+key = cache._generate_key((FailingObject(),), {})  # noqa: SLF001
assert isinstance(key, str)
assert len(key) == 64
@@ -290,20 +294,21 @@ class TestLLMCache:
args = ("test", 123, [1, 2, 3])
kwargs = {"flag": True, "data": {"nested": "value"}}
-key1 = cache1._generate_key(args, cast("dict[str, object]", kwargs))
-key2 = cache2._generate_key(args, cast("dict[str, object]", kwargs))
+key1 = cache1._generate_key(args, cast("dict[str, object]", kwargs))  # noqa: SLF001
+key2 = cache2._generate_key(args, cast("dict[str, object]", kwargs))  # noqa: SLF001
assert key1 == key2
def test_generate_key_order_matters(self) -> None:
"""Test that argument order affects key generation."""
cache = LLMCache()
-# Different arg order should produce different keys
-key1 = cache._generate_key((1, 2), {})
-key2 = cache._generate_key((2, 1), {})
+# Different arg order should produce different keys (accessing protected
+# method in tests)
+key1 = cache._generate_key((1, 2), {})  # noqa: SLF001
+key2 = cache._generate_key((2, 1), {})  # noqa: SLF001
assert key1 != key2
# Different kwarg order should produce same key (sorted)
-key3 = cache._generate_key((), {"a": 1, "b": 2})
-key4 = cache._generate_key((), {"b": 2, "a": 1})
+key3 = cache._generate_key((), {"a": 1, "b": 2})  # noqa: SLF001
+key4 = cache._generate_key((), {"b": 2, "a": 1})  # noqa: SLF001
assert key3 == key4
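Taken together, the key-generation tests pin down an observable contract: 64-character SHA-256 hex digests, deterministic for equal inputs, positional-argument order significant, keyword order insignificant, with a string fallback for non-JSON-serializable arguments. A standalone sketch that satisfies those properties (the real `LLMCache._generate_key` may differ internally):

```python
import hashlib
import json


def generate_key(args: tuple, kwargs: dict) -> str:
    """Deterministic SHA-256 cache key over positional and keyword args."""

    def encode(value: object) -> str:
        try:
            # default=str covers most non-JSON types; sort_keys makes
            # keyword-argument order irrelevant.
            return json.dumps(value, sort_keys=True, default=str)
        except Exception as exc:  # str(value) itself may raise
            return f"<unrepresentable: {type(exc).__name__}>"

    payload = encode(list(args)) + "|" + encode(kwargs)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

Hashing a canonical serialization, rather than the values themselves, is what makes the key stable across processes and across independently constructed cache instances.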


@@ -57,8 +57,8 @@ class TestRedisCache:
assert cache.url == "redis://example.com:6379"
assert cache.key_prefix == "myapp:"
-assert cache._decode_responses is True  # pyright: ignore[reportPrivateUsage]
-assert cache._client is None  # pyright: ignore[reportPrivateUsage]
+assert cache._decode_responses is True
+assert cache._client is None
@pytest.mark.asyncio
async def test_ensure_connected_success(
@@ -66,10 +66,10 @@ class TestRedisCache:
) -> None:
"""Test successful Redis connection."""
with patch("redis.asyncio.from_url", return_value=mock_redis_client):
-client = await cache._ensure_connected()  # pyright: ignore[reportPrivateUsage]
+client = await cache._ensure_connected()
assert client == mock_redis_client
-assert cache._client == mock_redis_client  # pyright: ignore[reportPrivateUsage]
+assert cache._client == mock_redis_client
mock_redis_client.ping.assert_called_once()
@pytest.mark.asyncio
@@ -82,13 +82,13 @@ class TestRedisCache:
patch("redis.asyncio.from_url", return_value=mock_client),
pytest.raises(ConfigurationError, match="Failed to connect to Redis"),
):
-await cache._ensure_connected()
+await cache._ensure_connected()  # noqa: SLF001
def test_make_key(self, cache: RedisCache) -> None:
-"""Test key prefixing."""
-assert cache._make_key("mykey") == "test:mykey"
-assert cache._make_key("") == "test:"
-assert cache._make_key("path/to/key") == "test:path/to/key"
+"""Test key prefixing (accessing protected method in tests)."""
+assert cache._make_key("mykey") == "test:mykey"  # noqa: SLF001
+assert cache._make_key("") == "test:"  # noqa: SLF001
+assert cache._make_key("path/to/key") == "test:path/to/key"  # noqa: SLF001
@pytest.mark.asyncio
async def test_get_not_found(
@@ -332,17 +332,18 @@ class TestRedisCache:
await cache.close()
mock_redis_client.close.assert_called_once()
-assert cache._client is None
+assert cache._client is None  # noqa: SLF001
@pytest.mark.asyncio
async def test_close_without_connection(self, cache: RedisCache) -> None:
-"""Test closing when no connection exists."""
-assert cache._client is None
+"""Test closing when no connection exists (accessing protected attribute
+in tests)."""
+assert cache._client is None  # noqa: SLF001
# Should not raise any exception
await cache.close()
-assert cache._client is None
+assert cache._client is None  # noqa: SLF001
class TestRedisCacheIntegration:


@@ -0,0 +1,13 @@
"""Edge helpers test suite.
Comprehensive tests for all edge helper modules covering:
- Core routing factory functions
- Flow control helpers
- Validation helpers
- User interaction helpers
- Error handling helpers
- Monitoring helpers
These tests ensure the edge helpers handle all edge cases properly
and are robust for use in LangGraph workflows.
"""


@@ -0,0 +1,458 @@
"""Comprehensive tests for core edge helper routing factories.
These tests cover edge cases including:
- None/empty state values
- Invalid data types
- Missing keys
- Boundary conditions
- Both dict and StateProtocol implementations
"""
from typing import Any
from bb_core.edge_helpers.core import (
create_bool_router,
create_enum_router,
create_field_presence_router,
create_list_length_router,
create_status_router,
create_threshold_router,
)
class MockState:
"""Mock state object implementing StateProtocol."""
def __init__(self, data: dict[str, Any]):
self._data = data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def __getattr__(self, name: str) -> Any:
return self._data.get(name)
class TestCreateEnumRouter:
"""Test create_enum_router with various edge cases."""
def test_basic_enum_routing_dict_state(self):
"""Test basic enum routing with dict state."""
enum_mapping = {
"continue": "next",
"retry": "retry_node",
"error": "error_node",
}
router = create_enum_router(enum_mapping)
assert router({"routing_decision": "continue"}) == "next"
assert router({"routing_decision": "retry"}) == "retry_node"
assert router({"routing_decision": "error"}) == "error_node"
def test_basic_enum_routing_state_protocol(self):
"""Test basic enum routing with StateProtocol."""
enum_mapping = {
"continue": "next",
"retry": "retry_node",
"error": "error_node",
}
router = create_enum_router(enum_mapping)
state = MockState({"routing_decision": "continue"})
assert router(state) == "next"
state = MockState({"routing_decision": "retry"})
assert router(state) == "retry_node"
def test_default_target_fallback(self):
"""Test fallback to default target."""
enum_mapping = {"continue": "next"}
router = create_enum_router(enum_mapping, default_target="fallback")
# Missing key should use default
assert router({}) == "fallback"
assert router({"routing_decision": None}) == "fallback"
assert router({"routing_decision": "unknown"}) == "fallback"
def test_none_values(self):
"""Test handling of None values."""
enum_mapping = {"continue": "next"}
router = create_enum_router(enum_mapping)
assert router({"routing_decision": None}) == "end"
assert router(MockState({"routing_decision": None})) == "end"
def test_numeric_enum_values(self):
"""Test enum routing with numeric values."""
enum_mapping = {"1": "step_one", "2": "step_two", "0": "reset"}
router = create_enum_router(enum_mapping, state_key="step")
assert router({"step": 1}) == "step_one"
assert router({"step": "2"}) == "step_two"
assert router({"step": 0}) == "reset"
def test_custom_state_key(self):
"""Test with custom state key."""
enum_mapping = {"active": "process", "inactive": "skip"}
router = create_enum_router(enum_mapping, state_key="status")
assert router({"status": "active"}) == "process"
assert router({"status": "inactive"}) == "skip"
def test_missing_state_key(self):
"""Test behavior when state key is missing."""
enum_mapping = {"continue": "next"}
router = create_enum_router(enum_mapping, state_key="missing_key")
assert router({"other_key": "value"}) == "end"
assert router(MockState({"other_key": "value"})) == "end"
def test_empty_state(self):
"""Test with empty state."""
enum_mapping = {"continue": "next"}
router = create_enum_router(enum_mapping)
assert router({}) == "end"
assert router(MockState({})) == "end"
class TestCreateBoolRouter:
"""Test create_bool_router with various edge cases."""
def test_basic_bool_routing(self):
"""Test basic boolean routing."""
router = create_bool_router("success", "failure")
assert router({"condition": True}) == "success"
assert router({"condition": False}) == "failure"
assert router(MockState({"condition": True})) == "success"
def test_truthy_falsy_values(self):
"""Test truthy/falsy value handling."""
router = create_bool_router("yes", "no", state_key="flag")
# Truthy values
assert router({"flag": 1}) == "yes"
assert router({"flag": "true"}) == "yes"
assert router({"flag": [1, 2, 3]}) == "yes"
assert router({"flag": {"key": "value"}}) == "yes"
# Falsy values
assert router({"flag": 0}) == "no"
assert router({"flag": ""}) == "no"
assert router({"flag": []}) == "no"
assert router({"flag": {}}) == "no"
assert router({"flag": None}) == "no"
def test_missing_condition_key(self):
"""Test behavior when condition key is missing."""
router = create_bool_router("true_path", "false_path", state_key="missing")
assert router({}) == "false_path"
assert router({"other_key": True}) == "false_path"
def test_default_false_behavior(self):
"""Test default false behavior."""
router = create_bool_router("yes", "no")
# Missing key should default to False
assert router({}) == "no"
assert router(MockState({})) == "no"
class TestCreateThresholdRouter:
"""Test create_threshold_router with various edge cases."""
def test_basic_threshold_routing(self):
"""Test basic threshold comparison."""
router = create_threshold_router(0.5, "high", "low")
assert router({"score": 0.7}) == "high"
assert router({"score": 0.3}) == "low"
assert router({"score": 0.5}) == "high" # Equal goes to above_target
def test_equal_target_parameter(self):
"""Test custom equal_target parameter."""
router = create_threshold_router(0.5, "high", "low", equal_target="equal")
assert router({"score": 0.5}) == "equal"
assert router(MockState({"score": 0.5})) == "equal"
def test_string_numeric_conversion(self):
"""Test conversion of string numbers."""
router = create_threshold_router(10.0, "above", "below", state_key="value")
assert router({"value": "15.5"}) == "above"
assert router({"value": "5"}) == "below"
assert router({"value": "10.0"}) == "above"
def test_invalid_numeric_values(self):
"""Test handling of invalid numeric values."""
router = create_threshold_router(0.5, "high", "low")
# Non-numeric values should route to below_target
assert router({"score": "invalid"}) == "low"
assert router({"score": None}) == "low"
assert router({"score": []}) == "low"
assert router({"score": {}}) == "low"
def test_missing_score_key(self):
"""Test behavior when score key is missing."""
router = create_threshold_router(0.5, "high", "low", state_key="missing")
# Should default to 0.0 and route to below_target
assert router({}) == "low"
assert router({"other_key": 0.8}) == "low"
def test_boundary_conditions(self):
"""Test exact boundary conditions."""
router = create_threshold_router(0.0, "positive", "negative")
assert router({"score": 0.0}) == "positive"
assert router({"score": -0.1}) == "negative"
assert router({"score": 0.1}) == "positive"
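The threshold-router tests above imply a small closure: read a numeric state value, coerce numeric strings, treat anything non-numeric or missing as below the threshold, and send exact matches to `equal_target` when one is given. A sketch consistent with that tested behavior (not the actual bb_core implementation):

```python
from collections.abc import Callable


def create_threshold_router(
    threshold: float,
    above_target: str,
    below_target: str,
    state_key: str = "score",
    equal_target: "str | None" = None,
) -> "Callable[[dict], str]":
    """Build a router that branches on a numeric state value."""

    def router(state: dict) -> str:
        raw = state.get(state_key, 0.0)
        try:
            value = float(raw)  # also accepts numeric strings like "15.5"
        except (TypeError, ValueError):
            return below_target  # non-numeric values fall through safely
        if value == threshold:
            return equal_target if equal_target is not None else above_target
        return above_target if value > threshold else below_target

    return router
```

Because the closure holds no mutable state, the same router instance is safe to call concurrently, which is what the stateless/thread-safety test later in this file relies on.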
class TestCreateFieldPresenceRouter:
"""Test create_field_presence_router with various edge cases."""
def test_all_fields_present(self):
"""Test when all required fields are present."""
router = create_field_presence_router(
["field1", "field2", "field3"], "complete", "incomplete"
)
state = {"field1": "value1", "field2": "value2", "field3": "value3"}
assert router(state) == "complete"
def test_missing_fields(self):
"""Test when some fields are missing."""
router = create_field_presence_router(
["field1", "field2", "field3"], "complete", "incomplete"
)
# Missing field2
state = {"field1": "value1", "field3": "value3"}
assert router(state) == "incomplete"
# All missing
assert router({}) == "incomplete"
def test_empty_string_values(self):
"""Test handling of empty string values."""
router = create_field_presence_router(
["field1", "field2"], "complete", "incomplete"
)
# Empty string should be considered missing
state = {"field1": "value", "field2": ""}
assert router(state) == "incomplete"
def test_none_values(self):
"""Test handling of None values."""
router = create_field_presence_router(
["field1", "field2"], "complete", "incomplete"
)
state = {"field1": "value", "field2": None}
assert router(state) == "incomplete"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = create_field_presence_router(
["field1", "field2"], "complete", "incomplete"
)
# Complete case
state = MockState({"field1": "value1", "field2": "value2"})
assert router(state) == "complete"
# Incomplete case
state = MockState({"field1": "value1"})
assert router(state) == "incomplete"
def test_zero_and_false_values(self):
"""Test that 0 and False are considered valid values."""
router = create_field_presence_router(
["number", "flag"], "complete", "incomplete"
)
# 0 and False should be valid (not None or empty string)
state = {"number": 0, "flag": False}
assert router(state) == "complete"
class TestCreateListLengthRouter:
"""Test create_list_length_router with various edge cases."""
def test_sufficient_length(self):
"""Test when list meets minimum length."""
router = create_list_length_router(3, "enough", "not_enough")
assert router({"items": [1, 2, 3]}) == "enough"
assert router({"items": [1, 2, 3, 4, 5]}) == "enough"
def test_insufficient_length(self):
"""Test when list is too short."""
router = create_list_length_router(3, "enough", "not_enough")
assert router({"items": [1, 2]}) == "not_enough"
assert router({"items": []}) == "not_enough"
def test_exact_length(self):
"""Test exact minimum length."""
router = create_list_length_router(3, "enough", "not_enough")
assert router({"items": [1, 2, 3]}) == "enough"
def test_non_list_values(self):
"""Test handling of non-list values."""
router = create_list_length_router(3, "enough", "not_enough")
# Non-list values should route to insufficient
assert router({"items": "string"}) == "not_enough"
assert router({"items": None}) == "not_enough"
assert router({"items": 123}) == "not_enough"
assert router({"items": {"key": "value"}}) == "not_enough"
def test_missing_items_key(self):
"""Test behavior when items key is missing."""
router = create_list_length_router(
3, "enough", "not_enough", state_key="missing"
)
# Should default to empty list and route to insufficient
assert router({}) == "not_enough"
assert router({"other_key": [1, 2, 3, 4]}) == "not_enough"
def test_custom_state_key(self):
"""Test with custom state key."""
router = create_list_length_router(2, "ok", "too_few", state_key="results")
assert router({"results": [1, 2, 3]}) == "ok"
assert router({"results": [1]}) == "too_few"
def test_string_as_iterable(self):
"""Test string handling (has length but not a list)."""
router = create_list_length_router(5, "enough", "not_enough")
# Strings have len() but should be treated as non-list
assert router({"items": "hello world"}) == "not_enough"
def test_list_like_iterables(self):
"""Test handling of list-like but non-list iterables (tuple, set)."""
router = create_list_length_router(3, "ok", "not_ok")
# Tuple with enough elements
assert router({"items": (1, 2, 3)}) == "ok"
# Tuple with too few elements
assert router({"items": (1,)}) == "not_ok"
# Set with enough elements
assert router({"items": {1, 2, 3}}) == "ok"
# Set with too few elements
assert router({"items": {1}}) == "not_ok"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = create_list_length_router(2, "enough", "not_enough")
state = MockState({"items": [1, 2, 3]})
assert router(state) == "enough"
state = MockState({"items": [1]})
assert router(state) == "not_enough"
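The list-length tests encode a subtle contract: strings have a length but must be rejected, while list-like iterables such as tuples and sets pass. A sketch matching that behavior (the real bb_core factory may differ):

```python
from collections.abc import Callable


def create_list_length_router(
    min_length: int,
    sufficient_target: str,
    insufficient_target: str,
    state_key: str = "items",
) -> "Callable[[dict], str]":
    """Build a router that branches on collection length."""

    def router(state: dict) -> str:
        value = state.get(state_key, [])
        # Strings are explicitly rejected despite having len();
        # tuples and sets are accepted alongside lists.
        if not isinstance(value, (list, tuple, set)):
            return insufficient_target
        if len(value) >= min_length:
            return sufficient_target
        return insufficient_target

    return router
```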
class TestStatusRouter:
"""Test create_status_router (alias for create_enum_router)."""
def test_status_routing(self):
"""Test status-based routing."""
status_mapping = {
"pending": "wait",
"running": "monitor",
"completed": "finish",
"failed": "error_handler",
}
router = create_status_router(status_mapping)
assert router({"status": "pending"}) == "wait"
assert router({"status": "running"}) == "monitor"
assert router({"status": "completed"}) == "finish"
assert router({"status": "failed"}) == "error_handler"
assert router({"status": "unknown"}) == "end" # default
class TestEdgeCasesAndIntegration:
"""Test various edge cases and integration scenarios."""
def test_extremely_large_values(self):
"""Test with extremely large numeric values."""
router = create_threshold_router(1e10, "huge", "normal")
assert router({"score": 1e15}) == "huge"
assert router({"score": 1e5}) == "normal"
def test_extremely_small_values(self):
"""Test with extremely small numeric values."""
router = create_threshold_router(1e-10, "above", "below")
assert router({"score": 1e-15}) == "below"
assert router({"score": 1e-5}) == "above"
def test_unicode_string_handling(self):
"""Test handling of unicode strings."""
enum_mapping = {"🔥": "hot", "❄️": "cold", "🌞": "warm"}
router = create_enum_router(enum_mapping, state_key="weather")
assert router({"weather": "🔥"}) == "hot"
assert router({"weather": "❄️"}) == "cold"
assert router({"weather": "🌈"}) == "end" # default
def test_deeply_nested_state_access(self):
"""Test that routers handle flat state keys only."""
router = create_bool_router("yes", "no", state_key="nested.deep.value")
# Should look for literal key "nested.deep.value", not traverse
state = {"nested.deep.value": True}
assert router(state) == "yes"
# Should not traverse nested structures
nested_state = {"nested": {"deep": {"value": True}}}
assert router(nested_state) == "no" # key not found at top level
def test_concurrent_router_usage(self):
"""Test that routers are stateless and thread-safe."""
router = create_threshold_router(0.5, "high", "low")
# Multiple calls with different states should not interfere
state1 = {"score": 0.8}
state2 = {"score": 0.2}
assert router(state1) == "high"
assert router(state2) == "low"
assert router(state1) == "high" # Should still work
def test_memory_efficiency_large_lists(self):
"""Test memory efficiency with large lists."""
router = create_list_length_router(1000, "big", "small")
large_list = list(range(2000))
small_list = list(range(500))
assert router({"items": large_list}) == "big"
assert router({"items": small_list}) == "small"
def test_type_coercion_consistency(self):
"""Test consistent type coercion across routers."""
# Test that all routers handle type coercion consistently
enum_router = create_enum_router({"1": "one", "2": "two"})
threshold_router = create_threshold_router(1.5, "high", "low")
# Numeric as string
assert enum_router({"routing_decision": 1}) == "one"
assert threshold_router({"score": "2.0"}) == "high"
# String as number (where applicable)
assert enum_router({"routing_decision": "2"}) == "two"

File diff suppressed because it is too large


@@ -0,0 +1,614 @@
"""Comprehensive tests for flow control edge helpers.
Tests cover edge cases including:
- Various message formats for tool calls
- Timeout conditions and edge cases
- Multi-step progress tracking
- Iteration limits and boundary conditions
- Workflow state transitions
"""
import time
from typing import Any
from unittest.mock import patch
from bb_core.edge_helpers.flow_control import (
check_completion_criteria,
check_iteration_limit,
check_workflow_state,
multi_step_progress,
should_continue,
timeout_check,
)
class MockState:
"""Mock state object for testing."""
def __init__(self, data: dict[str, Any]):
self._data = data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def __getattr__(self, name: str) -> Any:
return self._data.get(name)
class TestShouldContinue:
"""Test should_continue function with various message formats."""
def test_no_messages_should_end(self):
"""Test that empty messages list results in end."""
assert should_continue({}) == "end"
assert should_continue({"messages": []}) == "end"
assert should_continue(MockState({})) == "end"
def test_dict_message_with_tool_calls(self):
"""Test dict message format with tool_calls."""
messages = [
{
"role": "assistant",
"content": "I'll help you with that.",
"tool_calls": [{"function": {"name": "search"}}],
}
]
assert should_continue({"messages": messages}) == "continue"
def test_dict_message_with_additional_kwargs_tool_calls(self):
"""Test dict message with tool_calls in additional_kwargs."""
messages = [
{
"role": "assistant",
"content": "Let me search for that.",
"additional_kwargs": {"tool_calls": [{"function": {"name": "search"}}]},
}
]
assert should_continue({"messages": messages}) == "continue"
def test_dict_message_with_function_call(self):
"""Test dict message with legacy function_call format."""
messages = [
{
"role": "assistant",
"content": "",
"function_call": {"name": "search", "arguments": "{}"},
}
]
assert should_continue({"messages": messages}) == "continue"
def test_object_message_with_tool_calls_attribute(self):
"""Test message object with tool_calls attribute."""
class MockMessage:
def __init__(self, tool_calls=None):
self.tool_calls = tool_calls
# With tool calls
messages = [MockMessage(tool_calls=[{"function": "search"}])]
assert should_continue({"messages": messages}) == "continue"
# Without tool calls
messages = [MockMessage(tool_calls=None)]
assert should_continue({"messages": messages}) == "end"
# Empty tool calls
messages = [MockMessage(tool_calls=[])]
assert should_continue({"messages": messages}) == "end"
def test_object_message_with_additional_kwargs(self):
"""Test message object with additional_kwargs containing tool_calls."""
class MockMessage:
def __init__(self, additional_kwargs=None):
self.additional_kwargs = additional_kwargs
# With tool calls in additional_kwargs
messages = [
MockMessage(additional_kwargs={"tool_calls": [{"function": "search"}]})
]
assert should_continue({"messages": messages}) == "continue"
# Without tool calls
messages = [MockMessage(additional_kwargs={})]
assert should_continue({"messages": messages}) == "end"
# None additional_kwargs
messages = [MockMessage(additional_kwargs=None)]
assert should_continue({"messages": messages}) == "end"
def test_message_without_tool_calls(self):
"""Test message without any tool calls should end."""
messages = [
{"role": "assistant", "content": "Here's your answer without tool calls."}
]
assert should_continue({"messages": messages}) == "end"
def test_mixed_message_formats(self):
"""Test with multiple messages in different formats."""
messages = [
{"role": "user", "content": "Help me search"},
{"role": "assistant", "content": "I'll search for you."},
{
"role": "assistant",
"content": "Let me use a tool.",
"tool_calls": [{"function": "search"}],
},
]
# Should check only the last message
assert should_continue({"messages": messages}) == "continue"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
messages = [{"tool_calls": [{"function": {"name": "search"}}]}]
state = MockState({"messages": messages})
assert should_continue(state) == "continue"
def test_edge_case_empty_tool_calls(self):
"""Test edge case with empty tool_calls list."""
messages = [{"role": "assistant", "content": "Response", "tool_calls": []}]
assert should_continue({"messages": messages}) == "end"
def test_edge_case_malformed_messages(self):
"""Test with malformed message structures."""
# Non-dict, non-object message
messages = ["string_message"]
assert should_continue({"messages": messages}) == "end"
# None message
messages = [None]
assert should_continue({"messages": messages}) == "end"
# Message with invalid tool_calls type
messages = [{"tool_calls": "not_a_list"}]
assert should_continue({"messages": messages}) == "end"
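`should_continue` has to recognize tool calls across several message shapes: dict messages carrying `tool_calls`, `additional_kwargs["tool_calls"]`, or the legacy `function_call` key, plus object messages with the analogous attributes. A defensive sketch matching the cases tested above (the real helper may differ):

```python
def should_continue(state) -> str:
    """Route to "continue" when the last message requests a tool call."""
    messages = state.get("messages", [])
    if not messages:
        return "end"
    last = messages[-1]
    if isinstance(last, dict):
        tool_calls = last.get("tool_calls") or (
            last.get("additional_kwargs") or {}
        ).get("tool_calls")
        # Legacy OpenAI-style function_call also counts as a tool request.
        if not tool_calls and last.get("function_call"):
            return "continue"
    else:
        # Object messages (e.g. LangChain AIMessage-like): check attributes.
        # Non-message values (strings, None) fall through to "end".
        tool_calls = getattr(last, "tool_calls", None) or (
            getattr(last, "additional_kwargs", None) or {}
        ).get("tool_calls")
    return "continue" if isinstance(tool_calls, list) and tool_calls else "end"
```

The final `isinstance(tool_calls, list)` check is what makes malformed values like `{"tool_calls": "not_a_list"}` route to `"end"` instead of raising.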
class TestTimeoutCheck:
"""Test timeout_check function with various edge cases."""
@patch("time.time")
def test_no_timeout_within_limit(self, mock_time):
"""Test no timeout when within time limit."""
mock_time.return_value = 1000.0
router = timeout_check(timeout_seconds=300.0)
# Start time 100 seconds ago
state = {"start_time": 900.0}
assert router(state) == "continue"
@patch("time.time")
def test_timeout_exceeded(self, mock_time):
"""Test timeout when time limit exceeded."""
mock_time.return_value = 1000.0
router = timeout_check(timeout_seconds=300.0)
# Start time 400 seconds ago (exceeds 300s limit)
state = {"start_time": 600.0}
assert router(state) == "timeout"
@patch("time.time")
def test_exact_timeout_boundary(self, mock_time):
"""Test exact timeout boundary."""
mock_time.return_value = 1000.0
router = timeout_check(timeout_seconds=300.0)
# Exactly at timeout limit
state = {"start_time": 700.0}
assert router(state) == "continue"
# Just over timeout limit
state = {"start_time": 699.9}
assert router(state) == "timeout"
def test_missing_start_time(self):
"""Test behavior when start_time is missing."""
router = timeout_check(timeout_seconds=300.0)
assert router({}) == "continue"
assert router({"other_key": 123}) == "continue"
def test_none_start_time(self):
"""Test behavior when start_time is None."""
router = timeout_check(timeout_seconds=300.0)
assert router({"start_time": None}) == "continue"
def test_custom_start_time_key(self):
"""Test with custom start time key."""
router = timeout_check(timeout_seconds=60.0, start_time_key="begin_time")
with patch("time.time", return_value=1000.0):
state = {"begin_time": 950.0} # 50 seconds ago
assert router(state) == "continue"
state = {"begin_time": 930.0} # 70 seconds ago
assert router(state) == "timeout"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = timeout_check(timeout_seconds=300.0)
with patch("time.time", return_value=1000.0):
state = MockState({"start_time": 900.0})
assert router(state) == "continue"
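Taken together, these cases imply a small closure-based router. One way `timeout_check` could look, sketched purely from the asserted behavior; note the strict `>` comparison is what makes the exact-boundary case continue:

```python
import time


def timeout_check(timeout_seconds, start_time_key="start_time"):
    """Return a router that emits "timeout" once elapsed time exceeds the limit."""
    def router(state):
        start = state.get(start_time_key)
        if start is None:
            return "continue"  # no start time recorded yet
        elapsed = time.time() - start
        # Strictly greater-than: exactly at the limit still continues
        return "timeout" if elapsed > timeout_seconds else "continue"
    return router
```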
class TestMultiStepProgress:
"""Test multi_step_progress function with various edge cases."""
def test_normal_progression(self):
"""Test normal step progression."""
router = multi_step_progress(total_steps=5)
assert router({"current_step": 0}) == "next_step"
assert router({"current_step": 1}) == "next_step"
assert router({"current_step": 4}) == "next_step"
def test_completion(self):
"""Test completion when steps are done."""
router = multi_step_progress(total_steps=5)
assert router({"current_step": 5}) == "complete"
assert router({"current_step": 6}) == "complete"
assert router({"current_step": 100}) == "complete"
def test_negative_step_error(self):
"""Test error for negative step numbers."""
router = multi_step_progress(total_steps=5)
assert router({"current_step": -1}) == "error"
assert router({"current_step": -10}) == "error"
def test_invalid_step_types(self):
"""Test error for invalid step types."""
router = multi_step_progress(total_steps=5)
assert router({"current_step": "invalid"}) == "error"
assert router({"current_step": []}) == "error"
assert router({"current_step": {}}) == "error"
assert router({"current_step": None}) == "error"
def test_missing_step_key(self):
"""Test behavior when step key is missing."""
router = multi_step_progress(total_steps=5)
# Should default to 0 and return next_step
assert router({}) == "next_step"
def test_custom_step_key(self):
"""Test with custom step key."""
router = multi_step_progress(total_steps=3, step_key="progress")
assert router({"progress": 0}) == "next_step"
assert router({"progress": 3}) == "complete"
def test_single_step_workflow(self):
"""Test single-step workflow."""
router = multi_step_progress(total_steps=1)
assert router({"current_step": 0}) == "next_step"
assert router({"current_step": 1}) == "complete"
def test_zero_steps_workflow(self):
"""Test zero-steps workflow edge case."""
router = multi_step_progress(total_steps=0)
assert router({"current_step": 0}) == "complete"
assert router({"current_step": -1}) == "error"
def test_string_numeric_conversion(self):
"""Test conversion of string numbers."""
router = multi_step_progress(total_steps=5)
assert router({"current_step": "2"}) == "next_step"
assert router({"current_step": "5"}) == "complete"
assert router({"current_step": "-1"}) == "error"
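The string-conversion and invalid-type cases suggest `multi_step_progress` normalizes the step with `int()` and maps conversion failures to "error". A hedged sketch matching the assertions above:

```python
def multi_step_progress(total_steps, step_key="current_step"):
    """Router yielding "next_step", "complete", or "error" from the step counter."""
    def router(state):
        raw = state.get(step_key, 0)  # missing key defaults to step 0
        try:
            step = int(raw)  # accepts ints and numeric strings like "2"
        except (TypeError, ValueError):
            return "error"  # None, lists, dicts, non-numeric strings
        if step < 0:
            return "error"
        return "complete" if step >= total_steps else "next_step"
    return router
```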
class TestCheckIterationLimit:
"""Test check_iteration_limit function with various edge cases."""
def test_within_limit(self):
"""Test iterations within limit."""
router = check_iteration_limit(max_iterations=10)
assert router({"iteration_count": 0}) == "continue"
assert router({"iteration_count": 5}) == "continue"
assert router({"iteration_count": 9}) == "continue"
def test_limit_reached(self):
"""Test when iteration limit is reached."""
router = check_iteration_limit(max_iterations=10)
assert router({"iteration_count": 10}) == "limit_reached"
assert router({"iteration_count": 15}) == "limit_reached"
def test_missing_iteration_count(self):
"""Test behavior when iteration count is missing."""
router = check_iteration_limit(max_iterations=5)
# Should default to 0 and return continue
assert router({}) == "continue"
def test_invalid_iteration_types(self):
"""Test handling of invalid iteration count types."""
router = check_iteration_limit(max_iterations=5)
# Invalid types should default to continue
assert router({"iteration_count": "invalid"}) == "continue"
assert router({"iteration_count": []}) == "continue"
assert router({"iteration_count": None}) == "continue"
def test_custom_iteration_key(self):
"""Test with custom iteration key."""
router = check_iteration_limit(max_iterations=3, iteration_key="retry_count")
assert router({"retry_count": 2}) == "continue"
assert router({"retry_count": 3}) == "limit_reached"
def test_zero_max_iterations(self):
"""Test with zero max iterations."""
router = check_iteration_limit(max_iterations=0)
assert router({"iteration_count": 0}) == "limit_reached"
assert router({"iteration_count": -1}) == "continue"
def test_negative_iteration_count(self):
"""Test with negative iteration count."""
router = check_iteration_limit(max_iterations=5)
assert router({"iteration_count": -1}) == "continue"
assert router({"iteration_count": -10}) == "continue"
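Note the asymmetry with `multi_step_progress`: invalid counts here fall back to "continue" rather than "error". A sketch consistent with that, assuming the same closure pattern:

```python
def check_iteration_limit(max_iterations, iteration_key="iteration_count"):
    """Router yielding "limit_reached" once the count meets the maximum."""
    def router(state):
        raw = state.get(iteration_key, 0)
        try:
            count = int(raw)
        except (TypeError, ValueError):
            return "continue"  # invalid counts fail open, unlike step progress
        return "limit_reached" if count >= max_iterations else "continue"
    return router
```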
class TestCheckCompletionCriteria:
"""Test check_completion_criteria function with various edge cases."""
def test_all_criteria_met(self):
"""Test when all completion criteria are met."""
router = check_completion_criteria(["task1", "task2", "task3"])
state = {
"completed_task1": True,
"completed_task2": True,
"completed_task3": True,
}
assert router(state) == "complete"
def test_some_criteria_missing(self):
"""Test when some criteria are not met."""
router = check_completion_criteria(["task1", "task2", "task3"])
state = {
"completed_task1": True,
"completed_task2": False,
"completed_task3": True,
}
assert router(state) == "continue"
def test_no_criteria_met(self):
"""Test when no criteria are met."""
router = check_completion_criteria(["task1", "task2"])
state = {}
assert router(state) == "continue"
def test_custom_condition_prefix(self):
"""Test with custom condition prefix."""
router = check_completion_criteria(
["validation", "processing"], condition_prefix="done_"
)
state = {"done_validation": True, "done_processing": True}
assert router(state) == "complete"
state = {"done_validation": True, "done_processing": False}
assert router(state) == "continue"
def test_falsy_values_as_incomplete(self):
"""Test that falsy values are treated as incomplete."""
router = check_completion_criteria(["task1", "task2"])
# False, 0, None, empty string should be incomplete
state = {"completed_task1": False, "completed_task2": True}
assert router(state) == "continue"
state = {"completed_task1": 0, "completed_task2": True}
assert router(state) == "continue"
state = {"completed_task1": None, "completed_task2": True}
assert router(state) == "continue"
def test_truthy_values_as_complete(self):
"""Test that truthy values are treated as complete."""
router = check_completion_criteria(["task1", "task2"])
state = {"completed_task1": 1, "completed_task2": "done"}
assert router(state) == "complete"
state = {"completed_task1": [1, 2, 3], "completed_task2": {"status": "done"}}
assert router(state) == "complete"
def test_empty_criteria_list(self):
"""Test with empty criteria list."""
router = check_completion_criteria([])
# No criteria means always complete
assert router({}) == "complete"
assert router({"any_key": "any_value"}) == "complete"
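These cases reduce to plain Python truthiness over prefixed keys, with an empty criteria list trivially complete. An illustrative implementation (an assumption, not the shipped code):

```python
def check_completion_criteria(criteria, condition_prefix="completed_"):
    """Router yielding "complete" only when every prefixed criterion is truthy."""
    def router(state):
        # all() over an empty list is True, so no criteria means always complete
        if all(state.get(condition_prefix + c) for c in criteria):
            return "complete"
        return "continue"
    return router
```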
class TestCheckWorkflowState:
"""Test check_workflow_state function with various edge cases."""
def test_valid_state_transitions(self):
"""Test valid state transitions."""
transitions = {
"init": "process",
"process": "validate",
"validate": "complete",
"complete": "end",
}
router = check_workflow_state(transitions)
assert router({"workflow_state": "init"}) == "process"
assert router({"workflow_state": "process"}) == "validate"
assert router({"workflow_state": "validate"}) == "complete"
assert router({"workflow_state": "complete"}) == "end"
def test_unknown_state_default(self):
"""Test unknown state uses default target."""
transitions = {"known": "next"}
router = check_workflow_state(transitions, default_target="fallback")
assert router({"workflow_state": "unknown"}) == "fallback"
assert router({"workflow_state": None}) == "fallback"
def test_missing_workflow_state(self):
"""Test missing workflow state key."""
transitions = {"init": "process"}
router = check_workflow_state(transitions)
assert router({}) == "error" # default_target="error"
def test_custom_state_key(self):
"""Test with custom state key."""
transitions = {"start": "middle", "middle": "end"}
router = check_workflow_state(
transitions, state_key="current_phase", default_target="unknown"
)
assert router({"current_phase": "start"}) == "middle"
assert router({"current_phase": "middle"}) == "end"
assert router({"current_phase": "invalid"}) == "unknown"
def test_numeric_state_values(self):
"""Test with numeric state values."""
transitions = {"1": "stage_one", "2": "stage_two"}
router = check_workflow_state(transitions)
# Numeric values should be converted to strings
assert router({"workflow_state": 1}) == "stage_one"
assert router({"workflow_state": "2"}) == "stage_two"
def test_case_sensitivity(self):
"""Test case sensitivity in state matching."""
transitions = {"Running": "monitor", "running": "process"}
router = check_workflow_state(transitions)
# Should be case sensitive
assert router({"workflow_state": "Running"}) == "monitor"
assert router({"workflow_state": "running"}) == "process"
assert router({"workflow_state": "RUNNING"}) == "error"
def test_complex_state_names(self):
"""Test with complex state names."""
transitions = {
"user_input_required": "wait_for_input",
"processing_with_retry": "retry_handler",
"validation_failed_critical": "escalate",
}
router = check_workflow_state(transitions)
assert router({"workflow_state": "user_input_required"}) == "wait_for_input"
assert router({"workflow_state": "processing_with_retry"}) == "retry_handler"
assert router({"workflow_state": "validation_failed_critical"}) == "escalate"
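The numeric-value and case-sensitivity cases imply a simple `str()`-keyed lookup with a default target. A sketch under that assumption:

```python
def check_workflow_state(transitions, state_key="workflow_state", default_target="error"):
    """Router mapping the current workflow state to its next target."""
    def router(state):
        value = state.get(state_key)
        if value is None:
            return default_target
        # str() lets numeric states match string keys; matching is case sensitive
        return transitions.get(str(value), default_target)
    return router
```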
class TestIntegrationAndEdgeCases:
"""Test integration scenarios and edge cases."""
def test_state_protocol_consistency(self):
"""Test that all functions work consistently with StateProtocol."""
state_data = {
"messages": [{"tool_calls": [{"function": "search"}]}],
"start_time": time.time() - 100,
"current_step": 2,
"iteration_count": 3,
"completed_task1": True,
"completed_task2": False,
"workflow_state": "processing",
}
dict_state = state_data
protocol_state = MockState(state_data)
# All functions should return same results for both state types
assert should_continue(dict_state) == should_continue(protocol_state)
timeout_router = timeout_check(300.0)
assert timeout_router(dict_state) == timeout_router(protocol_state)
progress_router = multi_step_progress(5)
assert progress_router(dict_state) == progress_router(protocol_state)
limit_router = check_iteration_limit(10)
assert limit_router(dict_state) == limit_router(protocol_state)
criteria_router = check_completion_criteria(["task1", "task2"])
assert criteria_router(dict_state) == criteria_router(protocol_state)
workflow_router = check_workflow_state({"processing": "next"})
assert workflow_router(dict_state) == workflow_router(protocol_state)
def test_performance_with_large_message_lists(self):
"""Test performance with large message lists."""
large_message_list = [
{"role": "assistant", "content": f"Message {i}"} for i in range(1000)
]
# Add a tool call (as a list, matching the malformed-message tests) to the last message
large_message_list[-1]["tool_calls"] = [{"function": "search"}]
state = {"messages": large_message_list}
# Should efficiently check only the last message
assert should_continue(state) == "continue"
def test_memory_safety_with_recursive_structures(self):
"""Test memory safety with recursive data structures."""
# Create a recursive structure (not recommended but should not crash)
recursive_dict = {}
recursive_dict["self"] = recursive_dict
state = {"workflow_state": recursive_dict}
# Should handle gracefully: str() of the recursive dict matches no transition key
workflow_router = check_workflow_state({"some_state": "next"})
result = workflow_router(state)
# Should not crash; falls through to the default target
assert result == "error"
def test_concurrent_router_instances(self):
"""Test multiple router instances don't interfere."""
router1 = timeout_check(timeout_seconds=100.0)
router2 = timeout_check(timeout_seconds=200.0)
with patch("time.time", return_value=1000.0):
state = {"start_time": 850.0} # 150 seconds ago
# Different timeouts should give different results
assert router1(state) == "timeout" # 150 > 100
assert router2(state) == "continue" # 150 < 200
def test_extreme_numeric_values(self):
"""Test handling of extreme numeric values."""
# Very large iteration count
limit_router = check_iteration_limit(max_iterations=10)
assert limit_router({"iteration_count": 2**63 - 1}) == "limit_reached"
# Very large step count
progress_router = multi_step_progress(total_steps=5)
assert progress_router({"current_step": 10**10}) == "complete"
# Very old start time
with patch("time.time", return_value=1000.0):
timeout_router = timeout_check(timeout_seconds=60.0)
assert timeout_router({"start_time": -1000000.0}) == "timeout"

"""Comprehensive tests for user interaction edge helpers.
Tests cover edge cases including:
- Various interrupt signal formats
- Complex user authorization scenarios
- Multiple feedback conditions
- Escalation triggers and edge cases
- Input collection edge cases
"""
from typing import Any
from bb_core.edge_helpers.user_interaction import (
check_user_authorization,
collect_user_input,
escalate_to_human,
human_interrupt,
pass_status_to_user,
user_feedback_loop,
)
class MockState:
"""Mock state object for testing."""
def __init__(self, data: dict[str, Any]):
self._data = data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def __getattr__(self, name: str) -> Any:
return self._data.get(name)
class TestHumanInterrupt:
"""Test human_interrupt function with various edge cases."""
def test_default_interrupt_signals(self):
"""Test default interrupt signals."""
router = human_interrupt()
# Default interrupt signals
assert router({"human_interrupt": "stop"}) == "interrupt"
assert router({"human_interrupt": "pause"}) == "interrupt"
assert router({"human_interrupt": "cancel"}) == "interrupt"
assert router({"human_interrupt": "abort"}) == "interrupt"
assert router({"human_interrupt": "interrupt"}) == "interrupt"
assert router({"human_interrupt": "halt"}) == "interrupt"
def test_non_interrupt_signals(self):
"""Test signals that should not trigger interrupt."""
router = human_interrupt()
assert router({"human_interrupt": "continue"}) == "continue"
assert router({"human_interrupt": "proceed"}) == "continue"
assert router({"human_interrupt": "go"}) == "continue"
assert router({"human_interrupt": "run"}) == "continue"
def test_case_insensitive_signals(self):
"""Test case insensitive signal matching."""
router = human_interrupt()
assert router({"human_interrupt": "STOP"}) == "interrupt"
assert router({"human_interrupt": "Cancel"}) == "interrupt"
assert router({"human_interrupt": "ABORT"}) == "interrupt"
def test_whitespace_handling(self):
"""Test whitespace handling in signals."""
router = human_interrupt()
assert router({"human_interrupt": " stop "}) == "interrupt"
assert router({"human_interrupt": "\tcancel\n"}) == "interrupt"
assert router({"human_interrupt": " abort "}) == "interrupt"
def test_custom_interrupt_signals(self):
"""Test custom interrupt signals."""
router = human_interrupt(interrupt_signals=["quit", "exit", "emergency"])
assert router({"human_interrupt": "quit"}) == "interrupt"
assert router({"human_interrupt": "exit"}) == "interrupt"
assert router({"human_interrupt": "emergency"}) == "interrupt"
# Default signals should not work with custom list
assert router({"human_interrupt": "stop"}) == "continue"
def test_missing_interrupt_key(self):
"""Test behavior when interrupt key is missing."""
router = human_interrupt()
assert router({}) == "continue"
assert router({"other_key": "stop"}) == "continue"
def test_none_interrupt_signal(self):
"""Test behavior with None interrupt signal."""
router = human_interrupt()
assert router({"human_interrupt": None}) == "continue"
def test_custom_interrupt_key(self):
"""Test with custom interrupt key."""
router = human_interrupt(interrupt_key="user_command")
assert router({"user_command": "stop"}) == "interrupt"
assert router({"user_command": "continue"}) == "continue"
def test_numeric_signals(self):
"""Test numeric signals converted to strings."""
router = human_interrupt(interrupt_signals=["0", "1"])
assert router({"human_interrupt": 0}) == "interrupt"
assert router({"human_interrupt": "1"}) == "interrupt"
assert router({"human_interrupt": 2}) == "continue"
def test_empty_signal(self):
"""Test empty signal handling."""
router = human_interrupt()
assert router({"human_interrupt": ""}) == "continue"
assert router({"human_interrupt": " "}) == "continue" # Only whitespace
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = human_interrupt()
state = MockState({"human_interrupt": "stop"})
assert router(state) == "interrupt"
state = MockState({"human_interrupt": "continue"})
assert router(state) == "continue"
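A sketch of `human_interrupt` consistent with the normalization (strip + lowercase) and the `None` short-circuit these tests require; everything beyond the public name is inferred:

```python
def human_interrupt(interrupt_signals=None, interrupt_key="human_interrupt"):
    """Router yielding "interrupt" when the state carries a recognized signal."""
    defaults = ["stop", "pause", "cancel", "abort", "interrupt", "halt"]
    signals = {str(s).lower() for s in (interrupt_signals or defaults)}
    def router(state):
        value = state.get(interrupt_key)
        if value is None:
            return "continue"
        # str() first, so numeric signals like 0 can match "0" in a custom list
        normalized = str(value).strip().lower()
        return "interrupt" if normalized in signals else "continue"
    return router
```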
class TestPassStatusToUser:
"""Test pass_status_to_user function with various edge cases."""
def test_default_status_levels(self):
"""Test default status level mappings."""
router = pass_status_to_user()
assert router({"status": "error", "notify_user": True}) == "urgent"
assert router({"status": "failed", "notify_user": True}) == "urgent"
assert router({"status": "warning", "notify_user": True}) == "medium"
assert router({"status": "completed", "notify_user": True}) == "low"
assert router({"status": "success", "notify_user": True}) == "low"
assert router({"status": "in_progress", "notify_user": True}) == "info"
def test_notification_disabled(self):
"""Test when notifications are disabled."""
router = pass_status_to_user()
assert router({"status": "error", "notify_user": False}) == "no_notification"
assert router({"status": "failed", "notify_user": False}) == "no_notification"
def test_missing_notify_key(self):
"""Test default notification behavior when notify key is missing."""
router = pass_status_to_user()
# Should default to True (notify)
assert router({"status": "error"}) == "urgent"
assert router({"status": "warning"}) == "medium"
def test_custom_status_levels(self):
"""Test custom status level mappings."""
custom_levels = {
"critical": "immediate",
"high": "urgent",
"normal": "standard",
}
router = pass_status_to_user(status_levels=custom_levels)
assert router({"status": "critical", "notify_user": True}) == "immediate"
assert router({"status": "high", "notify_user": True}) == "urgent"
assert router({"status": "normal", "notify_user": True}) == "standard"
# Unknown status should return no_notification
assert router({"status": "unknown", "notify_user": True}) == "no_notification"
def test_custom_keys(self):
"""Test with custom status and notify keys."""
router = pass_status_to_user(
status_key="current_status", notify_key="send_notification"
)
state = {"current_status": "error", "send_notification": True}
assert router(state) == "urgent"
state = {"current_status": "error", "send_notification": False}
assert router(state) == "no_notification"
def test_missing_status_key(self):
"""Test behavior when status key is missing."""
router = pass_status_to_user()
assert router({"notify_user": True}) == "no_notification"
assert router({}) == "no_notification"
def test_none_status(self):
"""Test behavior with None status."""
router = pass_status_to_user()
assert router({"status": None, "notify_user": True}) == "no_notification"
def test_case_insensitive_status(self):
"""Test case insensitive status matching."""
router = pass_status_to_user()
assert router({"status": "ERROR", "notify_user": True}) == "urgent"
assert router({"status": "Warning", "notify_user": True}) == "medium"
assert router({"status": "COMPLETED", "notify_user": True}) == "low"
def test_numeric_status(self):
"""Test numeric status values."""
custom_levels = {"1": "low", "2": "medium", "3": "high"}
router = pass_status_to_user(status_levels=custom_levels)
assert router({"status": 1, "notify_user": True}) == "low"
assert router({"status": "2", "notify_user": True}) == "medium"
assert router({"status": 3, "notify_user": True}) == "high"
def test_complex_status_values(self):
"""Test complex status values converted to strings."""
router = pass_status_to_user()
# Complex objects should be converted to strings
assert router({"status": ["error"], "notify_user": True}) == "no_notification"
assert (
router({"status": {"type": "error"}, "notify_user": True})
== "no_notification"
)
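These mappings suggest a lowercased string lookup gated by the notify flag, which defaults to True when absent. An inferred sketch:

```python
def pass_status_to_user(status_levels=None, status_key="status", notify_key="notify_user"):
    """Router mapping a status value to a notification priority."""
    levels = status_levels or {
        "error": "urgent", "failed": "urgent", "warning": "medium",
        "completed": "low", "success": "low", "in_progress": "info",
    }
    def router(state):
        if not state.get(notify_key, True):  # notifications default to on
            return "no_notification"
        status = state.get(status_key)
        if status is None:
            return "no_notification"
        # str().lower() handles numeric and mixed-case statuses; complex values
        # like lists stringify to something no mapping key matches
        return levels.get(str(status).lower(), "no_notification")
    return router
```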
class TestUserFeedbackLoop:
"""Test user_feedback_loop function with various edge cases."""
def test_explicit_feedback_required(self):
"""Test explicit feedback requirement."""
router = user_feedback_loop()
state = {"requires_feedback": True, "feedback_type": "validation"}
assert router(state) == "feedback_validation"
state = {"requires_feedback": True, "feedback_type": "clarification"}
assert router(state) == "feedback_clarification"
def test_default_feedback_type(self):
"""Test default feedback type when not specified."""
router = user_feedback_loop()
state = {"requires_feedback": True}
assert router(state) == "feedback_general"
def test_condition_based_feedback(self):
"""Test feedback triggered by conditions."""
router = user_feedback_loop()
# Default conditions that trigger feedback
assert router({"low_confidence": True}) == "feedback_low_confidence"
assert router({"ambiguous_input": True}) == "feedback_ambiguous_input"
assert router({"multiple_options": True}) == "feedback_multiple_options"
assert router({"validation_failed": True}) == "feedback_validation_failed"
def test_no_feedback_needed(self):
"""Test when no feedback is needed."""
router = user_feedback_loop()
assert router({}) == "no_feedback"
assert router({"requires_feedback": False}) == "no_feedback"
assert router({"low_confidence": False}) == "no_feedback"
def test_custom_feedback_conditions(self):
"""Test custom feedback conditions."""
custom_conditions = ["needs_review", "unclear_intent", "missing_data"]
router = user_feedback_loop(feedback_required_conditions=custom_conditions)
assert router({"needs_review": True}) == "feedback_needs_review"
assert router({"unclear_intent": True}) == "feedback_unclear_intent"
assert router({"missing_data": True}) == "feedback_missing_data"
# Default conditions should not trigger with custom list
assert router({"low_confidence": True}) == "no_feedback"
def test_custom_feedback_keys(self):
"""Test with custom feedback keys."""
router = user_feedback_loop(
feedback_key="user_input_needed", feedback_type_key="input_type"
)
state = {"user_input_needed": True, "input_type": "choice"}
assert router(state) == "feedback_choice"
def test_multiple_conditions_triggered(self):
"""Test when multiple conditions are triggered."""
router = user_feedback_loop()
# Should return feedback for first matching condition
state = {"low_confidence": True, "ambiguous_input": True}
result = router(state)
assert result in ["feedback_low_confidence", "feedback_ambiguous_input"]
def test_priority_explicit_over_conditions(self):
"""Test that explicit feedback takes priority over conditions."""
router = user_feedback_loop()
state = {
"requires_feedback": True,
"feedback_type": "explicit",
"low_confidence": True, # This should be ignored
}
assert router(state) == "feedback_explicit"
def test_falsy_conditions(self):
"""Test that falsy condition values don't trigger feedback."""
router = user_feedback_loop()
state = {
"low_confidence": False,
"ambiguous_input": 0,
"multiple_options": None,
"validation_failed": "",
}
assert router(state) == "no_feedback"
def test_truthy_non_boolean_conditions(self):
"""Test truthy non-boolean values trigger feedback."""
router = user_feedback_loop()
assert router({"low_confidence": 1}) == "feedback_low_confidence"
assert router({"ambiguous_input": "yes"}) == "feedback_ambiguous_input"
assert router({"multiple_options": [1, 2, 3]}) == "feedback_multiple_options"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = user_feedback_loop()
state = MockState({"requires_feedback": True, "feedback_type": "test"})
assert router(state) == "feedback_test"
state = MockState({"low_confidence": True})
assert router(state) == "feedback_low_confidence"
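The priority test fixes the evaluation order: explicit `requires_feedback` wins, then the condition list is scanned in order. One plausible implementation:

```python
def user_feedback_loop(feedback_required_conditions=None,
                       feedback_key="requires_feedback",
                       feedback_type_key="feedback_type"):
    """Router yielding "feedback_<type>" when user feedback is needed."""
    conditions = feedback_required_conditions or [
        "low_confidence", "ambiguous_input", "multiple_options", "validation_failed",
    ]
    def router(state):
        # Explicit request takes priority over any condition flags
        if state.get(feedback_key):
            return f"feedback_{state.get(feedback_type_key) or 'general'}"
        for condition in conditions:
            if state.get(condition):  # first truthy condition wins
                return f"feedback_{condition}"
        return "no_feedback"
    return router
```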
class TestEscalateToHuman:
"""Test escalate_to_human function with various edge cases."""
def test_automatic_escalation(self):
"""Test automatic escalation based on reason."""
router = escalate_to_human()
state = {"auto_escalate": True, "escalation_reason": "critical_error"}
assert router(state) == "escalate_immediate"
state = {"auto_escalate": True, "escalation_reason": "security_issue"}
assert router(state) == "escalate_immediate"
state = {"auto_escalate": True, "escalation_reason": "multiple_failures"}
assert router(state) == "escalate_urgent"
def test_trigger_based_escalation(self):
"""Test escalation based on trigger flags."""
router = escalate_to_human()
assert router({"critical_error": True}) == "escalate_immediate"
assert router({"security_issue": True}) == "escalate_immediate"
assert router({"multiple_failures": True}) == "escalate_urgent"
assert router({"user_request": True}) == "escalate_standard"
assert router({"manual_review": True}) == "escalate_standard"
assert router({"complex_case": True}) == "escalate_expert"
def test_no_escalation_needed(self):
"""Test when no escalation is needed."""
router = escalate_to_human()
assert router({}) == "continue_automated"
assert router({"auto_escalate": False}) == "continue_automated"
assert router({"some_other_flag": True}) == "continue_automated"
def test_custom_escalation_triggers(self):
"""Test custom escalation triggers."""
custom_triggers = {
"data_breach": "immediate",
"system_down": "urgent",
"user_complaint": "standard",
}
router = escalate_to_human(escalation_triggers=custom_triggers)
assert router({"data_breach": True}) == "escalate_immediate"
assert router({"system_down": True}) == "escalate_urgent"
assert router({"user_complaint": True}) == "escalate_standard"
# Default triggers should not work with custom triggers
assert router({"critical_error": True}) == "continue_automated"
def test_custom_escalation_keys(self):
"""Test with custom escalation keys."""
router = escalate_to_human(
auto_escalate_key="force_escalate", escalation_reason_key="reason"
)
state = {"force_escalate": True, "reason": "critical_error"}
assert router(state) == "escalate_immediate"
def test_partial_reason_matching(self):
"""Test partial reason matching in escalation reason."""
router = escalate_to_human()
state = {
"auto_escalate": True,
"escalation_reason": "encountered critical_error in processing",
}
assert router(state) == "escalate_immediate"
state = {
"auto_escalate": True,
"escalation_reason": "multiple_failures detected",
}
assert router(state) == "escalate_urgent"
def test_unknown_escalation_reason(self):
"""Test unknown escalation reason defaults to standard."""
router = escalate_to_human()
state = {"auto_escalate": True, "escalation_reason": "unknown_issue"}
assert router(state) == "escalate_standard"
def test_missing_escalation_reason(self):
"""Test auto escalate without reason."""
router = escalate_to_human()
state = {"auto_escalate": True}
assert router(state) == "continue_automated"
def test_case_insensitive_reason_matching(self):
"""Test case insensitive reason matching."""
router = escalate_to_human()
state = {"auto_escalate": True, "escalation_reason": "CRITICAL_ERROR"}
assert router(state) == "escalate_immediate"
state = {"auto_escalate": True, "escalation_reason": "Security_Issue"}
assert router(state) == "escalate_immediate"
def test_multiple_triggers_precedence(self):
"""Test precedence when multiple triggers are present."""
router = escalate_to_human()
# Should find the first trigger in the dictionary order
state = {"critical_error": True, "user_request": True, "complex_case": True}
result = router(state)
assert result.startswith("escalate_")
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = escalate_to_human()
state = MockState(
{"auto_escalate": True, "escalation_reason": "critical_error"}
)
assert router(state) == "escalate_immediate"
state = MockState({"critical_error": True})
assert router(state) == "escalate_immediate"
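The partial-matching tests imply substring search of trigger names inside the lowercased reason, with flag checks as the fallback path. A sketch on those assumptions:

```python
def escalate_to_human(escalation_triggers=None,
                      auto_escalate_key="auto_escalate",
                      escalation_reason_key="escalation_reason"):
    """Router yielding "escalate_<level>" or "continue_automated"."""
    triggers = escalation_triggers or {
        "critical_error": "immediate", "security_issue": "immediate",
        "multiple_failures": "urgent", "user_request": "standard",
        "manual_review": "standard", "complex_case": "expert",
    }
    def router(state):
        if state.get(auto_escalate_key):
            reason = state.get(escalation_reason_key)
            if not reason:
                return "continue_automated"  # auto_escalate without a reason
            reason = str(reason).lower()
            for trigger, level in triggers.items():
                if trigger in reason:  # partial, case-insensitive match
                    return f"escalate_{level}"
            return "escalate_standard"  # unknown reason defaults to standard
        for trigger, level in triggers.items():  # flag-based path, insertion order
            if state.get(trigger):
                return f"escalate_{level}"
        return "continue_automated"
    return router
```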
class TestCheckUserAuthorization:
"""Test check_user_authorization function with various edge cases."""
def test_authorized_user(self):
"""Test user with all required permissions."""
router = check_user_authorization(["read", "write"])
user = {"permissions": ["read", "write", "admin"]}
assert router({"user": user}) == "authorized"
def test_unauthorized_user_missing_permissions(self):
"""Test user missing required permissions."""
router = check_user_authorization(["read", "write", "admin"])
user = {"permissions": ["read", "write"]} # Missing admin
assert router({"user": user}) == "unauthorized"
def test_unauthorized_user_no_permissions(self):
"""Test user with no permissions."""
router = check_user_authorization(["read"])
user = {"permissions": []}
assert router({"user": user}) == "unauthorized"
user = {} # No permissions key
assert router({"user": user}) == "unauthorized"
def test_missing_user(self):
"""Test when user is missing from state."""
router = check_user_authorization(["read"])
assert router({}) == "unauthorized"
assert router({"other_key": "value"}) == "unauthorized"
def test_none_user(self):
"""Test when user is None."""
router = check_user_authorization(["read"])
assert router({"user": None}) == "unauthorized"
def test_user_object_with_attributes(self):
"""Test user object with permission attributes."""
class MockUser:
def __init__(self, permissions):
self.permissions = permissions
router = check_user_authorization(["read", "write"])
user = MockUser(["read", "write", "admin"])
assert router({"user": user}) == "authorized"
user = MockUser(["read"]) # Missing write
assert router({"user": user}) == "unauthorized"
def test_custom_permission_keys(self):
"""Test with custom user and permissions keys."""
router = check_user_authorization(
required_permissions=["access"],
user_key="current_user",
permissions_key="roles",
)
user = {"roles": ["access", "modify"]}
assert router({"current_user": user}) == "authorized"
user = {"roles": ["modify"]} # Missing access
assert router({"current_user": user}) == "unauthorized"
def test_default_permissions(self):
"""Test default required permissions."""
router = check_user_authorization() # Default: ["basic_access"]
user = {"permissions": ["basic_access"]}
assert router({"user": user}) == "authorized"
user = {"permissions": ["other_permission"]}
assert router({"user": user}) == "unauthorized"
def test_empty_required_permissions(self):
"""Test with empty required permissions list."""
router = check_user_authorization(required_permissions=[])
# No requirements means everyone is authorized
user = {"permissions": []}
assert router({"user": user}) == "authorized"
user = {}
assert router({"user": user}) == "authorized"
def test_non_list_permissions(self):
"""Test when user permissions is not a list."""
router = check_user_authorization(["read"])
user = {"permissions": "read"} # String instead of list
assert router({"user": user}) == "unauthorized"
user = {"permissions": {"read": True}} # Dict instead of list
assert router({"user": user}) == "unauthorized"
def test_case_sensitive_permissions(self):
"""Test case sensitivity of permission matching."""
router = check_user_authorization(["Read", "Write"])
user = {"permissions": ["read", "write"]} # Different case
assert router({"user": user}) == "unauthorized"
user = {"permissions": ["Read", "Write"]} # Exact case
assert router({"user": user}) == "authorized"
def test_user_object_missing_permissions_attribute(self):
"""Test user object missing permissions attribute."""
class MockUser:
def __init__(self):
self.name = "John"
router = check_user_authorization(["read"])
user = MockUser() # No permissions attribute
assert router({"user": user}) == "unauthorized"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = check_user_authorization(["read", "write"])
user = {"permissions": ["read", "write"]}
state = MockState({"user": user})
assert router(state) == "authorized"
user = {"permissions": ["read"]}
state = MockState({"user": user})
assert router(state) == "unauthorized"
class TestCollectUserInput:
"""Test collect_user_input function with various edge cases."""
def test_pending_input_with_default_types(self):
"""Test pending input with default input types."""
router = collect_user_input()
assert (
router({"pending_user_input": True, "input_type": "text"}) == "text_input"
)
assert (
router({"pending_user_input": True, "input_type": "choice"})
== "choice_input"
)
assert (
router({"pending_user_input": True, "input_type": "confirmation"})
== "confirm_input"
)
assert (
router({"pending_user_input": True, "input_type": "file"}) == "file_input"
)
assert (
router({"pending_user_input": True, "input_type": "numeric"})
== "number_input"
)
def test_no_input_needed(self):
"""Test when no input is needed."""
router = collect_user_input()
assert router({"pending_user_input": False}) == "no_input_needed"
assert router({}) == "no_input_needed" # Missing key defaults to False
def test_default_input_type(self):
"""Test default input type when not specified."""
router = collect_user_input()
assert router({"pending_user_input": True}) == "text_input" # Default to text
def test_custom_input_types(self):
"""Test custom input type mappings."""
custom_types = {
"voice": "voice_recorder",
"image": "image_capture",
"location": "gps_picker",
}
router = collect_user_input(input_types=custom_types)
assert (
router({"pending_user_input": True, "input_type": "voice"})
== "voice_recorder"
)
assert (
router({"pending_user_input": True, "input_type": "image"})
== "image_capture"
)
assert (
router({"pending_user_input": True, "input_type": "location"})
== "gps_picker"
)
# Unknown type should default to text_input
assert (
router({"pending_user_input": True, "input_type": "unknown"})
== "text_input"
)
def test_custom_input_keys(self):
"""Test with custom input keys."""
router = collect_user_input(
pending_input_key="awaiting_input", input_type_key="required_input_type"
)
state = {"awaiting_input": True, "required_input_type": "choice"}
assert router(state) == "choice_input"
state = {"awaiting_input": False, "required_input_type": "text"}
assert router(state) == "no_input_needed"
def test_case_insensitive_input_types(self):
"""Test case insensitive input type matching."""
router = collect_user_input()
assert (
router({"pending_user_input": True, "input_type": "TEXT"}) == "text_input"
)
assert (
router({"pending_user_input": True, "input_type": "Choice"})
== "choice_input"
)
assert (
router({"pending_user_input": True, "input_type": "FILE"}) == "file_input"
)
def test_numeric_input_type(self):
"""Test numeric input type values."""
custom_types = {"1": "type_one", "2": "type_two"}
router = collect_user_input(input_types=custom_types)
assert router({"pending_user_input": True, "input_type": 1}) == "type_one"
assert router({"pending_user_input": True, "input_type": "2"}) == "type_two"
def test_truthy_falsy_pending_values(self):
"""Test truthy/falsy values for pending input."""
router = collect_user_input()
# Truthy values should indicate pending input
assert router({"pending_user_input": 1, "input_type": "text"}) == "text_input"
assert (
router({"pending_user_input": "yes", "input_type": "text"}) == "text_input"
)
assert router({"pending_user_input": [1], "input_type": "text"}) == "text_input"
# Falsy values should indicate no input needed
assert (
router({"pending_user_input": 0, "input_type": "text"}) == "no_input_needed"
)
assert (
router({"pending_user_input": "", "input_type": "text"})
== "no_input_needed"
)
assert (
router({"pending_user_input": [], "input_type": "text"})
== "no_input_needed"
)
assert (
router({"pending_user_input": None, "input_type": "text"})
== "no_input_needed"
)
def test_fallback_to_text_input(self):
"""Test fallback to text input for unknown types."""
router = collect_user_input()
assert (
router({"pending_user_input": True, "input_type": "unknown_type"})
== "text_input"
)
assert router({"pending_user_input": True, "input_type": None}) == "text_input"
assert router({"pending_user_input": True, "input_type": ""}) == "text_input"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = collect_user_input()
state = MockState({"pending_user_input": True, "input_type": "choice"})
assert router(state) == "choice_input"
state = MockState({"pending_user_input": False})
assert router(state) == "no_input_needed"
class TestIntegrationAndEdgeCases:
"""Test integration scenarios and edge cases."""
def test_all_user_interaction_functions_consistency(self):
"""Test all user interaction functions work consistently."""
state = {
"human_interrupt": "continue",
"status": "warning",
"notify_user": True,
"requires_feedback": False,
"low_confidence": False,
"auto_escalate": False,
"user": {"permissions": ["read", "write"]},
"pending_user_input": False,
}
interrupt_router = human_interrupt()
status_router = pass_status_to_user()
feedback_router = user_feedback_loop()
escalation_router = escalate_to_human()
auth_router = check_user_authorization(["read", "write"])
input_router = collect_user_input()
assert interrupt_router(state) == "continue"
assert status_router(state) == "medium"
assert feedback_router(state) == "no_feedback"
assert escalation_router(state) == "continue_automated"
assert auth_router(state) == "authorized"
assert input_router(state) == "no_input_needed"
def test_state_protocol_consistency_across_functions(self):
"""Test StateProtocol consistency across all functions."""
state_data = {
"human_interrupt": "stop",
"status": "error",
"notify_user": True,
"requires_feedback": True,
"feedback_type": "urgent",
"critical_error": True,
"user": {"permissions": ["admin"]},
"pending_user_input": True,
"input_type": "choice",
}
dict_state = state_data
protocol_state = MockState(state_data)
# All functions should return same results for both state types
interrupt_router = human_interrupt()
assert interrupt_router(dict_state) == interrupt_router(protocol_state)
status_router = pass_status_to_user()
assert status_router(dict_state) == status_router(protocol_state)
feedback_router = user_feedback_loop()
assert feedback_router(dict_state) == feedback_router(protocol_state)
escalation_router = escalate_to_human()
assert escalation_router(dict_state) == escalation_router(protocol_state)
auth_router = check_user_authorization(["admin"])
assert auth_router(dict_state) == auth_router(protocol_state)
input_router = collect_user_input()
assert input_router(dict_state) == input_router(protocol_state)
def test_cascading_user_interaction_scenario(self):
"""Test cascading user interaction scenario."""
# Start with normal operation
state = {"status": "in_progress", "notify_user": True}
status_router = pass_status_to_user()
assert status_router(state) == "info"
# User requests interrupt
state["human_interrupt"] = "pause"
interrupt_router = human_interrupt()
assert interrupt_router(state) == "interrupt"
# System needs feedback due to interrupt
state["requires_feedback"] = True
state["feedback_type"] = "confirmation"
feedback_router = user_feedback_loop()
assert feedback_router(state) == "feedback_confirmation"
# Collect user input for feedback
state["pending_user_input"] = True
state["input_type"] = "confirmation"
input_router = collect_user_input()
assert input_router(state) == "confirm_input"
def test_security_and_authorization_flow(self):
"""Test security and authorization flow."""
# User with insufficient permissions
unauthorized_user = {"permissions": ["read"]}
state: dict[str, Any] = {"user": unauthorized_user}
auth_router = check_user_authorization(["read", "write", "admin"])
assert auth_router(state) == "unauthorized"
# Should escalate due to unauthorized access
state["security_issue"] = True
escalation_router = escalate_to_human()
assert escalation_router(state) == "escalate_immediate"
# Should notify user of security issue
state["status"] = "failed"
state["notify_user"] = True
status_router = pass_status_to_user()
assert status_router(state) == "urgent"
def test_error_recovery_interaction_flow(self):
"""Test error recovery interaction flow."""
# Multiple failures trigger escalation
state: dict[str, Any] = {"multiple_failures": True}
escalation_router = escalate_to_human()
assert escalation_router(state) == "escalate_urgent"
# Low confidence requires feedback
state["low_confidence"] = True
feedback_router = user_feedback_loop()
assert feedback_router(state) == "feedback_low_confidence"
# Need user input to resolve
state["pending_user_input"] = True
state["input_type"] = "choice"
input_router = collect_user_input()
assert input_router(state) == "choice_input"
# User might interrupt the process
state["human_interrupt"] = "abort"
interrupt_router = human_interrupt()
assert interrupt_router(state) == "interrupt"
def test_performance_with_complex_user_objects(self):
"""Test performance with complex user objects."""
# Create complex user object
complex_user = {
"id": "user123",
"profile": {
"name": "John Doe",
"department": "Engineering",
"level": "Senior",
},
"permissions": ["read", "write", "deploy"]
+ [f"perm_{i}" for i in range(100)],
"metadata": {
"login_count": 1500,
"last_active": "2024-01-01",
"preferences": {"theme": "dark", "notifications": True},
},
}
state = {"user": complex_user}
auth_router = check_user_authorization(["read", "write"])
# Should handle complex user object efficiently
assert auth_router(state) == "authorized"
def test_concurrent_router_usage(self):
"""Test that routers are stateless and thread-safe."""
interrupt_router = human_interrupt()
# Multiple calls with different states should not interfere
state1 = {"human_interrupt": "stop"}
state2 = {"human_interrupt": "continue"}
assert interrupt_router(state1) == "interrupt"
assert interrupt_router(state2) == "continue"
assert interrupt_router(state1) == "interrupt" # Should still work
def test_memory_efficiency_with_large_permission_lists(self):
"""Test memory efficiency with large permission lists."""
# Create user with many permissions
large_permissions = [f"permission_{i}" for i in range(1000)]
user = {"permissions": large_permissions}
state = {"user": user}
# Should handle large permission lists efficiently
auth_router = check_user_authorization(["permission_500"])
assert auth_router(state) == "authorized"
auth_router = check_user_authorization(["permission_not_in_list"])
assert auth_router(state) == "unauthorized"


@@ -0,0 +1,813 @@
"""Comprehensive tests for validation edge helpers.
Tests cover edge cases including:
- Malformed JSON, XML, CSV formats
- Invalid confidence and accuracy scores
- Privacy pattern edge cases
- Timestamp parsing failures
- Required field validation edge cases
- Length validation boundary conditions
"""
import time
from datetime import datetime
from typing import Any
from unittest.mock import patch
from bb_core.edge_helpers.validation import (
check_accuracy,
check_confidence_level,
check_data_freshness,
check_output_length,
check_privacy_compliance,
validate_output_format,
validate_required_fields,
)
class MockState:
"""Mock state object for testing."""
def __init__(self, data: dict[str, Any]):
self._data = data
def get(self, key: str, default: Any = None) -> Any:
return self._data.get(key, default)
def __getattr__(self, name: str) -> Any:
return self._data.get(name)
class TestCheckAccuracy:
"""Test check_accuracy function with various edge cases."""
def test_high_accuracy_scores(self):
"""Test scores above threshold."""
router = check_accuracy(threshold=0.8)
assert router({"accuracy_score": 0.85}) == "high_accuracy"
assert router({"accuracy_score": 1.0}) == "high_accuracy"
assert router({"accuracy_score": 0.8}) == "high_accuracy" # Equal to threshold
def test_low_accuracy_scores(self):
"""Test scores below threshold."""
router = check_accuracy(threshold=0.8)
assert router({"accuracy_score": 0.75}) == "low_accuracy"
assert router({"accuracy_score": 0.0}) == "low_accuracy"
assert router({"accuracy_score": 0.79}) == "low_accuracy"
def test_string_numeric_conversion(self):
"""Test conversion of string numbers."""
router = check_accuracy(threshold=0.5)
assert router({"accuracy_score": "0.8"}) == "high_accuracy"
assert router({"accuracy_score": "0.3"}) == "low_accuracy"
assert router({"accuracy_score": "1"}) == "high_accuracy"
assert router({"accuracy_score": "0"}) == "low_accuracy"
def test_invalid_accuracy_values(self):
"""Test handling of invalid accuracy values."""
router = check_accuracy(threshold=0.5)
# Non-numeric values should route to low_accuracy
assert router({"accuracy_score": "invalid"}) == "low_accuracy"
assert router({"accuracy_score": []}) == "low_accuracy"
assert router({"accuracy_score": {}}) == "low_accuracy"
assert router({"accuracy_score": None}) == "low_accuracy"
def test_missing_accuracy_key(self):
"""Test behavior when accuracy key is missing."""
router = check_accuracy(threshold=0.5)
# Should default to 0.0 and route to low_accuracy
assert router({}) == "low_accuracy"
assert router({"other_key": 0.9}) == "low_accuracy"
def test_custom_accuracy_key(self):
"""Test with custom accuracy key."""
router = check_accuracy(threshold=0.7, accuracy_key="confidence_score")
assert router({"confidence_score": 0.8}) == "high_accuracy"
assert router({"confidence_score": 0.6}) == "low_accuracy"
def test_extreme_values(self):
"""Test extreme accuracy values."""
router = check_accuracy(threshold=0.5)
# Values above 1.0
assert router({"accuracy_score": 1.5}) == "high_accuracy"
assert router({"accuracy_score": 100.0}) == "high_accuracy"
# Negative values
assert router({"accuracy_score": -0.1}) == "low_accuracy"
assert router({"accuracy_score": -10.0}) == "low_accuracy"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = check_accuracy(threshold=0.8)
state = MockState({"accuracy_score": 0.9})
assert router(state) == "high_accuracy"
state = MockState({"accuracy_score": 0.7})
assert router(state) == "low_accuracy"
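The threshold contract asserted above (inclusive comparison, numeric-string coercion, non-numeric values routed low) can be sketched as follows. This is a hypothetical minimal implementation, not the `bb_core` source.

```python
from typing import Any, Callable


def sketch_check_accuracy(
    threshold: float = 0.8, accuracy_key: str = "accuracy_score"
) -> Callable[[Any], str]:
    """Router: "high_accuracy" if score >= threshold, else "low_accuracy"."""

    def router(state: Any) -> str:
        raw = state.get(accuracy_key, 0.0)
        try:
            score = float(raw)  # accepts ints, floats, and numeric strings
        except (TypeError, ValueError):
            score = 0.0  # non-numeric values (lists, dicts, None, "invalid") fall low
        return "high_accuracy" if score >= threshold else "low_accuracy"

    return router
```

`check_confidence_level` presumably follows the same shape with a `confidence` key and "high_confidence"/"low_confidence" routes.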
class TestCheckConfidenceLevel:
"""Test check_confidence_level function with various edge cases."""
def test_high_confidence_scores(self):
"""Test scores above threshold."""
router = check_confidence_level(threshold=0.7)
assert router({"confidence": 0.8}) == "high_confidence"
assert router({"confidence": 1.0}) == "high_confidence"
assert router({"confidence": 0.7}) == "high_confidence" # Equal to threshold
def test_low_confidence_scores(self):
"""Test scores below threshold."""
router = check_confidence_level(threshold=0.7)
assert router({"confidence": 0.6}) == "low_confidence"
assert router({"confidence": 0.0}) == "low_confidence"
assert router({"confidence": 0.69}) == "low_confidence"
def test_invalid_confidence_values(self):
"""Test handling of invalid confidence values."""
router = check_confidence_level(threshold=0.5)
# Non-numeric values should route to low_confidence
assert router({"confidence": "high"}) == "low_confidence"
assert router({"confidence": None}) == "low_confidence"
assert router({"confidence": []}) == "low_confidence"
def test_custom_confidence_key(self):
"""Test with custom confidence key."""
router = check_confidence_level(threshold=0.6, confidence_key="certainty")
assert router({"certainty": 0.8}) == "high_confidence"
assert router({"certainty": 0.5}) == "low_confidence"
def test_percentage_values(self):
"""Test percentage-style values."""
router = check_confidence_level(threshold=70.0)
assert router({"confidence": 80.0}) == "high_confidence"
assert router({"confidence": 60.0}) == "low_confidence"
assert router({"confidence": "75"}) == "high_confidence"
class TestValidateOutputFormat:
"""Test validate_output_format function with various edge cases."""
def test_valid_json_format(self):
"""Test valid JSON format validation."""
router = validate_output_format(expected_format="json")
valid_json = '{"key": "value", "number": 123}'
assert router({"output": valid_json}) == "valid_format"
# Complex JSON
complex_json = (
'{"users": [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]}'
)
assert router({"output": complex_json}) == "valid_format"
def test_invalid_json_format(self):
"""Test invalid JSON format validation."""
router = validate_output_format(expected_format="json")
# Malformed JSON
assert (
router({"output": '{"key": value}'}) == "invalid_format"
) # Missing quotes
assert router({"output": '{key: "value"}'}) == "invalid_format" # Unquoted key
assert (
router({"output": '{"key": "value",}'}) == "invalid_format"
) # Trailing comma
assert router({"output": "not json at all"}) == "invalid_format"
def test_valid_xml_format(self):
"""Test valid XML format validation."""
router = validate_output_format(expected_format="xml")
assert router({"output": "<root><item>value</item></root>"}) == "valid_format"
assert router({"output": "<simple/>"}) == "valid_format"
assert (
router({"output": " <root>content</root> "}) == "valid_format"
) # With whitespace
def test_invalid_xml_format(self):
"""Test invalid XML format validation."""
router = validate_output_format(expected_format="xml")
assert router({"output": "not xml"}) == "invalid_format"
assert router({"output": "<unclosed"}) == "invalid_format"
assert router({"output": "unclosed>"}) == "invalid_format"
assert router({"output": ""}) == "invalid_format"
def test_valid_csv_format(self):
"""Test valid CSV format validation."""
router = validate_output_format(expected_format="csv")
assert router({"output": "name,age,city\nJohn,30,NYC"}) == "valid_format"
assert router({"output": "col1,col2,col3"}) == "valid_format"
assert router({"output": "a,b,c\n1,2,3\n4,5,6"}) == "valid_format"
def test_invalid_csv_format(self):
"""Test invalid CSV format validation."""
router = validate_output_format(expected_format="csv")
assert router({"output": "no commas here"}) == "invalid_format"
assert router({"output": ""}) == "invalid_format"
assert router({"output": "single_column"}) == "invalid_format"
def test_text_format_validation(self):
"""Test text format validation."""
router = validate_output_format(expected_format="text")
assert router({"output": "This is valid text"}) == "valid_format"
assert router({"output": "123"}) == "valid_format"
assert router({"output": "Multi\nline\ntext"}) == "valid_format"
# Empty or non-string should be invalid
assert router({"output": ""}) == "invalid_format"
assert router({"output": None}) == "invalid_format"
def test_custom_output_key(self):
"""Test with custom output key."""
router = validate_output_format(expected_format="json", output_key="result")
assert router({"result": '{"valid": true}'}) == "valid_format"
assert router({"result": "invalid json"}) == "invalid_format"
def test_missing_output_key(self):
"""Test behavior when output key is missing."""
router = validate_output_format(expected_format="json")
assert router({}) == "invalid_format"
assert router({"other_key": '{"valid": true}'}) == "invalid_format"
def test_case_insensitive_format(self):
"""Test case insensitive format specification."""
json_router = validate_output_format(expected_format="JSON")
xml_router = validate_output_format(expected_format="XML")
csv_router = validate_output_format(expected_format="CSV")
assert json_router({"output": '{"key": "value"}'}) == "valid_format"
assert xml_router({"output": "<root/>"}) == "valid_format"
assert csv_router({"output": "a,b,c"}) == "valid_format"
def test_numeric_output_handling(self):
"""Test handling of numeric output values."""
router = validate_output_format(expected_format="json")
# Numbers should be converted to strings for validation
assert (
router({"output": 123}) == "valid_format"
) # Valid JSON format (123 is valid JSON)
text_router = validate_output_format(expected_format="text")
assert text_router({"output": 123}) == "valid_format" # Valid as text
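The format-validation behavior these tests pin down (real JSON/XML parsing, a comma heuristic for CSV, non-empty check for text, numeric values coerced to strings) can be sketched like this. A hypothetical reference implementation only; the actual helper may use different parsers.

```python
import json
import xml.etree.ElementTree as ET
from typing import Any, Callable


def sketch_validate_output_format(
    expected_format: str = "json", output_key: str = "output"
) -> Callable[[Any], str]:
    """Router: "valid_format" | "invalid_format"."""
    fmt = expected_format.lower()  # case-insensitive format name

    def router(state: Any) -> str:
        raw = state.get(output_key)
        if raw is None:
            return "invalid_format"  # missing key or explicit None
        text = str(raw)  # numbers validate as their string form
        try:
            if fmt == "json":
                json.loads(text)
            elif fmt == "xml":
                ET.fromstring(text.strip())  # tolerate surrounding whitespace
            elif fmt == "csv":
                if "," not in text:
                    return "invalid_format"  # single column / empty fails
            else:  # "text": any non-empty string
                if not text:
                    return "invalid_format"
            return "valid_format"
        except (json.JSONDecodeError, ET.ParseError):
            return "invalid_format"

    return router
```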
class TestCheckPrivacyCompliance:
"""Test check_privacy_compliance function with various edge cases."""
def test_ssn_detection(self):
"""Test SSN pattern detection."""
router = check_privacy_compliance()
# Valid SSN patterns
assert router({"content": "SSN: 123-45-6789"}) == "privacy_violation"
assert router({"content": "Contact info: 987-65-4321"}) == "privacy_violation"
# Invalid SSN patterns (but these might match other patterns)
assert (
router({"content": "Phone: 123-456-7890"}) == "privacy_violation"
) # Matches phone pattern
assert (
router({"content": "Code: 12-34-567"}) == "compliant"
) # Wrong format for any pattern
def test_credit_card_detection(self):
"""Test credit card pattern detection."""
router = check_privacy_compliance()
# Valid credit card patterns
assert router({"content": "Card: 1234 5678 9012 3456"}) == "privacy_violation"
assert router({"content": "CC: 1234-5678-9012-3456"}) == "privacy_violation"
assert router({"content": "Number: 1234567890123456"}) == "privacy_violation"
# Invalid patterns
assert router({"content": "Code: 123 456 789"}) == "compliant" # Too short
def test_email_detection(self):
"""Test email pattern detection."""
router = check_privacy_compliance()
# Valid email patterns
assert router({"content": "Contact: john@example.com"}) == "privacy_violation"
assert (
router({"content": "Email: user.name+tag@domain.co.uk"})
== "privacy_violation"
)
# Invalid email patterns
assert router({"content": "Not an email: john@"}) == "compliant"
assert router({"content": "@example.com"}) == "compliant"
def test_phone_number_detection(self):
"""Test phone number pattern detection."""
router = check_privacy_compliance()
# Valid phone patterns
assert router({"content": "Call: 123-456-7890"}) == "privacy_violation"
assert router({"content": "Phone: 123 456 7890"}) == "privacy_violation"
assert router({"content": "Contact: 1234567890"}) == "privacy_violation"
# Invalid patterns
assert (
router({"content": "Code: 123-45-6789"}) == "privacy_violation"
) # Matches SSN pattern
def test_custom_patterns(self):
"""Test custom sensitive patterns."""
custom_patterns = [
r"\bAPI[_-]?KEY[_-]?\w+", # API key pattern
r"\bTOKEN[_-]?\w{10,}", # Token pattern
]
router = check_privacy_compliance(sensitive_patterns=custom_patterns)
assert router({"content": "API_KEY_abc123def456"}) == "privacy_violation"
assert router({"content": "TOKEN_1234567890abcdef"}) == "privacy_violation"
assert router({"content": "Normal text here"}) == "compliant"
def test_case_insensitive_matching(self):
"""Test case insensitive pattern matching."""
router = check_privacy_compliance()
# Email should match regardless of case
assert router({"content": "CONTACT: JOHN@EXAMPLE.COM"}) == "privacy_violation"
assert router({"content": "Email: John@Example.Com"}) == "privacy_violation"
def test_multiple_violations(self):
"""Test content with multiple privacy violations."""
router = check_privacy_compliance()
content = (
"Contact John at john@example.com or call 123-456-7890. SSN: 123-45-6789"
)
assert router({"content": content}) == "privacy_violation"
def test_custom_content_key(self):
"""Test with custom content key."""
router = check_privacy_compliance(content_key="message")
assert router({"message": "Email: test@example.com"}) == "privacy_violation"
assert router({"message": "Clean content"}) == "compliant"
def test_missing_content_key(self):
"""Test behavior when content key is missing."""
router = check_privacy_compliance()
assert router({}) == "compliant"
assert router({"other_key": "john@example.com"}) == "compliant"
def test_none_content(self):
"""Test behavior with None content."""
router = check_privacy_compliance()
assert router({"content": None}) == "compliant"
def test_numeric_content(self):
"""Test handling of numeric content."""
router = check_privacy_compliance()
# Should convert to string for matching
assert (
router({"content": 1234567890123456}) == "privacy_violation"
) # Credit card
assert router({"content": 123456789}) == "compliant" # Too short
def test_empty_patterns_list(self):
"""Test with empty patterns list."""
router = check_privacy_compliance(sensitive_patterns=[])
# No patterns means everything is compliant
assert router({"content": "john@example.com"}) == "compliant"
assert router({"content": "123-45-6789"}) == "compliant"
class TestCheckDataFreshness:
"""Test check_data_freshness function with various edge cases."""
@patch("bb_core.edge_helpers.validation.time")
def test_fresh_data_within_limit(self, mock_time_module):
"""Test data within freshness limit."""
mock_time_module.time.return_value = 1000.0
router = check_data_freshness(max_age_seconds=3600) # 1 hour
# Data from 30 minutes ago
assert router({"timestamp": -800.0}) == "fresh"  # Age 1800s, within limit
@patch("bb_core.edge_helpers.validation.time")
def test_stale_data_beyond_limit(self, mock_time_module):
"""Test data beyond freshness limit."""
mock_time_module.time.return_value = 1000.0
router = check_data_freshness(max_age_seconds=3600) # 1 hour
# Data from 2 hours ago
assert router({"timestamp": -6200.0}) == "stale"  # Age 7200s, over limit
@patch("bb_core.edge_helpers.validation.time")
def test_exact_freshness_boundary(self, mock_time_module):
"""Test exact freshness boundary."""
mock_time_module.time.return_value = 1000.0
router = check_data_freshness(max_age_seconds=3600)
# Exactly at limit
assert router({"timestamp": -2600.0}) == "fresh"
# Just over limit
assert router({"timestamp": -2601.0}) == "stale"
def test_iso_timestamp_parsing(self):
"""Test ISO format timestamp parsing."""
router = check_data_freshness(max_age_seconds=3600)
# Recent ISO timestamp
recent_iso = datetime.now().isoformat()
assert router({"timestamp": recent_iso}) == "fresh"
# Old ISO timestamp
old_iso = "2020-01-01T00:00:00Z"
assert router({"timestamp": old_iso}) == "stale"
def test_invalid_timestamp_formats(self):
"""Test handling of invalid timestamp formats."""
router = check_data_freshness(max_age_seconds=3600)
# Invalid formats should be considered stale
assert router({"timestamp": "invalid"}) == "stale"
assert router({"timestamp": "not-a-date"}) == "stale"
assert router({"timestamp": []}) == "stale"
assert router({"timestamp": {}}) == "stale"
def test_missing_timestamp_key(self):
"""Test behavior when timestamp key is missing."""
router = check_data_freshness(max_age_seconds=3600)
assert router({}) == "stale"
assert router({"other_key": time.time()}) == "stale"
def test_none_timestamp(self):
"""Test behavior with None timestamp."""
router = check_data_freshness(max_age_seconds=3600)
assert router({"timestamp": None}) == "stale"
def test_custom_timestamp_key(self):
"""Test with custom timestamp key."""
router = check_data_freshness(max_age_seconds=1800, timestamp_key="created_at")
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
mock_time_module.time.return_value = 1000.0
assert router({"created_at": 500.0}) == "fresh" # 500 seconds ago
assert router({"created_at": -1000.0}) == "stale" # 2000 seconds ago
def test_string_numeric_timestamp(self):
"""Test string numeric timestamp conversion."""
router = check_data_freshness(max_age_seconds=3600)
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
mock_time_module.time.return_value = 1000.0
assert router({"timestamp": "800.0"}) == "fresh"
assert router({"timestamp": "-6200.0"}) == "stale"
def test_zero_max_age(self):
"""Test with zero max age."""
router = check_data_freshness(max_age_seconds=0)
with patch("bb_core.edge_helpers.validation.time") as mock_time_module:
mock_time_module.time.return_value = 1000.0
# With max age 0, only an exactly-current timestamp counts as fresh
assert router({"timestamp": 1000.0}) == "fresh" # Exactly current
assert router({"timestamp": 999.9}) == "stale" # Slightly old
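The freshness contract above (epoch floats, numeric strings, ISO-8601 strings, inclusive age boundary, unparseable values stale) can be sketched as follows. A hypothetical implementation; note the tests patch `time` inside the real module, whereas this sketch calls `time.time()` directly.

```python
import time
from datetime import datetime
from typing import Any, Callable


def sketch_check_data_freshness(
    max_age_seconds: float = 3600, timestamp_key: str = "timestamp"
) -> Callable[[Any], str]:
    """Router: "fresh" if age <= max_age_seconds, else "stale"."""

    def router(state: Any) -> str:
        raw = state.get(timestamp_key)
        if raw is None:
            return "stale"  # missing timestamp cannot be trusted
        try:
            ts = float(raw)  # epoch seconds, numeric or numeric string
        except (TypeError, ValueError):
            try:
                # Fall back to ISO-8601; map a trailing "Z" to an explicit offset.
                parsed = datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
                ts = parsed.timestamp()  # naive datetimes are read as local time
            except ValueError:
                return "stale"  # unrecognized timestamp format
        return "fresh" if time.time() - ts <= max_age_seconds else "stale"

    return router
```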
class TestValidateRequiredFields:
"""Test validate_required_fields function with various edge cases."""
def test_all_fields_present(self):
"""Test when all required fields are present."""
router = validate_required_fields(["field1", "field2", "field3"])
state = {"field1": "value1", "field2": "value2", "field3": "value3"}
assert router(state) == "valid"
def test_missing_fields(self):
"""Test when some fields are missing."""
router = validate_required_fields(["field1", "field2", "field3"])
# Missing field2
state = {"field1": "value1", "field3": "value3"}
assert router(state) == "missing_fields"
# All missing
assert router({}) == "missing_fields"
def test_strict_mode_empty_values(self):
"""Test strict mode with empty values."""
router = validate_required_fields(["field1", "field2"], strict_mode=True)
# Empty string should fail in strict mode
state = {"field1": "value", "field2": ""}
assert router(state) == "missing_fields"
# Empty list should fail in strict mode
state = {"field1": "value", "field2": []}
assert router(state) == "missing_fields"
def test_non_strict_mode_empty_values(self):
"""Test non-strict mode with empty values."""
router = validate_required_fields(["field1", "field2"], strict_mode=False)
# Empty string should pass in non-strict mode
state = {"field1": "value", "field2": ""}
assert router(state) == "valid"
# Empty list should pass in non-strict mode
state = {"field1": "value", "field2": []}
assert router(state) == "valid"
def test_none_values(self):
"""Test handling of None values."""
router = validate_required_fields(["field1", "field2"])
# None should always fail regardless of strict mode
state = {"field1": "value", "field2": None}
assert router(state) == "missing_fields"
def test_falsy_values_in_non_strict_mode(self):
"""Test falsy values in non-strict mode."""
router = validate_required_fields(["flag", "count"], strict_mode=False)
# False and 0 should be valid in non-strict mode
state = {"flag": False, "count": 0}
assert router(state) == "valid"
def test_falsy_values_in_strict_mode(self):
"""Test falsy values in strict mode."""
router = validate_required_fields(["flag", "count"], strict_mode=True)
# False and 0 should be valid even in strict mode (not empty)
state = {"flag": False, "count": 0}
assert router(state) == "valid"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = validate_required_fields(["field1", "field2"])
# Valid case
state = MockState({"field1": "value1", "field2": "value2"})
assert router(state) == "valid"
# Missing field case
state = MockState({"field1": "value1"})
assert router(state) == "missing_fields"
def test_complex_data_types(self):
"""Test with complex data types as field values."""
router = validate_required_fields(["data", "config"])
state = {
"data": {"nested": {"value": 123}},
"config": ["item1", "item2", {"key": "value"}],
}
assert router(state) == "valid"
def test_empty_required_fields_list(self):
"""Test with empty required fields list."""
router = validate_required_fields([])
# No requirements means always valid
assert router({}) == "valid"
assert router({"any_field": "any_value"}) == "valid"
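The strict/non-strict semantics exercised above (None always fails; strict mode additionally rejects empty strings and containers but keeps falsy scalars like `False` and `0`) can be sketched like this. A hypothetical reference implementation; the real helper's default for `strict_mode` is an assumption here.

```python
from typing import Any, Callable


def sketch_validate_required_fields(
    required_fields: list[str], strict_mode: bool = True
) -> Callable[[Any], str]:
    """Router: "valid" | "missing_fields"."""

    def router(state: Any) -> str:
        for field in required_fields:
            value = state.get(field)
            if value is None:
                return "missing_fields"  # absent key or explicit None always fails
            # Strict mode also rejects empty sized values ("" , [], {}), while
            # False and 0 pass in both modes because they have no __len__.
            if strict_mode and hasattr(value, "__len__") and len(value) == 0:
                return "missing_fields"
        return "valid"

    return router
```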
class TestCheckOutputLength:
"""Test check_output_length function with various edge cases."""
def test_valid_length_within_bounds(self):
"""Test output length within bounds."""
router = check_output_length(min_length=5, max_length=20)
assert router({"output": "12345"}) == "valid_length" # Exactly min
assert router({"output": "1234567890"}) == "valid_length" # Middle
assert (
router({"output": "12345678901234567890"}) == "valid_length"
) # Exactly max
def test_output_too_short(self):
"""Test output shorter than minimum."""
router = check_output_length(min_length=10)
assert router({"output": "short"}) == "too_short"
assert router({"output": ""}) == "too_short"
assert router({"output": "123456789"}) == "too_short" # 9 chars, min is 10
def test_output_too_long(self):
"""Test output longer than maximum."""
router = check_output_length(min_length=1, max_length=10)
assert router({"output": "12345678901"}) == "too_long" # 11 chars, max is 10
assert router({"output": "a" * 100}) == "too_long"
def test_no_max_length_limit(self):
"""Test with no maximum length limit."""
router = check_output_length(min_length=5, max_length=None)
assert router({"output": "12345"}) == "valid_length"
assert (
router({"output": "a" * 1000}) == "valid_length"
) # Very long should be OK
def test_custom_content_key(self):
"""Test with custom content key."""
router = check_output_length(min_length=3, max_length=10, content_key="result")
assert router({"result": "hello"}) == "valid_length"
assert router({"result": "hi"}) == "too_short"
assert router({"result": "very long result text"}) == "too_long"
def test_missing_content_key(self):
"""Test behavior when content key is missing."""
router = check_output_length(min_length=5)
# Should default to empty string and be too short
assert router({}) == "too_short"
assert router({"other_key": "valid content"}) == "too_short"
def test_none_content(self):
"""Test behavior with None content."""
router = check_output_length(min_length=1)
# None should be converted to string "None" (4 chars)
assert router({"output": None}) == "valid_length"
def test_numeric_content(self):
"""Test handling of numeric content."""
router = check_output_length(min_length=3, max_length=5)
assert router({"output": 123}) == "valid_length" # "123" = 3 chars
assert router({"output": 12}) == "too_short" # "12" = 2 chars
assert router({"output": 123456}) == "too_long" # "123456" = 6 chars
def test_zero_min_length(self):
"""Test with zero minimum length."""
router = check_output_length(min_length=0, max_length=5)
assert router({"output": ""}) == "valid_length"
assert router({"output": "hello"}) == "valid_length"
assert router({"output": "toolong"}) == "too_long"
def test_equal_min_max_length(self):
"""Test with equal min and max length."""
router = check_output_length(min_length=5, max_length=5)
assert router({"output": "12345"}) == "valid_length"
assert router({"output": "1234"}) == "too_short"
assert router({"output": "123456"}) == "too_long"
def test_unicode_character_counting(self):
"""
Test Unicode character counting.
Note: Length uses Python's built-in len(), which counts Unicode code points,
not grapheme clusters. Some user-perceived characters (like certain emojis or
emoji sequences) may be counted as more than one character.
"""
router = check_output_length(min_length=3, max_length=5)
# Unicode code points are counted; some emojis span more than one.
# For example, "🔥❄️🌞" is 4 code points: '🔥', '❄' (U+2744), U+FE0F (variation selector-16), '🌞'
assert (
router({"output": "🔥❄️🌞"}) == "valid_length"
) # 4 code points ('❄️' alone is two: U+2744 + U+FE0F)
assert router({"output": "café"}) == "valid_length" # 4 code points with é
# "🔥❄️" is 3 code points: '🔥', '❄' (U+2744), U+FE0F
assert (
router({"output": "🔥❄️"}) == "valid_length"
) # 3 code points (emoji sequences can be multi-codepoint)
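The code-point counting asserted above can be checked directly with the standard library (the emoji literals match those used in the test):

```python
import unicodedata

# "❄️" is one perceived glyph but two code points: U+2744 + U+FE0F
snowflake = "❄️"
assert len(snowflake) == 2
assert unicodedata.name(snowflake[1]) == "VARIATION SELECTOR-16"

# "🔥" is a single code point (outside the BMP), and the full sequence is 4
assert len("🔥") == 1
assert len("🔥❄️🌞") == 4
assert len("café") == 4  # composed é (U+00E9) is one code point
```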
def test_multiline_content(self):
"""Test multiline content length calculation."""
router = check_output_length(min_length=10, max_length=20)
multiline = "line1\nline2\nline3" # 17 chars including newlines
assert router({"output": multiline}) == "valid_length"
def test_state_protocol_implementation(self):
"""Test with StateProtocol implementation."""
router = check_output_length(min_length=5, max_length=15)
state = MockState({"output": "valid output"})
assert router(state) == "valid_length"
state = MockState({"output": "hi"})
assert router(state) == "too_short"
class TestIntegrationAndEdgeCases:
"""Test integration scenarios and edge cases."""
def test_all_validators_with_same_state(self):
"""Test all validators work consistently with same state."""
state = {
"accuracy_score": 0.85,
"confidence": 0.9,
"output": '{"result": "success", "data": [1, 2, 3]}',
"content": "This is clean content without sensitive data",
"timestamp": time.time() - 1800, # 30 minutes ago
"user_id": "12345",
"request_data": {"query": "test"},
"response_time": "fast",
}
accuracy_router = check_accuracy(threshold=0.8)
confidence_router = check_confidence_level(threshold=0.8)
format_router = validate_output_format(expected_format="json")
privacy_router = check_privacy_compliance()
freshness_router = check_data_freshness(max_age_seconds=3600)
fields_router = validate_required_fields(["user_id", "request_data"])
length_router = check_output_length(min_length=10, max_length=100)
assert accuracy_router(state) == "high_accuracy"
assert confidence_router(state) == "high_confidence"
assert format_router(state) == "valid_format"
assert privacy_router(state) == "compliant"
assert freshness_router(state) == "fresh"
assert fields_router(state) == "valid"
assert length_router(state) == "valid_length"
def test_chain_validation_failures(self):
"""Test various validation failures in sequence."""
problematic_state = {
"accuracy_score": 0.3, # Low accuracy
"confidence": "invalid", # Invalid confidence
"output": "malformed json {", # Invalid JSON
"content": "Contact me at john@example.com", # Privacy violation
"timestamp": "2020-01-01", # Stale data
"user_id": "", # Empty required field
}
accuracy_router = check_accuracy(threshold=0.8)
confidence_router = check_confidence_level(threshold=0.8)
format_router = validate_output_format(expected_format="json")
privacy_router = check_privacy_compliance()
freshness_router = check_data_freshness(max_age_seconds=3600)
fields_router = validate_required_fields(["user_id", "missing_field"])
assert accuracy_router(problematic_state) == "low_accuracy"
assert confidence_router(problematic_state) == "low_confidence"
assert format_router(problematic_state) == "invalid_format"
assert privacy_router(problematic_state) == "privacy_violation"
assert freshness_router(problematic_state) == "stale"
assert fields_router(problematic_state) == "missing_fields"
def test_performance_with_large_content(self):
"""Test performance with large content."""
large_content = "x" * 100000 # 100KB of text
privacy_router = check_privacy_compliance()
length_router = check_output_length(min_length=1000, max_length=200000)
state = {"content": large_content, "output": large_content}
# Should handle large content efficiently
assert privacy_router(state) == "compliant"
assert length_router(state) == "valid_length"
def test_concurrent_validator_usage(self):
"""Test that validators are stateless and thread-safe."""
router = check_accuracy(threshold=0.5)
# Multiple calls with different states should not interfere
state1 = {"accuracy_score": 0.8}
state2 = {"accuracy_score": 0.2}
assert router(state1) == "high_accuracy"
assert router(state2) == "low_accuracy"
assert router(state1) == "high_accuracy" # Should still work
def test_memory_efficiency_with_patterns(self):
"""Test memory efficiency with regex patterns."""
# Create router with many custom patterns
many_patterns = [f"pattern{i}_\\w+" for i in range(100)]
router = check_privacy_compliance(sensitive_patterns=many_patterns)
# Should handle many patterns without issues
assert router({"content": "clean content"}) == "compliant"
assert router({"content": "pattern50_sensitive_data"}) == "privacy_violation"

View File

@@ -1,7 +1,7 @@
"""Unit tests for error aggregation and deduplication."""
import asyncio
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
import pytest
@@ -386,30 +386,35 @@ class TestErrorAggregator:
assert summary["by_severity"]["critical"] == 3
def test_cleanup_old_errors(self):
"""Test cleanup of old aggregated errors."""
"""Test cleanup of old recent errors tracking."""
reset_error_aggregator()
aggregator = get_error_aggregator()
# Manually add old error
# Manually add old entry to recent_errors
old_error = create_error_info(message="Old error", error_type="Old")
fingerprint = ErrorFingerprint.from_error_info(old_error)
# Create aggregated error with old timestamp
old_time = datetime.now(UTC) - timedelta(hours=2)
aggregated = AggregatedError(
fingerprint=fingerprint,
first_seen=old_time,
last_seen=old_time,
count=1,
sample_errors=[old_error],
)
# Add old entry to recent_errors (this is what _cleanup_old_entries cleans)
import time
aggregator.aggregated_errors[fingerprint.hash] = aggregated
old_time = time.time() - (aggregator.dedup_window * 3) # Older than cutoff
aggregator.recent_errors[fingerprint.hash] = old_time
# Cleanup should remove it
# Add current entry that should remain
current_error = create_error_info(message="Current error", error_type="Current")
current_fingerprint = ErrorFingerprint.from_error_info(current_error)
current_time = time.time()
aggregator.recent_errors[current_fingerprint.hash] = current_time
# Should have 2 entries before cleanup
assert len(aggregator.recent_errors) == 2
# Cleanup should remove the old entry but keep the current one
aggregator._cleanup_old_entries()
assert len(aggregator.aggregated_errors) == 0
assert len(aggregator.recent_errors) == 1
assert current_fingerprint.hash in aggregator.recent_errors
assert fingerprint.hash not in aggregator.recent_errors
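The updated test targets `_cleanup_old_entries`, which prunes the `recent_errors` deduplication map rather than the aggregated errors themselves. A minimal sketch of that time-window pruning — the `dedup_window` and `recent_errors` names follow the test, while the class name here is hypothetical and the real aggregator holds more state:

```python
import time


class RecentErrorTracker:
    """Sketch of the dedup-window bookkeeping cleaned by _cleanup_old_entries."""

    def __init__(self, dedup_window: float = 60.0) -> None:
        self.dedup_window = dedup_window           # seconds an entry counts as "recent"
        self.recent_errors: dict[str, float] = {}  # fingerprint hash -> last-seen time

    def _cleanup_old_entries(self) -> None:
        # Entries older than the cutoff can no longer suppress duplicates.
        cutoff = time.time() - self.dedup_window
        self.recent_errors = {
            h: seen for h, seen in self.recent_errors.items() if seen >= cutoff
        }
```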
@pytest.mark.asyncio
async def test_concurrent_access(self):

View File

@@ -240,6 +240,7 @@ class TestErrorTelemetry:
# Add matching errors
now = datetime.now(UTC)
last_entry = None
for i in range(4):
entry = ErrorLogEntry(
timestamp=(now - timedelta(seconds=i)).isoformat(),
@@ -258,9 +259,11 @@ class TestErrorTelemetry:
stack_trace=None,
)
telemetry.state.recent_errors.append(entry)
last_entry = entry
# Check patterns
telemetry._check_patterns(entry)
assert last_entry is not None
telemetry._check_patterns(last_entry)
# Should trigger alert
alert_callback.assert_called_once()

View File

@@ -155,7 +155,8 @@ class TestPerformanceFilter:
assert filter_obj.filter(record)
assert hasattr(record, "timestamp")
timestamp = record.timestamp
timestamp = getattr(record, "timestamp", None)
assert timestamp is not None
assert isinstance(timestamp, str)
# Should be ISO format timestamp
assert "T" in timestamp

View File

@@ -236,10 +236,6 @@ class TestRunAsyncChain:
def sync_double(x: int) -> int:
return x * 2
async def async_add_ten(x: int) -> int:
await asyncio.sleep(0.01) # Simulate async work
return x + 10
# Note: This test may not work with the stricter typing in async_helpers
# Let's use only sync functions for now
def add_ten_sync(x: int) -> int:

View File

@@ -222,7 +222,7 @@ class TestTypeValidation:
entity: EntityReference = {
"id": f"test_{entity_type}",
"name": f"Test {entity_type}",
"type": entity_type, # type: ignore
"type": entity_type,
}
assert entity["type"] == entity_type

View File

@@ -191,7 +191,7 @@ class TestPatternValidator:
validator = PatternValidator(r"^\d+$")
# Should convert to string first
is_valid, error = validator.validate(123)
is_valid, _ = validator.validate(123)
assert is_valid # "123" matches the pattern

View File

@@ -151,7 +151,7 @@ class TestDownloadDocument:
await download_document("https://example.com/doc.pdf", timeout=60)
# Verify timeout was passed to requests
args, kwargs = mock_get.call_args
_, kwargs = mock_get.call_args
assert kwargs["timeout"] == 60
@@ -187,7 +187,7 @@ class TestCreateTempDocumentFile:
created_files: list[str] = []
try:
for ext in extensions:
file_path, file_name = create_temp_document_file(content, ext)
file_path, _ = create_temp_document_file(content, ext)
created_files.append(file_path)
assert file_path.endswith(ext)

View File

@@ -418,10 +418,10 @@ class TestValidatedNode:
class OutputModel(BaseModel):
greeting: str
@validated_node( # type: ignore[arg-type]
@validated_node( # type: ignore[misc]
name="test_node", input_model=InputModel, output_model=OutputModel
)
def test_node(state: Any) -> dict[str, str]:
def test_node(state: dict[str, Any]) -> dict[str, str]:
return {"greeting": f"Hello, {state.get('name')}!"}
state = {"name": "World"}
@@ -435,8 +435,8 @@ class TestValidatedNode:
class InputModel(BaseModel):
name: str
@validated_node(input_model=InputModel) # type: ignore[arg-type]
def test_node(state: Any) -> dict[str, str]:
@validated_node(input_model=InputModel) # type: ignore[misc]
def test_node(state: dict[str, Any]) -> dict[str, str]:
return {"result": f"Processed {state.get('name')}"}
state = {"name": "test"}
@@ -450,8 +450,8 @@ class TestValidatedNode:
class OutputModel(BaseModel):
result: str
@validated_node(output_model=OutputModel) # type: ignore[arg-type] # pyright: ignore[reportArgumentType,reportGeneralTypeIssues]
def test_node(state: Any) -> dict[str, str]:
@validated_node(output_model=OutputModel) # type: ignore[misc]
def test_node(state: dict[str, Any]) -> dict[str, str]:
return {"result": "success"}
result = test_node({})
@@ -461,8 +461,8 @@ class TestValidatedNode:
def test_validated_node_with_metadata(self):
"""Test validated_node with metadata."""
@validated_node(name="test_node", custom_field="custom_value")
def test_node(state):
@validated_node(name="test_node", custom_field="custom_value") # type: ignore[misc]
def test_node(state: dict[str, Any]) -> dict[str, Any]:
return state
# The decorator should not modify the function behavior
@@ -621,10 +621,14 @@ class TestValidateAllGraphs:
async def test_validate_all_graphs_non_async_function(self):
"""Test validation with non-async graph function."""
def non_async_function():
async def non_async_function() -> MagicMock:
return MagicMock()
graph_functions = {"non_async": non_async_function}
result = await validate_all_graphs(graph_functions)
from typing import cast
result = await validate_all_graphs(
cast(dict[str, Callable[[], Awaitable[Any]]], graph_functions)
)
assert result is False

View File

@@ -3,7 +3,6 @@
from unittest.mock import patch
from bb_core.validation.merge import (
_handle_average_collection,
_handle_dict_update,
_handle_list_extend,
_handle_numeric_operation,
@@ -291,16 +290,15 @@ class TestMergeHelperFunctions:
result_list = merged_list[0]
assert len(result_list) == 2 # Deduplicated
def test_handle_average_collection(self):
"""Test average collection helper."""
def test_handle_dict_update_simple(self):
"""Test dict update helper."""
merged = {}
_handle_average_collection(merged, "score", 85.5)
_handle_average_collection(merged, "score", 92.0)
_handle_average_collection(merged, "rating", 4.5)
_handle_dict_update(merged, {"key1": "value1"})
_handle_dict_update(merged, {"key2": "value2"})
assert merged["values"]["score"] == [85.5, 92.0]
assert merged["values"]["rating"] == [4.5]
assert merged["key1"] == "value1"
assert merged["key2"] == "value2"
class TestMergeEdgeCases:

View File

@@ -232,7 +232,8 @@ class TestAssessAuthoritativeSources:
# First 3 should be authoritative
assert len(auth_sources) == 3
assert all(
source["url"] in [s["url"] for s in auth_sources] for source in sources[:3]
source.get("url") in [s.get("url") for s in auth_sources]
for source in sources[:3]
)
def test_assess_authoritative_quality_score(self):
@@ -247,7 +248,7 @@ class TestAssessAuthoritativeSources:
# Only high quality score should be authoritative
assert len(auth_sources) == 1
assert auth_sources[0]["quality_score"] == 0.9
assert auth_sources[0].get("quality_score") == 0.9
def test_assess_authoritative_credibility_terms(self):
"""Test assessment based on credibility terms."""

View File

@@ -113,7 +113,7 @@ from .text import (
try:
from .tools import CategoryExtractionTool
except ImportError:
CategoryExtractionTool = None # type: ignore[assignment,misc]
CategoryExtractionTool = None
__all__ = [
# From core
@@ -196,4 +196,4 @@ __all__ = [
# Conditionally add tools to __all__
if CategoryExtractionTool is not None:
__all__ += ["CategoryExtractionTool", "StatisticsExtractionTool"]
__all__ += ["CategoryExtractionTool"]

View File

@@ -1,13 +1,14 @@
"""Base classes and interfaces for extraction."""
from abc import ABC, abstractmethod
from typing import Any
class BaseExtractor(ABC):
"""Abstract base class for extractors."""
@abstractmethod
def extract(self, text: str) -> list[dict]:
def extract(self, text: str) -> list[dict[str, Any]]:
"""Extract information from text.
Args:

View File

@@ -67,14 +67,14 @@ class ExtractionConfig(TypedDict, total=False):
# Entity extraction types
class EntityTypedDict(TypedDict, total=False):
class EntityTypedDict(TypedDict):
"""Type definition for extracted entities."""
name: str
type: str
value: str | None
confidence: float
metadata: dict[str, str | int | float | bool]
value: NotRequired[str | None]
confidence: NotRequired[float]
metadata: NotRequired[dict[str, str | int | float | bool]]
class RelationshipTypedDict(TypedDict, total=False):
@@ -88,7 +88,7 @@ class RelationshipTypedDict(TypedDict, total=False):
# Company extraction types
class CompanyExtractionResultTypedDict(TypedDict, total=False):
class CompanyExtractionResultTypedDict(TypedDict):
"""Type definition for company extraction results."""
name: str

View File

@@ -71,7 +71,9 @@ def extract_company_names(
normalized = _normalize_company_name(name)
if not normalized: # Skip if normalization results in empty string
continue
if normalized not in merged or result["confidence"] > merged[normalized]["confidence"]:
if normalized not in merged or result.get("confidence", 0.0) > merged[normalized].get(
"confidence", 0.0
):
merged[normalized] = result
# Convert to list and filter by confidence
@@ -141,9 +143,9 @@ def add_source_metadata(
url: str,
) -> None:
"""Add source metadata to company extraction result."""
company["sources"] = company.get("sources", set())
if isinstance(company["sources"], set):
company["sources"].add(f"{source_type}_{result_index}")
sources = company.get("sources", set())
sources.add(f"{source_type}_{result_index}")
company["sources"] = sources
company["url"] = url
@@ -347,23 +349,27 @@ def _handle_duplicate_company(
) -> bool:
"""Handle potential duplicate company entries."""
for existing in final_results:
existing_normalized = _normalize_company_name(existing["name"])
existing_normalized = _normalize_company_name(existing.get("name", ""))
if normalized_name == existing_normalized or _is_substantial_name_overlap(
normalized_name, existing_normalized
):
# Merge sources
if (
"sources" in existing
and "sources" in company
and isinstance(existing["sources"], set)
and isinstance(company["sources"], set)
(
"sources" in existing
and "sources" in company
and isinstance(existing.get("sources"), set)
and isinstance(company.get("sources"), set)
)
and existing.get("sources")
and company.get("sources")
):
existing["sources"].update(company["sources"])
# Update confidence if higher
if company["confidence"] > existing["confidence"]:
existing["confidence"] = company["confidence"]
if company.get("confidence", 0.0) > existing.get("confidence", 0.0):
existing["confidence"] = company.get("confidence", 0.0)
return True
@@ -397,13 +403,13 @@ def _deduplicate_and_filter_results(
final_results: list[CompanyExtractionResultTypedDict] = []
# Sort by confidence (descending) to process higher confidence first
sorted_results = sorted(results, key=lambda x: x["confidence"], reverse=True)
sorted_results = sorted(results, key=lambda x: x.get("confidence", 0.0), reverse=True)
for company in sorted_results:
if company["confidence"] < min_confidence:
if company.get("confidence", 0.0) < min_confidence:
continue
normalized_name = _normalize_company_name(company["name"])
normalized_name = _normalize_company_name(company.get("name", ""))
if not _handle_duplicate_company(normalized_name, company, final_results):
final_results.append(company)
@@ -412,82 +418,3 @@ def _deduplicate_and_filter_results(
# === Functions for specific extraction needs ===
def _extract_companies_from_single_result(
result: dict[str, str],
result_index: int,
known_companies: set[str] | None,
min_confidence: float,
) -> list[CompanyExtractionResultTypedDict]:
"""Extract companies from a single search result."""
companies: list[CompanyExtractionResultTypedDict] = []
# Extract from title with higher confidence
title_companies = extract_company_names(
result.get("title", ""),
known_companies=known_companies,
min_confidence=min_confidence + 0.1,
)
for company in title_companies:
add_source_metadata("title", company, result_index, result.get("url", ""))
companies.append(company)
# Extract from snippet
snippet_companies = extract_company_names(
result.get("snippet", ""),
known_companies=known_companies,
min_confidence=min_confidence,
)
for company in snippet_companies:
# Check if not duplicate from title
is_duplicate = any(
_is_substantial_name_overlap(company["name"], tc["name"]) for tc in title_companies
)
if not is_duplicate:
add_source_metadata("snippet", company, result_index, result.get("url", ""))
companies.append(company)
return companies
def _update_existing_company(
new_company: CompanyExtractionResultTypedDict,
existing: CompanyExtractionResultTypedDict,
) -> None:
"""Update existing company with new information."""
# Update confidence (weighted average)
existing["confidence"] = (existing["confidence"] + new_company["confidence"]) / 2
# Merge sources
if (
"sources" in existing
and "sources" in new_company
and isinstance(existing["sources"], set)
and isinstance(new_company["sources"], set)
):
existing["sources"].update(new_company["sources"])
def _add_new_company(
company: CompanyExtractionResultTypedDict,
name_lower: str,
merged_results: dict[str, CompanyExtractionResultTypedDict],
) -> None:
"""Add new company to results."""
merged_results[name_lower] = company
def _merge_company_into_results(
company: CompanyExtractionResultTypedDict,
merged_results: dict[str, CompanyExtractionResultTypedDict],
) -> None:
"""Merge company into consolidated results."""
name_lower = company["name"].lower()
if name_lower in merged_results:
_update_existing_company(company, merged_results[name_lower])
else:
_add_new_company(company, name_lower, merged_results)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import re
from typing import cast
from bb_core.logging import get_logger
@@ -171,7 +172,8 @@ class ComponentExtractor(BaseExtractor):
if matches and isinstance(matches[0], tuple):
items = [match[0] if match else "" for match in matches]
else:
items = list(matches)
# Explicitly type the matches as strings to avoid Sized type
items = [str(match) for match in matches] if matches else []
best_pattern_idx = i
if best_pattern_idx >= 0:
@@ -181,10 +183,8 @@ class ComponentExtractor(BaseExtractor):
if not items:
items = [line.strip() for line in text_block.split("\n") if line.strip()]
# Process each item
# Ensure items is properly typed
items_list: list[str] = items if isinstance(items, list) else []
for item in items_list:
# Process each item - cast to ensure proper typing
for item in cast("list[str]", items):
# Check if this item contains multiple ingredients separated by commas
# e.g., "curry powder, ginger powder, allspice"
if "," in item and not any(char.isdigit() for char in item[:5]):

View File

@@ -34,7 +34,7 @@ async def extract_companies(text: str) -> list[SimpleNamespace]:
"""Extract company names from text asynchronously."""
raw = extract_company_names(text)
if raw:
names = [c["name"] for c in raw]
names = [c.get("name", "") for c in raw if c.get("name")]
else:
words = re.findall(r"\b[A-Z][a-zA-Z]+\b", text)
skip = {"The", "And", "Or", "For", "To", "In", "On", "Of"}
@@ -273,7 +273,7 @@ def extract_thought_action_pairs(text: str) -> list[JsonDict]:
and len(content) == 2
):
action_name, args_str = content
action = {
action: JsonDict = {
"thought": current_thought,
"action": action_name,
"args": parse_action_args(args_str) if args_str else {},
@@ -315,7 +315,7 @@ def extract_entities(text: str) -> dict[str, list[str]]:
# Extract companies using existing functionality
companies = extract_company_names(text)
company_names = [c["name"] for c in companies]
company_names = [c.get("name", "") for c in companies if c.get("name")]
# Extract URLs
url_pattern = re.compile(
@@ -453,7 +453,7 @@ class ExtractedEntity:
class EntityExtractor(BaseExtractor):
"""Extract entities from text."""
def extract(self, text: str) -> list[dict]:
def extract(self, text: str) -> list[dict[str, Any]]:
"""Extract entities from text.
Args:

View File

@@ -12,7 +12,7 @@ try:
from bb_core.validation.statistics import HIGH_CREDIBILITY_TERMS
except ImportError:
# Fallback definition if not available
HIGH_CREDIBILITY_TERMS = [
_fallback_credibility_terms = [
"university",
"research",
"institute",
@@ -30,6 +30,7 @@ except ImportError:
"bureau",
"department",
]
HIGH_CREDIBILITY_TERMS = _fallback_credibility_terms # pyright: ignore[reportConstantRedefinition]
from .numeric import extract_year
@@ -91,8 +92,9 @@ def extract_credibility_terms(text: str) -> list[str]:
>>> extract_credibility_terms("Reported by a renowned research institute")
['research', 'institute']
"""
if not isinstance(text, str):
raise TypeError("text must be a string")
# Type is guaranteed by signature; no need for isinstance
if not text:
return []
return [term for term in HIGH_CREDIBILITY_TERMS if term.lower() in text.lower()]
@@ -112,8 +114,9 @@ def assess_source_quality(text: str) -> float:
>>> assess_source_quality("According to a study by a .edu institution, ...")
0.8
"""
if not isinstance(text, str):
raise TypeError("text must be a string")
# Type is guaranteed by signature; no need for isinstance
if not text:
return 0.0
score = 0.5
# Count credibility terms

View File

@@ -58,10 +58,6 @@ def extract_json_from_text(text: str) -> JsonDict | None:
# This is a simplified approach - find strings and escape newlines in them
# Match JSON strings (basic pattern, not perfect but should work for most cases)
def replace_newlines_in_match(match: re.Match[str]) -> str:
string_content = match.group(0)
# Replace actual newlines with \n, being careful not to double-escape
return string_content.replace("\n", "\\n").replace("\r", "\\r")
# Pattern to match strings in JSON (handles escaped quotes)
string_pattern = r'"(?:[^"\\]|\\.)*"'
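The `string_pattern` above matches JSON string literals, including ones containing escaped quotes, because the `\\.` alternative consumes any escaped character before the closing quote can end the match. A quick illustration:

```python
import re

string_pattern = r'"(?:[^"\\]|\\.)*"'  # same pattern as above
text = '{"msg": "He said \\"hi\\"", "n": 1}'
# findall returns whole matches because the inner group is non-capturing
assert re.findall(string_pattern, text) == ['"msg"', '"He said \\"hi\\""', '"n"']
```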
@@ -445,7 +441,7 @@ def parse_action_args(text: str) -> ActionArgsDict:
"""
# Try JSON first
json_data = extract_json_from_text(text)
if json_data and isinstance(json_data, dict):
if json_data:
# Convert to ActionArgsDict format
result: dict[str, Any] = {}
for k, v in json_data.items():

View File

@@ -281,11 +281,8 @@ async def _process_chunked_content(
chunk_results.append(result)
# Merge chunk results
merged_result = merge_chunk_results(chunk_results, category=category)
# Ensure the result is a JsonDict
if isinstance(merged_result, dict):
return merged_result
return _default_extraction_result(category)
return merge_chunk_results(chunk_results, category=category)
# Return the merged result
def _default_extraction_result(category: str) -> JsonDict:

View File

@@ -78,6 +78,7 @@ def benchmark():
kwargs = {}
total_time = 0.0
result = None
for _ in range(rounds):
start = time.time()
for _ in range(iterations):

View File

@@ -119,8 +119,8 @@ class TestFactTypedDict:
assert main_fact["text"] == "Revenue analysis shows growth"
assert main_fact["type"] == "financial"
assert len(main_fact["statistics"]) == 1
assert main_fact["statistics"][0]["text"] == "Sub-statistic"
assert main_fact["statistics"][0]["value"] == 1.2
assert main_fact["statistics"][0].get("text") == "Sub-statistic"
assert main_fact["statistics"][0].get("value") == 1.2
def test_fact_confidence_types(self) -> None:
"""Test fact with different confidence types."""
@@ -175,8 +175,8 @@ class TestCategoryFactTypedDict:
}
assert category_fact["type"] == "market_analysis"
assert category_fact["data"]["text"] == "Market share is 45%"
assert category_fact["data"]["value"] == 45.0
assert category_fact["data"].get("text") == "Market share is 45%"
assert category_fact["data"].get("value") == 45.0
def test_category_fact_with_complex_data(self) -> None:
"""Test category fact with complex nested data."""
@@ -201,9 +201,9 @@ class TestCategoryFactTypedDict:
}
assert category_fact["type"] == "comprehensive_analysis"
assert category_fact["data"]["quality_score"] == 0.92
assert len(category_fact["data"]["citations"]) == 1
assert category_fact["data"]["citations"][0]["source"] == "Research Institute"
assert category_fact["data"].get("quality_score") == 0.92
assert len(category_fact["data"].get("citations", [])) == 1
assert category_fact["data"].get("citations", [])[0].get("source") == "Research Institute"
class TestExtractionResult:
@@ -216,7 +216,7 @@ class TestExtractionResult:
result: ExtractionResult = {"facts": facts}
assert len(result["facts"]) == 1
assert result["facts"][0]["text"] == "Simple fact"
assert result["facts"][0].get("text") == "Simple fact"
def test_extraction_result_complete(self) -> None:
"""Test complete extraction result."""
@@ -366,11 +366,13 @@ class TestTypeCompatibility:
result: ExtractionResult = {"facts": [main_fact], "relevance_score": 0.95}
# Verify all levels work correctly
assert category_fact["data"]["text"] == "Main finding"
assert len(category_fact["data"]["statistics"]) == 1
assert category_fact["data"]["statistics"][0]["text"] == "Sub-analysis"
assert len(result["facts"]) == 1
assert result["facts"][0]["type"] == "research"
assert category_fact["data"].get("text") == "Main finding"
statistics = category_fact["data"].get("statistics", [])
assert len(statistics) == 1
assert statistics[0].get("text") == "Sub-analysis"
facts = result.get("facts", [])
assert len(facts) == 1
assert facts[0].get("type") == "research"
def test_optional_field_handling(self) -> None:
"""Test handling of optional fields across types."""
@@ -404,7 +406,7 @@ class TestTypeCompatibility:
facts_list.append(fact)
assert len(facts_list) == 1
assert isinstance(facts_list[0]["citations"], list)
assert isinstance(facts_list[0].get("citations"), list)
def test_union_type_handling(self) -> None:
"""Test handling of union types in FactTypedDict."""
@@ -469,8 +471,9 @@ class TestEdgeCasesAndValidation:
"statistics": nested_stats,
}
assert len(main_fact["statistics"]) == 10
assert main_fact["statistics"][5]["value"] == 50.0
statistics = main_fact.get("statistics", [])
assert len(statistics) == 10
assert statistics[5].get("value") == 50.0
def test_special_string_values(self) -> None:
"""Test handling of special string values."""

View File

@@ -20,7 +20,7 @@ def test_extract_company_names_basic() -> None:
results = extract_company_names(text)
assert len(results) >= 1
company_names = [r["name"] for r in results]
company_names = [r.get("name", "") for r in results]
assert any("Acme Inc" in name for name in company_names)
@@ -29,7 +29,7 @@ def test_extract_company_names_with_indicators() -> None:
text = "Apple Corporation and Microsoft Technologies are tech giants."
results = extract_company_names(text)
company_names = [r["name"] for r in results]
company_names = [r.get("name", "") for r in results]
assert any("Apple Corporation" in name for name in company_names)
assert any("Microsoft Technologies" in name for name in company_names)
@@ -72,7 +72,7 @@ def test_extract_companies_from_search_results() -> None:
results = extract_companies_from_search_results(search_results)
assert len(results) >= 2
company_names = [r["name"] for r in results]
company_names = [r.get("name", "") for r in results]
assert any("Acme Inc" in name for name in company_names)
assert any("Beta LLC" in name for name in company_names)
@@ -98,7 +98,7 @@ def test_add_source_metadata() -> None:
assert company["confidence"] == 0.9
assert "sources" in company
assert "search_0" in company["sources"]
assert company["url"] == "https://example.com"
assert company.get("url") == "https://example.com"
def test_add_source_metadata_multiple_sources() -> None:
@@ -114,7 +114,7 @@ def test_add_source_metadata_multiple_sources() -> None:
assert "sources" in company
assert "indicator_match" in company["sources"]
assert "search_1" in company["sources"]
assert company["url"] == "https://beta.com"
assert company.get("url") == "https://beta.com"
@pytest.mark.asyncio

View File

@@ -3,7 +3,7 @@
import re
from datetime import datetime
from enum import Enum
from typing import ClassVar, final
from typing import ClassVar, cast
from xml.etree.ElementTree import Element
import defusedxml.ElementTree as ET
@@ -263,7 +263,6 @@ class ArxivPaper(BaseModel):
)
@final
class ArxivSearchOptions(BaseModel):
"""Options for searching arXiv.
@@ -276,6 +275,12 @@ class ArxivSearchOptions(BaseModel):
         sort_order: Sort order (ascending or descending).
     """

+    model_config = {"arbitrary_types_allowed": True}
+
+    def __init__(self, **data: object) -> None:
+        """Initialize ArxivSearchOptions with proper Pydantic setup."""
+        super().__init__(**data)
+
     query: str = Field(
         default="",
         description="The search query string.",
@@ -330,14 +335,13 @@ class ArxivSearchOptions(BaseModel):
         """Validate and convert sort_by to enum."""
         if isinstance(value, SortBy):
             return value
-        if isinstance(value, str):
+        if isinstance(value, str) and value:
             try:
-                from typing import cast
-
                 return cast("SortBy", SortBy(value))
             except ValueError:
                 raise ValueError(f"Invalid sort_by value: {value}")
-        raise ValueError(f"sort_by must be a SortBy enum or string, got {type(value)}")
+        # Default to RELEVANCE if value is falsy
+        return SortBy.RELEVANCE

     @field_validator("sort_order", mode="before")
     @classmethod
@@ -345,19 +349,15 @@ class ArxivSearchOptions(BaseModel):
         """Validate and convert sort_order to enum."""
         if isinstance(value, SortOrder):
             return value
-        if isinstance(value, str):
+        if isinstance(value, str) and value:
             try:
-                from typing import cast
-
                 return cast("SortOrder", SortOrder(value))
             except ValueError:
                 raise ValueError(f"Invalid sort_order value: {value}")
-        raise ValueError(
-            f"sort_order must be a SortOrder enum or string, got {type(value)}"
-        )
+        # Default to DESCENDING if value is falsy
+        return SortOrder.DESCENDING


-@final
 class ArxivClient(BaseAPIClient):
     """ArXiv API client using bb_core infrastructure."""
@@ -397,19 +397,18 @@ class ArxivClient(BaseAPIClient):
         with error_context(operation="arxiv_search", query=query):
             # Convert string parameters to enum types if needed
-            from typing import cast

             if isinstance(sort_by, SortBy):
                 sort_by_enum = sort_by
             elif sort_by:
-                sort_by_enum = cast("SortBy", SortBy(sort_by))
+                sort_by_enum = SortBy(sort_by)
             else:
                 sort_by_enum = SortBy.RELEVANCE

             if isinstance(sort_order, SortOrder):
                 sort_order_enum = sort_order
             elif sort_order:
-                sort_order_enum = cast("SortOrder", SortOrder(sort_order))
+                sort_order_enum = SortOrder(sort_order)
             else:
                 sort_order_enum = SortOrder.DESCENDING
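Reviewer note: dropping the `cast(...)` wrappers here is sound because calling an `Enum` class with a value already returns a member of that class, so the expression is correctly typed without help. A quick illustration (member values assumed, not from this diff):

```python
from enum import Enum


class SortOrder(Enum):
    """Stand-in enum; the real values live in the diffed module."""
    ASCENDING = "ascending"
    DESCENDING = "descending"


# Value lookup returns the member itself, already typed as SortOrder,
# so cast("SortOrder", SortOrder(value)) was a no-op for the type checker.
order = SortOrder("descending")
```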


@@ -95,8 +95,6 @@ class BaseAPIClient(ABC):
         url: str | None = None,
         **kwargs: object,
     ) -> dict[str, object]:
-        from typing import Any
-
         from bb_core.networking import RequestMethod

         # Build full URL if relative
@@ -112,23 +110,29 @@ class BaseAPIClient(ABC):
             full_url = url

         # Convert method string to RequestMethod enum
-        method_enum = (
-            RequestMethod(method.upper()) if method else RequestMethod.GET
-        )
+        if isinstance(method, RequestMethod):
+            method_enum: RequestMethod = method
+        else:
+            method_str = method.upper() if method else "GET"
+            # Use getattr for safe enum access with fallback
+            method_enum = getattr(RequestMethod, method_str, RequestMethod.GET)

         # Extract parameters for the request method with proper typing
         params_raw = kwargs.get("params")
-        params: dict[str, Any] | None = (
+        params: dict[str, str | int | float | bool] | None = (
             params_raw if isinstance(params_raw, dict) else None
         )

         json_raw = kwargs.get("json_data") or kwargs.get("json")
-        json_data: dict[str, Any] | None = (
-            json_raw if isinstance(json_raw, dict) else None
-        )
+        json_data: (
+            dict[
+                str, str | int | float | bool | list[object] | dict[str, object]
+            ]
+            | None
+        ) = json_raw if isinstance(json_raw, dict) else None

         data_raw = kwargs.get("data")
-        data: dict[str, Any] | None = (
+        data: dict[str, str | int | float | bool | bytes] | None = (
             data_raw if isinstance(data_raw, dict) else None
         )
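Reviewer note: the `getattr` fallback in the hunk above only behaves as intended because `Enum` members are class attributes looked up by *name*. A sketch with a stand-in enum (the real `RequestMethod` comes from `bb_core.networking`; members here are assumed):

```python
from enum import Enum


class RequestMethod(Enum):
    """Stand-in for bb_core's RequestMethod; members are assumed."""
    GET = "GET"
    POST = "POST"


def to_method(method: "RequestMethod | str | None") -> RequestMethod:
    """Mirror of the method-normalization logic in the diff."""
    if isinstance(method, RequestMethod):
        return method
    method_str = method.upper() if method else "GET"
    # Name-based lookup: unknown names fall back to GET instead of raising
    return getattr(RequestMethod, method_str, RequestMethod.GET)
```

One caveat worth flagging: unlike the old `RequestMethod(method.upper())`, this silently maps typos and unsupported verbs to `GET` rather than failing fast.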
@@ -141,11 +145,9 @@ class BaseAPIClient(ABC):
             headers = {**self._headers_dict, **headers}

         # Call the HTTPClient request method
-        # Cast method to RequestMethod to satisfy type checker
-        from typing import cast
         response = await self._http_client.request(
-            method=cast("RequestMethod", method_enum),
+            method=method_enum,
             url=full_url,
             params=params,
             json_data=json_data,
@@ -249,8 +251,8 @@ class BaseAPIClient(ABC):
         raise ValueError(f"Invalid response from {url}")

     async def _post(
-        self, endpoint: str, json: dict | None = None, **kwargs: object
-    ) -> dict:
+        self, endpoint: str, json: dict[str, object] | None = None, **kwargs: object
+    ) -> dict[str, object]:
         """Make a POST request to the API.

         Args:
