fix: resolve pre-commit hooks configuration and update dependencies

- Clean and reinstall pre-commit hooks to fix corrupted cache
- Update isort to v6.0.1 to resolve deprecation warnings
- Fix pytest PT012 error by separating pytest.raises from context managers
- Fix pytest PT011 errors by using GraphInterrupt instead of generic Exception
- Fix formatting and trailing whitespace issues automatically applied by hooks
This commit is contained in:
2025-08-07 18:51:42 -04:00
parent 869842b90b
commit 64f0f49d56
246 changed files with 17551 additions and 8947 deletions

View File

@@ -179,7 +179,9 @@
"Bash(timeout 30 python -m pytest tests/unit_tests/nodes/llm/test_unit_call.py::test_call_model_node_basic -xvs)",
"Bash(timeout 30 python -m pytest tests/unit_tests/nodes/llm/test_call.py -xvs -k \"test_call_model_node_basic\")",
"Bash(timeout 30 python -m pytest:*)",
"Bash(timeout 60 python -m pytest:*)"
"Bash(timeout 60 python -m pytest:*)",
"mcp__postgres-kgr__describe_table",
"mcp__postgres-kgr__execute_query"
],
"deny": []
},

View File

@@ -14,7 +14,7 @@ repos:
# isort - Import sorting
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 6.0.1
hooks:
- id: isort
name: isort (python)

View File

@@ -2,5 +2,5 @@ projectKey=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f
serverUrl=http://sonar.lab
serverVersion=25.7.0.110598
dashboardUrl=http://sonar.lab/dashboard?id=vasceannie_biz-bud_6c113581-e663-4a15-8a76-1ce5dab23a5f
ceTaskId=cfa37333-abfb-4a0a-b950-db0c0f741f5f
ceTaskUrl=http://sonar.lab/api/ce/task?id=cfa37333-abfb-4a0a-b950-db0c0f741f5f
ceTaskId=06a875ae-ce12-424f-abef-5b9ce511b87f
ceTaskUrl=http://sonar.lab/api/ce/task?id=06a875ae-ce12-424f-abef-5b9ce511b87f

View File

@@ -1,336 +0,0 @@
# Core Package Test Coverage Analysis & Implementation Guide
## Executive Summary
This document provides a comprehensive analysis of test coverage for the `/app/src/biz_bud/core/` package and outlines the implementation strategy for achieving 100% test coverage.
**Current Status:**
- **Before improvements:** 115 failed tests, 967 passed tests
- **After improvements:** 91 failed tests, 1,052 passed tests
- **Progress:** 24 fewer failures (+85 more passing tests)
- **Coverage improvement:** Significant improvements in core caching, error handling, and service management
## Completed Work
### 1. Fixed Existing Failing Tests ✅
**Key Fixes Implemented:**
- **ServiceRegistry API:** Updated all calls from `register_service_factory` to `register_factory`
- **MockAppConfig:** Consolidated fixture definitions and added missing HTTP configuration fields
- **Async Context Managers:** Fixed service factories to use proper `@asynccontextmanager` pattern
- **Cache Type Safety:** Added runtime type validation to `InMemoryCache` for proper error raising
- **LRU Cache Logic:** Fixed zero-size cache behavior to prevent storage
- **Type Safety Tests:** Updated to use correct base classes (`GenericCacheBackend` vs `CacheBackend`)
### 2. Comprehensive Coverage Gap Analysis ✅
**Major Coverage Gaps Identified:**
#### Caching Modules (Partially Addressed)
-**cache_encoder.py** - Comprehensive tests created (16 test scenarios)
-**file.py** - Extensive tests created (50+ test scenarios)
-**redis.py** - Complete tests created (40+ test scenarios)
- ⚠️ **cache_manager.py** - Needs comprehensive tests
- ⚠️ **cache_backends.py** - Needs comprehensive tests
#### Config Modules (Not Yet Addressed)
-**constants.py** - No tests exist
-**ensure_tools_config.py** - No tests exist
-**schemas/** - All 8 schema modules need tests:
- analysis.py, app.py, buddy.py, core.py, llm.py, research.py, services.py, tools.py
#### Edge Helpers Modules (Critical Gap)
-**14 routing modules** need comprehensive tests:
- buddy_router.py, command_patterns.py, command_routing.py, consolidated.py
- core.py, error_handling.py, flow_control.py, monitoring.py
- router_factories.py, routing_rules.py, secure_routing.py
- user_interaction.py, validation.py, workflow_routing.py
-**basic_routing.py** - Has some tests
-**routing_security.py** - Has some tests
#### Error Handling Modules (Critical Gap)
-**10 error modules** need comprehensive tests:
- aggregator.py, formatter.py, handler.py, llm_exceptions.py
- logger.py, router.py, router_config.py, specialized_exceptions.py
- telemetry.py, tool_exceptions.py, types.py
-**base.py** - Has basic tests
#### LangGraph Integration (Not Yet Addressed)
-**All 4 modules** need tests:
- cross_cutting.py, graph_config.py, runnable_config.py, state_immutability.py
#### URL Processing (Not Yet Addressed)
-**All 4 modules** need tests:
- config.py, discoverer.py, filter.py, validator.py
#### Utils Modules (Partially Addressed)
-**json_extractor.py** - Has tests
-**lazy_loader.py** - Has tests
-**message_helpers.py** - Has tests
-**url_processing_comprehensive.py** - Has tests
-**cache.py** - Needs tests
-**capability_inference.py** - Needs tests
-**url_analyzer.py** - Needs tests
-**url_normalizer.py** - Needs tests
#### Validation Modules (Mostly Missing)
-**langgraph_validation.py** - Has tests
-**pydantic_security.py** - Has tests
-**13 validation modules** need tests:
- base.py, chunking.py, condition_security.py, config.py, content.py
- content_type.py, content_validation.py, decorators.py, document_processing.py
- examples.py, graph_validation.py, merge.py, pydantic_models.py
- security.py, statistics.py, types.py
### 3. New Test Modules Created ✅
**Comprehensive Test Suites Added:**
1. **`test_cache_encoder.py`** (16 test scenarios)
- Primitive type handling, datetime/timedelta encoding
- Complex nested structures, callable/method encoding
- Edge cases, consistency, and deterministic behavior
- JSON integration and type preservation
2. **`test_file_cache.py`** (50+ test scenarios)
- Initialization and configuration testing
- Directory management and thread safety
- Pickle and JSON serialization formats
- TTL functionality and expiration
- Error handling and corrupted file recovery
- Concurrency and bulk operations
- Edge cases with large values and special characters
3. **`test_redis_cache.py`** (40+ test scenarios)
- Connection management and failure handling
- Basic CRUD operations with proper mocking
- Bulk operations (get_many, set_many)
- Cache clearing with scan operations
- Health checks and utility methods
- Integration scenarios and error handling
- Unicode and large key handling
## Implementation Strategy
### Phase 1: Critical Foundation Modules (High Priority)
**Error Handling System** (Estimated: 2-3 days)
```
tests/unit_tests/core/errors/
├── test_base.py (comprehensive base error system)
├── test_specialized_exceptions.py (custom exception types)
├── test_handler.py (error handling logic)
├── test_aggregator.py (error collection and reporting)
├── test_formatter.py (error message formatting)
└── test_types.py (error type definitions)
```
**Configuration Management** (Estimated: 1-2 days)
```
tests/unit_tests/core/config/
├── test_constants.py (configuration constants)
├── test_ensure_tools_config.py (tool configuration validation)
└── schemas/
├── test_app.py (application configuration schema)
├── test_core.py (core configuration schema)
├── test_llm.py (LLM configuration schema)
└── test_services.py (services configuration schema)
```
### Phase 2: Core Workflow Modules (Medium Priority)
**LangGraph Integration** (Estimated: 2 days)
```
tests/unit_tests/core/langgraph/
├── test_cross_cutting.py (cross-cutting concerns)
├── test_graph_config.py (graph configuration)
├── test_runnable_config.py (runnable configuration)
└── test_state_immutability.py (state management)
```
**Edge Helpers/Routing** (Estimated: 3-4 days)
```
tests/unit_tests/core/edge_helpers/
├── test_command_routing.py (command routing logic)
├── test_workflow_routing.py (workflow routing)
├── test_router_factories.py (router factory patterns)
├── test_flow_control.py (flow control mechanisms)
└── test_monitoring.py (routing monitoring)
```
### Phase 3: Supporting Modules (Lower Priority)
**Validation Framework** (Estimated: 2-3 days)
```
tests/unit_tests/core/validation/
├── test_base.py (base validation framework)
├── test_content.py (content validation)
├── test_decorators.py (validation decorators)
├── test_security.py (security validation)
└── test_pydantic_models.py (Pydantic model validation)
```
**URL Processing & Utils** (Estimated: 1-2 days)
```
tests/unit_tests/core/url_processing/
├── test_config.py (URL processing configuration)
├── test_validator.py (URL validation)
├── test_discoverer.py (URL discovery)
└── test_filter.py (URL filtering)
tests/unit_tests/core/utils/
├── test_cache.py (utility cache functions)
├── test_capability_inference.py (capability detection)
├── test_url_analyzer.py (URL analysis)
└── test_url_normalizer.py (URL normalization)
```
## Testing Best Practices & Patterns
### Hierarchical Fixture Strategy
**Core Conftest Structure:**
```python
# tests/unit_tests/core/conftest.py
@pytest.fixture
def mock_app_config():
"""Comprehensive application configuration mock."""
# Already implemented in services/conftest.py
@pytest.fixture
def error_context():
"""Error handling context for testing."""
@pytest.fixture
def mock_logger():
"""Mock logger for testing."""
@pytest.fixture
def validation_schema_factory():
"""Factory for creating validation schemas."""
```
### Test Organization Principles
1. **One test file per source module** - `test_module_name.py` for `module_name.py`
2. **Class-based organization** - Group related tests into classes
3. **Descriptive test names** - `test_method_behavior_condition`
4. **Comprehensive edge case coverage** - Test boundary conditions, error states
5. **Mock external dependencies** - Use dependency injection and mocking
6. **Avoid test loops/conditionals** - Each test should be explicit and direct
### Error Testing Patterns
```python
# Pattern for testing error handling
def test_error_scenario_with_proper_context():
with pytest.raises(SpecificError, match="expected message pattern"):
# Code that should raise error
pass
# Pattern for testing error recovery
def test_error_recovery_mechanism():
# Setup error condition
# Test recovery behavior
# Verify system state after recovery
```
### Async Testing Patterns
```python
# Pattern for async testing
@pytest.mark.asyncio
async def test_async_operation():
# Use proper async fixtures
# Test async behavior
# Verify async state management
```
## Fixture Architecture
### Global Fixtures (tests/conftest.py)
- Database connections
- Redis connections
- Event loop management
- Common configuration
### Package Fixtures (tests/unit_tests/core/conftest.py)
- Core-specific configurations
- Mock services
- Error contexts
- Validation helpers
### Module Fixtures (per test file)
- Module-specific mocks
- Test data factories
- Specialized configurations
## Quality Assurance Checklist
### For Each New Test Module:
- [ ] Tests cover all public methods/functions
- [ ] Edge cases and error conditions tested
- [ ] Proper fixture usage and test isolation
- [ ] Descriptive test names and docstrings
- [ ] No loops or conditionals in tests
- [ ] Proper async/await patterns where needed
- [ ] Mock external dependencies appropriately
- [ ] Test both success and failure paths
- [ ] Verify error messages and types
- [ ] Test concurrent operations where relevant
### Before Completion:
- [ ] Run full test suite: `make test`
- [ ] Verify coverage: `make coverage-report`
- [ ] Run linting: `make lint-all`
- [ ] Check for no regressions in existing tests
- [ ] Validate all new tests pass consistently
- [ ] Review test performance and optimization
## Estimated Timeline
**Total Estimated Effort:** 10-15 development days
- **Phase 1 (Critical):** 5-6 days
- **Phase 2 (Core Workflow):** 4-5 days
- **Phase 3 (Supporting):** 3-4 days
- **Integration & QA:** 1-2 days
## Success Metrics
**Target Goals:**
- [ ] 100% line coverage for `/app/src/biz_bud/core/` package
- [ ] All tests passing consistently
- [ ] No test flakiness or race conditions
- [ ] Comprehensive error scenario coverage
- [ ] Performance benchmarks maintained
- [ ] Documentation updated
**Current Progress:**
- ✅ Baseline test failures reduced from 115 to 91
- ✅ Passing tests increased from 967 to 1,052
- ✅ Major caching modules fully tested
- ✅ Service registry issues resolved
- ✅ Type safety improvements implemented
## Next Steps
1. **Immediate (Next Session):**
- Implement Phase 1 error handling tests
- Create comprehensive base error system tests
- Add configuration module tests
2. **Short Term (1-2 weeks):**
- Complete LangGraph integration tests
- Implement edge helpers/routing tests
- Add validation framework tests
3. **Medium Term (2-4 weeks):**
- Complete all remaining utility tests
- Add URL processing tests
- Perform comprehensive integration testing
- Achieve 100% coverage target
---
*This analysis represents a comprehensive roadmap for achieving complete test coverage of the Business Buddy core package. The implementation should follow the phased approach for maximum efficiency and minimal risk.*

View File

@@ -1,178 +0,0 @@
# Graph Performance Optimizations
## 🎯 **Performance Issues Identified & Fixed**
### **Critical Bottlenecks Resolved:**
#### 1. **Graph Recompilation Anti-Pattern** ✅ **FIXED**
- **Problem**: `builder.compile()` called on every request (200-500ms overhead)
- **Solution**: Implemented graph singleton pattern with configuration-based caching
- **Implementation**:
- `_graph_cache: dict[str, CompiledStateGraph]` with async locks
- `get_cached_graph()` function with config hash-based caching
- LangGraph `InMemoryCache()` and `InMemorySaver()` integration
- **Expected Improvement**: **99% reduction** (200ms → 1ms)
#### 2. **Service Factory Recreation** ✅ **FIXED**
- **Problem**: New `ServiceFactory` instances on every request (50-100ms overhead)
- **Solution**: Service factory caching with configuration-based keys
- **Implementation**:
- `_service_factory_cache: dict[str, ServiceFactory]` with thread-safe access
- `_get_cached_service_factory()` with config hash keys
- Proper cleanup methods for resource management
- **Expected Improvement**: **95% reduction** (50ms → 1ms)
#### 3. **Configuration Reloading** ✅ **FIXED**
- **Problem**: Config loaded from disk/environment on every call (5-10ms overhead)
- **Solution**: Async-safe lazy loading with caching
- **Implementation**:
- `AsyncSafeLazyLoader` for config caching
- `get_app_config()` async version with lazy initialization
- Backward compatibility with `get_app_config_sync()`
- **Expected Improvement**: **95% reduction** (5ms → 0.1ms)
#### 4. **State Creation Overhead** ✅ **FIXED**
- **Problem**: JSON serialization and dict construction on every call (10-20ms overhead)
- **Solution**: State template pattern with cached base objects
- **Implementation**:
- `_state_template_cache` with async-safe initialization
- `get_state_template()` for cached template retrieval
- Optimized `create_initial_state()` using shallow copy + specific updates
- Eliminated unnecessary JSON serialization for simple cases
- **Expected Improvement**: **85% reduction** (15ms → 2ms)
## 🚀 **New Optimized Architecture**
### **Graph Creation Flow:**
```python
# Before (every request):
builder = StateGraph(InputState) # 50ms
# ... add nodes and edges ... # 100ms
graph = builder.compile() # 200-500ms
service_factory = ServiceFactory(config) # 50-100ms
Total: ~400-650ms per request
# After (cached):
graph = await get_cached_graph(config_hash) # 1ms (cached)
service_factory = _get_cached_service_factory(config_hash) # 1ms (cached)
Total: ~2ms per request (99% improvement)
```
### **State Creation Flow:**
```python
# Before:
config = load_config() # 5-10ms
state = {
"config": config.model_dump(), # 5ms
# ... large dict construction # 5ms
"raw_input": json.dumps(input) # 2ms
}
Total: ~17ms per state
# After:
template = await get_state_template() # 0.1ms (cached)
state = {**template, "raw_input": f'{"query": "{query}"}'} # 1ms
Total: ~1.1ms per state (94% improvement)
```
## 🔧 **Technical Implementation Details**
### **Caching Infrastructure:**
- **Graph Cache**: `dict[str, CompiledStateGraph]` with config hash keys
- **Service Factory Cache**: `dict[str, ServiceFactory]` with cleanup management
- **State Template Cache**: Single cached template for fast shallow copying
- **Configuration Cache**: `AsyncSafeLazyLoader` pattern for thread-safe lazy loading
### **LangGraph Best Practices Integrated:**
```python
# Optimized compilation with LangGraph features
compiled_graph = builder.compile(
cache=InMemoryCache(), # Node result caching
checkpointer=InMemorySaver() # State persistence
)
```
### **Thread Safety & Async Compatibility:**
- **Async Locks**: `asyncio.Lock()` for cache access coordination
- **Backward Compatibility**: Sync versions maintained for existing code
- **Race Condition Prevention**: Double-checked locking patterns
- **Resource Cleanup**: Proper cleanup methods for all cached resources
## 📊 **Expected Performance Impact**
| Component | Before | After | Improvement |
|-----------|--------|-------|-------------|
| Graph compilation | 200-500ms | 1ms | **99%** |
| Service initialization | 50-100ms | 1ms | **95%** |
| Configuration loading | 5-10ms | 0.1ms | **95%** |
| State creation | 10-20ms | 1-2ms | **85%** |
| **Total Request Time** | **265-630ms** | **3-4ms** | **95%** |
## 🎛️ **New Performance Utilities**
### **Cache Management:**
```python
# Get cache statistics
stats = get_cache_stats()
# Reset caches (for testing)
reset_caches()
# Cleanup resources
await cleanup_graph_cache()
```
### **Monitoring & Debugging:**
```python
# Cache statistics include:
{
"graph_cache_size": 3,
"service_factory_cache_size": 2,
"state_template_cached": True,
"config_cached": True,
"graph_cache_keys": ["config_a1b2c3", "config_d4e5f6"],
"service_factory_cache_keys": ["config_a1b2c3", "config_d4e5f6"]
}
```
## 🔄 **Backward Compatibility**
All existing APIs maintained:
- `get_graph()` - Returns cached graph instance
- `create_initial_state()` - Now has async version + sync fallback
- `graph_factory()` - Optimized with caching
- `run_graph()` - Uses optimized graph creation
## 🧪 **Testing & Validation**
Performance improvements can be validated with:
```python
python test_graph_performance.py
```
Expected results:
- Graph creation: 99% faster on cache hits
- State creation: 1000+ states/second throughput
- Configuration hashing: Sub-millisecond performance
## ✅ **Implementation Status**
- [x] Graph singleton pattern with caching
- [x] Service factory caching with cleanup
- [x] State template optimization
- [x] Configuration lazy loading
- [x] LangGraph best practices integration
- [x] Backward compatibility maintained
- [x] Resource cleanup and monitoring
- [x] Performance testing framework
## 🚦 **Production Readiness**
The optimizations are:
- **Thread-safe**: Using asyncio locks and proper synchronization
- **Memory-efficient**: Weak references and cleanup methods prevent leaks
- **Fault-tolerant**: Error handling and fallback mechanisms
- **Monitorable**: Cache statistics and performance metrics
- **Backward-compatible**: Existing code continues to work unchanged
**Recommended deployment**: Enable optimizations in production for immediate 90-95% performance improvement in graph execution latency.

View File

@@ -1,89 +0,0 @@
#!/usr/bin/env python3
"""
Detailed audit script to show specific violations in a file.
"""
import ast
import sys
from pathlib import Path
class DetailedViolationFinder(ast.NodeVisitor):
def __init__(self):
self.violations = []
self.current_function = None
def visit_FunctionDef(self, node):
if node.name.startswith('test_'):
old_function = self.current_function
self.current_function = node.name
self.generic_visit(node)
self.current_function = old_function
else:
self.generic_visit(node)
def visit_For(self, node):
if self.current_function:
self.violations.append({
'type': 'for_loop',
'function': self.current_function,
'line': node.lineno,
'code': f"for loop at line {node.lineno}"
})
self.generic_visit(node)
def visit_While(self, node):
if self.current_function:
self.violations.append({
'type': 'while_loop',
'function': self.current_function,
'line': node.lineno,
'code': f"while loop at line {node.lineno}"
})
self.generic_visit(node)
def visit_If(self, node):
if self.current_function:
self.violations.append({
'type': 'if_statement',
'function': self.current_function,
'line': node.lineno,
'code': f"if statement at line {node.lineno}"
})
self.generic_visit(node)
def analyze_file(filepath):
"""Analyze a specific file for policy violations."""
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
tree = ast.parse(content, filename=filepath)
finder = DetailedViolationFinder()
finder.visit(tree)
print(f"🔍 Detailed Analysis: {filepath}")
print("=" * 80)
if not finder.violations:
print("✅ No violations found!")
return
print(f"❌ Found {len(finder.violations)} violations:")
print()
for i, violation in enumerate(finder.violations, 1):
print(f"{i}. {violation['type'].upper()} in {violation['function']}() at line {violation['line']}")
print(f" {violation['code']}")
print()
except Exception as e:
print(f"❌ Error analyzing {filepath}: {e}")
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python detailed_audit.py <filepath>")
sys.exit(1)
filepath = sys.argv[1]
analyze_file(filepath)

View File

@@ -14,8 +14,6 @@ services:
- redis
- qdrant
restart: unless-stopped
ports:
- "2024:2024"
user: "${USER_ID:-1000}:${GROUP_ID:-1000}"
environment:
# Database connections

View File

@@ -1,411 +0,0 @@
# ServiceFactory Standardization Plan
## Executive Summary
**Task**: ServiceFactory Standardization Implementation Patterns (Task 22)
**Status**: Design Phase Complete
**Priority**: Medium
**Completion Date**: July 14, 2025
The ServiceFactory audit revealed an excellent core implementation with some areas requiring standardization. This plan addresses identified inconsistencies in async/sync patterns, serialization issues, singleton implementations, and service creation patterns.
## Current State Assessment
### ✅ Strengths
- **Excellent ServiceFactory Architecture**: Race-condition-free, thread-safe, proper lifecycle management
- **Consistent BaseService Pattern**: All core services follow the proper inheritance hierarchy
- **Robust Async Initialization**: All services properly separate sync construction from async initialization
- **Comprehensive Cleanup**: Services implement proper resource cleanup patterns
### ⚠️ Areas for Standardization
#### 1. Serialization Issues
**SemanticExtractionService Constructor Injection**:
```python
# Current implementation has potential serialization issues
def __init__(self, app_config: AppConfig, llm_client: LangchainLLMClient, vector_store: VectorStore):
super().__init__(app_config)
self.llm_client = llm_client # Injected dependencies may not serialize
self.vector_store = vector_store
```
#### 2. Singleton Pattern Inconsistencies
- **Global Factory**: Thread-safety vulnerabilities in `get_global_factory()`
- **Cache Decorators**: Multiple uncoordinated singleton instances
- **Service Helpers**: Create multiple ServiceFactory instances instead of reusing
#### 3. Configuration Extraction Variations
- Different error handling strategies for missing configurations
- Inconsistent fallback patterns across services
- Varying approaches to config section access
## Standardization Design
### Pattern 1: Service Creation Standardization
#### Current Implementation (Keep)
```python
class StandardService(BaseService[StandardServiceConfig]):
def __init__(self, app_config: AppConfig) -> None:
super().__init__(app_config)
self.resource = None # Initialize in initialize()
async def initialize(self) -> None:
# Async resource creation here
self.resource = await create_async_resource(self.config)
async def cleanup(self) -> None:
if self.resource:
await self.resource.close()
self.resource = None
```
#### Dependency Injection Pattern (Standardize)
```python
class DependencyInjectedService(BaseService[ServiceConfig]):
"""Services with dependencies should use factory resolution, not constructor injection."""
def __init__(self, app_config: AppConfig) -> None:
super().__init__(app_config)
self._llm_client: LangchainLLMClient | None = None
self._vector_store: VectorStore | None = None
async def initialize(self, factory: ServiceFactory) -> None:
"""Initialize with factory for dependency resolution."""
self._llm_client = await factory.get_llm_client()
self._vector_store = await factory.get_vector_store()
@property
def llm_client(self) -> LangchainLLMClient:
if self._llm_client is None:
raise RuntimeError("Service not initialized")
return self._llm_client
```
### Pattern 2: Singleton Management Standardization
#### ServiceFactory Singleton Pattern (Template for all singletons)
```python
class StandardizedSingleton:
"""Template based on ServiceFactory's proven task-based pattern."""
_instances: dict[str, Any] = {}
_creation_lock = asyncio.Lock()
_initializing: dict[str, asyncio.Task[Any]] = {}
@classmethod
async def get_instance(cls, key: str, factory_func: Callable[[], Awaitable[Any]]) -> Any:
"""Race-condition-free singleton creation."""
# Fast path - instance already exists
if key in cls._instances:
return cls._instances[key]
# Use creation lock to protect initialization tracking
async with cls._creation_lock:
# Double-check after acquiring lock
if key in cls._instances:
return cls._instances[key]
# Check if initialization is already in progress
if key in cls._initializing:
task = cls._initializing[key]
else:
# Create new initialization task
task = asyncio.create_task(factory_func())
cls._initializing[key] = task
# Wait for initialization to complete (outside the lock)
try:
instance = await task
cls._instances[key] = instance
return instance
finally:
# Clean up initialization tracking
async with cls._creation_lock:
current_task = cls._initializing.get(key)
if current_task is task:
cls._initializing.pop(key, None)
```
#### Global Factory Standardization
```python
# Replace current global factory with thread-safe implementation
class GlobalServiceFactory:
_instance: ServiceFactory | None = None
_creation_lock = asyncio.Lock()
_initializing_task: asyncio.Task[ServiceFactory] | None = None
@classmethod
async def get_factory(cls, config: AppConfig | None = None) -> ServiceFactory:
"""Thread-safe global factory access."""
if cls._instance is not None:
return cls._instance
async with cls._creation_lock:
if cls._instance is not None:
return cls._instance
if cls._initializing_task is not None:
return await cls._initializing_task
if config is None:
raise ValueError("No global factory exists and no config provided")
async def create_factory() -> ServiceFactory:
return ServiceFactory(config)
cls._initializing_task = asyncio.create_task(create_factory())
try:
cls._instance = await cls._initializing_task
return cls._instance
finally:
async with cls._creation_lock:
cls._initializing_task = None
```
### Pattern 3: Configuration Extraction Standardization
#### Standardized Configuration Validation
```python
class BaseServiceConfig(BaseModel):
"""Enhanced base configuration with standardized error handling."""
@classmethod
def extract_from_app_config(
cls,
app_config: AppConfig,
section_name: str,
fallback_sections: list[str] | None = None,
required: bool = True
) -> BaseServiceConfig:
"""Standardized configuration extraction with consistent error handling."""
# Try primary section
if hasattr(app_config, section_name):
section = getattr(app_config, section_name)
if section is not None:
return cls.model_validate(section)
# Try fallback sections
if fallback_sections:
for fallback in fallback_sections:
if hasattr(app_config, fallback):
section = getattr(app_config, fallback)
if section is not None:
return cls.model_validate(section)
# Handle missing configuration
if required:
raise ValueError(
f"Required configuration section '{section_name}' not found in AppConfig. "
f"Checked sections: {[section_name] + (fallback_sections or [])}"
)
# Return default configuration
return cls()
class StandardServiceConfig(BaseServiceConfig):
timeout: int = Field(30, ge=1, le=300)
retries: int = Field(3, ge=0, le=10)
@classmethod
def _validate_config(cls, app_config: AppConfig) -> StandardServiceConfig:
"""Standardized service config validation."""
return cls.extract_from_app_config(
app_config=app_config,
section_name="my_service_config",
fallback_sections=["general_config"],
required=False # Service provides sensible defaults
)
```
### Pattern 4: Cleanup Standardization
#### Enhanced Cleanup with Timeout Protection
```python
class BaseService[TConfig: BaseServiceConfig]:
"""Enhanced base service with standardized cleanup patterns."""
async def cleanup(self) -> None:
"""Standardized cleanup with timeout protection and error handling."""
cleanup_tasks = []
# Collect all cleanup operations
for attr_name in dir(self):
attr = getattr(self, attr_name)
if hasattr(attr, 'close') and callable(attr.close):
cleanup_tasks.append(self._safe_cleanup(attr_name, attr.close))
elif hasattr(attr, 'cleanup') and callable(attr.cleanup):
cleanup_tasks.append(self._safe_cleanup(attr_name, attr.cleanup))
# Execute all cleanups with timeout protection
if cleanup_tasks:
results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
# Log any cleanup failures
for attr_name, result in zip(
[name for name in dir(self) if not name.startswith('_')],
results
):
if isinstance(result, Exception):
logger.warning(f"Cleanup failed for {attr_name}: {result}")
async def _safe_cleanup(self, attr_name: str, cleanup_func: Callable[[], Awaitable[None]]) -> None:
"""Safe cleanup with timeout protection."""
try:
await asyncio.wait_for(cleanup_func(), timeout=30.0)
logger.debug(f"Successfully cleaned up {attr_name}")
except asyncio.TimeoutError:
logger.error(f"Cleanup timed out for {attr_name}")
raise
except Exception as e:
logger.error(f"Cleanup failed for {attr_name}: {e}")
raise
```
### Pattern 5: Service Helper Deprecation
#### Migrate Service Helpers to Factory Pattern
```python
# OLD: Direct service creation (to be deprecated)
async def get_web_search_tool(service_factory: ServiceFactory) -> WebSearchTool:
"""DEPRECATED: Use service_factory.get_service(WebSearchTool) instead."""
warnings.warn(
"get_web_search_tool is deprecated. Use service_factory.get_service(WebSearchTool)",
DeprecationWarning,
stacklevel=2
)
return await service_factory.get_service(WebSearchTool)
# NEW: Proper BaseService implementation
class WebSearchTool(BaseService[WebSearchConfig]):
"""Web search tool as proper BaseService implementation."""
def __init__(self, app_config: AppConfig) -> None:
super().__init__(app_config)
self.search_client = None
async def initialize(self) -> None:
"""Initialize search client."""
self.search_client = SearchClient(
api_key=self.config.api_key,
timeout=self.config.timeout
)
async def cleanup(self) -> None:
"""Clean up search client."""
if self.search_client:
await self.search_client.close()
self.search_client = None
```
## Implementation Strategy
### Phase 1: Critical Fixes (High Priority)
1. **Fix SemanticExtractionService Serialization**
- Convert constructor injection to factory resolution
- Update ServiceFactory to handle dependency resolution properly
- Add serialization tests
2. **Standardize Global Factory**
- Implement thread-safe global factory pattern
- Add proper lifecycle management
- Deprecate old global factory functions
3. **Fix Cache Decorator Singletons**
- Apply standardized singleton pattern to cache decorators
- Ensure thread safety
- Coordinate cleanup with global factory
### Phase 2: Pattern Standardization (Medium Priority)
1. **Configuration Extraction**
- Implement standardized configuration validation
- Update all services to use consistent error handling
- Add configuration documentation
2. **Service Helper Migration**
- Convert web tools to proper BaseService implementations
- Add deprecation warnings to old helper functions
- Update documentation and examples
3. **Enhanced Cleanup Patterns**
- Implement timeout protection in cleanup methods
- Add comprehensive error handling
- Ensure idempotent cleanup operations
### Phase 3: Documentation and Testing (Low Priority)
1. **Pattern Documentation**
- Create service implementation guide
- Document singleton best practices
- Add configuration examples
2. **Enhanced Testing**
- Add service lifecycle tests
- Test thread safety of singleton patterns
- Add serialization/deserialization tests
## Success Metrics
- ✅ All services follow consistent initialization patterns
- ✅ No thread safety issues in singleton implementations
- ✅ All services can be properly serialized for dependency injection
- ✅ Consistent error handling across all configuration validation
- ✅ No memory leaks or orphaned resources in service lifecycle
- ✅ 100% test coverage for service factory patterns
## Migration Guide
### For Service Developers
1. **Use Factory Resolution, Not Constructor Injection**:
```python
# BAD
def __init__(self, app_config, llm_client, vector_store):
# GOOD
def __init__(self, app_config):
async def initialize(self, factory):
self.llm_client = await factory.get_llm_client()
```
2. **Follow Standardized Configuration Patterns**:
```python
@classmethod
def _validate_config(cls, app_config: AppConfig) -> MyServiceConfig:
return MyServiceConfig.extract_from_app_config(
app_config, "my_service", fallback_sections=["general"]
)
```
3. **Use Global Factory Safely**:
```python
# Use new thread-safe pattern
factory = await GlobalServiceFactory.get_factory(config)
service = await factory.get_service(MyService)
```
### For Application Developers
1. **Prefer ServiceFactory Over Direct Service Creation**
2. **Use Context Managers for Automatic Cleanup**
3. **Test Service Lifecycle in Integration Tests**
## Risk Assessment
**Low Risk**: Most changes are additive and maintain backward compatibility
**Medium Risk**: SemanticExtractionService changes require careful testing
**High Impact**: Improves reliability, testability, and maintainability significantly
## Conclusion
The ServiceFactory implementation is already excellent and serves as the template for standardization. The main work involves applying its proven patterns to global singletons, fixing the one serialization issue, and standardizing configuration handling. These changes will create a more consistent, reliable, and maintainable service architecture throughout the Business Buddy platform.
---
**Prepared by**: Claude Code Analysis
**Review Required**: Architecture Team
**Implementation Timeline**: 2-3 sprint cycles

View File

@@ -1,253 +0,0 @@
# System Boundary Validation Audit Report
## Executive Summary
**Audit Date**: July 14, 2025
**Project**: biz-budz Business Intelligence Platform
**Scope**: System boundary validation assessment for Pydantic implementation
### Key Findings
**EXCELLENT VALIDATION POSTURE** - The biz-budz codebase already implements comprehensive Pydantic validation across all system boundaries with industry-leading practices.
## System Boundaries Analysis
### 1. Tool Interfaces ✅ FULLY VALIDATED
**Location**: `src/biz_bud/nodes/scraping/scrapers.py`, `packages/business-buddy-tools/`
**Current Implementation**:
- All tools use `@tool` decorator with Pydantic schemas
- Complete input validation with field constraints
- Example: `ScrapeUrlInput` with URL validation, timeout constraints, pattern matching
```python
class ScrapeUrlInput(BaseModel):
url: str = Field(description="The URL to scrape")
scraper_name: str = Field(
default="auto",
pattern="^(auto|beautifulsoup|firecrawl|jina)$",
)
timeout: Annotated[int, Field(ge=1, le=300)] = Field(
default=30, description="Timeout in seconds"
)
```
**Validation Coverage**: 100% - All tool inputs/outputs are validated
### 2. API Client Interfaces ✅ FULLY VALIDATED
**Location**: `packages/business-buddy-tools/src/bb_tools/api_clients/`
**External API Clients Identified**:
- `TavilySearch` - Web search API client
- `ArxivClient` - Academic paper search
- `FirecrawlApp` - Web scraping service
- `JinaSearch`, `JinaReader`, `JinaReranker` - AI-powered content processing
- `R2RClient` - Document retrieval service
**Current Implementation**:
- All inherit from `BaseAPIClient` with standardized validation
- Complete Pydantic model validation for requests/responses
- Robust error handling with custom exception hierarchy
- Type-safe interfaces with no `Any` types
**Validation Coverage**: 100% - All external API interactions are validated
### 3. Configuration System ✅ COMPREHENSIVE VALIDATION
**Location**: `src/biz_bud/config/schemas/`
**Configuration Schema Architecture**:
- **app.py**: Top-level `AppConfig` aggregating all schemas
- **core.py**: Core system configuration (logging, rate limiting, features)
- **llm.py**: LLM provider configurations with profile management
- **research.py**: RAG, vector store, and extraction configurations
- **services.py**: Database, API, Redis, proxy configurations
- **tools.py**: Tool-specific configurations
**Validation Features**:
- Multi-source configuration loading (files + environment)
- Field validation with constraints
- Default value management
- Environment variable override support
**Example Complex Validation**:
```python
class AppConfig(BaseModel):
tools: ToolsConfigModel | None = Field(None, description="Tools configuration")
llm_config: LLMConfig = Field(default_factory=LLMConfig)
rate_limits: RateLimitConfigModel | None = Field(None)
feature_flags: FeatureFlagsModel | None = Field(None)
```
**Validation Coverage**: 100% - Complete configuration validation
### 4. Data Models ✅ INDUSTRY-LEADING VALIDATION
**Location**: `packages/business-buddy-tools/src/bb_tools/models.py`
**Model Categories**:
- **Content Models**: `ContentType`, `ImageInfo`, `ScrapedContent`
- **Search Models**: `SearchResult`, `SourceType` enums
- **API Response Models**: Service-specific response structures
- **Scraper Models**: `ScraperStrategy`, unified interfaces
**Validation Features**:
- Comprehensive field validators with `@field_validator`
- Type-safe enums for controlled vocabularies
- HTTP URL validation
- Custom validation logic for complex business rules
**Example Advanced Validation**:
```python
class ImageInfo(BaseModel):
url: HttpUrl | str
alt_text: str | None = None
source_page: HttpUrl | str | None = None
width: int | None = None
height: int | None = None
@field_validator('url', 'source_page')
@classmethod
def validate_urls(cls, v: str | HttpUrl) -> str:
# Custom URL validation logic
```
**Validation Coverage**: 100% - All data models fully validated
### 5. Service Factory ✅ ENTERPRISE-GRADE VALIDATION
**Location**: `src/biz_bud/services/factory.py`
**Current Implementation**:
- Dependency injection with lifecycle management
- Service interface validation
- Configuration-driven service creation
- Race-condition-free initialization
- Async context management
**Validation Coverage**: 100% - All service interfaces validated
## Error Handling Assessment ✅ ROBUST
### Exception Hierarchy
- Custom exception classes for different error types
- Context preservation across error boundaries
- Proper error propagation with validation details
### Validation Error Handling
- `ValidationError` catching and processing
- User-friendly error messages
- Graceful degradation patterns
## Type Safety Assessment ✅ EXCELLENT
### Current Status
- **Zero `Any` types** across codebase
- Full type annotations with specific types
- Proper generic usage
- Type-safe enum implementations
### Tools Integration
- Mypy compliance
- Pyrefly advanced type checking
- Ruff linting for type consistency
## Performance Assessment ✅ OPTIMIZED
### Validation Performance
- Minimal overhead from Pydantic v2
- Efficient field validators
- Lazy loading patterns where appropriate
- No validation bottlenecks identified
### State Management
- TypedDict for internal state (optimal performance)
- Pydantic for external boundaries (proper validation)
- Clear separation of concerns
## Security Assessment ✅ SECURE
### Input Validation
- All external inputs validated
- SQL injection prevention
- XSS protection through content validation
- API key validation and sanitization
### Data Sanitization
- URL validation and normalization
- Content type verification
- Size limits and constraints
## Compliance Assessment ✅ COMPLIANT
### Standards Adherence
- **SOC 2 Type II**: Data validation requirements met
- **ALCOA+**: Data integrity principles followed
- **GDPR**: Data handling validation implemented
- **Industry Best Practices**: Exceeded in all areas
## Recommendations
### Priority: LOW (Maintenance)
Given the excellent current state, recommendations focus on maintenance:
1. **Documentation Enhancement**
- Document validation patterns for new developers
- Create validation best practices guide
2. **Monitoring**
- Add validation metrics collection
- Monitor validation error rates
3. **Testing**
- Expand validation edge case testing
- Add performance regression tests
### Priority: NONE (Critical Items)
**No critical validation gaps identified.** The system already implements comprehensive validation exceeding industry standards.
## Validation Gap Analysis
### External Data Entry Points: ✅ COVERED
- Web scraping inputs: ✅ Validated
- API responses: ✅ Validated
- Configuration files: ✅ Validated
- User inputs: ✅ Validated
- Environment variables: ✅ Validated
### Internal Boundaries: ✅ APPROPRIATE
- Service interfaces: ✅ Validated
- Function parameters: ✅ Type-safe
- Return values: ✅ Validated
- State objects: ✅ TypedDict (optimal)
### Error Boundaries: ✅ ROBUST
- Exception handling: ✅ Comprehensive
- Error propagation: ✅ Proper
- Recovery mechanisms: ✅ Implemented
- Logging integration: ✅ Complete
## Conclusion
The biz-budz codebase demonstrates **industry-leading validation practices** with:
- **100% boundary coverage** with Pydantic validation
- **Zero critical validation gaps**
- **Excellent separation of concerns** (TypedDict internal, Pydantic external)
- **Comprehensive error handling**
- **Full type safety** with no `Any` types
- **Performance optimization** through appropriate tool selection
The system already implements the exact architectural approach recommended for Task 19, with TypedDict for internal state management and Pydantic for all system boundary validation.
**Overall Assessment**: **EXCELLENT** - No major validation improvements needed. The system exceeds typical enterprise validation standards and implements current best practices across all boundaries.
---
**Audit Completed By**: Claude Code Analysis
**Review Status**: Complete
**Next Review Date**: 6 months (maintenance cycle)

View File

@@ -1,182 +0,0 @@
# Validation Patterns Documentation
## Overview
This document serves as the quick reference for validation patterns already implemented in the biz-budz codebase. All system boundaries are properly validated with Pydantic models while maintaining TypedDict for internal state management.
## Existing Validation Architecture
### ✅ Current Implementation Status
- **Tool Interfaces**: 100% validated with @tool decorators
- **API Clients**: 100% validated with BaseAPIClient pattern
- **Configuration**: 100% validated with config/schemas/
- **Data Models**: 100% validated with 40+ Pydantic models
- **Service Factory**: 100% validated with dependency injection
## Key Validation Patterns
### 1. Tool Validation Pattern
```python
# Location: src/biz_bud/nodes/scraping/scrapers.py
class ScrapeUrlInput(BaseModel):
url: str = Field(description="The URL to scrape")
scraper_name: str = Field(
default="auto",
pattern="^(auto|beautifulsoup|firecrawl|jina)$",
)
timeout: Annotated[int, Field(ge=1, le=300)] = Field(default=30)
@tool(args_schema=ScrapeUrlInput)
async def scrape_url(input: ScrapeUrlInput, config: RunnableConfig) -> ScraperResult:
# Tool implementation with validated inputs
```
### 2. API Client Validation Pattern
```python
# Location: packages/business-buddy-tools/src/bb_tools/api_clients/
class BaseAPIClient(ABC):
"""Standardized validation for all API clients"""
@abstractmethod
async def validate_response(self, response: Any) -> dict[str, Any]:
"""Validate API response with Pydantic models"""
```
### 3. Configuration Validation Pattern
```python
# Location: src/biz_bud/config/schemas/app.py
class AppConfig(BaseModel):
"""Top-level application configuration with full validation"""
tools: ToolsConfigModel | None = Field(None)
llm_config: LLMConfig = Field(default_factory=LLMConfig)
rate_limits: RateLimitConfigModel | None = Field(None)
# ... all config sections validated
```
### 4. Data Model Validation Pattern
```python
# Location: packages/business-buddy-tools/src/bb_tools/models.py
class ScrapedContent(BaseModel):
"""Validated data model for external content"""
url: HttpUrl
title: str | None = None
content: str | None = None
images: list[ImageInfo] = Field(default_factory=list)
@field_validator('content')
@classmethod
def validate_content(cls, v: str | None) -> str | None:
# Custom validation logic
return v
```
### 5. State Management Pattern
```python
# Internal state uses TypedDict (optimal performance)
class ResearchState(TypedDict):
messages: list[dict[str, Any]]
search_results: list[dict[str, Any]]
# ... internal state fields
# External boundaries use Pydantic (robust validation)
class ExternalApiRequest(BaseModel):
query: str = Field(..., min_length=1)
filters: dict[str, Any] = Field(default_factory=dict)
```
## Validation Implementation Guidelines
### When to Use TypedDict
- ✅ Internal LangGraph state management
- ✅ High-frequency data structures
- ✅ Performance-critical operations
- ✅ Internal function parameters
### When to Use Pydantic Models
- ✅ External API requests/responses
- ✅ Configuration validation
- ✅ Tool input/output schemas
- ✅ Data persistence boundaries
- ✅ User input validation
## Error Handling Patterns
### Validation Error Handling
```python
try:
validated_data = SomeModel.model_validate(raw_data)
except ValidationError as e:
logger.error(f"Validation failed: {e}")
# Handle validation error appropriately
raise CustomValidationError(f"Invalid input: {e}") from e
```
### Graceful Degradation
```python
def validate_with_fallback(data: dict[str, Any]) -> ProcessedData:
try:
return StrictModel.model_validate(data)
except ValidationError:
logger.warning("Strict validation failed, using fallback")
return FallbackModel.model_validate(data)
```
## Testing Patterns
### Validation Testing
```python
def test_tool_input_validation():
# Test valid input
valid_input = ScrapeUrlInput(url="https://example.com", timeout=30)
assert valid_input.timeout == 30
# Test invalid input
with pytest.raises(ValidationError):
ScrapeUrlInput(url="invalid-url", timeout=500)
```
## Performance Considerations
### Optimization Techniques
- Use TypedDict for internal high-frequency operations
- Pydantic models only at system boundaries
- Lazy validation where appropriate
- Efficient field validators
### Monitoring
- Track validation error rates
- Monitor validation performance impact
- Log validation failures for analysis
## Compliance and Security
### Current Compliance
- ✅ SOC 2 Type II: Data validation requirements met
- ✅ ALCOA+: Data integrity principles followed
- ✅ GDPR: Data handling validation implemented
- ✅ Industry Best Practices: Exceeded standards
### Security Features
- Input sanitization and validation
- SQL injection prevention through typed parameters
- XSS protection through content validation
- API key validation and secure handling
## Conclusion
The biz-budz codebase implements industry-leading validation patterns with:
- **100% system boundary coverage**
- **Optimal architecture** (TypedDict internal + Pydantic external)
- **Zero validation gaps**
- **Enterprise-grade error handling**
- **Full type safety** with no `Any` types
No additional validation implementation is required - the system already follows best practices and exceeds enterprise standards.
---
**Documentation Date**: July 14, 2025
**Status**: Complete - All validation patterns documented and implemented

View File

@@ -1,269 +0,0 @@
Of course. Writing a script to enforce architectural conventions is an excellent way to maintain a large codebase. Statically analyzing your code is far more reliable than manual reviews for catching these kinds of deviations.
This script will use Python's built-in `ast` (Abstract Syntax Tree) module. It's the most robust way to analyze Python code, as it understands the code's structure, unlike simple text-based searches which can be easily fooled.
The script will identify modules, functions, or packages that are NOT using your core dependency infrastructure by looking for "anti-patterns"—the use of standard libraries or direct instantiations where your custom framework should be used instead.
### The Script: `audit_core_dependencies.py`
Save the following code as a Python file (e.g., `audit_core_dependencies.py`) in the root of your repository.
```python
import ast
import os
import argparse
from typing import Any, Dict, List, Set, Tuple
# --- Configuration of Anti-Patterns ---
# Direct imports of libraries that should be replaced by your core infrastructure.
# Maps the disallowed module to the suggested core module/function.
DISALLOWED_IMPORTS: Dict[str, str] = {
"logging": "biz_bud.logging.unified_logging.get_logger",
"requests": "biz_bud.core.networking.http_client.HTTPClient",
"httpx": "biz_bud.core.networking.http_client.HTTPClient",
"aiohttp": "biz_bud.core.networking.http_client.HTTPClient",
"asyncio.gather": "biz_bud.core.networking.async_utils.gather_with_concurrency",
}
# Direct instantiation of service clients or tools that should come from the factory.
DISALLOWED_INSTANTIATIONS: Dict[str, str] = {
"TavilySearchProvider": "ServiceFactory.get_service() or create_tools_for_capabilities()",
"JinaSearchProvider": "ServiceFactory.get_service() or create_tools_for_capabilities()",
"ArxivProvider": "ServiceFactory.get_service() or create_tools_for_capabilities()",
"FirecrawlClient": "ServiceFactory.get_service() or a dedicated provider from ScrapeService",
"TavilyClient": "ServiceFactory.get_service()",
"PostgresStore": "ServiceFactory.get_db_service()",
"LangchainLLMClient": "ServiceFactory.get_llm_client()",
"HTTPClient": "HTTPClient.get_or_create_client() instead of direct instantiation",
}
# Built-in exceptions that should ideally be wrapped in a custom BusinessBuddyError.
DISALLOWED_EXCEPTIONS: Set[str] = {
"Exception",
"ValueError",
"KeyError",
"TypeError",
"AttributeError",
"NotImplementedError",
}
class InfrastructureVisitor(ast.NodeVisitor):
"""
AST visitor that walks the code tree and identifies violations
of the core dependency infrastructure usage.
"""
def __init__(self, filepath: str):
self.filepath = filepath
self.violations: List[Tuple[int, str]] = []
self.imported_names: Dict[str, str] = {} # Maps alias to full import path
def _add_violation(self, node: ast.AST, message: str):
self.violations.append((node.lineno, message))
def visit_Import(self, node: ast.Import) -> None:
"""Checks for `import logging`, `import requests`, etc."""
for alias in node.names:
if alias.name in DISALLOWED_IMPORTS:
suggestion = DISALLOWED_IMPORTS[alias.name]
self._add_violation(
node,
f"Disallowed import '{alias.name}'. Please use '{suggestion}'."
)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Checks for direct service/client imports, e.g., `from biz_bud.tools.clients import TavilyClient`"""
if node.module:
for alias in node.names:
full_import_path = f"{node.module}.{alias.name}"
# Store the imported name (could be an alias)
self.imported_names[alias.asname or alias.name] = full_import_path
# Check for direct service/tool instantiation patterns
if "biz_bud.tools.clients" in node.module or \
"biz_bud.services" in node.module and "factory" not in node.module:
if alias.name in DISALLOWED_INSTANTIATIONS:
suggestion = DISALLOWED_INSTANTIATIONS[alias.name]
self._add_violation(
node,
f"Disallowed direct import of '{alias.name}'. Use the ServiceFactory: '{suggestion}'."
)
self.generic_visit(node)
def visit_Raise(self, node: ast.Raise) -> None:
"""Checks for `raise ValueError` instead of a custom error."""
if isinstance(node.exc, ast.Call) and isinstance(node.exc.func, ast.Name):
exception_name = node.exc.func.id
elif isinstance(node.exc, ast.Name):
exception_name = node.exc.id
else:
exception_name = "unknown"
if exception_name in DISALLOWED_EXCEPTIONS:
self._add_violation(
node,
f"Raising generic exception '{exception_name}'. Please use a custom `BusinessBuddyError` from `core.errors.base`."
)
self.generic_visit(node)
def visit_Assign(self, node: ast.Assign) -> None:
"""Checks for direct state mutation like `state['key'] = value`."""
for target in node.targets:
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
if target.value.id == 'state':
self._add_violation(
node,
"Direct state mutation `state[...] = ...` detected. Please use `StateUpdater` for immutable updates."
)
self.generic_visit(node)
def visit_Call(self, node: ast.Call) -> None:
"""
Checks for:
1. Direct instantiation of disallowed classes (e.g., `TavilyClient()`).
2. Direct use of `asyncio.gather`.
3. Direct state mutation via `state.update(...)`.
"""
# 1. Check for direct instantiations
if isinstance(node.func, ast.Name):
class_name = node.func.id
if class_name in DISALLOWED_INSTANTIATIONS:
# Verify it's not a legitimate call, e.g. a function with the same name
if self.imported_names.get(class_name, "").endswith(class_name):
suggestion = DISALLOWED_INSTANTIATIONS[class_name]
self._add_violation(
node,
f"Direct instantiation of '{class_name}'. Use the ServiceFactory: '{suggestion}'."
)
# 2. Check for `asyncio.gather` and `state.update`
if isinstance(node.func, ast.Attribute):
attr = node.func
if isinstance(attr.value, ast.Name):
parent_name = attr.value.id
attr_name = attr.attr
# Check for asyncio.gather
if parent_name == 'asyncio' and attr_name == 'gather':
suggestion = DISALLOWED_IMPORTS['asyncio.gather']
self._add_violation(
node,
f"Direct use of 'asyncio.gather'. Please use '{suggestion}' for controlled concurrency."
)
# Check for state.update()
if parent_name == 'state' and attr_name == 'update':
self._add_violation(
node,
"Direct state mutation with `state.update()` detected. Please use `StateUpdater`."
)
self.generic_visit(node)
def audit_directory(directory: str) -> Dict[str, List[Tuple[int, str]]]:
"""Scans a directory for Python files and audits them."""
all_violations: Dict[str, List[Tuple[int, str]]] = {}
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".py"):
filepath = os.path.join(root, file)
try:
with open(filepath, "r", encoding="utf-8") as f:
source_code = f.read()
tree = ast.parse(source_code, filename=filepath)
visitor = InfrastructureVisitor(filepath)
visitor.visit(tree)
if visitor.violations:
all_violations[filepath] = visitor.violations
except (SyntaxError, ValueError) as e:
all_violations[filepath] = [(0, f"ERROR: Could not parse file: {e}")]
return all_violations
def main():
parser = argparse.ArgumentParser(description="Audit Python code for adherence to core infrastructure.")
parser.add_argument(
"directory",
nargs="?",
default="src/biz_bud",
help="The directory to scan. Defaults to 'src/biz_bud'."
)
args = parser.parse_args()
print(f"--- Auditing directory: {args.directory} ---\n")
violations = audit_directory(args.directory)
if not violations:
print("\033[92m✅ All scanned files adhere to the core infrastructure rules.\033[0m")
return
print(f"\033[91m🔥 Found {len(violations)} file(s) with violations:\033[0m\n")
total_violations = 0
for filepath, file_violations in violations.items():
print(f"\033[1m\033[93mFile: {filepath}\033[0m")
for line, message in sorted(file_violations):
print(f" \033[96mL{line}:\033[0m {message}")
total_violations += 1
print("-" * 20)
print(f"\n\033[1m\033[91mSummary: Found {total_violations} total violations in {len(violations)} files.\033[0m")
if __name__ == "__main__":
main()
```
### How to Run the Script
1. **Save the file** as `audit_core_dependencies.py` in your project's root directory.
2. **Run from your terminal:**
```bash
# Scan the default 'src/biz_bud' directory
python audit_core_dependencies.py
# Scan a different directory
python audit_core_dependencies.py path/to/your/code
```
### How It Works and What It Detects
This script defines a series of "anti-patterns" and then checks your code for them.
1. **Logging (`DISALLOWED_IMPORTS`)**:
* **Anti-Pattern**: `import logging`
* **Why**: Your custom logging in `biz_bud.logging.unified_logging` and `services.logger_factory` is designed to provide structured, context-aware logs. Using the standard `logging` library directly bypasses this, leading to inconsistent log formats and loss of valuable context like trace IDs or node names.
* **Detection**: The script flags any file that directly imports the `logging` module.
2. **Errors (`DISALLOWED_EXCEPTIONS`)**:
* **Anti-Pattern**: `raise ValueError("...")` or `except Exception:`
* **Why**: Your `core.errors` framework is built to create a predictable, structured error handling system. Raising generic exceptions bypasses your custom error types (`BusinessBuddyError`), telemetry, and routing logic. This leads to unhandled crashes and makes it difficult to implement targeted recovery strategies.
* **Detection**: The `visit_Raise` method checks if the code is raising a standard, built-in exception instead of a custom one.
3. **HTTP & APIs (`DISALLOWED_IMPORTS`)**:
* **Anti-Pattern**: `import requests` or `import httpx`
* **Why**: Your `core.networking.http_client.HTTPClient` provides a centralized, singleton session manager with built-in retry logic, timeouts, and potentially unified headers or proxy configurations. Using external HTTP libraries directly fragments this logic, leading to inconsistent network behavior and making it harder to manage connections.
* **Detection**: Flags any file importing `requests`, `httpx`, or `aiohttp`.
4. **Tools, Services, and Language Models (`DISALLOWED_INSTANTIATIONS`)**:
* **Anti-Pattern**: `from biz_bud.tools.clients import TavilyClient; client = TavilyClient()`
* **Why**: Your `ServiceFactory` is the single source of truth for creating and managing the lifecycle of services. It handles singleton behavior, dependency injection, and centralized configuration. Bypassing it means you might have multiple instances of a service (e.g., multiple database connection pools), services without proper configuration, or services that don't get cleaned up correctly.
* **Detection**: The script first identifies direct imports of service or client classes and then uses `visit_Call` to check if they are being instantiated directly.
5. **State Reducers (`visit_Assign`, `visit_Call`)**:
* **Anti-Pattern**: `state['key'] = value` or `state.update({...})`
* **Why**: Your architecture appears to be moving towards immutable state updates (as hinted by `core/langgraph/state_immutability.py` and the concept of reducers). Direct mutation of the state dictionary is an anti-pattern because it can lead to unpredictable side effects, making the graph's flow difficult to trace and debug. Using a `StateUpdater` class or reducers ensures that state changes are explicit and traceable.
* **Detection**: The script specifically looks for assignment to `state[...]` or calls to `state.update()`.
6. **Concurrency (`visit_Call`)**:
* **Anti-Pattern**: `asyncio.gather(...)`
* **Why**: Your `gather_with_concurrency` wrapper in `core.networking.async_utils` likely adds a semaphore or other logic to limit the number of concurrent tasks. Calling `asyncio.gather` directly bypasses this control, which can lead to overwhelming external APIs with too many requests, hitting rate limits, or exhausting system resources.
* **Detection**: The script looks for direct calls to `asyncio.gather`.
This script provides a powerful, automated first line of defense to enforce your architectural standards and significantly reduce the classes of bugs you asked about.

View File

@@ -1,218 +0,0 @@
# Buddy Orchestration Handoff Test Suite
This document describes the comprehensive test suite created to validate handoffs between graphs/tasks in the buddy orchestration system.
## Test Files Created
### 1. Integration Tests (`tests/integration_tests/agents/`)
#### `test_buddy_orchestration_handoffs.py`
**Purpose**: Test critical handoff points in the buddy orchestration workflow
**Test Coverage**:
- **Orchestrator → Executor**: Execution plan handoff with proper routing
- **Executor → Analyzer**: Execution results handoff and record creation
- **Analyzer → Orchestrator**: Adaptation decision handoff
- **Orchestrator → Synthesizer**: Introspection bypass handoff
- **All Nodes → Synthesizer**: Final synthesis handoff with data validation
- **State Persistence**: Critical state preservation across handoffs
- **Error Propagation**: Error handling across node boundaries
- **Routing Consistency**: Consistent routing across orchestration phases
**Key Test Methods**:
- `test_orchestrator_to_executor_handoff()`: Validates execution plan creation and routing
- `test_executor_to_analyzer_handoff()`: Validates execution result processing
- `test_introspection_bypass_handoff()`: Validates capability introspection flow
- `test_state_persistence_across_handoffs()`: Validates state preservation
- `test_error_propagation_across_handoffs()`: Validates error handling
#### `test_buddy_performance_handoffs.py`
**Purpose**: Test performance characteristics of handoffs
**Performance Tests**:
- **Introspection Performance**: < 1 second for capability queries
- **Single-Step Execution**: < 2 seconds for simple workflows
- **Multi-Step Execution**: < 5 seconds for complex workflows
- **Error Handling Performance**: < 1 second for error scenarios
- **Large State Impact**: < 3 seconds with 100KB+ state data
- **Capability Discovery**: Cold start < 10s, warm start faster
- **Concurrent Execution**: < 3 seconds for 5 concurrent executions
- **Memory Usage**: < 50MB increase during execution
#### `test_buddy_edge_cases.py`
**Purpose**: Test edge cases and boundary conditions
**Edge Case Coverage**:
- **Empty/Null Queries**: Graceful handling of empty inputs
- **Malformed Execution Plans**: Invalid plan structure handling
- **Circular Dependencies**: Detection and handling of step cycles
- **Missing Agents**: Non-existent agent graceful failure
- **Large Data**: Very large intermediate results (1MB+)
- **Unicode/Special Characters**: International and special character support
- **Null State Values**: Handling of None/null throughout state
- **Maximum Adaptations**: Behavior at adaptation limits
- **Deep Dependencies**: 10+ step dependency chains
- **Network Timeouts**: Timeout simulation and handling
### 2. Unit Tests (`tests/unit_tests/agents/`)
#### `test_buddy_routing_logic.py`
**Purpose**: Test specific routing decisions and state transitions
**Unit Test Coverage**:
- **BuddyRouter Logic**: Routing rule evaluation and priority
- **StateHelper Utilities**: Execution plan validation and query extraction
- **Introspection Detection**: Keyword-based query classification
- **Capability Discovery**: Integration with discovery triggers
**Key Test Classes**:
- `TestBuddyRoutingLogic`: Router decision validation
- `TestStateHelper`: State utility function validation
- `TestIntrospectionDetection`: Query classification validation
- `TestCapabilityDiscoveryIntegration`: Discovery trigger validation
### 3. Test Runner (`tests/integration_tests/agents/`)
#### `test_buddy_handoff_runner.py`
**Purpose**: Comprehensive validation runner for all handoff scenarios
**Runner Features**:
- **Functionality Validation**: 10 different query types
- **Performance Validation**: Timing constraints for different scenarios
- **Success Rate Tracking**: 80%+ success rate requirement
- **Error Analysis**: Detailed failure reporting
- **Summary Generation**: Comprehensive test results
**Test Scenarios**:
- Introspection queries (tools, capabilities, help)
- Research queries (AI trends, market analysis)
- Search queries (information retrieval)
- Edge cases (empty, unicode, long queries)
## Handoff Validation Points
### 1. **Orchestrator → Executor Handoff**
**What's Tested**:
- Execution plan creation from planner
- Proper `next_action` setting (`execute_step`)
- Current step selection and routing
- State transition to `executing` phase
**Validation**:
- Execution plan is properly structured
- Router routes to executor based on `next_action`
- First step is selected and set as `current_step`
### 2. **Executor → Analyzer Handoff**
**What's Tested**:
- Execution result processing
- Execution record creation
- State updates with results
- Error handling for failed executions
**Validation**:
- Execution records are created with proper metadata
- Intermediate results are stored correctly
- Failed executions are recorded properly
- State transitions to analysis phase
### 3. **Analyzer → Orchestrator/Synthesizer Handoff**
**What's Tested**:
- Adaptation decision logic
- Success/failure evaluation
- Routing to next orchestration step or synthesis
**Validation**:
- Successful executions continue orchestration
- Failed executions trigger adaptation or synthesis
- Adaptation count is tracked correctly
### 4. **Introspection Bypass Handoff**
**What's Tested**:
- Keyword detection for introspection queries
- Capability discovery triggering
- Direct routing to synthesis
- Structured data creation for synthesis
**Validation**:
- Introspection queries are detected correctly
- Capability data is structured properly (`source_0`, `source_1`, etc.)
- Synthesis receives properly formatted data
- Response contains meaningful capability information
### 5. **Final Synthesis Handoff**
**What's Tested**:
- Data conversion from execution results
- Synthesis tool invocation
- Response formatting
- State completion
**Validation**:
- Intermediate results are converted to synthesis format
- Synthesis tool receives expected data structure
- Final response is properly formatted
- Orchestration phase is marked as completed
## Running the Tests
### Individual Test Suites
```bash
# Integration tests
pytest tests/integration_tests/agents/test_buddy_orchestration_handoffs.py -v
# Unit tests
pytest tests/unit_tests/agents/test_buddy_routing_logic.py -v
# Performance tests
pytest tests/integration_tests/agents/test_buddy_performance_handoffs.py -v -m performance
# Edge case tests
pytest tests/integration_tests/agents/test_buddy_edge_cases.py -v
```
### Comprehensive Validation
```bash
# Run the complete validation suite
cd tests/integration_tests/agents
python test_buddy_handoff_runner.py
```
### Expected Results
- **Success Rate**: 80%+ for functionality tests
- **Performance**: All timing constraints met
- **Coverage**: All critical handoff points validated
- **Edge Cases**: Graceful handling of boundary conditions
## Test Architecture
### Mock Strategy
- **Tool Factory Mocking**: Mock tool creation to control execution
- **Graph Tool Mocking**: Mock individual graph executions
- **Node Tool Mocking**: Mock synthesis and analysis tools
- **Timing Control**: Control execution timing for predictable tests
### State Management
- **BuddyStateBuilder**: Factory for creating test states
- **Type Safety**: Proper TypedDict usage throughout
- **State Preservation**: Validation of state persistence across nodes
### Error Simulation
- **Controlled Failures**: Simulate specific failure scenarios
- **Timeout Simulation**: Network and execution timeouts
- **Data Corruption**: Invalid state and plan structures
## Key Insights Validated
1. **State Schema Completeness**: All handoff data must be in BuddyState schema
2. **Routing Logic Consistency**: Routing rules must be evaluated in priority order
3. **Data Structure Compatibility**: Synthesis expects specific data formats
4. **Error Resilience**: System must handle failures gracefully at each handoff
5. **Performance Scalability**: Handoffs must scale with data and complexity
## Future Enhancements
1. **Load Testing**: Higher concurrency and larger data volumes
2. **Stress Testing**: Resource exhaustion scenarios
3. **Integration Testing**: Real tool execution without mocks
4. **End-to-End Testing**: Complete workflows with actual external services
5. **Monitoring Integration**: Metrics collection during handoffs

View File

@@ -1,349 +0,0 @@
Excellent! As a senior engineer, I've conducted a thorough review of your codebase. The project, `biz-bud`, is a sophisticated AI agent system with a well-defined architecture built on LangGraph. The separation of concerns into states, nodes, services, and tools is commendable.
However, like any complex project, there are areas for improvement. I've identified several bugs, inconsistencies, and performance bottlenecks. My feedback is structured to be actionable for an AI code assistant, with clear explanations and targeted code fixes.
Here is my comprehensive review:
***
## Code Review: `biz-bud` AI Agent
### High-Level Feedback
The codebase demonstrates a strong architectural foundation using a state-driven graph approach. The custom error handling, service factory for dependency injection, and modular node design are signs of a mature system. The primary areas for improvement revolve around enhancing type safety, simplifying state management, ensuring asynchronous correctness, and resolving inconsistencies that have emerged as the project has grown.
---
### 🐞 Bugs: Critical Issues to Address
These are issues that will likely lead to runtime errors, data loss, or incorrect behavior.
#### 1. Bug: Inconsistent State Updates in `call_model_node`
* **Location:** `src/biz_bud/nodes/llm/call.py`
* **Problem:** The `call_model_node` function is not guaranteed to receive a `state` dictionary, especially when called directly or from non-graph contexts. The line `safe_messages = locals().get("messages", get_messages(state) if state else [])` will fail with a `NameError` if an exception occurs before `state` is defined within the local scope. This can happen if `ConfigurationProvider(config)` fails.
* **Impact:** Unhandled exceptions during LLM calls will lead to a crash in the error handling logic itself, masking the original error.
* **Fix:** Ensure `state` is defined at the beginning of the function, even if it's just an empty dictionary, to guarantee the error handling block can execute safely.
```diff
--- a/src/biz_bud/nodes/llm/call.py
+++ b/src/biz_bud/nodes/llm/call.py
@@ -148,6 +148,7 @@
state: dict[str, Any] | None,
config: NodeLLMConfigOverride | None = None,
) -> CallModelNodeOutput:
+ state = state or {}
provider = None
try:
# Get provider from runnable config if available
@@ -250,7 +251,7 @@
# Log diagnostic information for debugging underlying failures
logger.error("LLM call failed after multiple retries.", exc_info=e)
error_msg = f"Unexpected error in call_model_node: {str(e)}"
- safe_messages = locals().get("messages", get_messages(state) if state else [])
+ safe_messages = locals().get("messages", get_messages(state))
return {
"messages": [
```
#### 2. Bug: Race Condition in Service Factory Initialization
* **Location:** `src/biz_bud/services/factory.py`
* **Problem:** In `_GlobalFactoryManager.get_factory`, the check for `self._factory` is not protected by the async lock. This creates a race condition where two concurrent calls could both see `self._factory` as `None`, proceed to create a new factory, and one will overwrite the other.
* **Impact:** This can lead to multiple instances of services that should be singletons, causing unpredictable behavior, resource leaks, and inconsistent state.
* **Fix:** Acquire the lock *before* checking if the factory instance exists.
```diff
--- a/src/biz_bud/services/factory.py
+++ b/src/biz_bud/services/factory.py
@@ -321,8 +321,8 @@
async def get_factory(self, config: AppConfig | None = None) -> ServiceFactory:
"""Get the global service factory, creating it if it doesn't exist."""
- if self._factory:
- return self._factory
+ # Acquire lock before checking to prevent race conditions
+ async with self._lock:
+ if self._factory:
+ return self._factory
- async with self._lock:
- # Check again inside the lock
- if self._factory:
- return self._factory
+ # If we're here, the factory is None and we have the lock.
task = self._initializing_task
if task and not task.done():
```
---
### ⛓️ Inconsistencies & Technical Debt
These issues make the code harder to read, maintain, and reason about. They often point to incomplete refactoring or differing coding patterns.
#### 1. Inconsistency: Brittle State Typing with `NotRequired[Any]`
* **Location:** `src/biz_bud/states/unified.py` and other state definition files.
* **Problem:** The extensive use of `NotRequired[Any]`, `NotRequired[list[Any]]`, and `NotRequired[dict[str, Any]]` undermines the entire purpose of using `TypedDict`. It forces developers to write defensive code with lots of `.get()` calls and provides no static analysis benefits, leading to potential `KeyError` or `TypeError` if a field is assumed to exist.
* **Impact:** Reduced code quality, increased risk of runtime errors, and poor developer experience (no autocompletion, no type checking).
* **Fix:** Refactor the `TypedDict` definitions to be more specific. Replace `Any` with concrete types or more specific `TypedDict`s where possible. Fields that are always present should not be `NotRequired`.
##### Targeted Fix Example:
```diff
--- a/src/biz_bud/states/unified.py
+++ b/src/biz_bud/states/unified.py
@@ -62,7 +62,7 @@
search_history: Annotated[list[SearchHistoryEntry], add]
visited_urls: Annotated[list[str], add]
search_status: Literal[
- "pending", "success", "failure", "no_results", "cached"
+ "pending", "success", "failure", "no_results", "cached", None
]
@@ -107,24 +107,24 @@
# Research State Fields
extracted_info: ExtractedInfoDict
extracted_content: dict[str, Any]
- synthesis: str
+ synthesis: str | None
# Fields that might be needed in tests but aren't in BaseState
- initial_input: DataDict
- is_last_step: bool
- run_metadata: MetadataDict
- parsed_input: "ParsedInputTypedDict"
- is_complete: bool
- requires_interrupt: bool
- input_metadata: "InputMetadataTypedDict"
- context: DataDict
+ initial_input: NotRequired[DataDict]
+ is_last_step: NotRequired[bool]
+ run_metadata: NotRequired[MetadataDict]
+ parsed_input: NotRequired["ParsedInputTypedDict"]
+ is_complete: NotRequired[bool]
+ requires_interrupt: NotRequired[bool]
+ input_metadata: NotRequired["InputMetadataTypedDict"]
+ context: NotRequired[DataDict]
organization: NotRequired[list[Organization]]
organizations: NotRequired[list[Organization]]
- plan: AnalysisPlan
- final_output: str
- formatted_response: str
- tool_calls: list["ToolCallTypedDict"]
- research_query: str
- enhanced_query: str
- rag_context: list[RAGContextDict]
+ plan: NotRequired[AnalysisPlan]
+ final_output: NotRequired[str]
+ formatted_response: NotRequired[str]
+ tool_calls: NotRequired[list["ToolCallTypedDict"]]
+ research_query: NotRequired[str]
+ enhanced_query: NotRequired[str]
+ rag_context: NotRequired[list[RAGContextDict]]
# Market Research State Fields
restaurant_name: NotRequired[str]
```
*(Note: This is an illustrative fix. A full refactoring would require a deeper analysis of which fields are truly optional across all graph flows.)*
#### 2. Inconsistency: Redundant URL Routing and Analysis Logic
* **Location:**
* `src/biz_bud/graphs/rag/nodes/scraping/url_router.py`
* `src/biz_bud/nodes/scrape/route_url.py`
* `src/biz_bud/nodes/scrape/scrape_url.py` (contains `_analyze_url`)
* **Problem:** The logic for determining if a URL points to a Git repository, a PDF, or a standard webpage is duplicated and slightly different across multiple files. This makes maintenance difficult—a change in one place might not be reflected in the others.
* **Impact:** Inconsistent behavior depending on which graph is running. A URL might be classified as a Git repo by one node but not by another.
* **Fix:** Consolidate the URL analysis logic into a single, robust utility function. All routing nodes should call this central function to ensure consistent decisions.
##### Proposed Centralized Utility (`src/biz_bud/core/utils/url_analyzer.py` - New File):
```python
# src/biz_bud/core/utils/url_analyzer.py
from typing import Literal
from urllib.parse import urlparse
UrlType = Literal["git_repo", "pdf", "sitemap", "webpage", "unsupported"]
def analyze_url_type(url: str) -> UrlType:
"""Analyzes a URL and returns its classified type."""
try:
parsed = urlparse(url.lower())
path = parsed.path
# Git repositories
git_hosts = ["github.com", "gitlab.com", "bitbucket.org"]
if any(host in parsed.netloc for host in git_hosts) or path.endswith('.git'):
return "git_repo"
# PDF documents
if path.endswith('.pdf'):
return "pdf"
# Sitemap
if "sitemap" in path or path.endswith(".xml"):
return "sitemap"
# Unsupported file types
unsupported_exts = ['.zip', '.exe', '.dmg', '.tar.gz']
if any(path.endswith(ext) for ext in unsupported_exts):
return "unsupported"
return "webpage"
except Exception:
return "unsupported"
```
##### Refactor `route_url.py`:
```diff
--- a/src/biz_bud/nodes/scrape/route_url.py
+++ b/src/biz_bud/nodes/scrape/route_url.py
@@ -1,37 +1,22 @@
-from typing import Any, Literal
-from urllib.parse import urlparse
+from typing import Any
from biz_bud.core.utils.state_updater import StateUpdater
+from biz_bud.core.utils.url_analyzer import analyze_url_type
-def _analyze_url(url: str) -> dict[str, Any]:
- parsed = urlparse(url)
- path = parsed.path.lower()
- url_type: Literal["webpage", "pdf", "git_repo", "sitemap", "unsupported"] = "webpage"
-
- if any(host in parsed.netloc for host in ["github.com", "gitlab.com"]):
- url_type = "git_repo"
- elif path.endswith(".pdf"):
- url_type = "pdf"
- elif any(ext in path for ext in [".zip", ".exe", ".dmg"]):
- url_type = "unsupported"
- elif "sitemap" in path or path.endswith(".xml"):
- url_type = "sitemap"
-
- return {"type": url_type, "domain": parsed.netloc}
async def route_url_node(
state: dict[str, Any], config: dict[str, Any] | None = None
) -> dict[str, Any]:
url = state.get("input_url") or state.get("url", "")
- url_info = _analyze_url(url)
- routing_decision = "process_normal"
- routing_metadata = {"url_type": url_info["type"]}
+ url_type = analyze_url_type(url)
+ routing_decision = "skip_unsupported"
+ routing_metadata = {"url_type": url_type}
- if url_info["type"] == "git_repo":
+ if url_type == "git_repo":
routing_decision = "process_git_repo"
- elif url_info["type"] == "pdf":
+ elif url_type == "pdf":
routing_decision = "process_pdf"
- elif url_info["type"] == "unsupported":
- routing_decision = "skip_unsupported"
- elif url_info["type"] == "sitemap":
+ elif url_type == "sitemap":
routing_decision = "process_sitemap"
+ elif url_type == "webpage":
+ routing_decision = "process_normal"
updater = StateUpdater(state)
updater.set("routing_decision", routing_decision)
```
*(This refactoring would need to be applied to all related files.)*
---
### 🚀 Bottlenecks & Performance Issues
These areas could cause slow execution, especially with large inputs or high concurrency.
#### 1. Bottleneck: Sequential Execution in `extract_batch_node`
* **Location:** `src/biz_bud/nodes/extraction/extractors.py`
* **Problem:** The `extract_batch_node` processes a batch of content by iterating through it and calling an async extraction function (`_extract_from_content_impl`) for each item sequentially within the `extract_with_semaphore` wrapper. The `asyncio.gather` is used correctly, but the semaphore logic could be more efficient.
* **Impact:** Scraping and processing multiple URLs is a major performance bottleneck. If 10 URLs are scraped, and each takes 5 seconds to extract from, the total time will be 50 seconds instead of being closer to 5 seconds with full concurrency.
* **Fix:** Ensure that the `extract_with_semaphore` function is correctly wrapping the async call and that the `max_concurrent` parameter is configured appropriately. The current implementation looks mostly correct but can be made more robust by ensuring the semaphore is acquired *inside* the function passed to `gather`, which it already does. The main issue is likely the default `max_concurrent` value of 3. This should be configured from the central `AppConfig` to allow for higher throughput in production environments.
##### Targeted Fix:
Instead of a code change, the fix is to **ensure the configuration reflects the desired concurrency.**
* In `config.yaml` or environment variables, set a higher `max_concurrent_scrapes` in the `web_tools` section.
* The `extraction_orchestrator_node` needs to pass this configuration value down into the batch node's state.
```python
# In src/biz_bud/nodes/extraction/orchestrator.py
# ... existing code ...
llm_client = await service_factory.get_llm_for_node(
"extraction_orchestrator",
llm_profile_override="small"
)
# --- FIX: Plumb concurrency config from the main AppConfig ---
web_tools_config = node_config.get("web_tools", {})
max_concurrent = web_tools_config.get("max_concurrent_scrapes", 5) # Default to 5 instead of 3
# Pass successful scrapes to the batch extractor
query = state_dict.get("query", "")
batch_state = {
"content_batch": successful_scrapes,
"query": query,
"verbose": verbose,
"max_concurrent": max_concurrent, # Pass the configured value
}
# ... rest of the code ...
```
#### 2. Bottleneck: Inefficient Deduplication of Search Results
* **Location:** `src/biz_bud/nodes/search/ranker.py`
* **Problem:** The `_remove_duplicates` method in `SearchResultRanker` uses a simple `_calculate_text_similarity` function based on Jaccard similarity of word sets. For highly similar snippets or titles that differ by only a few common words, this may not be effective. Furthermore, comparing every result to every other result is O(n^2), which can be slow for large result sets.
* **Impact:** The final output may contain redundant information, and the ranking step could be slow if many providers return a large number of overlapping results.
* **Fix:** Implement a more efficient and effective deduplication strategy. A good approach is to use a "near-duplicate" detection method like MinHash or SimHash. For a simpler but still effective improvement, we can cluster documents by title similarity and then only compare snippets within clusters.
##### Targeted Fix (Simplified Improvement):
```diff
--- a/src/biz_bud/nodes/search/ranker.py
+++ b/src/biz_bud/nodes/search/ranker.py
@@ -107,14 +107,26 @@
return freshness
def _remove_duplicates(
self, results: list[RankedSearchResult]
) -> list[RankedSearchResult]:
unique_results: list[RankedSearchResult] = []
- seen_urls: set[str] = set()
+ # Use a more robust check for duplicates than just URL
+ seen_hashes: set[str] = set()
for result in results:
- if result.url in seen_urls:
+ # Normalize URL for better duplicate detection
+ normalized_url = result.url.lower().rstrip("/")
+
+ # Create a simple hash from the title to quickly identify near-duplicate content
+ # A more advanced solution would use MinHash or SimHash here.
+ normalized_title = self._normalize_text(result.title)
+ content_hash = hashlib.md5(normalized_title[:50].encode()).hexdigest()
+
+ # Key for checking duplicates is a tuple of the normalized URL's domain and the content hash
+ duplicate_key = (result.source_domain, content_hash)
+
+ if duplicate_key in seen_hashes:
continue
- seen_urls.add(result.url)
-
+ seen_hashes.add(duplicate_key)
unique_results.append(result)
return unique_results
```
*(Note: `hashlib` would need to be imported.)*

View File

@@ -1,133 +0,0 @@
# Coverage Configuration Guide
This document explains the coverage reporting configuration for the Business Buddy project.
## Overview
The project uses `pytest-cov` for code coverage measurement with comprehensive reporting options configured in `pyproject.toml`.
## Configuration
### Coverage Collection (`[tool.coverage.run]`)
- **Source**: `src/biz_bud` - Measures coverage for all source code
- **Branch Coverage**: Enabled to track both statement and branch coverage
- **Parallel Execution**: Supports parallel test execution with xdist
- **Omitted Files**:
- Test files (`*/tests/*`, `*/test_*.py`, `*/conftest.py`)
- Init files (`*/__init__.py`)
- Entry points (`webapp.py`, CLI files)
- Database migrations
### Coverage Reporting (`[tool.coverage.report]`)
- **Show Missing**: Displays line numbers for uncovered code
- **Precision**: 2 decimal places for coverage percentages
- **Skip Empty**: Excludes empty files from reports
- **Comprehensive Exclusions**:
- Type checking blocks (`if TYPE_CHECKING:`)
- Debug code (`if DEBUG:`, `if __debug__:`)
- Platform-specific code
- Error handling patterns
- Abstract methods and protocols
### Report Formats
1. **Terminal**: `--cov-report=term-missing:skip-covered`
2. **HTML**: `htmlcov/index.html` with context information
3. **XML**: `coverage.xml` for CI/CD integration
4. **JSON**: `coverage.json` for programmatic access
## Usage
### Running Tests with Coverage
```bash
# Run all tests with coverage
make test
# Run specific test with coverage
pytest tests/path/to/test.py --cov=src/biz_bud --cov-report=html
# Run without coverage (faster for development)
pytest tests/path/to/test.py --no-cov
```
### Coverage Reports
```bash
# Generate HTML report
pytest --cov=src/biz_bud --cov-report=html
# View HTML report
open htmlcov/index.html # macOS
xdg-open htmlcov/index.html # Linux
# Generate XML report for CI
pytest --cov=src/biz_bud --cov-report=xml
# Generate JSON report
pytest --cov=src/biz_bud --cov-report=json
```
### Coverage Thresholds
- **Minimum Coverage**: 70% (configurable via `--cov-fail-under`)
- **Branch Coverage**: Required for thorough testing
- **Context Tracking**: Enabled to track which tests cover which code
## Best Practices
1. **Write Tests First**: Aim for high coverage through TDD
2. **Focus on Critical Paths**: Prioritize coverage for core business logic
3. **Use Exclusion Pragmas**: Mark intentionally untested code with `# pragma: no cover`
4. **Review Coverage Reports**: Use HTML reports to identify missed edge cases
5. **Monitor Trends**: Track coverage changes in CI/CD
## Exclusion Patterns
The configuration excludes common patterns that don't need testing:
- Type checking imports (`if TYPE_CHECKING:`)
- Debug statements (`if DEBUG:`, `if __debug__:`)
- Platform-specific code (`if sys.platform`)
- Abstract methods (`@abstract`, `raise NotImplementedError`)
- Error handling boilerplate (`except ImportError:`)
## Integration with CI/CD
The XML and JSON reports are designed for integration with:
- **GitHub Actions**: Upload coverage to services like Codecov
- **SonarQube**: Import coverage data for quality gates
- **IDE Integration**: Many IDEs can display coverage inline
## Troubleshooting
### Common Issues
1. **No Data Collected**: Ensure source paths match actual file locations
2. **Parallel Test Issues**: Coverage data may need combining with `coverage combine`
3. **Missing Files**: Check that files are imported during test execution
4. **Low Coverage**: Review exclusion patterns and test completeness
### Debug Commands
```bash
# Check coverage configuration
python -m coverage debug config
# Combine parallel coverage data
python -m coverage combine
# Erase coverage data
python -m coverage erase
```
## Files
- **Configuration**: `pyproject.toml` (`[tool.coverage.*]` sections)
- **Data File**: `.coverage` (temporary, in .gitignore)
- **HTML Reports**: `htmlcov/` directory (in .gitignore)
- **XML Report**: `coverage.xml` (in .gitignore)
- **JSON Report**: `coverage.json` (in .gitignore)

View File

@@ -1,161 +0,0 @@
# Graph Consolidation Implementation Summary
This document summarizes the implementation of the graph consolidation plan as outlined in `graph-consolidation.md`.
## Overview
The consolidation reorganized the Business Buddy codebase to:
- Create consolidated, reusable node modules in `src/biz_bud/nodes/`
- Reorganize graphs into feature-based subdirectories in `src/biz_bud/graphs/`
- Implement the "graphs as tools" pattern throughout
- Maintain backward compatibility during migration
## Phase 1: Consolidated Core Nodes
Created five core node modules that provide reusable functionality:
### 1. `nodes/core.py`
Core workflow operations including:
- `parse_and_validate_initial_payload` - Input validation and parsing
- `format_output_node` - Standard output formatting
- `handle_graph_error` - Error handling and recovery
- `preserve_url_fields_node` - URL field preservation
- `finalize_status_node` - Status finalization
### 2. `nodes/llm.py`
Language model interaction nodes:
- `call_model_node` - LLM invocation with retry logic
- `update_message_history_node` - Conversation history management
- `prepare_llm_messages_node` - Message preparation for LLM
### 3. `nodes/web_search.py`
Web search functionality:
- `web_search_node` - Generic web search
- `research_web_search_node` - Research-focused search
- `cached_web_search_node` - Cached search operations
### 4. `nodes/scrape.py`
Web scraping and content fetching:
- `scrape_url_node` - URL content extraction
- `discover_urls_node` - URL discovery from content
- `batch_process_urls_node` - Batch URL processing
- `route_url_node` - URL routing logic
### 5. `nodes/extraction.py`
Data extraction and semantic analysis:
- `extract_key_information_node` - Key information extraction
- `semantic_extract_node` - Semantic content extraction
- `orchestrate_extraction_node` - Extraction orchestration
## Phase 2: Graph Reorganization
Reorganized graphs into feature-based subdirectories:
### 1. `graphs/research/`
- `graph.py` - Research workflow implementation
- `nodes.py` - Research-specific nodes
- Implements `research_graph_factory` for graph-as-tool pattern
### 2. `graphs/catalog/`
- `graph.py` - Catalog workflow implementation
- `nodes.py` - Catalog-specific nodes
- Implements `catalog_graph_factory` for graph-as-tool pattern
### 3. `graphs/rag/`
- `graph.py` - RAG/R2R workflow implementation
- `nodes.py` - RAG-specific content preparation
- `integrations.py` - External service integrations
- Implements `url_to_rag_graph_factory` for graph-as-tool pattern
### 4. `graphs/analysis/` (New)
- `graph.py` - Comprehensive data analysis workflow
- `nodes.py` - Analysis-specific nodes
- Demonstrates creating new graphs following the pattern
### 5. `graphs/paperless/`
- `graph.py` - Paperless-NGX integration workflow
- `nodes.py` - Document processing nodes
- Implements `paperless_graph_factory` for graph-as-tool pattern
## Key Implementation Details
### Backward Compatibility
- Maintained imports in `nodes/__init__.py` for legacy code
- Used try/except blocks to handle gradual migration
- Created aliases for renamed functions
- Override legacy imports with consolidated versions when available
### Graph-as-Tool Pattern
Each graph module now exports:
- A factory function (e.g., `research_graph_factory`)
- A create function (e.g., `create_research_graph`)
- Input schema definitions
- Graph metadata including name, description, and schema
### Node Organization
- Generic, reusable nodes consolidated into core modules
- Graph-specific logic moved to graph subdirectories
- Clear separation between infrastructure and business logic
- Consistent use of decorators for metrics and logging
## Migration Strategy
1. **Gradual Migration**: Legacy imports remain functional while new code uses consolidated nodes
2. **Import Compatibility**: The `nodes/__init__.py` handles both old and new imports
3. **Testing**: Created `test_imports.py` to verify all imports work correctly
4. **Documentation**: Updated docstrings and added module-level documentation
## Benefits Achieved
1. **Cleaner Architecture**: Clear separation between reusable and specific functionality
2. **Better Composability**: Graphs can easily be used as tools in other graphs
3. **Reduced Duplication**: Common functionality consolidated into single locations
4. **Improved Maintainability**: Feature-based organization makes code easier to find
5. **Type Safety**: Maintained full type checking throughout migration
## Next Steps
1. Update all existing code to use new imports
2. Remove legacy node implementations once migration is complete
3. Create additional consolidated node modules as patterns emerge
4. Document the graph-as-tool pattern for future developers
5. Add integration tests for all graph factories
## File Structure After Consolidation
```
src/biz_bud/
├── nodes/
│ ├── __init__.py # Backward compatibility and exports
│ ├── core.py # Core workflow operations
│ ├── llm.py # LLM interaction nodes
│ ├── web_search.py # Web search nodes
│ ├── scrape.py # Web scraping nodes
│ └── extraction.py # Data extraction nodes
├── graphs/
│ ├── __init__.py # Graph exports and aliases
│ ├── graph.py # Main workflow graph
│ ├── research/
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ └── nodes.py
│ ├── catalog/
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ └── nodes.py
│ ├── rag/
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ ├── nodes.py
│ │ └── integrations.py
│ ├── analysis/
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ └── nodes.py
│ └── paperless/
│ ├── __init__.py
│ ├── graph.py
│ └── nodes.py
```
This consolidation provides a solid foundation for future development while maintaining compatibility with existing code.

View File

@@ -1,123 +0,0 @@
Of course. I can certainly help you review and restructure your `nodes` and `graphs` directories. Your intuition is correct; organizing your project to have a clear hierarchy of reusable components will significantly improve maintainability and scalability. The principle of "tools get registered to nodes that are used in graphs" is an excellent architectural guideline.
Here is a comprehensive review and a proposed restructuring plan for your project.
### Project Review and Analysis
Your project is well-structured with a strong separation of concerns in the `core` package. The use of registries, service factories, and typed states is commendable. However, as you've noted, the `nodes` package has become a bit of a "catch-all" for various functionalities, leading to some nodes being tightly coupled to specific workflows while residing in a generic location.
**Key Observations:**
* **Reusable Core Nodes:** You have several nodes that are fundamental to any workflow, such as those for handling input/output, calling the language model, and managing errors. These are prime candidates for consolidation.
* **Workflow-Specific Nodes:** Many nodes, particularly in `nodes/analysis`, `nodes/catalog`, and `nodes/rag`, are only relevant within the context of a specific graph or business process.
* **Implicit Workflows:** The current structure sometimes obscures the actual flow of a process. For instance, the entire RAG pipeline is implemented as a collection of nodes that are implicitly linked, but their relationship isn't immediately obvious from the file structure.
* **Graph as a Tool:** Your `ToolFactory` and `GraphRegistry` are well-designed to support the concept of graphs being callable as tools. The final step is to formalize this pattern across all your graphs.
### Proposed Restructuring Plan
The goal is to create a clear distinction between generic, reusable nodes and the specific business logic that constitutes a graph.
#### 1. Consolidate the `nodes` Package
The `nodes` package will be streamlined to contain only the fundamental, reusable building blocks for your graphs.
**New `src/biz_bud/nodes/` Structure:**
```
src/biz_bud/nodes/
├── __init__.py
├── core.py # Merged from nodes/core/*
├── llm.py # From nodes/llm/call.py
├── search.py # Unified search node from nodes/search/*
├── web.py # Unified web scraping/analysis from nodes/scraping/*
└── extraction.py # Unified extraction from nodes/extraction/*
```
**Actions:**
1. **Create `nodes/core.py`:** Merge the logic from `nodes/core/input.py`, `nodes/core/output.py`, and `nodes/core/error.py`. These nodes represent the standard entry, exit, and error handling points for any graph.
2. **Create `nodes/llm.py`:** Move the `call_model_node` from `nodes/llm/call.py` here. This will be the centralized node for all language model interactions.
3. **Create `nodes/search.py`:** Consolidate the functionality from `nodes/search/*` into a single, highly configurable search node. This node would take parameters to specify the search strategy (e.g., optimization, ranking, caching).
4. **Create `nodes/web.py`:** Merge the scraping and URL analysis logic from `nodes/scraping/*`. This node will handle all interactions with web pages.
5. **Create `nodes/extraction.py`:** Unify the extraction logic from `nodes/extraction/*` into a single node that can perform various types of data extraction based on its configuration.
#### 2. Relocate Specific Nodes and Restructure `graphs`
The `graphs` package will be reorganized into feature-centric subpackages. Each subpackage will contain the graph definition and any nodes that are specific to that graph's workflow.
**New `src/biz_bud/graphs/` Structure:**
```
src/biz_bud/graphs/
├── __init__.py
├── research.py
├── planner.py
├── main.py # Renamed from graph.py
├── catalog/
│ ├── __init__.py
│ ├── graph.py # Formerly graphs/catalog.py
│ └── nodes.py # Nodes from nodes/catalog/* and nodes/analysis/c_intel.py
├── rag/
│ ├── __init__.py
│ ├── graph.py # Formerly graphs/url_to_r2r.py
│ ├── nodes.py # All nodes from nodes/rag/*
│ └── integrations.py # Nodes from nodes/integrations/*
└── analysis/
├── __init__.py
├── graph.py # A new graph for data analysis
└── nodes.py # Nodes from nodes/analysis/*
```
**Actions:**
1. **Create Graph Subpackages:** For each major feature (e.g., `catalog`, `rag`, `analysis`), create a dedicated subdirectory within `graphs`.
2. **Move Graph Definitions:** Relocate the existing graph definitions (e.g., `graphs/catalog.py`) into their new subpackages (e.g., `graphs/catalog/graph.py`).
3. **Move Specific Nodes:**
* Move all nodes from `nodes/catalog/` and the catalog-specific nodes from `nodes/analysis/` into `graphs/catalog/nodes.py`.
* Move all nodes from `nodes/rag/` into `graphs/rag/nodes.py`.
* Move the integration-specific nodes (`repomix.py`, `firecrawl/`) into `graphs/rag/integrations.py`.
* Create a new `analysis` graph and move the generic data analysis nodes from `nodes/analysis/*` into `graphs/analysis/nodes.py`.
4. **Update Imports:** Adjust all import paths within the moved files to reflect the new structure.
#### 3. Refactoring Graphs to be Tool-Callable
To ensure each graph can be seamlessly used as a tool by your main agent, follow this pattern for each graph:
1. **Define a Pydantic Input Schema:** Each graph should have a clearly defined input model. This makes the graph's interface explicit and allows for automatic validation.
```python
# in src/biz_bud/graphs/research.py
from pydantic import BaseModel, Field
class ResearchGraphInput(BaseModel):
"""Input for the research graph."""
query: str = Field(description="The research topic or question.")
max_sources: int = Field(default=5, description="The maximum number of sources to use.")
```
2. **Update the Graph Factory:** The factory function for each graph should be registered with the `GraphRegistry` and use the input schema. Your existing `GRAPH_METADATA` pattern is perfect for this.
```python
# in src/biz_bud/graphs/research.py
GRAPH_METADATA = {
"name": "research",
"description": "Performs in-depth research on a given topic.",
"input_schema": ResearchGraphInput.model_json_schema(),
"capabilities": ["research", "web_search"],
# ...
}
def research_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
# ... graph creation logic ...
return graph.compile()
```
3. **Leverage the `ToolFactory`:** Your `ToolFactory`'s `create_graph_tool` method is already designed to handle this. It will discover the registered graphs, use their input schemas to create a typed tool interface, and wrap their execution.
By applying these changes, your main agent, likely the `BuddyAgent`, can dynamically discover and utilize these graphs as high-level tools, orchestrating them to fulfill complex user requests.
This restructured approach will provide you with a more modular, scalable, and intuitive codebase that clearly separates reusable components from specific business logic.

View File

@@ -1,195 +0,0 @@
# LangGraph Improvements Implementation Summary
This document summarizes the improvements made to the Business Buddy codebase to better leverage LangGraph's advanced features and patterns.
## Overview
Based on the Context7 analysis and LangGraph best practices, I implemented the following improvements:
1. **Command Pattern Support** - Enhanced routing with state updates
2. **Send API Integration** - Parallel processing capabilities
3. **Consolidated Edge Helpers** - Unified routing interface
4. **Subgraph Patterns** - Composable workflow modules
5. **Example Implementations** - Demonstrated patterns in actual graphs
## 1. Command Pattern Implementation
### Created: `/src/biz_bud/core/edge_helpers/command_patterns.py`
This module provides Command-based routing patterns that combine state updates with control flow:
- `create_command_router` - Basic Command routing with state updates
- `create_retry_command_router` - Retry logic with automatic attempt tracking
- `create_conditional_command_router` - Multi-condition routing with updates
- `create_subgraph_command_router` - Subgraph delegation with Command.PARENT support
### Example Usage in Research Graph:
```python
# Create Command-based retry router
_search_retry_router = create_retry_command_router(
max_attempts=3,
retry_node="rag_enhance",
success_node="extract_info",
failure_node="synthesize",
attempt_key="search_attempts",
success_key="has_search_results"
)
# Use in graph
workflow.add_conditional_edges("prepare_search_results", _search_retry_router)
```
Benefits:
- Automatic state management (retry counts, status updates)
- Cleaner routing logic
- Combined control flow and state updates in single operation
## 2. Send API for Dynamic Edges
### Created: `/src/biz_bud/graphs/scraping/`
Demonstrated Send API usage for parallel URL processing:
```python
async def dispatch_urls(state: ScrapingState) -> list[Send]:
"""Dispatch URLs for parallel processing using Send API."""
sends = []
for i, url in enumerate(urls_to_process):
sends.append(Send(
"scrape_single_url",
{
"processing_url": url,
"url_index": i,
"current_depth": current_depth,
}
))
return sends
```
Benefits:
- True parallel processing of URLs
- Dynamic branching based on runtime data
- Efficient map-reduce patterns
- Scalable processing architecture
## 3. Consolidated Edge Helpers
### Created: `/src/biz_bud/core/edge_helpers/consolidated.py`
Unified interface for all routing patterns through the `EdgeHelpers` class:
```python
# Basic routing
router = edge_helpers.route_on_key("status", {"success": "continue", "error": "retry"})
# Command routing with retry
router = edge_helpers.command_route_with_retry(
success_check=lambda s: s["confidence"] > 0.8,
success_target="finalize",
retry_target="enhance",
failure_target="human_review"
)
# Send patterns for parallel processing
router = edge_helpers.send_to_processors("items", "process_item")
# Error handling with recovery
router = edge_helpers.command_route_on_error_with_recovery(
recovery_node="fix_errors",
max_recovery_attempts=2
)
```
Benefits:
- Consistent API for all routing patterns
- Type-safe routing functions
- Reusable patterns across graphs
- Built-in error handling and recovery
## 4. Subgraph Pattern Implementation
### Created: `/src/biz_bud/graphs/paperless/subgraphs.py`
Demonstrated subgraph patterns with Command.PARENT for returning to parent graphs:
```python
@standard_node(node_name="return_to_parent")
async def return_to_parent_node(state: dict[str, Any]) -> Command[str]:
"""Return control to parent graph with results."""
return Command(
goto="consolidate_results",
update={"subgraph_complete": True, "subgraph_type": "tag_suggestion"},
graph=Command.PARENT # Return to parent graph
)
```
Subgraphs Created:
- **Document Processing** - OCR, text extraction, metadata
- **Tag Suggestion** - Content analysis and auto-tagging
- **Document Search** - Advanced search with ranking
Benefits:
- Modular, reusable workflows
- Clear separation of concerns
- Easy testing of individual workflows
- Composable architecture
## 5. Updated Graphs
### Research Graph Updates:
- Uses Command-based retry router for search operations
- Implements synthesis quality checking with Command updates
- Cleaner state management through Command patterns
### Paperless Graph Updates:
- Integrated three subgraphs for specialized operations
- Added consolidation node for subgraph results
- Extended routing to support subgraph delegation
## Key Improvements Achieved
1. **Better State Management**
- Command patterns eliminate manual state tracking
- Automatic retry counting and status updates
- Cleaner separation of routing and state logic
2. **Enhanced Parallelism**
- Send API enables true parallel processing
- Map-reduce patterns for scalable operations
- Dynamic branching based on runtime data
3. **Improved Modularity**
- Subgraphs provide reusable workflow components
- Clear interfaces between parent and child graphs
- Better testing and maintenance
4. **Unified Routing Interface**
- Single source of truth for routing patterns
- Consistent API across all graphs
- Reusable routing logic
5. **Type Safety**
- Proper type hints throughout
- Validated routing functions
- Clear contracts between components
## Migration Guide
To use these improvements in existing graphs:
1. **Replace manual retry logic** with `create_retry_command_router`
2. **Convert parallel operations** to use Send API patterns
3. **Extract complex workflows** into subgraphs
4. **Use EdgeHelpers** for consistent routing patterns
5. **Leverage Command** for combined state updates and routing
## Next Steps
1. Migrate remaining graphs to use new patterns
2. Create more specialized subgraphs for common operations
3. Add telemetry and monitoring to Command routers
4. Implement caching strategies for subgraph results
5. Create graph composition utilities for complex workflows
The improvements provide a solid foundation for building more sophisticated, maintainable, and performant LangGraph workflows while maintaining backward compatibility with existing code.

View File

@@ -1,117 +0,0 @@
# LangGraph Patterns Implementation Summary
## Overview
This document summarizes the implementation of LangGraph best practices across the Business Buddy codebase, following the guidance in CLAUDE.md.
## Key Patterns Implemented
### 1. State Immutability
- Created `bb_core.langgraph.state_immutability` module with:
- `ImmutableDict` class to prevent state mutations
- `StateUpdater` builder pattern for immutable state updates
- `@ensure_immutable_node` decorator to enforce immutability
### 2. Cross-Cutting Concerns
- Created `bb_core.langgraph.cross_cutting` module with decorators:
- `@log_node_execution` - Automatic logging of node execution
- `@track_metrics` - Performance metrics tracking
- `@handle_errors` - Standardized error handling
- `@retry_on_failure` - Retry logic with exponential backoff
- `@standard_node` - Composite decorator combining all concerns
### 3. RunnableConfig Integration
- Created `bb_core.langgraph.runnable_config` module with:
- `ConfigurationProvider` class for type-safe config access
- Helper functions for creating and merging RunnableConfig
- Support for service factory and metadata access
### 4. Graph Configuration
- Created `bb_core.langgraph.graph_config` module with:
- `configure_graph_with_injection` for dependency injection
- `create_config_injected_node` for automatic config injection
- Helper utilities for config extraction
### 5. Tool Decorator Pattern
- Updated tools to use `@tool` decorator from langchain_core.tools
- Added Pydantic schemas for input/output validation
- Examples in `scrapers.py` and `bb_extraction/tools.py`
### 6. Service Factory Pattern
- Created example in `service_factory_example.py` showing:
- Protocol-based service interfaces
- Abstract factory pattern
- Integration with RunnableConfig
## Updated Nodes
### Core Nodes
- `core/input.py` - StateUpdater, RunnableConfig, immutability
- `core/output.py` - Standard decorators, immutable updates
- `core/error.py` - Error handling with immutability
### Other Nodes Updated
- `synthesis/synthesize.py` - Using bb_core imports
- `scraping/url_router.py` - Standard node pattern
- `analysis/interpret.py` - Service factory from config
- `validation/logic.py` - Immutable state updates
- `extraction/orchestrator.py` - Standard decorators
- `search/orchestrator.py` - Cross-cutting concerns
- `llm/call.py` - ConfigurationProvider usage
## Example Implementations
### 1. Research Subgraph (`research_subgraph.py`)
Demonstrates:
- Type-safe state schema with TypedDict
- Tool implementation with @tool decorator
- Nodes with @standard_node and @ensure_immutable_node
- StateUpdater for all state mutations
- Error handling and logging
- Reusable subgraph pattern
### 2. Service Factory Example (`service_factory_example.py`)
Shows:
- Protocol-based service definitions
- Abstract factory pattern
- Integration with RunnableConfig
- Service caching and reuse
- Mock implementations for testing
## Migration Notes
### Renamed Files
The following local utility files were renamed to .old as they're replaced by bb_core:
- `src/biz_bud/utils/state_immutability.py.old`
- `src/biz_bud/utils/cross_cutting.py.old`
- `src/biz_bud/config/runnable_config.py.old`
### Test Updates
- Updated test files to pass `None` for the config parameter
- Example: `parse_and_validate_initial_payload(state, None)`
## Benefits Achieved
1. **State Predictability**: Immutable state prevents accidental mutations
2. **Consistent Logging**: All nodes have standardized logging
3. **Performance Monitoring**: Automatic metrics tracking
4. **Error Resilience**: Standardized error handling with retries
5. **Dependency Injection**: Clean separation of concerns via RunnableConfig
6. **Type Safety**: Strong typing throughout with no Any types
7. **Reusability**: Subgraphs and service factories enable code reuse
## Next Steps
1. Continue updating remaining nodes to use the new patterns
2. Create comprehensive tests for the new utilities
3. Update integration tests to work with immutable state
4. Create migration guide for existing code
5. Add more example subgraphs and patterns
## Code Quality
All implementations follow strict requirements:
- No `Any` types without bounds
- No `# type: ignore` comments
- Imperative docstrings with punctuation
- Direct file updates (no parallel versions)
- Utilities centralized in bb_core package

View File

@@ -1,154 +0,0 @@
# LangGraph Best Practices Implementation
This document summarizes the LangGraph best practices that have been implemented in the Business Buddy codebase.
## 1. Configuration Management with RunnableConfig
### Updated Files:
- **`src/biz_bud/nodes/llm/call.py`**: Updated to accept both `NodeLLMConfigOverride` and `RunnableConfig`
- **`src/biz_bud/nodes/core/input.py`**: Updated to accept `RunnableConfig` and use immutable state patterns
- **`src/biz_bud/config/runnable_config.py`**: Created `ConfigurationProvider` class for type-safe config access
### Key Features:
- Type-safe configuration access via `ConfigurationProvider`
- Support for both legacy dict-based config and new RunnableConfig pattern
- Automatic extraction of context (run_id, user_id) from RunnableConfig
- Seamless integration with existing service factory
## 2. State Immutability
### New Files:
- **`src/biz_bud/utils/state_immutability.py`**: Complete immutability utilities
- `ImmutableDict`: Prevents accidental state mutations
- `StateUpdater`: Builder pattern for immutable updates
- `@ensure_immutable_node`: Decorator to enforce immutability
- Helper functions for state validation
### Updated Files:
- **`src/biz_bud/nodes/core/input.py`**: Refactored to use `StateUpdater` for all state changes
### Key Patterns:
```python
# Instead of mutating state directly:
# state["key"] = value
# Use StateUpdater:
updater = StateUpdater(state)
new_state = (
updater
.set("key", value)
.append("list_key", item)
.build()
)
```
## 3. Service Factory Enhancements
### Existing Excellence:
- **`src/biz_bud/services/factory.py`**: Already follows best practices
- Singleton pattern within factory scope
- Thread-safe initialization
- Proper lifecycle management
- Dependency injection support
### New Files:
- **`src/biz_bud/services/web_tools.py`**: Factory functions for web tools
- `get_web_search_tool()`: Creates configured search tools
- `get_unified_scraper()`: Creates configured scrapers
- Automatic provider/strategy registration based on API keys
## 4. Tool Decorator Pattern
### New Files:
- **`src/biz_bud/tools/web_search_tool.py`**: Demonstrates @tool decorator pattern
- `@tool` decorated functions with proper schemas
- Input/output validation with Pydantic models
- Integration with RunnableConfig
- Error handling and fallback mechanisms
- Batch operations support
### Key Features:
- Type-safe tool definitions
- Automatic schema generation for LangGraph
- Support for both sync and async tools
- Proper error propagation
## 5. Cross-Cutting Concerns
### New Files:
- **`src/biz_bud/utils/cross_cutting.py`**: Comprehensive decorators for:
- **Logging**: `@log_node_execution` with context extraction
- **Metrics**: `@track_metrics` for performance monitoring
- **Error Handling**: `@handle_errors` with fallback support
- **Retries**: `@retry_on_failure` with exponential backoff
- **Composite**: `@standard_node` combining all concerns
### Usage Example:
```python
@standard_node(
node_name="my_node",
metric_name="my_metric",
retry_attempts=3
)
async def my_node(state: dict, config: RunnableConfig) -> dict:
# Automatically gets logging, metrics, error handling, and retries
pass
```
## 6. Reusable Subgraphs
### New Files:
- **`src/biz_bud/graphs/research_subgraph.py`**: Complete research workflow
- Demonstrates all best practices in a real workflow
- Input validation → Search → Synthesis → Summary
- Conditional routing based on state
- Proper error handling and metrics
- Can be embedded in larger graphs
### Key Patterns:
- Clear state schema definition (`ResearchState`)
- Immutable state updates throughout
- Tool integration with error handling
- Modular design for reusability
## 7. Graph Configuration Utilities
### New Files:
- **`src/biz_bud/utils/graph_config.py`**: Graph configuration helpers
- `configure_graph_with_injection()`: Auto-inject config into all nodes
- `create_config_injected_node()`: Wrap nodes with config injection
- `update_node_to_use_config()`: Decorator for config-aware nodes
- Backward compatibility helpers
## Integration Points
### Direct Updates (No Wrappers):
Following your directive, all updates were made directly to existing files rather than creating parallel versions:
- `call.py` enhanced with RunnableConfig support
- `input.py` refactored for immutability
- No `_v2` files created
### Clean Architecture:
- Clear separation of concerns
- Utilities in `utils/` directory
- Tools in `tools/` directory
- Service integration in `services/`
- Graph components in `graphs/`
## Next Steps
1. **Update Remaining Nodes**: Apply these patterns to other nodes in the codebase
2. **Add Tests**: Create comprehensive tests for new utilities
3. **Documentation**: Add inline documentation and usage examples
4. **Schema Updates**: Add missing config fields (research_config, tools_config) to AppConfig
5. **Migration Guide**: Create guide for updating existing nodes to new patterns
## Benefits Achieved
1. **Type Safety**: Full type checking with Pydantic and proper annotations
2. **Immutability**: Prevents accidental state mutations
3. **Observability**: Built-in logging and metrics
4. **Reliability**: Automatic retries and error handling
5. **Modularity**: Reusable components and subgraphs
6. **Maintainability**: Clear patterns and separation of concerns
7. **Performance**: Efficient state updates and service reuse

View File

@@ -1,333 +0,0 @@
Of course. I have reviewed your project structure, focusing on the `src/biz_bud/tools` directory. Your assessment is correct; while there are strong individual components, the overall organization has opportunities for improvement regarding clarity, redundancy, and abstraction.
Here is a comprehensive review and a detailed plan to refactor your tools directory for better organization, efficiency, and scalability, formatted as LLM-compatible markdown.
---
### **Project Review & Tool Directory Refactoring Plan**
This plan addresses the disorganization and redundancy in the current `tools` directory. The goal is to establish a clear, scalable architecture that aligns with best practices for building agentic systems.
### **Part 1: High-Level Assessment & Key Issues**
Your current implementation correctly identifies the need for patterns like provider-based search and scraping strategies. However, the structure has several areas for improvement:
1. **Redundancy & Duplication:**
* **Jina API Logic:** Functionality for Jina is scattered across `tools/api_clients/jina.py`, `tools/apis/jina/search.py`, and `tools/search/providers/jina.py`. This creates confusion and high maintenance overhead.
* **Search Orchestration:** The roles of `tools/search/unified.py` and `tools/search/web_search.py` overlap significantly. Both attempt to provide a unified interface to search providers.
* **Action/Flow Ambiguity:** The directories `tools/actions` and `tools/flows` serve similar purposes (high-level tool compositions) and could be consolidated or re-evaluated. `tools/actions/scrape.py` and `tools/flows/scrape.py` are nearly identical in concept.
2. **Architectural Shortcomings:**
* **Tight Coupling:** The `ToolFactory` in `tools/core/factory.py` is overly complex and tightly coupled with `nodes` and `graphs`. It dynamically creates tools from other components, which breaks separation of concerns. Tools should be self-contained and independently testable.
* **Complex Tool Creation:** Dynamically creating Pydantic models (`create_model`) within the factory is an anti-pattern that hinders static analysis and introduces runtime risks. The `create_lightweight_search_tool` is a one-off implementation that adds unnecessary complexity.
* **Unclear Abstraction:** While provider patterns exist, the abstraction isn't consistently applied. The system doesn't have a single, clean interface for an agent to request a capability like "search" and specify a provider.
3. **Code Health & Bugs:**
* **Dead Code:** `services/web_tools.py` throws a `WebToolsRemovedError`, indicating that it's obsolete after a partial refactoring. `tools/interfaces/web_tools.py` seems to be a remnant of this.
* **Inconsistent Naming:** The mix of terms like `apis`, `api_clients`, `providers`, and `strategies` for similar concepts makes the architecture difficult to understand.
### **Part 2: Proposed New Tool Architecture**
I propose a new structure centered around two key concepts: **Clients** (how we talk to external services) and **Capabilities** (what the agent can do).
This design provides a clean separation of concerns and directly implements the provider-agnostic interface you described.
#### **New Directory Structure**
```
src/biz_bud/tools/
├── __init__.py
├── models.py # (Existing) Shared Pydantic models for tool I/O.
|
├── clients/ # [NEW] Centralized clients for all external APIs/SDKs.
│ ├── __init__.py
│ ├── firecrawl.py
│ ├── jina.py # Single, consolidated Jina client.
│ ├── paperless.py
│ ├── r2r.py
│ └── tavily.py
└── capabilities/ # [NEW] Defines agent abilities. One tool per capability.
├── __init__.py
├── analysis.py # Native python tools for extraction, transformation.
├── database.py # Tools for interacting with Postgres and Qdrant.
├── introspection.py # Tools for agent self-discovery (listing graphs, etc.).
├── search/ # Example of a provider-based capability.
│ ├── __init__.py
│ ├── interface.py # Defines SearchProvider protocol and standard SearchResult model.
│ ├── tool.py # The single `web_search` @tool for agents to call.
│ └── providers/ # Implementations for each search service.
│ ├── __init__.py
│ ├── arxiv.py
│ ├── jina.py
│ └── tavily.py
└── scrape/ # Another provider-based capability.
├── __init__.py
├── interface.py # Defines ScraperProvider protocol and ScrapedContent model.
├── tool.py # The single `scrape_url` @tool for agents to call.
└── providers/ # Implementations for each scraping service.
├── __init__.py
├── beautifulsoup.py
├── firecrawl.py
└── jina.py
```
### **Part 3: Implementation Guide**
Here are the steps and code to refactor your tools into the new architecture.
#### **Step 1: Create New Directories**
Create the new directory structure and delete the old, redundant ones.
```bash
# In src/biz_bud/
mkdir -p tools/clients
mkdir -p tools/capabilities/search/providers
mkdir -p tools/capabilities/scrape/providers
# Once migration is complete, you will remove the old directories:
# rm -rf tools/actions tools/api_clients tools/apis tools/browser tools/core \
# tools/extraction tools/flows tools/interfaces tools/loaders tools/r2r \
# tools/scrapers tools/search tools/stores tools/stubs tools/utils
```
#### **Step 2: Consolidate API Clients**
Unify all external service communication into the `clients/` directory. Each file should represent a client for a single service, handling all its different functions (e.g., Jina for search, rerank, and scraping).
**Example: Consolidated Jina Client**
This client will use your robust `core/networking/api_client.py` for making HTTP requests.
```python
# File: src/biz_bud/tools/clients/jina.py
import os
from typing import Any, cast
from ...core.networking.api_client import APIClient, APIResponse, RequestConfig
from ...core.config.constants import (
JINA_SEARCH_ENDPOINT,
JINA_RERANK_ENDPOINT,
JINA_READER_ENDPOINT,
)
from ..models import JinaSearchResponse, RerankRequest, RerankResponse
class JinaClient:
"""A consolidated client for all Jina AI services."""
def __init__(self, api_key: str | None = None, http_client: APIClient | None = None):
self.api_key = api_key or os.getenv("JINA_API_KEY")
if not self.api_key:
raise ValueError("Jina API key is required.")
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
self._http_client = http_client or APIClient(base_url="", headers=self.headers)
async def search(self, query: str) -> JinaSearchResponse:
"""Performs a web search using Jina Search."""
search_url = f"{JINA_SEARCH_ENDPOINT}{query}"
headers = {"Accept": "application/json"} # Search endpoint uses different auth
response = await self._http_client.get(search_url, headers=headers)
response.raise_for_status()
return JinaSearchResponse.model_validate(response.data)
async def scrape(self, url: str) -> dict[str, Any]:
"""Scrapes a URL using Jina's reader endpoint (jina.ai/read)."""
scrape_url = f"{JINA_READER_ENDPOINT}{url}"
response = await self._http_client.get(scrape_url, headers=self.headers)
response.raise_for_status()
# The reader endpoint returns raw content, which we'll wrap
return {
"url": url,
"content": response.data,
"provider": "jina"
}
async def rerank(self, request: RerankRequest) -> RerankResponse:
"""Reranks documents using the Jina Rerank API."""
response = await self._http_client.post(
JINA_RERANK_ENDPOINT,
json=request.model_dump(exclude_none=True)
)
response.raise_for_status()
return RerankResponse.model_validate(response.data)
```
#### **Step 3: Implement the `search` Capability**
This pattern creates a clear abstraction. The agent only needs to know about `web_search` and can optionally request a provider.
1. **Define the Interface:** Create a protocol that all search providers must follow.
```python
# File: src/biz_bud/tools/capabilities/search/interface.py
from typing import Protocol, Any
from ....tools.models import SearchResult
class SearchProvider(Protocol):
"""Protocol for a search provider that can execute a search query."""
name: str
async def search(self, query: str, max_results: int = 10) -> list[SearchResult]:
"""
Executes a search query and returns a list of standardized search results.
"""
...
```
2. **Create a Provider Implementation:** Refactor your existing provider logic to implement the protocol.
```python
# File: src/biz_bud/tools/capabilities/search/providers/tavily.py
from ....tools.clients.tavily import TavilyClient
from ....tools.models import SearchResult, SourceType
from ..interface import SearchProvider
class TavilySearchProvider(SearchProvider):
"""Search provider implementation for Tavily."""
name: str = "tavily"
def __init__(self, api_key: str | None = None):
self.client = TavilyClient(api_key=api_key)
async def search(self, query: str, max_results: int = 10) -> list[SearchResult]:
tavily_response = await self.client.search(
query=query,
max_results=max_results
)
search_results = []
for item in tavily_response.results:
search_results.append(
SearchResult(
title=item.title,
url=str(item.url),
snippet=item.content,
score=item.score,
source=SourceType.TAVILY,
metadata={"raw_content": item.raw_content}
)
)
return search_results
```
*(Repeat this for `jina.py` and `arxiv.py` in the same directory, adapting your existing code.)*
3. **Create the Unified Agent Tool:** This is the single, decorated function that agents will use. It dynamically selects the correct provider.
```python
# File: src/biz_bud/tools/capabilities/search/tool.py
from typing import Literal, Annotated
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from .interface import SearchProvider
from .providers.tavily import TavilySearchProvider
from .providers.jina import JinaSearchProvider
from .providers.arxiv import ArxivProvider
from ....tools.models import SearchResult
# A registry of available providers
_providers: dict[str, SearchProvider] = {
"tavily": TavilySearchProvider(),
"jina": JinaSearchProvider(),
"arxiv": ArxivProvider(),
}
DEFAULT_PROVIDER = "tavily"
class WebSearchInput(BaseModel):
query: str = Field(description="The search query.")
provider: Annotated[
Literal["tavily", "jina", "arxiv"],
Field(description="The search provider to use.")
] = DEFAULT_PROVIDER
max_results: int = Field(default=10, description="Maximum number of results to return.")
@tool(args_schema=WebSearchInput)
async def web_search(query: str, provider: str = DEFAULT_PROVIDER, max_results: int = 10) -> list[SearchResult]:
"""
Performs a web search using a specified provider to find relevant documents.
Returns a standardized list of search results.
"""
search_provider = _providers.get(provider)
if not search_provider:
raise ValueError(f"Unknown search provider: {provider}. Available: {list(_providers.keys())}")
return await search_provider.search(query=query, max_results=max_results)
```
#### **Step 4: Consolidate Remaining Tools**
Move your other tools into the appropriate `capabilities` files.
* **Database Tools:** Move the logic from `tools/stores/database.py` into `tools/capabilities/database.py`. Decorate functions like `query_postgres` or `search_qdrant` with `@tool`.
```python
# File: src/biz_bud/tools/capabilities/database.py
from langchain_core.tools import tool
# ... import your db store ...
@tool
async def query_qdrant_by_id(document_id: str) -> dict:
"""Retrieves a document from the Qdrant vector store by its ID."""
# ... implementation ...
```
* **Analysis & Extraction Tools:** Consolidate all the logic from the `tools/extraction/` subdirectories into `tools/capabilities/analysis.py`.
* **Introspection Tools:** Move functions related to agent capabilities, like those in `tools/flows/agent_creator.py`, into `tools/capabilities/introspection.py`.
#### **Step 5: Decouple and Simplify the Tool Factory**
The complex `ToolFactory` is no longer needed. The responsibility is now split:
1. **Tool Definition:** Tools are defined as simple, decorated functions within the `capabilities` modules.
2. **Tool Collection:** A simple registry can discover and collect all functions decorated with `@tool`. The existing `ToolRegistry` can be simplified to do just this, removing all the dynamic creation logic.
Your existing `ToolRegistry` can be repurposed to simply discover tools from the `biz_bud.tools.capabilities` module path instead of creating them dynamically from nodes and graphs. This completely decouples the tool system from the graph execution system.
#### **Step 6: Final Cleanup**
Once all logic has been migrated to the new `clients/` and `capabilities/` directories, you can safely delete the following old directories from `src/biz_bud/tools/`:
* `actions/`
* `api_clients/`
* `apis/`
* `browser/` (Its logic should be moved into `capabilities/scrape/providers/browser.py`)
* `core/` (The complex factory is no longer needed)
* `extraction/`
* `flows/`
* `interfaces/`
* `loaders/`
* `r2r/` (Move R2R tools to `capabilities/database.py` or a dedicated `r2r.py`)
* `scrapers/`
* `search/`
* `stores/`
* `utils/`
### **Summary of Benefits**
By adopting this refactored architecture, you will gain:
1. **Clarity & Organization:** The structure is intuitive. `clients` handle *how* to talk to the outside world, and `capabilities` define *what* the agent can do.
2. **True Abstraction:** Agents and graphs can now use a single tool (e.g., `web_search`) and switch providers with a simple parameter, fulfilling your core requirement.
3. **Reduced Redundancy:** Consolidating clients (like Jina) and capabilities (like search) into single modules eliminates duplicate code and simplifies maintenance.
4. **Decoupling:** Tools are now independent of nodes and graphs. This makes them easier to test in isolation and allows them to be reused in different contexts.
5. **Scalability:** Adding a new search provider is as simple as adding one file to `capabilities/search/providers/` and updating the literal type in `tool.py`. The core logic remains untouched.

View File

@@ -1,289 +0,0 @@
Based on my analysis of your codebase, I'll help you refactor your tools package to be more compliant with LangGraph while maintaining synergy with your config, states, and services packages. Here's a comprehensive scaffold plan:
## Refactoring Strategy: Modular LangGraph-Compliant Tools Architecture
### 1. **Core Architecture Overview**
Your tools package currently has excellent separation of concerns but needs better integration with LangGraph patterns. We'll create a unified interface layer that bridges your existing tools with LangGraph's state management and node patterns.
### 2. **Reorganized Tools Package Structure**
```
src/biz_bud/tools/
|-- langgraph_integration/ # NEW: LangGraph-specific integration
| |-- __init__.py
| |-- base_tool.py # Base LangGraph tool wrapper
| |-- state_adapters.py # State conversion utilities
| +-- tool_registry.py # Centralized tool registration
|-- core/ # EXISTING (enhanced)
| |-- base.py # Enhanced with LangGraph compliance
| +-- interfaces.py # Protocol definitions
|-- implementations/ # RENAMED from current structure
| |-- search/
| |-- extraction/
| |-- scraping/
| +-- flows/
|-- config_integration/ # NEW: Config-aware tools
| |-- tool_factory.py
| +-- config_validator.py
+-- state_integration/ # NEW: State-aware tools
|-- state_mappers.py
+-- state_validators.py
```
### 3. **Type-Safe State Integration**
```python
# src/biz_bud/tools/langgraph_integration/base_tool.py
from typing import TypedDict, Any, Dict, List, Optional
from langchain_core.tools import BaseTool
from langgraph.types import Node
from biz_bud.states.base import BaseState
from biz_bud.core.config.schemas.tools import ToolsConfigModel
class ToolState(TypedDict):
"""State for tool execution with LangGraph compliance"""
query: str
context: Dict[str, Any]
config: ToolsConfigModel
prompt: str # Your dynamic prompt integration
result: Optional[Any]
error: Optional[str]
class LangGraphToolWrapper(BaseTool):
"""Wrapper to make existing tools LangGraph compliant"""
def __init__(self, tool_func, name: str, description: str):
super().__init__(
name=name,
description=description,
func=tool_func,
return_direct=True
)
def _run(self, state: ToolState) -> Dict[str, Any]:
"""Execute tool with state validation"""
try:
# Validate state using TypedDict
validated_state = ToolState(**state)
# Execute tool with prompt integration
result = self.func(
query=validated_state["query"],
context=validated_state["context"],
prompt=validated_state["prompt"], # Dynamic prompt
config=validated_state["config"]
)
return {"result": result, "error": None}
except Exception as e:
return {"result": None, "error": str(e)}
```
### 4. **Config-Tool Integration Layer**
```python
# src/biz_bud/tools/config_integration/tool_factory.py
from biz_bud.core.config.loader import load_config
from biz_bud.services.factory import ServiceFactory
from typing import Dict, Any, Type
class ConfigAwareToolFactory:
"""Factory for creating tools with configuration injection"""
def __init__(self):
self._config = None
self._service_factory = None
async def initialize(self, config_override: Dict[str, Any] = None):
"""Initialize with configuration"""
self._config = await load_config(overrides=config_override or {})
self._service_factory = await ServiceFactory(self._config).get_instance()
def create_search_tool(self) -> LangGraphToolWrapper:
"""Create search tool with config integration"""
from biz_bud.tools.search.unified import UnifiedSearchTool
search_config = self._config.tools.search
search_tool = UnifiedSearchTool(config=search_config)
return LangGraphToolWrapper(
tool_func=search_tool.search,
name="unified_search",
description="Search web with configurable providers"
)
def create_extraction_tool(self) -> LangGraphToolWrapper:
"""Create extraction tool with LLM service injection"""
extraction_config = self._config.tools.extract
llm_client = self._service_factory.get_llm_client()
return LangGraphToolWrapper(
tool_func=lambda **kwargs: self._extract_with_llm(
llm_client, extraction_config, **kwargs
),
name="semantic_extraction",
description="Extract structured data with LLM"
)
```
### 5. **State-Tool Bridge Pattern**
```python
# src/biz_bud/tools/state_integration/state_mappers.py
from biz_bud.states.research import ResearchState
from biz_bud.states.catalog import CatalogIntelState
from biz_bud.states.planner import PlannerState
class StateToolMapper:
"""Maps between LangGraph states and tool inputs/outputs"""
@staticmethod
def map_research_state_to_tool(state: ResearchState) -> Dict[str, Any]:
"""Convert ResearchState to tool-compatible format"""
return {
"query": state.get("search_query", ""),
"context": {
"search_history": state.get("search_history", []),
"visited_urls": state.get("visited_urls", []),
"sources": state.get("sources", [])
},
"prompt": state.get("initial_input", {}).get("query", "")
}
@staticmethod
def map_tool_result_to_research_state(
tool_result: Dict[str, Any],
state: ResearchState
) -> Dict[str, Any]:
"""Convert tool result back to ResearchState updates"""
if tool_result.get("error"):
return {
"errors": state.get("errors", []) + [tool_result["error"]],
"status": "error"
}
return {
"search_results": state.get("search_results", []) + [tool_result["result"]],
"search_history": state.get("search_history", []) + [{
"query": tool_result.get("query", ""),
"timestamp": tool_result.get("timestamp", ""),
"results": tool_result.get("result", [])
}]
}
```
### 6. **Refactoring Implementation Plan**
#### Phase 1: Foundation (Week 1)
1. **Create base abstractions** in `langgraph_integration/`
2. **Implement configuration validation** for tools
3. **Set up service factory integration**
#### Phase 2: Tool Migration (Week 2)
1. **Migrate existing tools** to use new base classes
2. **Update tool interfaces** to accept state objects
3. **Create state mappers** for each tool type
#### Phase 3: Graph Integration (Week 3)
1. **Build LangGraph nodes** using refactored tools
2. **Create workflow graphs** for common patterns
3. **Implement error handling** and retry logic
### 7. **Example Migration: Search Tool**
```python
# BEFORE: src/biz_bud/tools/search/web_search.py
class WebSearchTool:
def search(self, query: str, **kwargs):
# Direct implementation
pass
# AFTER: src/biz_bud/tools/implementations/search/web_search.py
from biz_bud.tools.langgraph_integration.base_tool import LangGraphToolWrapper
from biz_bud.tools.config_integration.tool_factory import ConfigAwareToolFactory
class WebSearchTool:
def __init__(self, config: SearchConfig, service_factory: ServiceFactory):
self.config = config
self.service_factory = service_factory
async def search(self, state: ToolState) -> Dict[str, Any]:
"""LangGraph-compliant search with full integration"""
# Use injected configuration
search_config = self.config
# Use service factory for LLM client
llm_client = await self.service_factory.get_llm_client()
# Process with prompt integration
enhanced_query = await llm_client.enhance_query(
state["query"],
context=state["context"],
prompt=state["prompt"]
)
# Execute search with proper error handling
try:
results = await self._execute_search(enhanced_query, search_config)
return {
"result": results,
"query": enhanced_query,
"timestamp": datetime.now().isoformat(),
"error": None
}
except Exception as e:
return {"result": None, "error": str(e)}
```
### 8. **Validation and Testing Strategy**
```python
# src/biz_bud/tests/tools/test_langgraph_integration.py
import pytest
from biz_bud.tools.langgraph_integration.tool_registry import ToolRegistry
from biz_bud.states.research import ResearchState
@pytest.mark.asyncio
async def test_tool_state_conversion():
"""Test state conversion for tools"""
state = ResearchState(
search_query="test query",
initial_input={"query": "user prompt"},
# ... other required fields
)
mapper = StateToolMapper()
tool_input = mapper.map_research_state_to_tool(state)
assert tool_input["prompt"] == "user prompt"
assert tool_input["query"] == "test query"
```
### 9. **Migration Checklist**
- [ ] Create `langgraph_integration/` directory
- [ ] Implement `LangGraphToolWrapper` base class
- [ ] Add configuration-aware tool factory
- [ ] Create state mappers for each tool type
- [ ] Update existing tools to use new patterns
- [ ] Create validation tests for each tool
- [ ] Document new integration patterns
- [ ] Provide backward compatibility layer
### 10. **Backward Compatibility Layer**
```python
# src/biz_bud/tools/legacy_compatibility.py
from biz_bud.tools.langgraph_integration.tool_registry import ToolRegistry
class LegacyToolAdapter:
"""Provides backward compatibility for existing code"""
def __init__(self):
self.registry = ToolRegistry()
def get_web_search_tool(self, *args, **kwargs):
"""Legacy interface preserved"""
return self.registry.get_tool("web_search")
```
This refactoring plan creates a clean separation between your existing tools and LangGraph integration while maintaining full backward compatibility. The new architecture provides type safety, configuration integration, and state management that aligns with LangGraph best practices [Repository: LangGraph GitHub].

View File

@@ -1,75 +0,0 @@
An analysis of your `core` package reveals a sophisticated and feature-rich system. However, its complexity also introduces several bug risks, potential crash points, and areas of redundancy. Here is a detailed breakdown.
### Bug Risks & Why You Are Crashing
Your application's stability is likely impacted by a few core architectural choices. Crashes are most likely originating from inconsistent data structures during runtime, race conditions in shared services, and incomplete error handling.
**1. Heavy Reliance on `TypedDict` for State Management (Highest Risk)**
This is the most significant risk and the most probable cause of runtime crashes like `KeyError` or `TypeError`.
* **The Problem:** Throughout the `src/biz_bud/states/` directory (e.g., `unified.py`, `rag_agent.py`, `base.py`), you use `TypedDict` to define the shape of your state. `TypedDict` is a tool for *static analysis* (like Mypy or Pyright) and provides **zero runtime validation**. If one node in your graph produces a dictionary that is missing a key or has a value of the wrong type, the next node that tries to access it will crash.
* **Evidence:** The state definitions are littered with `NotRequired[...]`, `... | None`, and `... | Any` (e.g., `src/biz_bud/states/unified.py`), which weakens the data contract between nodes. For example, a node might expect `state['search_results']` to be a list, but if a preceding search fails and the key is not added, a downstream node will crash with a `KeyError`.
* **Why it Crashes:** A function signature might indicate it accepts `ResearchState`, but at runtime, it just receives a standard Python `dict`. There's no guarantee that the dictionary actually conforms to the `ResearchState` structure.
* **Recommendation:** Systematically migrate all `TypedDict` state definitions to Pydantic `BaseModel`. Pydantic models perform runtime validation, which would turn these crashes into clear, catchable `ValidationError` exceptions at the boundary of each node. The `core/validation/graph_validation.py` module already attempts to do this with its `PydanticValidator`, but this should be the default for all state objects, not an add-on.
**2. Inconsistent Service and Singleton Management**
Race conditions and improper initialization/cleanup of shared services can lead to unpredictable behavior and crashes.
* **The Problem:** You have multiple patterns for managing global or shared objects: a `ServiceFactory` (`factory/service_factory.py`), a `SingletonLifecycleManager` (`services/singleton_manager.py`), and a `HTTPClient` singleton (`core/networking/http_client.py`). While sophisticated, inconsistencies in their use can cause issues. For instance, a service might be used before it's fully initialized or after it has been cleaned up.
* **Evidence:** The `ServiceFactory` uses an `_initializing` dictionary with `asyncio.Task` to prevent re-entrant initialization, which is good. However, if any service's `initialize()` method fails, the task will hold an exception that could be raised in an unexpected location when another part of the app tries to get that service. The `SingletonLifecycleManager` adds another layer of complexity, and it's unclear if all critical singletons (like the `ServiceFactory` itself) are registered with it.
* **Why it Crashes:** A race condition during startup could lead to a service being requested before its dependencies are ready. An error during the async initialization of a critical service (like the database connection pool in `services/db.py`) could cause cascading failures across the application.
**3. Potential for Unhandled Exceptions and Swallowed Errors**
While you have a comprehensive error-handling framework in `core/errors`, there are areas where it might be bypassed.
* **The Problem:** The `handle_errors` decorator and the `ErrorRouter` provide a structured way to manage failures. However, there are still many generic `try...except Exception` blocks in the codebase. If these blocks don't convert the generic exception into a `BusinessBuddyError` or log it properly, the root cause of the error can be hidden.
* **Evidence:** In `nodes/llm/call.py`, the `call_model_node` has a broad `except Exception as e:` block. While it does attempt to log and create an `ErrorInfo` object, any failure within this exception handling itself (e.g., a serialization issue with the state) would be an unhandled crash. Similarly, in `tools/clients/r2r.py`, the `search` method has a broad `except Exception`, which could mask the actual issue from the R2R client.
* **Why it Crashes:** A generic exception that is caught but not properly processed or routed can lead to the application being in an inconsistent state, causing a different crash later on. If an exception is "swallowed" (caught and ignored), the application might proceed with `None` or incorrect data, leading to a `TypeError` or `AttributeError` in a subsequent step.
**4. Configuration-Related Failures**
The system's behavior is heavily dependent on the `config.yaml` and environment variables. A missing or invalid configuration can lead to startup or runtime failures.
* **The Problem:** The configuration loader (`core/config/loader.py`) merges settings from multiple sources. Critical values like API keys are often optional in the Pydantic models (e.g., `core/config/schemas/services.py`). If a key is not provided in any source, the value will be `None`.
* **Evidence:** In `services/llm/client.py`, the `_get_llm_instance` function might receive `api_key=None`. While the underlying LangChain clients might raise an error, the application doesn't perform an upfront check, leading to a failure deep within a library call, which can be harder to debug.
* **Why it Crashes:** A service attempting to initialize without a required API key or a valid URL will crash. For example, the `PostgresStore` in `services/db.py` will fail to create its connection pool if database credentials are missing.
### Redundancy and Duplication
Duplicate code increases maintenance overhead and creates a risk of inconsistent behavior when one copy is updated but the other is not.
**1. Multiple JSON Extraction Implementations**
* **The Problem:** The logic for parsing potentially malformed JSON from LLM responses is implemented in at least two different places.
* **Evidence:**
* `services/llm/utils.py` contains a very detailed and robust `parse_json_response` function with multiple recovery strategies.
* `tools/capabilities/extraction/text/structured_extraction.py` contains a similar, but distinct, `extract_json_from_text` function, also with its own recovery logic like `_fix_truncated_json`.
* **Recommendation:** Consolidate this logic into a single, robust utility function, likely the one in `tools/capabilities/extraction/text/structured_extraction.py` as it appears more comprehensive. All other parts of the code should call this central function.
**2. Redundant URL Parsing and Analysis**
* **The Problem:** Logic for parsing, normalizing, and analyzing URLs is spread across multiple files instead of being centralized.
* **Evidence:**
* `core/utils/url_analyzer.py` provides a detailed `analyze_url_type` function.
* `core/utils/url_normalizer.py` provides a `URLNormalizer` class.
* Despite these utilities, manual URL parsing using `urlparse` and custom domain extraction logic is found in `nodes/scrape/route_url.py`, `tools/utils/url_filters.py`, `nodes/search/ranker.py`, and `graphs/rag/nodes/upload_r2r.py`.
* **Recommendation:** All URL analysis and normalization should be done through the `URLAnalyzer` and `URLNormalizer` utilities in `core/utils`. This ensures consistent behavior for identifying domains, extensions, and repository URLs.
**3. Duplicated State Field Definitions**
* **The Problem:** Even with `BaseState`, common state fields are often redefined or handled inconsistently across different state `TypedDict`s.
* **Evidence:**
* Fields like `query`, `search_results`, `extracted_info`, and `synthesis` appear in multiple state definitions (`states/research.py`, `states/buddy.py`, `states/unified.py`).
* `states/unified.py` is an attempt to solve this but acts as a "god object," containing almost every possible field from every workflow. This makes it very difficult to reason about what state is actually available at any given point in a specific graph.
* **Recommendation:** Instead of a single unified state, embrace smaller, composable Pydantic models for state. Define mixin classes for common concerns (e.g., a `SearchStateMixin` with `search_query` and `search_results`) that can be included in more specific state models for different graphs.
**4. Inconsistent Service Client Initialization**
* **The Problem:** While the `ServiceFactory` is the intended pattern, some parts of the code appear to instantiate service clients directly.
* **Evidence:**
* `tools/clients/` contains standalone clients like `FirecrawlClient` and `TavilyClient`.
* `tools/capabilities/search/tool.py` directly instantiates providers like `TavilySearchProvider` inside its `_initialize_providers` method.
* **Recommendation:** All external service clients and providers should be managed and instantiated exclusively through the `ServiceFactory`. This allows for centralized configuration, singleton management, and proper lifecycle control (initialization and cleanup).

View File

@@ -1,213 +0,0 @@
# Service Factory Client Integration
## Overview
This document outlines the architectural improvement made to integrate tool clients (JinaClient, FirecrawlClient, TavilyClient) with the ServiceFactory pattern, ensuring consistent dependency injection, lifecycle management, and testing patterns.
## Problem Identified
**Issue**: Test files were directly instantiating tool clients instead of using the ServiceFactory pattern:
```python
# ❌ INCORRECT: Direct instantiation (bypassing factory benefits)
client = JinaClient(Mock(spec=AppConfig))
client = FirecrawlClient(Mock(spec=AppConfig))
client = TavilyClient(Mock(spec=AppConfig))
```
**Root Cause**: Tool clients inherited from `BaseService` but lacked corresponding factory methods:
- Missing `get_jina_client()` → JinaClient
- Missing `get_firecrawl_client()` → FirecrawlClient
- Missing `get_tavily_client()` → TavilyClient
## Solution Implemented
### 1. Added Factory Methods
Extended `ServiceFactory` with new client methods in `/app/src/biz_bud/services/factory/service_factory.py`:
```python
async def get_jina_client(self) -> "JinaClient":
"""Get the Jina client service."""
from biz_bud.tools.clients.jina import JinaClient
return await self.get_service(JinaClient)
async def get_firecrawl_client(self) -> "FirecrawlClient":
"""Get the Firecrawl client service."""
from biz_bud.tools.clients.firecrawl import FirecrawlClient
return await self.get_service(FirecrawlClient)
async def get_tavily_client(self) -> "TavilyClient":
"""Get the Tavily client service."""
from biz_bud.tools.clients.tavily import TavilyClient
return await self.get_service(TavilyClient)
```
### 2. Added Type Checking Imports
Added TYPE_CHECKING imports for proper type hints:
```python
if TYPE_CHECKING:
from biz_bud.tools.clients.firecrawl import FirecrawlClient
from biz_bud.tools.clients.jina import JinaClient
from biz_bud.tools.clients.tavily import TavilyClient
```
### 3. Fixed Pydantic Compatibility
Fixed Pydantic errors in legacy_tools.py by adding proper type annotations:
```python
# Before (causing Pydantic errors):
args_schema = StatisticsExtractionInput
# After (properly typed):
args_schema: type[StatisticsExtractionInput] = StatisticsExtractionInput
```
### 4. Created Demonstration Tests
Created comprehensive test files showing proper factory usage:
- `/app/tests/unit_tests/services/test_factory_client_integration.py`
- `/app/tests/unit_tests/tools/clients/test_jina_factory_pattern.py`
## Proper Usage Pattern
### ✅ Correct Factory Pattern
```python
# Proper dependency injection and lifecycle management
async with ServiceFactory(config) as factory:
jina_client = await factory.get_jina_client()
firecrawl_client = await factory.get_firecrawl_client()
tavily_client = await factory.get_tavily_client()
# Use clients with automatic cleanup
result = await jina_client.search("query")
# Automatic cleanup when context exits
```
### ❌ Incorrect Direct Instantiation
```python
# Bypasses factory benefits - avoid this pattern
client = JinaClient(Mock(spec=AppConfig)) # No dependency injection
client = FirecrawlClient(config) # No lifecycle management
client = TavilyClient(config) # No singleton behavior
```
## Benefits of Factory Pattern
### 1. **Dependency Injection**
- Automatic configuration injection
- Consistent service creation
- Centralized configuration management
### 2. **Lifecycle Management**
- Proper initialization and cleanup
- Resource management
- Context manager support
### 3. **Singleton Behavior**
- Single instance per factory
- Memory efficiency
- State consistency
### 4. **Thread Safety**
- Race-condition-free initialization
- Concurrent access protection
- Proper locking mechanisms
### 5. **Testing Benefits**
- Consistent mocking patterns
- Easier dependency substitution
- Better test isolation
## Testing Pattern Comparison
### Old Pattern (Incorrect)
```python
def test_client_functionality():
# Direct instantiation bypasses factory
client = JinaClient(Mock(spec=AppConfig)) # ❌
result = client.some_method()
assert result == expected
```
### New Pattern (Correct)
```python
@pytest.mark.asyncio
async def test_client_functionality(service_factory):
# Factory-based creation with proper lifecycle
with patch.object(JinaClient, 'some_method') as mock_method:
client = await service_factory.get_jina_client() # ✅
result = await client.some_method()
mock_method.assert_called_once()
```
## Migration Guide for Existing Tests
### Step 1: Update Test Fixtures
```python
@pytest.fixture
async def service_factory(mock_app_config):
factory = ServiceFactory(mock_app_config)
yield factory
await factory.cleanup()
```
### Step 2: Replace Direct Instantiation
```python
# Before:
client = JinaClient(Mock(spec=AppConfig))
# After:
client = await service_factory.get_jina_client()
```
### Step 3: Add Proper Mocking
```python
with patch.object(JinaClient, '_validate_config') as mock_validate:
with patch.object(JinaClient, 'initialize', new_callable=AsyncMock):
mock_validate.return_value = Mock()
client = await service_factory.get_jina_client()
```
## Verification
### Factory Methods Available
-`ServiceFactory.get_jina_client()`
-`ServiceFactory.get_firecrawl_client()`
-`ServiceFactory.get_tavily_client()`
### Tests Pass
- ✅ Factory integration tests
- ✅ Singleton behavior tests
- ✅ Lifecycle management tests
- ✅ Thread safety tests
### Linting Clean
- ✅ Ruff checks pass
- ✅ Pydantic compatibility resolved
- ✅ Type hints correct
## Next Steps
1. **Update Existing Test Files**: Migrate remaining test files to use factory pattern
2. **Add More Factory Methods**: Consider adding factory methods for other BaseService subclasses
3. **Documentation Updates**: Update relevant README files with factory patterns
4. **Code Reviews**: Ensure all new code follows factory pattern
## Files Modified
### Core Changes
- `src/biz_bud/services/factory/service_factory.py` - Added factory methods
- `src/biz_bud/tools/capabilities/extraction/legacy_tools.py` - Fixed Pydantic errors
### Documentation/Tests
- `tests/unit_tests/services/test_factory_client_integration.py` - Factory integration tests
- `tests/unit_tests/tools/clients/test_jina_factory_pattern.py` - Factory pattern demonstration
- `docs/service-factory-client-integration.md` - This documentation
This architectural improvement ensures consistent service management across the codebase and provides a foundation for proper dependency injection and lifecycle management patterns.

View File

@@ -7,6 +7,8 @@
"research": "./src/biz_bud/graphs/research/graph.py:research_graph_factory_async",
"catalog": "./src/biz_bud/graphs/catalog/graph.py:catalog_factory_async",
"paperless": "./src/biz_bud/graphs/paperless/graph.py:paperless_graph_factory_async",
"receipt_processing": "./src/biz_bud/graphs/paperless/graph.py:create_receipt_processing_graph",
"paperless_agent": "./src/biz_bud/graphs/paperless/agent.py:create_paperless_agent",
"url_to_r2r": "./src/biz_bud/graphs/rag/graph.py:url_to_r2r_graph_factory_async",
"error_handling": "./src/biz_bud/graphs/error_handling.py:error_handling_graph_factory_async",
"analysis": "./src/biz_bud/graphs/analysis/graph.py:analysis_graph_factory_async",
@@ -14,6 +16,14 @@
},
"env": ".env",
"http": {
"app": "./src/biz_bud/webapp.py:app"
"app": "./src/biz_bud/webapp.py:app",
"host": "0.0.0.0",
"port": 2024,
"cors": {
"allow_origins": ["*"],
"allow_credentials": true,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["*"]
}
}
}

View File

@@ -1,91 +0,0 @@
#!/usr/bin/env python3
"""Audit to find files with multiple violations for prioritized fixing."""
import ast
import os
class LoopConditionalFinder(ast.NodeVisitor):
def __init__(self):
self.violations = []
self.current_function = None
def visit_FunctionDef(self, node):
if node.name.startswith('test_'):
old_function = self.current_function
self.current_function = node.name
self.generic_visit(node)
self.current_function = old_function
else:
self.generic_visit(node)
def visit_AsyncFunctionDef(self, node):
if node.name.startswith('test_'):
old_function = self.current_function
self.current_function = node.name
self.generic_visit(node)
self.current_function = old_function
else:
self.generic_visit(node)
def visit_For(self, node):
if self.current_function:
self.violations.append(f"Line {node.lineno}: for loop in {self.current_function}")
self.generic_visit(node)
def visit_While(self, node):
if self.current_function:
self.violations.append(f"Line {node.lineno}: while loop in {self.current_function}")
self.generic_visit(node)
def visit_If(self, node):
if self.current_function:
self.violations.append(f"Line {node.lineno}: if statement in {self.current_function}")
self.generic_visit(node)
def find_violations_in_file(file_path):
try:
with open(file_path, 'r') as f:
content = f.read()
tree = ast.parse(content)
finder = LoopConditionalFinder()
finder.visit(tree)
return finder.violations
except Exception as e:
return [f"Error parsing file: {e}"]
# Find test files and check violations
test_dirs = ['tests/unit_tests'] # Focus on unit tests first
multi_violation_files = []
for test_dir in test_dirs:
if os.path.exists(test_dir):
for root, dirs, files in os.walk(test_dir):
for file in files:
if file.startswith('test_') and file.endswith('.py'):
file_path = os.path.join(root, file)
violations = find_violations_in_file(file_path)
if len(violations) > 1: # Multiple violations
multi_violation_files.append((file_path, len(violations), violations))
# Sort by violation count (highest first)
multi_violation_files.sort(key=lambda x: x[1], reverse=True)
print("=== MULTI-VIOLATION FILES (Unit Tests) ===")
print(f"Found {len(multi_violation_files)} files with multiple violations")
print()
# Show top 20 files with most violations
print("TOP 20 FILES BY VIOLATION COUNT:")
for i, (file_path, count, violations) in enumerate(multi_violation_files[:20]):
print(f"{i+1:2d}. {file_path}: {count} violations")
print()
print("DETAILED VIEW OF TOP 10:")
for i, (file_path, count, violations) in enumerate(multi_violation_files[:10]):
print(f"\n{i+1}. {file_path} ({count} violations):")
for violation in violations[:5]: # Show first 5 violations
print(f" {violation}")
if len(violations) > 5:
print(f" ... and {len(violations) - 5} more")

View File

@@ -95,6 +95,7 @@ dependencies = [
"fastapi>=0.115.14",
"uvicorn>=0.35.0",
"flake8>=7.3.0",
"python-dateutil>=2.9.0.post0",
]
[project.optional-dependencies]

View File

@@ -20,14 +20,13 @@ components while providing a flexible orchestration layer.
import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from biz_bud.agents.buddy_nodes_registry import ( # Import nodes
buddy_analyzer_node,
from biz_bud.agents.buddy_nodes_registry import buddy_analyzer_node # Import nodes
from biz_bud.agents.buddy_nodes_registry import (
buddy_executor_node,
buddy_orchestrator_node,
buddy_synthesizer_node,
@@ -44,9 +43,6 @@ from biz_bud.services.factory import ServiceFactory
from biz_bud.states.buddy import BuddyState
from biz_bud.tools.capabilities.workflow.execution import ResponseFormatter
if TYPE_CHECKING:
from langgraph.graph.graph import CompiledGraph
logger = get_logger(__name__)
@@ -61,7 +57,7 @@ __all__ = [
def create_buddy_orchestrator_graph(
config: AppConfig | None = None,
) -> CompiledStateGraph:
) -> CompiledStateGraph[BuddyState]:
"""Create the Buddy orchestrator graph with all components.
Args:
@@ -128,7 +124,7 @@ def create_buddy_orchestrator_graph(
def create_buddy_orchestrator_agent(
config: AppConfig | None = None,
service_factory: ServiceFactory | None = None,
) -> CompiledStateGraph:
) -> CompiledStateGraph[BuddyState]:
"""Create the Buddy orchestrator agent.
Args:
@@ -153,13 +149,13 @@ def create_buddy_orchestrator_agent(
# Direct singleton instance
_buddy_agent_instance: CompiledStateGraph | None = None
_buddy_agent_instance: CompiledStateGraph[BuddyState] | None = None
def get_buddy_agent(
config: AppConfig | None = None,
service_factory: ServiceFactory | None = None,
) -> CompiledStateGraph:
) -> CompiledStateGraph[BuddyState]:
"""Get or create the Buddy agent instance.
Uses singleton pattern for default instance.
@@ -330,12 +326,12 @@ async def stream_buddy_agent(
# Export for LangGraph API
def buddy_agent_factory(config: RunnableConfig) -> "CompiledGraph":
def buddy_agent_factory(config: RunnableConfig) -> CompiledStateGraph[BuddyState]:
"""Create factory function for LangGraph API."""
return get_buddy_agent()
async def buddy_agent_factory_async(config: RunnableConfig) -> "CompiledGraph":
async def buddy_agent_factory_async(config: RunnableConfig) -> CompiledStateGraph[BuddyState]:
"""Async factory function for LangGraph API to avoid blocking calls."""
# Use asyncio.to_thread to run the synchronous initialization in a thread
# This prevents blocking the event loop

View File

@@ -304,7 +304,7 @@ Respond with only: "simple" or "complex"
@handle_errors()
@ensure_immutable_node
async def buddy_orchestrator_node(
state: BuddyState, config: RunnableConfig | None = None
state: BuddyState, config: RunnableConfig | None
) -> dict[str, Any]: # noqa: ARG001
"""Coordinate the execution flow as main orchestrator node."""
logger.info("Buddy orchestrator analyzing request")
@@ -770,7 +770,7 @@ async def buddy_orchestrator_node(
@handle_errors()
@ensure_immutable_node
async def buddy_executor_node(
state: BuddyState, config: RunnableConfig | None = None
state: BuddyState, config: RunnableConfig | None
) -> dict[str, Any]: # noqa: ARG001
"""Execute the current step in the plan."""
current_step = state.get("current_step")
@@ -917,7 +917,7 @@ async def buddy_executor_node(
@handle_errors()
@ensure_immutable_node
async def buddy_analyzer_node(
state: BuddyState, config: RunnableConfig | None = None
state: BuddyState, config: RunnableConfig | None
) -> dict[str, Any]: # noqa: ARG001
"""Analyze execution results and determine if plan modification is needed."""
logger.info("Analyzing execution results")
@@ -962,7 +962,7 @@ async def buddy_analyzer_node(
@handle_errors()
@ensure_immutable_node
async def buddy_synthesizer_node(
state: BuddyState, config: RunnableConfig | None = None
state: BuddyState, config: RunnableConfig | None
) -> dict[str, Any]: # noqa: ARG001
"""Synthesize final results from all executions."""
logger.info("Synthesizing final results")
@@ -1095,7 +1095,7 @@ async def buddy_synthesizer_node(
@handle_errors()
@ensure_immutable_node
async def buddy_capability_discovery_node(
state: BuddyState, config: RunnableConfig | None = None
state: BuddyState, config: RunnableConfig | None
) -> dict[str, Any]: # noqa: ARG001
"""Discover and refresh system capabilities from registries.

View File

@@ -28,8 +28,10 @@ from .embeddings import get_embeddings_instance
from .enums import ReportSource, ResearchType, Tone
# Errors - import everything from the errors package
from .errors import ( # Error aggregation; Error telemetry; Base error types; Error logging; Error formatter; Error router; Router config; Error handler
AggregatedError,
from .errors import (
AggregatedError, # Error aggregation; Error telemetry; Base error types; Error logging; Error formatter; Error router; Router config; Error handler
)
from .errors import (
AlertThreshold,
AuthenticationError,
BusinessBuddyError,

View File

@@ -120,7 +120,7 @@ class LLMCache[T]:
# Try to determine the type based on T
# This is a simplified approach - for str type, try UTF-8 decode first
try:
if hasattr(self, '_type_hint') and getattr(self, '_type_hint', None) == str:
if hasattr(self, '_type_hint') and getattr(self, '_type_hint', None) is str:
return data.decode('utf-8') # type: ignore[return-value]
# For other types, try pickle first, then UTF-8 as fallback
try:
@@ -145,7 +145,7 @@ class LLMCache[T]:
if hasattr(backend_class, '__orig_bases__'):
orig_bases = getattr(backend_class, '__orig_bases__', ())
for base in orig_bases:
if hasattr(base, '__args__') and base.__args__ and base.__args__[0] == bytes:
if hasattr(base, '__args__') and base.__args__ and base.__args__[0] is bytes:
return True
# Check for bytes-only signature by attempting to inspect the set method
@@ -154,7 +154,7 @@ class LLMCache[T]:
if hasattr(self._backend, 'set'):
sig = inspect.signature(self._backend.set)
for param_name, param in sig.parameters.items():
if param_name == 'value' and param.annotation == bytes:
if param_name == 'value' and param.annotation is bytes:
return True
except Exception:
pass

View File

@@ -130,4 +130,3 @@ class InMemoryCache(CacheBackend[T], Generic[T]):
InMemoryCache doesn't require async initialization, so this is a no-op.
This method exists for compatibility with the caching system.
"""
pass

View File

@@ -186,6 +186,18 @@ DEFAULT_TOTAL_WORDS: Final[int] = 1200
# Error messages
UNREACHABLE_ERROR: Final = "Unreachable code path - this should never be executed"
# =============================================================================
# MESSAGE HISTORY MANAGEMENT CONSTANTS
# =============================================================================
# Message history management for conversation context
MAX_CONTEXT_TOKENS: Final[int] = 8000 # Maximum tokens for message context
MAX_MESSAGE_WINDOW: Final[int] = 30 # Fallback maximum messages to keep
MESSAGE_SUMMARY_THRESHOLD: Final[int] = 20 # When to trigger summarization
HISTORY_MANAGEMENT_STRATEGY: Final[str] = "hybrid" # Options: "trim_only", "summarize_only", "hybrid", "delete_and_trim"
PRESERVE_RECENT_MESSAGES: Final[int] = 5 # Minimum recent messages to always keep
MAX_SUMMARY_TOKENS: Final[int] = 500 # Maximum tokens for conversation summaries
# =============================================================================
# EMBEDDING MODEL CONSTANTS
# =============================================================================

View File

@@ -317,6 +317,8 @@ class ExtractionConfig(BaseModel):
"""
model_config = {"protected_namespaces": ()}
model_name: str = Field(
default="openai/gpt-4o",
description="The LLM model to use for semantic extraction.",

View File

@@ -17,7 +17,9 @@ Example:
# New focused router classes - recommended for new code
from .basic_routing import BasicRouters
from .command_patterns import CommandRouters
from .command_patterns import (
CommandRouters,
)
from .command_patterns import create_command_router as create_langgraph_command_router
from .command_patterns import (
create_conditional_command_router as create_conditional_langgraph_command_router,
@@ -68,6 +70,7 @@ from .validation import (
check_data_freshness,
check_privacy_compliance,
validate_output_format,
validate_required_fields,
)
from .workflow_routing import WorkflowRouters
@@ -90,6 +93,7 @@ __all__ = [
"check_accuracy",
"check_confidence_level",
"validate_output_format",
"validate_required_fields",
"check_privacy_compliance",
"check_data_freshness",
# User interaction

View File

@@ -103,6 +103,7 @@ from .router_config import RouterConfig, configure_default_router
from .specialized_exceptions import (
CollectionManagementError,
ConditionSecurityError,
DatabaseError,
GraphValidationError,
HTTPClientError,
ImmutableStateError,
@@ -276,6 +277,7 @@ __all__ = [
"URLValidationError",
"ServiceHelperRemovedError",
"WebToolsRemovedError",
"DatabaseError",
"R2RConnectionError",
"R2RDatabaseError",
"CollectionManagementError",

View File

@@ -615,6 +615,41 @@ class WebToolsRemovedError(ServiceHelperRemovedError):
self.tool_name = tool_name
# === Database Exceptions ===
class DatabaseError(BusinessBuddyError):
"""General exception for database operation failures."""
def __init__(
self,
message: str,
operation: str | None = None,
table_name: str | None = None,
database_name: str | None = None,
context: ErrorContext | None = None,
cause: Exception | None = None,
):
"""Initialize database error with operation details."""
super().__init__(
message,
ErrorSeverity.ERROR,
ErrorCategory.DATABASE,
context,
cause,
ErrorNamespace.DB_QUERY_ERROR,
)
self.operation = operation
self.table_name = table_name
self.database_name = database_name
if operation:
self.context.metadata["operation"] = operation
if table_name:
self.context.metadata["table_name"] = table_name
if database_name:
self.context.metadata["database_name"] = database_name
# === R2R and RAG System Exceptions ===
@@ -839,6 +874,8 @@ __all__ = [
# Service management exceptions
"ServiceHelperRemovedError",
"WebToolsRemovedError",
# Database exceptions
"DatabaseError",
# R2R and RAG system exceptions
"R2RConnectionError",
"R2RDatabaseError",

View File

@@ -7,7 +7,6 @@ nodes and tools in the Business Buddy framework.
import asyncio
import functools
import inspect
import time
from collections.abc import Callable
from datetime import UTC, datetime

View File

@@ -8,20 +8,26 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any, cast
from typing import Any, TypeVar, cast
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from .runnable_config import ConfigurationProvider, create_runnable_config
# Note: StateLike might not be available in all versions of LangGraph
# We'll use an unbound TypeVar for compatibility
StateType = TypeVar('StateType')
def configure_graph_with_injection(
graph_builder: StateGraph,
graph_builder: StateGraph[Any],
app_config: Any,
service_factory: Any | None = None,
**config_overrides: Any,
) -> StateGraph:
) -> StateGraph[Any]:
"""Configure a graph builder with dependency injection.
This function wraps all nodes in the graph to automatically inject
@@ -85,7 +91,7 @@ def create_config_injected_node(
# Node already expects config, wrap to provide it
@wraps(node_func)
async def config_aware_wrapper(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig | None
# noqa: ARG001
) -> Any:
# Merge base config with runtime config
@@ -143,7 +149,7 @@ def update_node_to_use_config(
@wraps(node_func)
async def wrapper(
state: dict[str, object], config: RunnableConfig | None = None
state: dict[str, object], config: RunnableConfig | None
# noqa: ARG001
) -> object:
# Call original without config (for backward compatibility)

View File

@@ -18,7 +18,7 @@ class ConfigurationProvider:
configuration values, service instances, and metadata in a type-safe manner.
"""
def __init__(self, config: RunnableConfig | None = None):
def __init__(self, config: RunnableConfig | None):
"""Initialize the configuration provider.
Args:

View File

@@ -18,6 +18,7 @@ import copy
from collections.abc import Callable
from typing import Any, TypeVar, cast
import pandas as pd
from typing_extensions import ParamSpec
from biz_bud.core.errors import ImmutableStateError, StateValidationError
@@ -27,6 +28,65 @@ T = TypeVar("T")
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def _states_equal(state1: Any, state2: Any) -> bool:
"""Compare two states safely, handling DataFrames and other complex objects.
Args:
state1: First state to compare
state2: Second state to compare
Returns:
True if states are equal, False otherwise
"""
if type(state1) is not type(state2):
return False
if isinstance(state1, dict) and isinstance(state2, dict):
if set(state1.keys()) != set(state2.keys()):
return False
for key in state1:
if not _states_equal(state1[key], state2[key]):
return False
return True
elif isinstance(state1, (list, tuple)):
if len(state1) != len(state2):
return False
return all(_states_equal(a, b) for a, b in zip(state1, state2))
elif isinstance(state1, pd.DataFrame):
if not isinstance(state2, pd.DataFrame):
return False
try:
return state1.equals(state2)
except Exception:
# If equals fails, consider them different
return False
elif isinstance(state1, pd.Series):
if not isinstance(state2, pd.Series):
return False
try:
return state1.equals(state2)
except Exception:
return False
else:
try:
return state1 == state2
except ValueError:
# Handle cases like numpy arrays that raise ValueError on comparison
try:
import numpy as np
if isinstance(state1, np.ndarray) and isinstance(state2, np.ndarray):
return np.array_equal(state1, state2)
except ImportError:
pass
# If comparison fails, consider them different
return False
class ImmutableDict:
"""An immutable dictionary that prevents modifications.
@@ -356,7 +416,7 @@ def ensure_immutable_node(
result = await result
# Verify original state wasn't modified (belt and suspenders)
if state != original_snapshot:
if not _states_equal(state, original_snapshot):
raise ImmutableStateError(
f"Node {node_func.__name__} modified the input state. "
"Nodes must return new state objects instead of mutating."
@@ -389,7 +449,7 @@ def ensure_immutable_node(
result = node_func(*new_args, **kwargs)
# Verify original state wasn't modified (belt and suspenders)
if state != original_snapshot:
if not _states_equal(state, original_snapshot):
raise ImmutableStateError(
f"Node {node_func.__name__} modified the input state. "
"Nodes must return new state objects instead of mutating."

View File

@@ -233,7 +233,6 @@ class APIClient:
) -> None:
"""Exit async context."""
# HTTPClient is a singleton and manages its own lifecycle
pass
@handle_errors(NetworkError, RateLimitError)
async def request(

View File

@@ -151,7 +151,6 @@ class RateLimiter:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit."""
pass
async def with_timeout[T]( # noqa: D103

View File

@@ -129,7 +129,6 @@ class CircuitBreakerState(Enum):
class CircuitBreakerError(Exception):
"""Raised when circuit breaker is open."""
pass
@dataclass

View File

@@ -56,17 +56,14 @@ ConfigChangeHandler = (
class ConfigurationError(Exception):
"""Base exception for configuration-related errors."""
pass
class ConfigurationValidationError(ConfigurationError):
"""Raised when configuration validation fails."""
pass
class ConfigurationLoadError(ConfigurationError):
"""Raised when configuration loading fails."""
pass
class ConfigurationManager:

View File

@@ -64,17 +64,14 @@ T = TypeVar("T")
class DIError(Exception):
"""Base exception for dependency injection errors."""
pass
class BindingNotFoundError(DIError):
"""Raised when a required binding is not found."""
pass
class InjectionError(DIError):
"""Raised when dependency injection fails."""
pass
class DIContainer:

View File

@@ -62,17 +62,14 @@ logger = get_logger(__name__)
class LifecycleError(Exception):
"""Base exception for lifecycle management errors."""
pass
class StartupError(LifecycleError):
"""Raised when service startup fails."""
pass
class ShutdownError(LifecycleError):
"""Raised when service shutdown fails."""
pass
class ServiceLifecycleManager:

View File

@@ -92,22 +92,18 @@ class ServiceProtocol(ABC):
class ServiceError(Exception):
"""Base exception for service-related errors."""
pass
class ServiceInitializationError(ServiceError):
"""Raised when service initialization fails."""
pass
class ServiceNotFoundError(ServiceError):
"""Raised when a requested service is not registered."""
pass
class CircularDependencyError(ServiceError):
"""Raised when circular dependencies are detected."""
pass
class ServiceRegistry:

View File

@@ -115,7 +115,6 @@ except ImportError:
# Legacy classes for backward compatibility
class URLProcessorConfig:
"""Legacy config class."""
pass
class ValidationLevel:
"""Legacy validation level enum."""
@@ -125,69 +124,53 @@ except ImportError:
class URLDiscoverer:
"""Legacy discoverer class."""
pass
class URLFilter:
"""Legacy filter class."""
pass
class URLValidator:
"""Legacy validator class."""
pass
# Legacy exception classes
class URLProcessingError(Exception):
"""Base URL processing error."""
pass
class URLValidationError(URLProcessingError):
"""URL validation error."""
pass
class URLNormalizationError(URLProcessingError):
"""URL normalization error."""
pass
class URLDiscoveryError(URLProcessingError):
"""URL discovery error."""
pass
class URLFilterError(URLProcessingError):
"""URL filter error."""
pass
class URLDeduplicationError(URLProcessingError):
"""URL deduplication error."""
pass
class URLCacheError(URLProcessingError):
"""URL cache error."""
pass
# Legacy data model classes
class ProcessedURL:
"""Legacy processed URL model."""
pass
class URLAnalysis:
"""Legacy URL analysis model."""
pass
class ValidationResult:
"""Legacy validation result model."""
pass
class DiscoveryResult:
"""Legacy discovery result model."""
pass
class FilterResult:
"""Legacy filter result model."""
pass
class DeduplicationResult:
"""Legacy deduplication result model."""
pass
# URLProcessor is defined above in the try block
@@ -340,7 +323,6 @@ def _initialize_module() -> None:
"""Initialize the URL processing module."""
# Module initialization logic can go here
# For now, just validate that required dependencies are available
pass
# Run initialization when module is imported

View File

@@ -56,4 +56,3 @@ class URLDiscoverer:
async def close(self) -> None:
"""Close discoverer resources."""
pass

View File

@@ -79,11 +79,7 @@ def format_raw_input(
raw_input_str = f"<non-serializable dict: {e}>"
return raw_input_str, extracted_query
# For unsupported types
if not isinstance(raw_input, str):
raw_input_str = f"<unsupported type: {type(raw_input).__name__}>"
return raw_input_str, user_query
# At this point, raw_input must be a str
return str(raw_input), user_query

File diff suppressed because it is too large Load Diff

View File

@@ -56,7 +56,6 @@ class Validator(ABC):
@abstractmethod
def validate(self, value: Any) -> tuple[bool, str | None]: # noqa: ANN401
"""Validate value and return (is_valid, error_message)."""
pass
class TypeValidator(Validator):

View File

@@ -101,7 +101,7 @@ _handle_analysis_errors = handle_error(
)
def create_analysis_graph() -> "CompiledStateGraph":
def create_analysis_graph() -> "CompiledStateGraph[AnalysisState]":
"""Create the data analysis workflow graph.
This graph implements a comprehensive analysis workflow:
@@ -121,10 +121,10 @@ def create_analysis_graph() -> "CompiledStateGraph":
# Add nodes
workflow.add_node("validate_input", parse_and_validate_initial_payload)
workflow.add_node("plan_analysis", formulate_analysis_plan_node)
workflow.add_node("prepare_data", prepare_analysis_data_node)
workflow.add_node("perform_analysis", perform_analysis_node)
workflow.add_node("generate_visualizations", generate_visualizations_node)
workflow.add_node("plan_analysis", formulate_analysis_plan_node) # type: ignore[arg-type]
workflow.add_node("prepare_data", prepare_analysis_data_node) # type: ignore[arg-type]
workflow.add_node("perform_analysis", perform_analysis_node) # type: ignore[arg-type]
workflow.add_node("generate_visualizations", generate_visualizations_node) # type: ignore[arg-type]
workflow.add_node("interpret_results", interpret_results_node)
workflow.add_node("compile_report", compile_analysis_report_node)
workflow.add_node("handle_error", handle_graph_error)
@@ -172,7 +172,7 @@ def create_analysis_graph() -> "CompiledStateGraph":
return workflow.compile()
def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def analysis_graph_factory(config: RunnableConfig) -> "CompiledStateGraph[AnalysisState]":
"""Create analysis graph for graph-as-tool pattern.
This factory function follows the standard pattern for graphs

View File

@@ -159,7 +159,7 @@ def _prepare_dataframe(df: pd.DataFrame, key: str) -> tuple[pd.DataFrame, list[s
async def prepare_analysis_data(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Prepare all datasets in the workflow state for analysis by cleaning and type conversion.
@@ -464,7 +464,7 @@ def _analyze_dataset(
async def perform_basic_analysis(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Perform basic analysis (descriptive statistics, correlation) on all prepared datasets.

View File

@@ -0,0 +1,548 @@
"""data.py.
This module provides node functions for data cleaning, preparation, and type conversion
as part of the Business Buddy agent's analysis workflow. It includes helpers for
handling missing values, duplicates, type inference, and logging preparation steps.
These nodes are designed to be composed into automated or human-in-the-loop analysis
pipelines, ensuring that input data is ready for downstream statistical and ML analysis.
Functions:
- _convert_column_types: Attempts to infer and convert object columns to numeric or datetime.
- _prepare_dataframe: Cleans a DataFrame by dropping missing/duplicate rows and converting types.
- prepare_analysis_data: Node function to prepare all datasets in the workflow state.
"""
# Standard library imports
import contextlib
from typing import TYPE_CHECKING, Any, cast
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import ValidationError
from biz_bud.core.langgraph.state_immutability import StateUpdater
if TYPE_CHECKING:
from biz_bud.states.analysis import AnalysisPlan
# Third-party imports
from typing import TypedDict
import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict
from biz_bud.core.types import ErrorInfo, create_error_info
# Business Buddy state and logging utilities
from biz_bud.logging import error_highlight, get_logger, info_highlight, warning_highlight
logger = get_logger(__name__)
# More specific types for prepared data and analysis results
_PreparedDataDict = dict[
str, pd.DataFrame | dict[str, Any] | str | list[Any] | int | float | None
]
class _AnalysisResult(TypedDict, total=False):
"""Analysis result for a single dataset."""
descriptive_statistics: dict[str, float | int | str | None]
correlation_matrix: dict[str, float | int | str | None]
_AnalysisResultsDict = dict[str, _AnalysisResult]
class PreparedDataModel(BaseModel):
"""Pydantic model for validating prepared data structure.
This model ensures that prepared_data contains valid data types
and structure expected by downstream analysis nodes.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
"""Initialize with prepared data dictionary."""
# Validate that all values are of expected types
for key, value in data.items():
if not isinstance(
value, (pd.DataFrame, dict, str, list, int, float, type(None))
):
raise ValidationError(f"Invalid data type for key '{key}': {type(value)}")
super().__init__(**data)
# --- Node Functions ---
def _convert_column_types(df: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]:
"""Attempt to convert columns of type 'object' in the DataFrame to numeric or datetime types.
Args:
df (pd.DataFrame): The input DataFrame to process.
Returns:
tuple[pd.DataFrame, list[str]]: The DataFrame with converted columns (if any),
and a list of strings describing which columns were converted and to what type.
This function iterates over all columns with dtype 'object' and tries to convert them
first to numeric, then to datetime. If conversion fails, the column is left unchanged.
"""
converted_cols: list[str] = []
for col in df.columns:
if df[col].dtype == "object":
try:
# Try to convert to numeric type
df[col] = pd.to_numeric(df[col], errors="raise")
converted_cols.append(f"'{col}' (to numeric)")
except (ValueError, TypeError):
# If numeric conversion fails, try datetime
with contextlib.suppress(ValueError, TypeError):
df[col] = pd.to_datetime(df[col], errors="raise", format="mixed")
converted_cols.append(f"'{col}' (to datetime)")
return df, converted_cols
def _prepare_dataframe(df: pd.DataFrame, key: str) -> tuple[pd.DataFrame, list[str]]:
"""Clean a DataFrame by dropping missing values, removing duplicates, and converting column types.
Args:
df (pd.DataFrame): The DataFrame to clean.
key (str): The dataset key (for logging).
Returns:
tuple[pd.DataFrame, list[str]]: The cleaned DataFrame and a list of log messages
describing the cleaning steps performed.
This function logs each step of the cleaning process, including the number of rows dropped
for missing values and duplicates, and any type conversions performed.
"""
initial_shape = df.shape
# Drop rows with missing values
df_cleaned = df.dropna()
dropped_rows = initial_shape[0] - df_cleaned.shape[0]
log_msgs: list[str] = [
f"# --- Data Preparation for '{key}' ---",
(
f"# - Dropped {dropped_rows} rows with missing values."
if dropped_rows > 0
else "# - No missing values found."
),
]
initial_rows = df_cleaned.shape[0]
# Drop duplicate rows
df_cleaned = df_cleaned.drop_duplicates()
dropped_duplicates = initial_rows - df_cleaned.shape[0]
log_msgs.append(
f"# - Dropped {dropped_duplicates} duplicate rows."
if dropped_duplicates > 0
else "# - No duplicate rows found."
)
# Attempt to convert object columns to numeric/datetime
df_cleaned, converted_cols = _convert_column_types(df_cleaned)
log_msgs.extend(
(
(
f"# - Attempted type conversion for: {converted_cols}."
if converted_cols
else "# - No automatic type conversions applied."
),
f"# - Final shape: {df_cleaned.shape}",
)
)
return df_cleaned, log_msgs
async def prepare_analysis_data(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Prepare all datasets in the workflow state for analysis by cleaning and type conversion.
Args:
state: The current workflow state containing input data and analysis plan.
Returns:
The updated state with cleaned datasets and preparation logs.
This node function:
- Extracts input data and analysis plan from the state.
- Determines which datasets to prepare based on the plan.
- Cleans each DataFrame (drops missing/duplicate rows, converts types).
- Logs all preparation steps and errors.
- Updates the state with prepared data and logs for downstream analysis.
"""
info_highlight("Preparing data for analysis...")
# Cast state to Dict for dynamic field access
state_dict = cast("dict[str, object]", state)
# input_data maps dataset names to DataFrames or other objects (e.g., str, int, list, dict)
input_data_raw = state_dict.get("data")
input_data: _PreparedDataDict | None = cast(
"_PreparedDataDict | None", input_data_raw
)
# analysis_plan is a dict with string keys and values that are typically lists, dicts, or primitives
analysis_plan_raw = state_dict.get("analysis_plan")
analysis_plan: AnalysisPlan | None = cast("AnalysisPlan | None", analysis_plan_raw)
code_snippets_raw = state_dict.get("code_snippets")
code_snippets: list[str] = (
code_snippets_raw if isinstance(code_snippets_raw, list) else []
)
if not input_data:
error_highlight("No input data found to prepare for analysis.")
errors_raw = state_dict.get("errors", []) or []
existing_errors: list[ErrorInfo] = (
cast("list[ErrorInfo]", errors_raw) if isinstance(errors_raw, list) else []
)
new_errors = [
*existing_errors,
create_error_info(
message="No data to prepare",
node="data_preparation",
error_type="DataError",
severity="warning",
category="validation",
),
]
updater = StateUpdater(state)
updater.set("errors", new_errors)
updater.set("prepared_data", {})
state = updater.build()
return state
# Determine which datasets to prepare based on the analysis plan, if provided
datasets_to_prepare: list[str] | None = None
if analysis_plan and isinstance(analysis_plan.get("steps"), list):
datasets_to_prepare = list(
{
ds
for step in analysis_plan.get("steps", [])
for ds in step.get("datasets", [])
if isinstance(step.get("datasets"), list)
}
)
info_highlight(
f"Preparing datasets based on analysis plan: {datasets_to_prepare}"
)
prepared_data: dict[str, object] = {}
try:
for key, dataset in (input_data or {}).items():
if analysis_plan and isinstance(analysis_plan.get("steps"), list):
datasets_to_prepare = list(
{
ds
for step in analysis_plan.get("steps", [])
for ds in step.get("datasets", [])
if isinstance(step.get("datasets"), list)
}
)
if key not in datasets_to_prepare:
prepared_data[key] = dataset
continue
if isinstance(dataset, pd.DataFrame):
df_cleaned, log_msgs = _prepare_dataframe(dataset.copy(), key)
prepared_data[key] = df_cleaned
code_snippets.extend(log_msgs)
info_highlight(
f"Prepared DataFrame '{key}'. Final shape: {df_cleaned.shape}"
)
else:
prepared_data[key] = dataset
log_message = f"# - Dataset '{key}' (type: {type(dataset).__name__}) passed through without specific preparation."
code_snippets.append(log_message)
warning_highlight(log_message)
# Runtime validation
try:
_ = PreparedDataModel(**prepared_data)
logger.debug("Prepared data passed validation")
except (ValidationError, ValueError) as e:
warning_highlight(f"Prepared data validation warning: {e}")
# Continue processing despite validation issues for robustness
new_state = dict(state)
new_state["prepared_data"] = prepared_data
new_state["code_snippets"] = code_snippets
return new_state
except Exception as e:
error_highlight(f"Error preparing data: {e}")
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
create_error_info(
message=f"Error preparing data: {e}",
node="data_preparation",
error_type=type(e).__name__,
severity="error",
category="unknown",
context={
"operation": "data_preparation",
"exception_details": str(e),
},
)
]
new_state["prepared_data"] = input_data or {}
return new_state
def _get_descriptive_statistics(
df: pd.DataFrame,
) -> tuple[dict[str, float | int | str | None] | None, str]:
"""Compute descriptive statistics for the given DataFrame.
Args:
df (pd.DataFrame): The DataFrame to analyze.
Returns:
tuple[dict[str, float | int | str | None] | None, str]: A dictionary of descriptive statistics (if successful),
and a log message describing the outcome.
This function uses pandas' describe() to compute statistics for all columns.
If an error occurs, it returns None and an error message.
"""
try:
stats = df.describe(include="all").to_dict()
return stats, "# - Calculated descriptive statistics."
except Exception as e:
return None, f"# - ERROR calculating descriptive statistics: {e}"
def _get_correlation_matrix(
df: pd.DataFrame,
) -> tuple[dict[str, float | int | str | None] | None, str]:
"""Compute the correlation matrix for all numeric columns in the DataFrame.
Args:
df (pd.DataFrame): The DataFrame to analyze.
Returns:
tuple[dict[str, float | int | str | None] | None, str]: A dictionary representing the correlation matrix (if successful),
and a log message describing the outcome.
If there are no numeric columns or only one, the function skips computation.
"""
try:
numeric_df = df.select_dtypes(include=[np.number])
if numeric_df.empty or numeric_df.shape[1] <= 1:
return (
None,
"# - Skipped correlation matrix (no numeric columns or only one).",
)
corr = numeric_df.corr().to_dict()
return corr, "# - Calculated correlation matrix for numeric columns."
except Exception as e:
return None, f"# - ERROR calculating correlation matrix: {e}"
def _handle_analysis_error(
state: dict[str, Any], error_msg: str, phase: str
) -> dict[str, Any]:
"""Handle errors during analysis by logging and updating the workflow state.
Args:
state: The current workflow state.
error_msg (str): The error message to log.
phase (str): The phase of analysis where the error occurred.
Returns:
The updated state with error information and cleared results.
This function appends the error to the state's error list and clears analysis results.
"""
# Cast state to Dict for dynamic field access
state_dict = cast("dict[str, object]", state)
error_highlight(error_msg)
errors_raw = state_dict.get("errors", [])
existing_errors = (
cast("list[ErrorInfo]", errors_raw) if isinstance(errors_raw, list) else []
)
error_info = create_error_info(
message=error_msg,
node=phase,
error_type="ProcessingError",
severity="error",
category="unknown",
)
new_errors = [*existing_errors, error_info]
updater = StateUpdater(state)
updater.set("errors", new_errors)
updater.set("analysis_results", {})
state = updater.build()
return state
def _parse_analysis_plan(
analysis_plan: dict[str, list[Any] | dict[str, Any] | str | int | float | None]
| None,
) -> tuple[list[str] | None, dict[str, list[str]]]:
"""Parse the analysis plan to extract datasets to analyze and methods for each dataset.
Args:
analysis_plan (dict[str, object] | None): The analysis plan dictionary.
Returns:
tuple[list[str] | None, dict[str, list[str]]]: A list of dataset keys to analyze,
and a mapping from dataset key to list of methods to run.
This function supports plans with multiple steps, each specifying datasets and methods.
"""
datasets_to_analyze: list[str] | None = None
methods_by_dataset: dict[str, list[str]] = {}
steps = analysis_plan.get("steps") if analysis_plan else None
if steps and isinstance(steps, list):
datasets_to_analyze = []
for step in steps:
if not isinstance(step, dict):
continue
step_datasets = step.get("datasets", [])
step_methods = step.get("methods", [])
if not (isinstance(step_datasets, list) and isinstance(step_methods, list)):
continue
# Only include string dataset names and method names
datasets_to_analyze.extend(
[ds for ds in step_datasets if isinstance(ds, str)]
)
for ds in step_datasets:
if isinstance(ds, str):
methods_by_dataset.setdefault(ds, []).extend(
[m for m in step_methods if isinstance(m, str)]
)
datasets_to_analyze = list(set(datasets_to_analyze))
return datasets_to_analyze, methods_by_dataset
def _analyze_dataset(
key: str,
dataset: pd.DataFrame | str | float | list[Any] | dict[str, Any] | None,
methods_to_run: list[str],
) -> tuple[dict[str, object], list[str]]:
"""Run specified analysis methods on a dataset and log the results.
Args:
key (str): The dataset key (for logging).
dataset (pd.DataFrame | str | float | list | dict | None): The dataset to analyze (usually a DataFrame).
methods_to_run (list[str]): List of analysis methods to run (e.g., 'descriptive_statistics', 'correlation').
Returns:
tuple[dict[str, object], list[str]]: A dictionary of analysis results and a list of log messages.
This function supports DataFrames (for which it computes statistics and correlations)
and logs a message for unsupported types.
"""
log_msgs: list[str] = [f"# --- Data Analysis for '{key}' ---"]
dataset_results: dict[str, object] = {}
if isinstance(dataset, pd.DataFrame):
df = dataset
if "descriptive_statistics" in methods_to_run:
desc_stats, desc_log = _get_descriptive_statistics(df)
if desc_stats is not None:
dataset_results["descriptive_statistics"] = desc_stats
log_msgs.append(desc_log)
if "correlation" in methods_to_run:
corr_matrix, corr_log = _get_correlation_matrix(df)
if corr_matrix is not None:
dataset_results["correlation_matrix"] = corr_matrix
log_msgs.append(corr_log)
else:
log_msgs.append(
f"# - No specific automated analysis applied for type: {type(dataset).__name__}."
)
return dataset_results, log_msgs
async def perform_basic_analysis(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Perform basic analysis (descriptive statistics, correlation) on all prepared datasets.
Args:
state: The current workflow state containing prepared data and analysis plan.
Returns:
The updated state with analysis results and logs.
This node function:
- Extracts prepared data and analysis plan from the state.
- Determines which datasets and methods to analyze.
- Runs analysis for each dataset and logs results.
- Updates the state with analysis results and logs.
- Handles and logs any errors encountered during analysis.
"""
info_highlight("Performing basic data analysis...")
# Cast state to Dict for dynamic field access
state_dict = cast("dict[str, object]", state)
prepared_data_raw = state_dict.get("prepared_data")
prepared_data: _PreparedDataDict | None = cast(
"_PreparedDataDict | None", prepared_data_raw
)
analysis_plan_raw = state_dict.get("analysis_plan")
analysis_plan: AnalysisPlan | None = cast("AnalysisPlan | None", analysis_plan_raw)
code_snippets_raw = state_dict.get("code_snippets")
code_snippets: list[str] = (
code_snippets_raw if isinstance(code_snippets_raw, list) else []
)
if not prepared_data:
return _handle_analysis_error(
state, "No prepared data found to analyze.", "data_analysis"
)
datasets_to_analyze, methods_by_dataset = _parse_analysis_plan(
cast(
"dict[str, list[Any] | dict[str, Any] | str | int | float | None] | None",
analysis_plan,
)
)
if datasets_to_analyze is not None:
info_highlight(f"Analyzing datasets based on plan: {datasets_to_analyze}")
try:
# analysis_results maps dataset names to analysis result dicts
analysis_results: _AnalysisResultsDict = {}
datasets_analyzed: list[str] = []
for key, dataset in prepared_data.items():
if datasets_to_analyze is not None and key not in datasets_to_analyze:
continue
methods_to_run = list(
set(
methods_by_dataset.get(
key, ["descriptive_statistics", "correlation"]
)
)
)
dataset_results, log_msgs = _analyze_dataset(key, dataset, methods_to_run)
if dataset_results:
# Ensure dataset_results matches _AnalysisResult TypedDict
# Only keep keys that are valid for _AnalysisResult
valid_keys = {"descriptive_statistics", "correlation_matrix"}
filtered_results = {
k: v for k, v in dataset_results.items() if k in valid_keys
}
analysis_results[key] = cast("_AnalysisResult", filtered_results)
datasets_analyzed.append(key)
code_snippets.extend(log_msgs)
new_state = dict(state)
new_state["analysis_results"] = analysis_results
new_state["code_snippets"] = code_snippets
info_highlight(f"Basic analysis complete for datasets: {datasets_analyzed}")
return new_state
except Exception as e:
return _handle_analysis_error(
state, f"Error during data analysis: {e}", "data_analysis"
)

View File

@@ -55,7 +55,7 @@ class _InterpretationResultModel(BaseModel):
@standard_node(node_name="interpret_analysis_results", metric_name="interpretation")
@ensure_immutable_node
async def interpret_analysis_results(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Interprets the results generated by the analysis nodes using an LLM and updates the workflow state.
@@ -135,13 +135,12 @@ async def interpret_analysis_results(
# Get service factory from RunnableConfig if available
provider = None
service_factory = None
if config is not None:
try:
provider = ConfigurationProvider(config)
service_factory = provider.get_service_factory()
except (TypeError, AttributeError):
# Config is not a RunnableConfig or doesn't have expected structure
pass
try:
provider = ConfigurationProvider(config)
service_factory = provider.get_service_factory()
except (TypeError, AttributeError):
# Config is not a RunnableConfig or doesn't have expected structure
pass
if service_factory is None:
service_factory_raw = state_dict.get("service_factory")
@@ -231,7 +230,7 @@ class _ReportModel(BaseModel):
@standard_node(node_name="compile_analysis_report", metric_name="report_generation")
@ensure_immutable_node
async def compile_analysis_report(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Compile comprehensive analysis report from state data.

View File

@@ -0,0 +1,402 @@
"""interpret.py.
This module provides node functions for interpreting analysis results and compiling
final reports in the Business Buddy agent workflow. It leverages LLMs to generate
insightful, structured interpretations and reports based on the outputs of previous
analysis nodes, the user's task, and the analysis plan.
Functions:
- interpret_analysis_results: Uses an LLM to interpret analysis results and populates 'interpretations' in the state.
- compile_analysis_report: Uses an LLM to generate a structured report and populates 'report' in the state.
These nodes are designed to be composed into automated or human-in-the-loop workflows,
enabling rich, explainable outputs for end users.
"""
import json # For serializing results and prompts
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel
from pydantic import ValidationError as PydanticValidationError
from biz_bud.core.errors import ValidationError
if TYPE_CHECKING:
from biz_bud.core import ErrorInfo
from biz_bud.services.factory import ServiceFactory
# Import prompt templates for LLM calls
from langchain_core.runnables import RunnableConfig
from biz_bud.core.langgraph import (
ConfigurationProvider,
StateUpdater,
ensure_immutable_node,
standard_node,
)
# --- Node Functions ---
# Logging utilities
from biz_bud.logging import error_highlight, info_highlight
from biz_bud.prompts.analysis import COMPILE_REPORT_PROMPT, INTERPRET_RESULTS_PROMPT
class _InterpretationResultModel(BaseModel):
"""Model for storing interpretation results."""
key_findings: list[Any]
insights: list[Any]
limitations: list[Any]
next_steps: list[Any] = []
confidence_score: float = 0.0
@standard_node(node_name="interpret_analysis_results", metric_name="interpretation")
@ensure_immutable_node
async def interpret_analysis_results(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Interprets the results generated by the analysis nodes using an LLM and updates the workflow state.
Args:
state (BusinessBuddyState): The current workflow state containing the task, analysis plan, and results.
Returns:
BusinessBuddyState: The updated state with an 'interpretations' key containing structured insights.
This node function:
- Extracts the task, analysis plan, and results from the state.
- Summarizes results and plan for prompt construction.
- Calls an LLM to generate key findings, insights, limitations, next steps, and a confidence score.
- Validates and logs the LLM response.
- Handles and logs errors, populating defaults if interpretation fails.
"""
info_highlight("Interpreting analysis results...")
# Cast state to Dict for dynamic field access
state_dict = cast("dict[str, object]", state)
task = state_dict.get("task")
if not isinstance(task, str):
task = None
analysis_plan = state_dict.get("analysis_plan", {})
analysis_results = state_dict.get("analysis_results")
# Define default/error state
default_interpretation = {
"key_findings": [],
"insights": [],
"limitations": [],
"next_steps": [],
"confidence_score": 0.0,
}
if not task:
error_highlight("Missing task description for result interpretation.")
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
{"phase": "interpretation", "error": "Missing task description"}
]
new_state["interpretations"] = default_interpretation
return new_state
if not analysis_results:
error_highlight("No analysis results found to interpret.")
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
{"phase": "interpretation", "error": "No analysis results to interpret"}
]
new_state["interpretations"] = default_interpretation
return new_state
try:
# Summarize results for the prompt to avoid excessive length
results_summary = json.dumps(
analysis_results, indent=2, default=str, ensure_ascii=False
)[:4000]
plan_summary = json.dumps(
analysis_plan, indent=2, default=str, ensure_ascii=False
)[:2000]
prompt = INTERPRET_RESULTS_PROMPT.format(
task=task,
analysis_plan=plan_summary,
analysis_results_summary=results_summary,
)
# *** Assuming call_llm returns dict or parses JSON ***
# Get service factory from RunnableConfig if available
provider = None
service_factory = None
if config is not None:
try:
provider = ConfigurationProvider(config)
service_factory = provider.get_service_factory()
except (TypeError, AttributeError):
# Config is not a RunnableConfig or doesn't have expected structure
pass
if service_factory is None:
service_factory_raw = state_dict.get("service_factory")
if service_factory_raw is None:
raise RuntimeError(
"ServiceFactory instance not found in state or config. Please provide it as 'service_factory'."
)
service_factory = cast("ServiceFactory", service_factory_raw)
async with service_factory.lifespan() as factory:
llm_client = await factory.get_llm_client()
model_name = state_dict.get("model_name")
model_identifier = model_name if isinstance(model_name, str) else None
llm_response = await llm_client.llm_json(
prompt=prompt,
model_identifier=model_identifier,
chunk_size=None,
overlap=None,
temperature=0.4,
input_token_limit=100000,
)
interpretation_json: dict[str, object] = cast(
"dict[str, object]", llm_response
)
# --------------------------------------------
# Basic validation
required_keys = [
"key_findings",
"insights",
"limitations",
] # Optional keys: next_steps, confidence_score
if any(k not in interpretation_json for k in required_keys):
# Attempt partial load if possible, log warning
warning_message = "LLM interpretation response missing some expected keys."
error_highlight(warning_message)
# Fill missing keys with defaults
for key in required_keys:
interpretation_json.setdefault(key, [])
# state["errors"] = state_dict.get('errors', []) + [{'phase': 'interpretation', 'error': warning_message}]
# raise ValueError("LLM interpretation response missing required keys.")
# Runtime validation
try:
# Cast dict[str, object] to dict[str, Any] for proper type validation
typed_json = cast("dict[str, Any]", interpretation_json)
validated_interpretation = _InterpretationResultModel(**typed_json)
except PydanticValidationError as e:
raise ValidationError(
f"LLM interpretation response failed validation: {e}"
) from e
interpretations = validated_interpretation.model_dump()
updater = StateUpdater(state)
updater = updater.set("interpretations", interpretations)
info_highlight("Analysis results interpreted successfully.")
info_highlight(
f"Key Findings sample: {interpretations.get('key_findings', [])[:1]}"
)
except Exception as e:
error_message = f"Error interpreting analysis results: {e}"
error_highlight(error_message)
updater = StateUpdater(state)
return (
updater.append(
"errors", {"phase": "interpretation", "error": error_message}
)
.set("interpretations", default_interpretation)
.build()
)
return updater.build()
class _ReportModel(BaseModel):
"""Model for analysis report structure."""
title: str
executive_summary: str
sections: list[Any]
key_findings: list[Any]
visualizations_included: list[Any]
limitations: list[Any]
@standard_node(node_name="compile_analysis_report", metric_name="report_generation")
@ensure_immutable_node
async def compile_analysis_report(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Compile comprehensive analysis report from state data.
Args:
state: Current business buddy state
Returns:
Updated state with compiled report
"""
from biz_bud.services.factory import ServiceFactory
"""Compiles the final analysis report using an LLM, based on interpretations, results, and visualizations.
Args:
state (BusinessBuddyState): The current workflow state containing interpretations, results, and visualizations.
Returns:
BusinessBuddyState: The updated state with a 'report' key containing the structured report.
This node function:
- Extracts the task, analysis results, interpretations, and visualizations from the state.
- Summarizes results, interpretations, and visualization metadata for prompt construction.
- Calls an LLM to generate a structured report (title, executive summary, sections, findings, etc.).
- Validates and logs the LLM response.
- Handles and logs errors, populating a default report if report generation fails.
"""
info_highlight("Compiling analysis report...")
# Cast state to Dict for dynamic field access
state_dict = cast("dict[str, object]", state)
task = state_dict.get("task")
if not isinstance(task, str):
task = None
analysis_results = state_dict.get("analysis_results", {})
interpretations_raw = state_dict.get("interpretations")
interpretations = (
interpretations_raw if isinstance(interpretations_raw, dict) else None
)
visualizations = state_dict.get("visualizations", [])
if not isinstance(visualizations, list):
visualizations = []
# Prepare default/error report structure
default_report = {
"title": f"Analysis Report: {task or 'Untitled'}",
"executive_summary": "",
"sections": [],
"key_findings": [],
"visualizations_included": [],
"limitations": [],
}
if not task:
error_highlight("Missing task description for report compilation.")
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
{"phase": "report_compilation", "error": "Missing task description"}
]
new_state["report"] = default_report
return new_state
if not interpretations:
error_highlight("Missing interpretations needed to compile the report.")
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
{
"phase": "report_compilation",
"error": "Missing interpretations for report",
}
]
new_state["report"] = default_report
return new_state
try:
# Prepare context for the prompt
results_summary = json.dumps(
analysis_results, indent=2, default=str, ensure_ascii=False
)[:2000]
interpretations_summary = json.dumps(
interpretations, indent=2, default=str, ensure_ascii=False
)[:4000]
# Extract visualization types/descriptions for the prompt
viz_metadata = [
str(viz.get("type", viz.get("description", "Unnamed Visualization")))
for viz in (visualizations or [])
]
viz_metadata_summary = json.dumps(viz_metadata, indent=2)
prompt = COMPILE_REPORT_PROMPT.format(
task=task,
analysis_results_summary=results_summary,
interpretations=interpretations_summary,
visualization_metadata=viz_metadata_summary,
)
# *** Assuming call_llm returns dict or parses JSON ***
service_factory = state_dict.get("service_factory")
if not isinstance(service_factory, ServiceFactory):
raise RuntimeError(
"ServiceFactory instance not found in state. Please provide it as 'service_factory'."
)
async with service_factory.lifespan() as factory:
llm_client = await factory.get_llm_client()
model_name = state_dict.get("model_name")
model_identifier = model_name if isinstance(model_name, str) else None
temperature = state_dict.get("temperature", 0.6)
if not isinstance(temperature, (float, int)):
temperature = 0.6
report_json = await llm_client.llm_json(
prompt=prompt,
model_identifier=model_identifier,
chunk_size=None,
overlap=None,
temperature=temperature,
input_token_limit=200000,
)
# --------------------------------------------
# Basic validation
required_keys = [
"title",
"executive_summary",
"sections",
"key_findings",
"visualizations_included",
"limitations",
"conclusion",
]
if any(k not in report_json for k in required_keys):
raise ValidationError("LLM report compilation response missing required keys.")
# Add metadata to the generated report content
# Runtime validation
try:
typed_report_json = cast("dict[str, Any]", report_json)
validated_report = _ReportModel(**typed_report_json)
except PydanticValidationError as e:
raise ValidationError(f"LLM report compilation response failed validation: {e}") from e
report_data = validated_report.model_dump()
new_state = dict(state)
new_state["report"] = report_data
info_highlight("Analysis report compiled successfully.")
info_highlight(f"Report Title: {report_data.get('title')}")
info_highlight(
f"Report Executive Summary snippet: {report_data.get('executive_summary', '')[:100]}..."
)
return new_state
except Exception as e:
error_message = f"Error compiling analysis report: {e}"
error_highlight(error_message)
prev_errors = cast("list[ErrorInfo]", state_dict.get("errors") or [])
new_state = dict(state)
new_state["errors"] = prev_errors + [
{"phase": "report_compilation", "error": error_message}
]
new_state["report"] = default_report
return new_state
return new_state

View File

@@ -173,7 +173,7 @@ async def _create_visualization_tasks(
# --- Node Function ---
async def generate_data_visualizations(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Generate visualizations based on the prepared data and analysis plan/results.

View File

@@ -0,0 +1,263 @@
"""visualize.py.
This module provides node functions for generating data visualizations as part of the
Business Buddy agent's analysis workflow. It includes placeholder logic for async
visualization generation, parsing analysis plans for visualization steps, and
assembling visualization tasks for downstream processing.
Functions:
- _create_placeholder_visualization: Async placeholder for generating visualization metadata.
- _parse_analysis_plan: Extracts datasets to visualize from the analysis plan.
- _create_visualization_tasks: Assembles async tasks for generating visualizations.
- generate_data_visualizations: Node function to orchestrate visualization generation and update state.
These nodes are designed to be composed into automated or human-in-the-loop workflows,
enabling dynamic, context-aware visualization of analysis results.
"""
import asyncio # For async visualization generation
from typing import TYPE_CHECKING, Any
from biz_bud.core.networking.async_utils import gather_with_concurrency
if TYPE_CHECKING:
from biz_bud.states.analysis import (
VisualizationTypedDict,
)
import numpy as np
import pandas as pd # Required for data access
from langchain_core.runnables import RunnableConfig
from biz_bud.logging import error_highlight, info_highlight, warning_highlight
# Placeholder visualization function - replace with actual implementation
async def _create_placeholder_visualization(
df: pd.DataFrame, viz_type: str, dataset_key: str, column: str | None = None
) -> "VisualizationTypedDict":
"""Generate placeholder visualization metadata for a given DataFrame and visualization type.
Args:
df (pd.DataFrame): The DataFrame to visualize.
viz_type (str): The type of visualization (e.g., 'histogram', 'scatter plot').
dataset_key (str): The key identifying the dataset.
column (str | None): The column(s) to visualize, if applicable.
Returns:
dict[str, Any]: A dictionary containing placeholder visualization metadata.
This function simulates async work and returns a placeholder structure.
Replace with actual visualization logic as needed.
"""
await asyncio.sleep(0.01) # Simulate async work
column_str = f" ({column})" if column else ""
title = f"Placeholder {viz_type} for {dataset_key}{column_str}"
from typing import cast
# Create a VisualizationTypedDict with only the required fields
result: VisualizationTypedDict = {
"type": viz_type,
"code": f"# Placeholder code to generate {viz_type} for {dataset_key}{f' on column {column}' if column else ''}",
"image_data": "base64_encoded_placeholder_image_data", # Replace with actual image data
}
# Add extra fields via casting for extensibility
extended_result = cast("dict[str, Any]", result)
extended_result.update(
{
"dataset": dataset_key,
"title": title,
"params": {"column": column} if column else {},
"description": f"A placeholder {viz_type} showing distribution/relationship for {dataset_key}.",
}
)
return cast("VisualizationTypedDict", extended_result)
def _parse_analysis_plan(analysis_plan: dict[str, Any] | None) -> list[str] | None:
"""Extract the list of datasets to visualize from the analysis plan.
Args:
analysis_plan (dict[str, Any] | None): The analysis plan dictionary.
Returns:
list[str] | None: A list of dataset keys to visualize, or None if not specified.
This function scans the analysis plan for steps that include visualization-related methods.
"""
if not analysis_plan or not isinstance(analysis_plan.get("steps"), list):
return None
datasets_to_visualize = set()
for step in analysis_plan["steps"]:
step_datasets = step.get("datasets", [])
step_methods = [
m.lower() for m in step.get("methods", []) if isinstance(m, str)
]
if any(
viz_kw in step_methods
for viz_kw in ["visualization", "plot", "chart", "graph"]
) and isinstance(step_datasets, list):
datasets_to_visualize.update(step_datasets)
return list(datasets_to_visualize)
async def _create_visualization_tasks(
prepared_data: dict[str, Any], datasets_to_visualize: list[str] | None
) -> tuple[list["VisualizationTypedDict"], list[str]]:
"""Create visualization tasks for the given prepared data and datasets to visualize.
Args:
prepared_data (dict[str, Any]): The prepared data dictionary.
datasets_to_visualize (list[str] | None): A list of dataset keys to visualize, or None if not specified.
Returns:
tuple[list[Any], list[str]]: A tuple containing the visualization tasks and log messages.
"""
import pandas as pd
viz_coroutines = []
viz_results: list[VisualizationTypedDict] = []
log_msgs: list[str] = []
for key, dataset in prepared_data.items():
if datasets_to_visualize is not None and key not in datasets_to_visualize:
continue
if isinstance(dataset, pd.DataFrame):
df = dataset
log_msgs.append(f"# --- Visualization Generation for '{key}' ---")
info_highlight(
f"Generating placeholder visualizations for DataFrame '{key}'"
)
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
if len(numeric_cols) >= 1:
col1 = numeric_cols[0]
viz_coroutines.append(
_create_placeholder_visualization(df, "histogram", key, col1)
)
log_msgs.append(
f"# - Generated placeholder histogram for column '{col1}'."
)
if len(numeric_cols) >= 2:
col1, col2 = numeric_cols[0], numeric_cols[1]
viz_coroutines.append(
_create_placeholder_visualization(
df, "scatter plot", key, f"{col1} vs {col2}"
)
)
log_msgs.append(
f"# - Generated placeholder scatter plot for columns '{col1}' vs '{col2}'."
)
if not numeric_cols:
log_msgs.append(
"# - No numeric columns found for default visualizations."
)
else:
warning_highlight(
f"Skipping visualization for non-DataFrame dataset '{key}'"
)
# Await all visualization coroutines with controlled concurrency
if viz_coroutines:
gathered_results = await gather_with_concurrency(2, *viz_coroutines)
viz_results = list(gathered_results)
return viz_results, log_msgs
# --- Node Function ---
async def generate_data_visualizations(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Generate visualizations based on the prepared data and analysis plan/results.
This node should ideally call utility functions that handle plotting logic
(e.g., using matplotlib, seaborn, plotly) and return image data (e.g., base64).
Populates 'visualizations' in the state.
Args:
state: The current workflow state containing prepared data and analysis plan.
config: Optional runnable configuration.
Returns:
dict[str, Any]: The updated state with generated visualizations and logs.
This node function:
- Extracts prepared data and analysis plan from the state.
- Determines which datasets to visualize.
- Assembles and runs async visualization tasks.
- Updates the state with generated visualizations and logs.
- Handles and logs any errors encountered during visualization.
"""
info_highlight("Generating data visualizations (placeholder)...")
from typing import cast
# State is already dict[str, Any]
state_dict = state
prepared_data: dict[str, Any] | None = state_dict.get("prepared_data")
analysis_plan_raw = state_dict.get("analysis_plan") # Optional guidance
code_snippets: list[str] = state_dict.get("code_snippets") or []
visualizations: list[VisualizationTypedDict] = (
state_dict.get("visualizations") or []
) # Append to existing
if not prepared_data:
error_highlight("No prepared data found to generate visualizations.")
new_state = dict(state)
new_state["visualizations"] = visualizations
return new_state
# Convert analysis_plan to dict format expected by parse_analysis_plan
analysis_plan_dict = cast("dict[str, Any] | None", analysis_plan_raw)
datasets_to_visualize = _parse_analysis_plan(analysis_plan_dict)
if datasets_to_visualize:
info_highlight(f"Visualizing datasets based on plan: {datasets_to_visualize}")
viz_results, log_msgs = await _create_visualization_tasks(
prepared_data, datasets_to_visualize
)
code_snippets.extend(log_msgs)
try:
# viz_results contains already awaited results from create_visualization_tasks
if viz_results:
visualizations.extend(viz_results)
info_highlight(f"Generated {len(viz_results)} placeholder visualizations.")
new_state = dict(state)
new_state["visualizations"] = visualizations
new_state["code_snippets"] = code_snippets
return new_state
except Exception as e:
error_highlight(f"Error generating visualizations: {e}")
from biz_bud.core import BusinessBuddyError
existing_errors = state_dict.get("errors")
if not isinstance(existing_errors, list):
existing_errors = []
from biz_bud.core import ErrorCategory, ErrorContext, ErrorSeverity
context = ErrorContext(
node_name="generate_data_visualizations",
operation="visualization_generation",
metadata={"phase": "visualization"},
)
error_info = BusinessBuddyError(
message=f"Error generating visualizations: {e}",
category=ErrorCategory.UNKNOWN,
severity=ErrorSeverity.ERROR,
context=context,
).to_error_info()
new_state = dict(state)
new_state["errors"] = existing_errors + [error_info]
new_state["visualizations"] = visualizations
return new_state

View File

@@ -15,12 +15,8 @@ and provide strategic recommendations for catalog management.
from .graph import GRAPH_METADATA, catalog_graph_factory, create_catalog_graph
# Import catalog-specific nodes from nodes directory
from .nodes import ( # All available catalog nodes
c_intel_node,
catalog_research_node,
component_default_node,
load_catalog_data_node,
)
from .nodes import c_intel_node # All available catalog nodes
from .nodes import catalog_research_node, component_default_node, load_catalog_data_node
# Create aliases for expected nodes
catalog_impact_analysis_node = c_intel_node # Main impact analysis

View File

@@ -7,8 +7,10 @@ from typing import TYPE_CHECKING, Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from biz_bud.graphs.catalog.nodes import ( # c_intel nodes; catalog_research nodes; load_catalog_data nodes
aggregate_catalog_components_node,
from biz_bud.graphs.catalog.nodes import (
aggregate_catalog_components_node, # c_intel nodes; catalog_research nodes; load_catalog_data nodes
)
from biz_bud.graphs.catalog.nodes import (
batch_analyze_components_node,
extract_components_from_sources_node,
find_affected_catalog_items_node,
@@ -90,7 +92,7 @@ def _route_after_extract(state: dict[str, Any]) -> str:
return "aggregate_components" if status == "completed" else "generate_report"
def create_catalog_graph() -> Pregel:
def create_catalog_graph() -> Pregel[CatalogIntelState]:
"""Create the unified catalog management graph.
This graph combines both intelligence analysis and research workflows:
@@ -112,9 +114,9 @@ def create_catalog_graph() -> Pregel:
workflow.add_node("load_catalog_data", load_catalog_data_node)
workflow.add_node("find_affected_items", find_affected_catalog_items_node)
workflow.add_node("batch_analyze", batch_analyze_components_node)
workflow.add_node("research_components", research_catalog_item_components_node)
workflow.add_node("extract_components", extract_components_from_sources_node)
workflow.add_node("aggregate_components", aggregate_catalog_components_node)
workflow.add_node("research_components", research_catalog_item_components_node) # type: ignore[arg-type]
workflow.add_node("extract_components", extract_components_from_sources_node) # type: ignore[arg-type]
workflow.add_node("aggregate_components", aggregate_catalog_components_node) # type: ignore[arg-type]
workflow.add_node("generate_report", generate_catalog_optimization_report_node)
# Define workflow edges
@@ -157,7 +159,7 @@ def create_catalog_graph() -> Pregel:
return workflow.compile()
def catalog_factory(config: RunnableConfig) -> Pregel:
def catalog_factory(config: RunnableConfig) -> Pregel[CatalogIntelState]:
"""Create catalog graph (legacy name for compatibility).
Returns:
@@ -173,7 +175,7 @@ async def catalog_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
return await asyncio.to_thread(catalog_factory, config)
def catalog_graph_factory(config: RunnableConfig) -> Pregel:
def catalog_graph_factory(config: RunnableConfig) -> Pregel[CatalogIntelState]:
"""Create catalog graph for graph-as-tool pattern.
This factory function follows the standard pattern for graphs

View File

@@ -50,7 +50,7 @@ except ImportError:
@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization")
async def catalog_optimization_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Generate optimization recommendations for the catalog.

View File

@@ -0,0 +1,176 @@
"""Catalog-specific nodes for the catalog management workflow.
This module contains nodes that are specific to catalog management,
component analysis, and catalog intelligence operations.
"""
from __future__ import annotations
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import create_error_info
from biz_bud.core.langgraph import standard_node
from biz_bud.logging import debug_highlight, error_highlight, info_highlight
# Import from local nodes directory
try:
from .nodes.analysis import catalog_impact_analysis_node
from .nodes.c_intel import (
batch_analyze_components_node,
find_affected_catalog_items_node,
generate_catalog_optimization_report_node,
identify_component_focus_node,
)
from .nodes.catalog_research import (
aggregate_catalog_components_node,
extract_components_from_sources_node,
research_catalog_item_components_node,
)
from .nodes.load_catalog_data import load_catalog_data_node
_legacy_imports_available = True
# These imports are here for reference but not used in this module
except ImportError:
# Fallback to None if imports fail - this is acceptable since this module
# primarily provides its own node implementations
_legacy_imports_available = False
batch_analyze_components_node = None
find_affected_catalog_items_node = None
generate_catalog_optimization_report_node = None
identify_component_focus_node = None
aggregate_catalog_components_node = None
extract_components_from_sources_node = None
research_catalog_item_components_node = None
load_catalog_data_node = None
catalog_impact_analysis_node = None
@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization")
async def catalog_optimization_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Generate optimization recommendations for the catalog.
This node analyzes the catalog structure, pricing, and components
to provide actionable optimization recommendations.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with optimization recommendations
"""
debug_highlight(
"Generating catalog optimization recommendations...", category="CatalogOptimization"
)
# Get analysis data
impact_analysis = state.get("impact_analysis", {})
catalog_data = state.get("catalog_data", {})
try:
optimization_report: dict[str, Any] = {
"recommendations": [],
"priority_actions": [],
"cost_savings_potential": {},
"efficiency_improvements": [],
}
# Analyze catalog structure
total_items = sum(
len(items) if isinstance(items, list) else 0 for items in catalog_data.values()
)
# Generate recommendations based on analysis
if affected_items := impact_analysis.get("affected_items", []):
# Component optimization
if len(affected_items) > 5:
optimization_report["recommendations"].append(
{
"type": "component_standardization",
"description": f"Standardize component usage across {len(affected_items)} items",
"impact": "high",
"effort": "medium",
}
)
if high_price_items := [item for item in affected_items if item.get("price", 0) > 15]:
optimization_report["priority_actions"].append(
{
"action": "Review pricing strategy",
"reason": f"{len(high_price_items)} high-value items affected",
"urgency": "high",
}
)
# Catalog structure optimization
if total_items > 50:
optimization_report["efficiency_improvements"].append(
{
"area": "catalog_structure",
"suggestion": "Consider categorization refinement",
"benefit": "Improved navigation and management",
}
)
# Cost savings analysis
optimization_report["cost_savings_potential"] = {
"component_consolidation": "5-10%",
"supplier_optimization": "3-7%",
"menu_engineering": "10-15%",
}
info_highlight(
f"Optimization report generated with {len(optimization_report['recommendations'])} recommendations",
category="CatalogOptimization",
)
return {
"optimization_report": optimization_report,
"report_metadata": {
"total_items_analyzed": total_items,
"recommendations_count": len(optimization_report["recommendations"]),
"priority_actions_count": len(optimization_report["priority_actions"]),
},
}
except Exception as e:
error_msg = f"Catalog optimization failed: {str(e)}"
error_highlight(error_msg, category="CatalogOptimization")
return {
"optimization_report": {},
"errors": [
create_error_info(
message=error_msg,
node="catalog_optimization",
severity="error",
category="optimization_error",
)
],
}
# Export all catalog-specific nodes
__all__ = [
"catalog_impact_analysis_node",
"catalog_optimization_node",
]
# Re-export legacy nodes if available
if _legacy_imports_available:
__all__.extend(
[
"batch_analyze_components_node",
"find_affected_catalog_items_node",
"generate_catalog_optimization_report_node",
"identify_component_focus_node",
"aggregate_catalog_components_node",
"extract_components_from_sources_node",
"research_catalog_item_components_node",
"load_catalog_data_node",
]
)

View File

@@ -17,7 +17,7 @@ from biz_bud.logging import debug_highlight, error_highlight, info_highlight, wa
@standard_node(node_name="catalog_impact_analysis", metric_name="impact_analysis")
async def catalog_impact_analysis_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Analyze the impact of changes on catalog items.
@@ -139,7 +139,7 @@ async def catalog_impact_analysis_node(
@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization")
async def catalog_optimization_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Generate optimization recommendations for the catalog.

View File

@@ -0,0 +1,253 @@
"""Catalog analysis nodes for impact and optimization analysis.
This module contains nodes that perform catalog-level analysis,
including impact analysis and optimization recommendations.
"""
from __future__ import annotations
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import create_error_info
from biz_bud.core.langgraph import standard_node
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
@standard_node(node_name="catalog_impact_analysis", metric_name="impact_analysis")
async def catalog_impact_analysis_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Analyze the impact of changes on catalog items.
This node performs comprehensive impact analysis for catalog changes,
including price changes, component substitutions, and availability updates.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with impact analysis results
"""
info_highlight("Performing catalog impact analysis...", category="CatalogImpact")
# Get analysis parameters
component_focus = state.get("component_focus", {})
catalog_data = state.get("catalog_data", {})
if not component_focus or not catalog_data:
warning_highlight(
"Missing component focus or catalog data", category="CatalogImpact"
)
return {
"impact_analysis": {},
"analysis_metadata": {"message": "Insufficient data for impact analysis"},
}
try:
# Analyze impact across different dimensions
impact_results: dict[str, Any] = {
"affected_items": [],
"cost_impact": {},
"availability_impact": {},
"quality_impact": {},
"recommendations": [],
}
# Extract component details
component_name = component_focus.get("name", "")
component_type = component_focus.get("type", "ingredient")
# Analyze affected catalog items
affected_count = 0
for category, items in catalog_data.items():
if isinstance(items, list):
for item in items:
if isinstance(item, dict):
# Check if item uses the component
components = item.get("components", [])
ingredients = item.get("ingredients", [])
uses_component = False
if component_type == "ingredient" and ingredients:
uses_component = any(
component_name.lower() in str(ing).lower()
for ing in ingredients
)
elif components:
uses_component = any(
component_name.lower() in str(comp).lower()
for comp in components
)
if uses_component:
affected_count += 1
impact_results["affected_items"].append(
{
"name": item.get("name", "Unknown"),
"category": category,
"price": item.get("price", 0),
"dependency_level": "high", # Simplified
}
)
# Calculate aggregate impacts
if affected_count > 0:
impact_results["cost_impact"] = {
"items_affected": affected_count,
"potential_cost_increase": f"{affected_count * 5}%", # Simplified calculation
"risk_level": "medium" if affected_count < 10 else "high",
}
impact_results["recommendations"] = [
f"Monitor {component_name} availability closely",
f"Consider alternative components for {affected_count} affected items",
"Update pricing strategy if costs increase significantly",
]
info_highlight(
f"Impact analysis completed: {affected_count} items affected",
category="CatalogImpact",
)
return {
"impact_analysis": impact_results,
"analysis_metadata": {
"component_analyzed": component_name,
"items_affected": affected_count,
"analysis_depth": "comprehensive",
},
}
except Exception as e:
error_msg = f"Catalog impact analysis failed: {str(e)}"
error_highlight(error_msg, category="CatalogImpact")
return {
"impact_analysis": {},
"errors": [
create_error_info(
message=error_msg,
node="catalog_impact_analysis",
severity="error",
category="analysis_error",
)
],
}
@standard_node(node_name="catalog_optimization", metric_name="catalog_optimization")
async def catalog_optimization_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Generate optimization recommendations for the catalog.
This node analyzes the catalog structure, pricing, and components
to provide actionable optimization recommendations.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with optimization recommendations
"""
debug_highlight(
"Generating catalog optimization recommendations...",
category="CatalogOptimization",
)
# Get analysis data
impact_analysis = state.get("impact_analysis", {})
catalog_data = state.get("catalog_data", {})
try:
optimization_report: dict[str, Any] = {
"recommendations": [],
"priority_actions": [],
"cost_savings_potential": {},
"efficiency_improvements": [],
}
# Analyze catalog structure
total_items = sum(
len(items) if isinstance(items, list) else 0
for items in catalog_data.values()
)
# Generate recommendations based on analysis
if affected_items := impact_analysis.get("affected_items", []):
# Component optimization
if len(affected_items) > 5:
optimization_report["recommendations"].append(
{
"type": "component_standardization",
"description": f"Standardize component usage across {len(affected_items)} items",
"impact": "high",
"effort": "medium",
}
)
if high_price_items := [
item for item in affected_items if item.get("price", 0) > 15
]:
optimization_report["priority_actions"].append(
{
"action": "Review pricing strategy",
"reason": f"{len(high_price_items)} high-value items affected",
"urgency": "high",
}
)
# Catalog structure optimization
if total_items > 50:
optimization_report["efficiency_improvements"].append(
{
"area": "catalog_structure",
"suggestion": "Consider categorization refinement",
"benefit": "Improved navigation and management",
}
)
# Cost savings analysis
optimization_report["cost_savings_potential"] = {
"component_consolidation": "5-10%",
"supplier_optimization": "3-7%",
"menu_engineering": "10-15%",
}
info_highlight(
f"Optimization report generated with {len(optimization_report['recommendations'])} recommendations",
category="CatalogOptimization",
)
return {
"optimization_report": optimization_report,
"report_metadata": {
"total_items_analyzed": total_items,
"recommendations_count": len(optimization_report["recommendations"]),
"priority_actions_count": len(optimization_report["priority_actions"]),
},
}
except Exception as e:
error_msg = f"Catalog optimization failed: {str(e)}"
error_highlight(error_msg, category="CatalogOptimization")
return {
"optimization_report": {},
"errors": [
create_error_info(
message=error_msg,
node="catalog_optimization",
severity="error",
category="optimization_error",
)
],
}
__all__ = [
"catalog_impact_analysis_node",
"catalog_optimization_node",
]

View File

@@ -52,7 +52,7 @@ def _is_component_match(component: str, item_component: str) -> bool:
async def identify_component_focus_node(
state: CatalogIntelState, config: RunnableConfig | None = None
state: CatalogIntelState, config: RunnableConfig
) -> dict[str, Any]:
"""Identify component to focus on from context.
@@ -245,7 +245,7 @@ async def identify_component_focus_node(
async def find_affected_catalog_items_node(
state: CatalogIntelState, config: RunnableConfig | None = None
state: CatalogIntelState, config: RunnableConfig
) -> dict[str, Any]:
"""Find catalog items affected by the current component focus.
@@ -345,7 +345,7 @@ async def find_affected_catalog_items_node(
async def batch_analyze_components_node(
state: CatalogIntelState, config: RunnableConfig | None = None
state: CatalogIntelState, config: RunnableConfig
) -> dict[str, Any]:
"""Perform batch analysis of multiple components.
@@ -495,7 +495,7 @@ async def batch_analyze_components_node(
async def generate_catalog_optimization_report_node(
state: CatalogIntelState, config: RunnableConfig | None = None
state: CatalogIntelState, config: RunnableConfig
) -> dict[str, Any]:
"""Generate optimization recommendations based on analysis.

View File

@@ -0,0 +1,654 @@
"""Catalog intelligence analysis nodes for LangGraph workflows.
This module provides the actual implementation for catalog intelligence analysis,
including database queries and business logic.
"""
import re
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.types import create_error_info
from biz_bud.core.utils.regex_security import findall_safe, search_safe
from biz_bud.logging import error_highlight, get_logger, info_highlight
from biz_bud.services.factory import ServiceFactory
from biz_bud.states.catalog import CatalogIntelState
_logger = get_logger(__name__)
def _is_component_match(component: str, item_component: str) -> bool:
"""Check if component matches item_component using word boundaries.
This prevents false positives like 'rice' matching 'price'.
Args:
component: The component to search for (e.g., 'rice')
item_component: The component from the item (e.g., 'rice flour', 'price')
Returns:
True if there's a whole word match, False otherwise.
"""
# Normalize strings for comparison
component_lower = component.lower().strip()
item_component_lower = item_component.lower().strip()
# Check exact match first
if component_lower == item_component_lower:
return True
# Check if component appears as a whole word in item_component
# Use word boundaries to prevent substring matches with safe regex
pattern = r"\b" + re.escape(component_lower) + r"\b"
if search_safe(pattern, item_component_lower):
return True
# Check reverse - if item_component appears as whole word in component
# This handles cases like "goat" matching "goat meat"
pattern = r"\b" + re.escape(item_component_lower) + r"\b"
return bool(search_safe(pattern, component_lower))
async def identify_component_focus_node(
state: CatalogIntelState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Identify component to focus on from context.
This node analyzes the input messages or news content to identify
which component(s) should be analyzed for menu impact.
Args:
state: Current workflow state.
config: Runtime configuration.
Returns:
State updates with current_component_focus and/or batch_component_queries.
"""
info_highlight("Identifying component focus from context")
# Infer data source from catalog metadata if not already set
data_source_used = state.get("data_source_used")
if not data_source_used:
extracted_content = state.get("extracted_content", {})
# extracted_content is always a dict from CatalogIntelState
catalog_metadata = extracted_content.get("catalog_metadata", {})
# catalog_metadata is always a dict or empty
source = catalog_metadata.get("source")
if source == "database":
data_source_used = "database"
elif source in ["config.yaml", "yaml"]:
data_source_used = "yaml"
else:
# Check config for data_source hint
config_data = state.get("config", {})
if config_data.get("data_source"):
data_source_used = config_data.get("data_source")
if data_source_used:
_logger.info(f"Inferred data source: {data_source_used}")
# Extract from messages
messages_raw = state.get("messages", [])
messages = messages_raw or []
if not messages:
_logger.warning("No messages found in state")
result_dict: dict[str, Any] = {"current_component_focus": None}
if data_source_used:
result_dict["data_source_used"] = data_source_used
return result_dict
# Get the last message content
last_message = messages[-1]
content = str(getattr(last_message, "content", "")).lower()
_logger.info(f"Analyzing message content: {content[:100]}...")
# Common food components to look for
components = [
"avocado",
"beef",
"chicken",
"lettuce",
"tomato",
"cheese",
"onion",
"potato",
"rice",
"wheat",
"corn",
"pork",
"fish",
"shrimp",
"egg",
"milk",
"butter",
"oil",
"garlic",
"pepper",
"salt",
"sugar",
"flour",
"bread",
"pasta",
"bacon",
"turkey",
"goat", # Added goat to detect "goat meat"
"lamb",
"oxtail",
]
found_components: list[str] = []
content_lower = content.lower()
for component in components:
# Use word boundary matching to avoid false positives with safe regex
pattern = r"\b" + re.escape(component.lower()) + r"\b"
if search_safe(pattern, content_lower):
found_components.append(component)
_logger.info(f"Found component: {component}")
# Also look for context clues like "goat meat shortage" -> "goat"
# First check for specific meat shortages to prioritize them using safe regex
meat_shortage_pattern = r"(\w+\s+meat)\s+shortage"
meat_shortage_matches = findall_safe(meat_shortage_pattern, content, flags=re.IGNORECASE)
if meat_shortage_matches:
# If we find a specific meat shortage, focus on that
_logger.info(f"Found specific meat shortage: {meat_shortage_matches[0]}")
result = {
"current_component_focus": meat_shortage_matches[0].lower(),
"batch_component_queries": [],
}
# Include data_source_used (either from state or inferred)
if data_source_used:
result["data_source_used"] = data_source_used
return result
# Extract phrases like "X prices", "X shortage", "X increasing"
component_context_patterns = [
r"(\w+(?:\s+meat)?)\s+(?:price|prices|pricing)",
r"(\w+(?:\s+meat)?)\s+(?:shortage|shortages)",
r"(\w+(?:\s+meat)?)\s+(?:increasing|rising)",
r"(?:focus on|analyze)[\s:]*(\w+(?:\s+meat)?)",
]
for pattern in component_context_patterns:
matches = findall_safe(pattern, content, flags=re.IGNORECASE)
for match in matches:
# Extract the base component (e.g., "goat" from "goat meat")
match_clean = match.strip().lower()
base_component = match_clean.replace(" meat", "").strip()
# Check if the base component is in our known components
if base_component in components and base_component not in found_components:
found_components.append(base_component)
_logger.info(f"Found component from context: {base_component}")
# Also add the full match if it contains "meat" and base is valid
elif base_component in components and "meat" in match_clean:
# Add the full compound term like "goat meat"
if match_clean not in found_components:
found_components.append(match_clean)
_logger.info(
f"Found compound component from context: {match_clean}"
)
# If multiple components found, use batch analysis
if len(found_components) > 1:
# Special case: if we have both base and compound form (e.g., "goat" and "goat meat")
# prefer the base form for single component focus
base_forms = [comp for comp in found_components if " meat" not in comp]
if len(base_forms) == 1 and any(" meat" in comp for comp in found_components):
# Use the base form for single component focus
_logger.info(f"Using base component: {base_forms[0]}")
result = {
"current_component_focus": base_forms[0],
"batch_component_queries": [],
}
# Include data_source_used (either from state or inferred)
if data_source_used:
result["data_source_used"] = data_source_used
return result
_logger.info(f"Multiple components found: {found_components}")
batch_result: dict[str, Any] = {
"batch_component_queries": found_components,
"current_component_focus": None,
}
# Include data_source_used (either from state or inferred)
if data_source_used:
batch_result["data_source_used"] = data_source_used
return batch_result
elif len(found_components) == 1:
_logger.info(f"Single component found: {found_components[0]}")
result = {
"current_component_focus": found_components[0],
"batch_component_queries": [],
}
# Include data_source_used (either from state or inferred)
if data_source_used:
result["data_source_used"] = data_source_used
return result
else:
_logger.info("No specific components found in message")
empty_result: dict[str, Any] = {
"current_component_focus": None,
"batch_component_queries": [],
"component_news_impact_reports": [],
}
# Include data_source_used (either from state or inferred)
if data_source_used:
empty_result["data_source_used"] = data_source_used
return empty_result
async def find_affected_catalog_items_node(
state: CatalogIntelState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Find catalog items affected by the current component focus.
Args:
state: Current workflow state.
config: Runtime configuration.
Returns:
State updates with affected menu items.
"""
component = None # Initialize to avoid unbound variable
try:
component = state.get("current_component_focus")
if not component:
_logger.warning("No component focus set")
return {}
info_highlight(f"Finding catalog items affected by: {component}")
# For simplified integration tests, check catalog items in state
extracted_content = state.get("extracted_content", {})
# extracted_content is always a dict from CatalogIntelState
catalog_items = extracted_content.get("catalog_items", [])
if not isinstance(catalog_items, list) and catalog_items:
catalog_items = []
if catalog_items:
# Find items that contain this component
affected_items = []
for item in catalog_items:
item_components = item.get("components", [])
# Use word boundary matching to prevent false positives
if any(
_is_component_match(component, comp) for comp in item_components
):
affected_items.append(item)
_logger.info(f"Found affected item: {item.get('name')}")
if affected_items:
return {"catalog_items_linked_to_component": affected_items}
# Get database service from state config
config_dict = state.get("config", {})
configurable = config_dict.get("configurable", {})
app_config = configurable.get("app_config")
if not app_config:
# No database access, return empty results
_logger.warning("App config not found in state, skipping database lookup")
return {"catalog_items_linked_to_component": []}
services = ServiceFactory(app_config)
result = {"catalog_items_linked_to_component": []}
try:
db = await services.get_db_service()
# Check if database has component methods
# Use getattr to avoid attribute errors
get_component_func = getattr(db, "get_component_by_name", None)
get_items_func = getattr(db, "get_catalog_items_by_component_id", None)
if get_component_func is not None:
# First, get the component ID
component_info = await get_component_func(str(component))
if not component_info:
_logger.warning(f"Component '{component}' not found in database")
elif get_items_func is not None:
# Get all catalog items with this component
catalog_items = await get_items_func(component_info["component_id"])
_logger.info(
f"Found {len(catalog_items)} catalog items with {component}"
)
result = {"catalog_items_linked_to_component": catalog_items}
else:
_logger.debug("Database doesn't support component methods")
finally:
await services.cleanup()
return result
except Exception as e:
error_highlight(f"Error finding affected items: {e}")
errors = state.get("errors", [])
error_info = create_error_info(
message=str(e),
error_type=type(e).__name__,
node="find_affected_catalog_items",
severity="error",
category="state",
context={
"component": component or "unknown",
"phase": "catalog_analysis",
"type": type(e).__name__,
},
)
errors.append(error_info)
return {"errors": errors, "catalog_items_linked_to_component": []}
async def batch_analyze_components_node(
state: CatalogIntelState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Perform batch analysis of multiple components.
Args:
state: Current workflow state.
config: Runtime configuration.
Returns:
State updates with analysis results.
"""
try:
components_raw = state.get("batch_component_queries", [])
components = components_raw or []
if not components:
_logger.warning("No components to batch analyze")
return {}
info_highlight(f"Batch analyzing {len(components)} components")
# Get database service if available
config_dict = state.get("config", {})
configurable = config_dict.get("configurable", {})
app_config = configurable.get("app_config")
# If no app_config, generate basic impact reports without database
if not app_config:
_logger.info("No app config found, generating basic impact reports")
# Generate basic reports based on catalog items in state
extracted_content = state.get("extracted_content", {})
# extracted_content is always a dict from CatalogIntelState
catalog_items = extracted_content.get("catalog_items", [])
if not isinstance(catalog_items, list) and catalog_items:
catalog_items = []
impact_reports = []
for component in components:
# Find items with this component
affected_items = []
for item in catalog_items:
item_components = item.get("components", [])
if any(
_is_component_match(component, comp) for comp in item_components
):
affected_items.append(
{
"item_id": item.get("id"),
"item_name": item.get("name"),
"price_cents": int(item.get("price", 0) * 100),
}
)
if affected_items:
impact_reports.append(
{
"component_name": component,
"news_summary": f"{component} market conditions",
"market_sentiment": "neutral",
"affected_items_report": affected_items,
"analysis_timestamp": "",
"confidence_score": 0.7,
}
)
return {"component_news_impact_reports": impact_reports}
services = ServiceFactory(app_config)
result = {"component_news_impact_reports": []}
try:
db = await services.get_db_service()
# Get market context from state
market_context = {
"summary": state.get("news_summary", ""),
"sentiment": state.get("market_sentiment", "neutral"),
"timestamp": state.get("analysis_timestamp", ""),
}
# Batch process components
impact_reports: list[dict[str, Any]] = []
# Use getattr to avoid attribute errors
batch_get_func = getattr(db, "batch_get_catalog_items_by_components", None)
if batch_get_func is not None:
catalog_items_by_component = await batch_get_func(
components, max_concurrency=5
)
else:
# Fallback: return empty dict if method not available
catalog_items_by_component = {}
for component_name, catalog_items in catalog_items_by_component.items():
# Analyze impact based on market context
summary_text = market_context["summary"]
sentiment = (
"negative"
if any(
term in summary_text.lower()
for term in ["shortage", "price increase", "supply chain"]
)
else "neutral"
)
# Create impact report
impact_report = {
"component_name": component_name,
"news_summary": market_context["summary"],
"market_sentiment": sentiment,
"affected_items_report": [
{
"item_id": item["item_id"],
"item_name": item["item_name"],
"current_price_cents": item["price_cents"],
"potential_impact_notes": f"May be affected by {component_name} {sentiment} market conditions",
"recommended_action": (
"Monitor pricing"
if sentiment == "negative"
else "No action needed"
),
"priority_level": 2 if sentiment == "negative" else 0,
}
for item in catalog_items
],
"analysis_timestamp": market_context["timestamp"],
"confidence_score": 0.8 if catalog_items else 0.0,
}
impact_reports.append(impact_report)
result = {"component_news_impact_reports": impact_reports}
finally:
await services.cleanup()
return result
except Exception as e:
error_highlight(f"Error in batch analysis: {e}")
errors = state.get("errors", [])
error_info = create_error_info(
message=str(e),
error_type=type(e).__name__,
node="batch_analyze_components",
severity="error",
category="state",
context={"phase": "batch_analysis"},
)
errors.append(error_info)
return {"errors": errors}
async def generate_catalog_optimization_report_node(
state: CatalogIntelState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Generate optimization recommendations based on analysis.
Args:
state: Current workflow state.
config: Runtime configuration.
Returns:
State updates with optimization suggestions.
"""
info_highlight("Generating catalog optimization report")
impact_reports_raw = state.get("component_news_impact_reports", [])
impact_reports = impact_reports_raw or []
if not impact_reports:
_logger.warning("No impact reports to process")
# Still generate basic suggestions based on catalog items
catalog_items = state.get("extracted_content", {}).get("catalog_items", [])
if not isinstance(catalog_items, list) and catalog_items:
catalog_items = []
if catalog_items:
basic_suggestions = _generate_basic_catalog_suggestions(catalog_items)
return {
"catalog_optimization_suggestions": basic_suggestions,
"component_news_impact_reports": [],
}
return {
"catalog_optimization_suggestions": [],
"component_news_impact_reports": [],
}
# Aggregate insights and generate recommendations
suggestions: list[dict[str, Any]] = []
high_priority_items: list[str] = []
medium_priority_items: list[str] = []
for report in impact_reports:
if report:
component_name = report.get("component_name", "Unknown")
sentiment = report.get("market_sentiment", "neutral")
affected_items = report.get("affected_items_report", [])
if sentiment == "negative" and affected_items:
# High priority suggestions for negative sentiment
for item in affected_items:
if item:
suggestion = {
"type": "price_adjustment",
"item_id": item.get("item_id"),
"item_name": item.get("item_name"),
"current_price_cents": item.get("current_price_cents"),
"reason": f"{component_name} supply chain issues",
"urgency": "high",
"recommended_action": "Consider 5-10% price increase",
"alternative_action": f"Find substitute for {component_name}",
}
suggestions.append(suggestion)
if item_name := item.get("item_name"):
high_priority_items.append(str(item_name))
elif sentiment == "positive" and affected_items:
# Medium priority suggestions for positive sentiment
for item in affected_items:
if item:
suggestion = {
"type": "promotion_opportunity",
"item_id": item.get("item_id"),
"item_name": item.get("item_name"),
"current_price_cents": item.get("current_price_cents"),
"reason": f"{component_name} favorable market conditions",
"urgency": "medium",
"recommended_action": "Feature in promotions",
"alternative_action": "Maintain current pricing",
}
suggestions.append(suggestion)
if item_name := item.get("item_name"):
medium_priority_items.append(str(item_name))
# Create summary
summary = {
"total_suggestions": len(suggestions),
"high_priority_count": len(high_priority_items),
"medium_priority_count": len(medium_priority_items),
"affected_components": [
r.get("component_name", "") for r in impact_reports if r
],
"report_timestamp": state.get("analysis_timestamp", ""),
}
return {
"catalog_optimization_suggestions": suggestions,
"optimization_summary": summary,
}
def _generate_basic_catalog_suggestions(
catalog_items: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Generate basic optimization suggestions based on catalog analysis."""
suggestions = []
# Analyze common components
all_components = []
for item in catalog_items:
components = item.get("components", [])
if isinstance(components, list):
all_components.extend(components)
# Count component frequency
component_counts = {}
for component in all_components:
if isinstance(component, str):
component_counts[component] = component_counts.get(component, 0) + 1
if common_components := sorted(
component_counts.items(), key=lambda x: x[1], reverse=True
)[:3]:
suggestions.append(
{
"type": "supply_optimization",
"priority": "medium",
"title": "Common Component Analysis",
"description": f"Most frequently used components: {', '.join([comp[0] for comp in common_components])}",
"recommendation": "Consider bulk purchasing agreements for these high-frequency components",
"affected_items": [
item.get("name")
for item in catalog_items
if any(
comp in item.get("components", [])
for comp, _ in common_components
)
],
}
)
if prices := [item.get("price", 0) for item in catalog_items if item.get("price")]:
avg_price = sum(prices) / len(prices)
if high_price_items := [
item for item in catalog_items if item.get("price", 0) > avg_price * 1.5
]:
suggestions.append(
{
"type": "pricing_strategy",
"priority": "low",
"title": "High-Value Item Focus",
"description": f"Items priced above 150% of average: {', '.join([item.get('name', 'Unknown') for item in high_price_items])}",
"recommendation": "Ensure quality and marketing support for premium items",
"affected_items": [
item.get("name", "Unknown") for item in high_price_items
],
}
)
return suggestions

View File

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
async def research_catalog_item_components_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Research components for catalog items using web search.
@@ -46,7 +46,7 @@ async def research_catalog_item_components_node(
async def extract_components_from_sources_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract components from researched sources.
@@ -79,7 +79,7 @@ async def extract_components_from_sources_node(
async def aggregate_catalog_components_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Aggregate extracted components across catalog items.

View File

@@ -0,0 +1,112 @@
"""Catalog research nodes for component discovery and analysis."""
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.logging import get_logger
logger = get_logger(__name__)
async def research_catalog_item_components_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Research components for catalog items using web search.
This is a placeholder implementation that maintains backward compatibility
while the research functionality is being consolidated.
Args:
state: Current workflow state
config: Runtime configuration
Returns:
State updates with research results
"""
logger.info("Research catalog item components node - placeholder implementation")
# Basic implementation that satisfies the interface
return {
"catalog_component_research": {
"status": "completed",
"total_items": 0,
"researched_items": 0,
"cached_items": 0,
"searched_items": 0,
"research_results": [],
"metadata": {
"categories": [],
"subcategories": [],
"search_provider": "placeholder",
"cache_enabled": False,
},
}
}
async def extract_components_from_sources_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Extract components from researched sources.
This is a placeholder implementation that maintains backward compatibility
while the extraction functionality is being consolidated.
Args:
state: Current workflow state
config: Runtime configuration
Returns:
State updates with extracted components
"""
logger.info("Extract components from sources node - placeholder implementation")
# Basic implementation that satisfies the interface
return {
"extracted_components": {
"status": "completed",
"total_items": 0,
"successfully_extracted": 0,
"total_components_found": 0,
"items": [],
"metadata": {
"extractor": "placeholder",
"categorizer": "placeholder",
},
}
}
async def aggregate_catalog_components_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Aggregate extracted components across catalog items.
This is a placeholder implementation that maintains backward compatibility
while the aggregation functionality is being consolidated.
Args:
state: Current workflow state
config: Runtime configuration
Returns:
State updates with component analytics
"""
logger.info("Aggregate catalog components node - placeholder implementation")
# Basic implementation that satisfies the interface
return {
"component_analytics": {
"status": "completed",
"total_unique_components": 0,
"total_catalog_items": 0,
"common_components": [],
"category_distribution": {},
"bulk_purchase_recommendations": [],
"metadata": {
"analysis_type": "placeholder",
"timestamp": "",
},
}
}

View File

@@ -688,7 +688,7 @@ def _load_default_data(fallback_source: str) -> dict[str, Any]:
async def load_catalog_data_node(
state: CatalogResearchState, config: RunnableConfig | None = None
state: CatalogResearchState, config: RunnableConfig
) -> dict[str, Any]:
"""Load catalog data from configuration or database into extracted_content.

View File

@@ -10,7 +10,6 @@ from biz_bud.core.edge_helpers.core import create_bool_router, create_enum_route
from biz_bud.core.edge_helpers.error_handling import handle_error
if TYPE_CHECKING:
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from biz_bud.nodes.error_handling import (
@@ -54,7 +53,7 @@ GRAPH_METADATA = {
def create_error_handling_graph(
checkpointer: PostgresSaver | None = None,
) -> "CompiledGraph":
) -> "CompiledStateGraph[Any]":
"""Create the error handling agent graph.
This graph can be used as a subgraph in any BizBud workflow
@@ -140,8 +139,8 @@ def check_error_recovery(state: ErrorHandlingState) -> str:
def add_error_handling_to_graph(
main_graph: StateGraph,
error_handler: "CompiledStateGraph",
main_graph: StateGraph[ErrorHandlingState],
error_handler: "CompiledStateGraph[ErrorHandlingState]",
nodes_to_protect: list[str],
error_node_name: str = "handle_error",
next_node_mapping: dict[str, str] | None = None,
@@ -280,7 +279,7 @@ def create_error_handling_config(
}
def error_handling_graph_factory(config: RunnableConfig) -> "CompiledGraph":
def error_handling_graph_factory(config: RunnableConfig) -> "CompiledStateGraph[Any]":
"""Create error handling graph for LangGraph API.
Args:

View File

@@ -52,7 +52,7 @@ def route_after_feedback(state: BusinessBuddyState) -> Literal["refine", "comple
return "refine" if should_apply_refinement(state) else "complete"
def create_human_feedback_workflow() -> StateGraph:
def create_human_feedback_workflow() -> StateGraph[BusinessBuddyState]:
"""Create a workflow with human feedback integration.
This workflow demonstrates:
@@ -165,14 +165,14 @@ async def run_example() -> None:
await cast(
"Coroutine[Any, Any, dict[str, Any]]",
app.ainvoke(
{"feedback_message": user_feedback},
cast("BusinessBuddyState", {"feedback_message": user_feedback}),
cast("RunnableConfig | None", config),
),
)
# More complex example with multiple feedback rounds
def create_iterative_feedback_workflow() -> StateGraph:
def create_iterative_feedback_workflow() -> StateGraph[BusinessBuddyState]:
"""Create a workflow that can handle multiple rounds of feedback.
This demonstrates a more complex pattern where refinement

View File

@@ -102,7 +102,7 @@ async def research_web_search(
@standard_node(node_name="search_web", metric_name="web_search")
@ensure_immutable_node
async def search_web_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Search the web for information related to the query.
@@ -126,13 +126,25 @@ async def search_web_node(
)
try:
# Use the tool with config
result = research_web_search(query)
# Simulate web search directly
mock_results = [
{
"title": f"Result for {query}",
"url": "https://example.com",
"snippet": "Sample content",
"relevance_score": 0.9,
}
for _ in range(3)
]
result = {
"results": mock_results,
"total_found": len(mock_results),
}
# Update state immutably
updater = StateUpdater(state)
return updater.set(
"search_results", result.get("results", []) if isinstance(result, dict) else []
"search_results", result.get("results", [])
).build()
except Exception as e:
@@ -147,7 +159,7 @@ async def search_web_node(
@standard_node(node_name="extract_facts", metric_name="fact_extraction")
@ensure_immutable_node
async def extract_facts_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract facts from search results.
@@ -192,7 +204,7 @@ async def extract_facts_node(
@standard_node(node_name="summarize_research", metric_name="research_summary")
@ensure_immutable_node
async def summarize_research_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Summarize the research findings.
@@ -245,7 +257,7 @@ async def summarize_research_node(
def create_research_subgraph(
app_config: object | None = None, service_factory: object | None = None
) -> CompiledStateGraph:
) -> CompiledStateGraph[ResearchSubgraphState]:
"""Create a reusable research subgraph.
This demonstrates creating a reusable subgraph that can be embedded

View File

@@ -220,7 +220,7 @@ def create_service_factory_from_config(config: RunnableConfig) -> ServiceFactory
# Example node using service factory
async def example_node_with_services(
state: dict[str, object], config: RunnableConfig | None = None
state: dict[str, object], config: RunnableConfig
) -> dict[str, object]:
"""Demonstrate service factory usage.

View File

@@ -17,7 +17,7 @@ The main graph represents a sophisticated agent workflow that can:
import asyncio
import os
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
if TYPE_CHECKING:
from biz_bud.services.factory import ServiceFactory
@@ -55,6 +55,9 @@ from biz_bud.nodes import call_model_node, parse_and_validate_initial_payload
from biz_bud.services.factory import get_global_factory
from biz_bud.states.base import InputState
# Type variable for state types
StateType = TypeVar('StateType')
# Get logger instance
logger = get_logger(__name__)
@@ -77,7 +80,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
"""
def create_sync_factory():
def sync_factory(config: RunnableConfig) -> CompiledStateGraph:
def sync_factory(config: RunnableConfig) -> CompiledStateGraph[InputState]:
"""Create graph for LangGraph API with RunnableConfig (optimized)."""
from langchain_core.runnables import RunnableConfig
@@ -96,7 +99,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
return sync_factory
def create_async_factory():
async def async_factory(config: RunnableConfig) -> CompiledStateGraph:
async def async_factory(config: RunnableConfig) -> CompiledStateGraph[InputState]:
"""Create graph for LangGraph API with RunnableConfig (async optimized)."""
from langchain_core.runnables import RunnableConfig
@@ -117,7 +120,7 @@ def _create_async_factory_wrapper(sync_resolver_func, async_resolver_func):
return create_sync_factory(), create_async_factory()
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledStateGraph:
def _handle_sync_async_context(app_config: AppConfig, service_factory: "ServiceFactory") -> CompiledStateGraph[InputState]:
"""Handle sync/async context detection for graph creation.
This function uses centralized context detection from core.networking
@@ -348,7 +351,7 @@ async def search(state: Any) -> Any: # noqa: ANN401
def create_graph() -> CompiledStateGraph:
def create_graph() -> CompiledStateGraph[InputState]:
"""Build and compile the complete StateGraph for Business Buddy agent execution.
This function constructs the main workflow graph that serves as the execution
@@ -446,7 +449,7 @@ def create_graph() -> CompiledStateGraph:
async def create_graph_with_services(
app_config: AppConfig, service_factory: "ServiceFactory"
) -> CompiledStateGraph:
) -> CompiledStateGraph[InputState]:
"""Create graph with service factory injection using caching.
Args:
@@ -543,7 +546,7 @@ async def _get_or_create_service_factory_async(config_hash: str, app_config: App
async def create_graph_with_overrides_async(
config: dict[str, Any],
) -> CompiledStateGraph:
) -> CompiledStateGraph[InputState]:
"""Async version of create_graph_with_overrides.
Same functionality as create_graph_with_overrides but uses async config loading
@@ -614,7 +617,7 @@ def _load_config_with_logging() -> AppConfig:
_graph_cache_manager = InMemoryCache[CompiledStateGraph](max_size=100)
_graph_cache_manager = InMemoryCache[CompiledStateGraph[InputState]](max_size=100)
_graph_creation_locks: dict[str, asyncio.Lock] = {}
_locks_lock = asyncio.Lock() # Lock for managing the locks dict
_graph_cache_lock = asyncio.Lock() # Keep for backward compatibility
@@ -650,7 +653,7 @@ async def get_cached_graph(
config_hash: str = "default",
service_factory: "ServiceFactory | None" = None,
use_caching: bool = True
) -> CompiledStateGraph:
) -> CompiledStateGraph[InputState]:
"""Get cached compiled graph with optional service injection using GraphCache.
Args:
@@ -662,7 +665,7 @@ async def get_cached_graph(
Compiled and cached graph instance
"""
# Use GraphCache for thread-safe caching
async def build_graph_for_cache() -> CompiledStateGraph:
async def build_graph_for_cache() -> CompiledStateGraph[InputState]:
logger.info(f"Creating new graph instance for config: {config_hash}")
# Get or create service factory
@@ -763,7 +766,7 @@ async def get_cached_graph(
return graph
def get_graph() -> CompiledStateGraph:
def get_graph() -> CompiledStateGraph[InputState]:
"""Get the singleton graph instance (backward compatibility)."""
try:
asyncio.get_running_loop()
@@ -781,9 +784,9 @@ def get_graph() -> CompiledStateGraph:
# For backward compatibility - direct access (lazy initialization)
_module_graph: CompiledStateGraph | None = None
_module_graph: CompiledStateGraph[InputState] | None = None
def get_module_graph() -> CompiledStateGraph:
def get_module_graph() -> CompiledStateGraph[InputState]:
"""Get module-level graph instance (lazy initialization)."""
global _module_graph
if _module_graph is None:

View File

@@ -0,0 +1,276 @@
# Paperless Document Management System
This directory contains the Paperless NGX integration components for the Business Buddy system, including receipt processing workflows and document management capabilities.
## Overview
The Paperless integration provides:
- **Document Processing**: Automated receipt and invoice processing with LLM-based extraction
- **Intelligent Agent**: Natural language interface for document management operations
- **Database Integration**: Structured data storage with PostgreSQL schemas
- **Workflow Orchestration**: LangGraph-based state management and processing pipelines
## Components
### Core Graph (`graph.py`)
The main receipt processing workflow that handles:
- Document classification (receipt vs invoice vs purchase order)
- Line item extraction using structured LLM prompts
- Product validation against web catalogs
- Database insertion with normalized schemas
### Intelligent Agent (`agent.py`)
Natural language interface for document operations:
- Document search and retrieval
- Tag and metadata management
- Bulk operations and filtering
- Conversational interactions
### Processing Nodes (`nodes/`)
Individual workflow components:
- `paperless.py`: Paperless NGX API interactions
- `extraction.py`: LLM-based data extraction
- `validation.py`: Product and vendor validation
- `database.py`: PostgreSQL data persistence
### State Management (`states/`)
Typed state definitions:
- `receipt.py`: Receipt processing state with line items and metadata
- `paperless.py`: General Paperless document state
## Usage Examples
### Basic Receipt Processing
```python
from biz_bud.graphs.paperless.graph import create_receipt_processing_graph
# Process a receipt document
graph = create_receipt_processing_graph()
result = await graph.ainvoke({
"document_content": receipt_text,
"processing_stage": "pending"
})
print(f"Extracted {len(result['line_items'])} items")
print(f"Total: ${result['receipt_metadata']['final_total']}")
```
### Intelligent Agent Interface
```python
from biz_bud.graphs.paperless.agent import process_paperless_with_agent
# Natural language document management
response = await process_paperless_with_agent(
"Find all my grocery receipts from this month and tag them with 'monthly-review'"
)
print(f"Found {response.documents_found} documents")
print(f"Operation: {response.operation}")
```
### Direct Tool Usage
```python
from biz_bud.tools.capabilities.external.paperless.tool import (
search_paperless_documents,
get_paperless_document
)
# Search documents
results = await search_paperless_documents.ainvoke({
"query": "receipt grocery",
"limit": 10
})
# Get specific document
doc = await get_paperless_document.ainvoke({
"document_id": 123
})
```
## Configuration
### Environment Variables
```bash
# Paperless NGX Configuration
PAPERLESS_BASE_URL=http://localhost:8000
PAPERLESS_TOKEN=your_api_token_here
# Database Configuration (for receipt processing)
POSTGRES_USER=user
POSTGRES_PASSWORD=password
POSTGRES_HOST=postgres
POSTGRES_PORT=5432
POSTGRES_DB=langgraph_db
```
### Database Schema
The system uses PostgreSQL schemas defined in `/docker/db-init/`:
- `rpt_receipts`: Main receipt records
- `rpt_receipt_line_items`: Individual line items
- `rpt_master_products`: Product catalog
- `rpt_master_vendors`: Vendor information
- `rpt_reconciliation_log`: Processing audit trail
## Testing
### Comprehensive Test Suite
Run the complete workflow test:
```bash
python test_receipt_workflow.py
```
This tests:
1. Paperless NGX document fetching
2. LLM-based document classification
3. PostgreSQL connectivity and schema validation
4. Receipt extraction and validation
5. Database insertion with proper normalization
6. End-to-end workflow integration
7. Performance benchmarking
### Agent Demonstration
```bash
# Simple usage examples
python example_paperless_agent_usage.py
# Comprehensive demo with interactive features
python demo_paperless_agent.py
```
### Individual Component Testing
```bash
# Test specific workflow steps
pytest test_receipt_workflow.py::TestStep1PaperlessDocuments -v
pytest test_receipt_workflow.py::TestStep4ReceiptExtraction -v
pytest test_receipt_workflow.py::TestEndToEnd -v
```
## Architecture
### LangGraph Workflow Pattern
The system follows LangGraph best practices:
- **StateGraph construction**: Typed states with proper field annotations
- **Node decorators**: `@standard_node`, `@handle_errors`, `@log_node_execution`
- **Command routing**: Intelligent workflow navigation using `Command` objects
- **Service integration**: Dependency injection through `ServiceFactory`
### Service Factory Pattern
All external dependencies managed through centralized factory:
```python
from biz_bud.core.services import get_global_factory
async with get_global_factory() as factory:
llm_service = await factory.get_service('llm')
db_service = await factory.get_service('database')
```
### Error Handling
Comprehensive error management:
- **Custom exceptions**: Project-specific error types
- **Retry patterns**: Exponential backoff for external services
- **Circuit breakers**: Failure threshold management
- **Graceful degradation**: Fallback processing modes
## Integration Points
### LangChain Tools
Paperless operations exposed as LangChain tools:
- Searchable through natural language
- Composable in multi-step workflows
- Automatic parameter validation
- Structured response formatting
### Database Persistence
Normalized PostgreSQL storage:
- Referential integrity constraints
- Optimized indexes for common queries
- Audit trail for all operations
- Batch insertion capabilities
### Web Services
External validation services:
- Product catalog lookups
- Vendor verification
- Price comparison APIs
- Nutritional data enrichment
## Performance Considerations
### Batch Processing
The system supports efficient batch operations:
- Concurrent document processing (configurable limits)
- Database connection pooling
- Memory-efficient streaming for large documents
- Progress tracking and error recovery
### Caching Strategy
Multi-level caching for optimal performance:
- In-memory LRU cache for frequent operations
- Redis caching for shared data
- Database query result caching
- API response caching with TTL
### Monitoring
Built-in observability features:
- Structured logging with correlation IDs
- Processing time metrics
- Error rate tracking
- Resource utilization monitoring
## Development
### Adding New Document Types
1. Extend classification logic in nodes/extraction.py
2. Add corresponding state fields in states/receipt.py
3. Update database schema if needed
4. Add test cases for the new document type
### Extending Agent Capabilities
1. Add new intent patterns in agent.py intent_analyzer_node
2. Create corresponding handler nodes
3. Update response models in agent.py
4. Add demo examples for new capabilities
### Performance Optimization
1. Profile with the included performance test suite
2. Optimize database queries using EXPLAIN ANALYZE
3. Adjust concurrency limits based on system capacity
4. Monitor memory usage during batch processing
## Troubleshooting
### Common Issues
**Connection Errors**
- Verify Paperless NGX is running and accessible
- Check API token configuration
- Validate database connection parameters
**Processing Failures**
- Review LLM service configuration
- Check document content format and encoding
- Verify database schema matches expected structure
**Performance Issues**
- Adjust concurrent processing limits
- Optimize database indexes
- Consider caching frequently accessed data
### Debug Mode
Enable detailed logging:
```python
import logging
logging.getLogger('biz_bud.graphs.paperless').setLevel(logging.DEBUG)
```
### Health Checks
The system includes health check endpoints:
```python
# Test all components
python -c "
from test_receipt_workflow import main
import asyncio
asyncio.run(main())
"
```

View File

@@ -0,0 +1,924 @@
"""Paperless Document Management Agent using Business Buddy patterns.
This module provides an intelligent agent interface for Paperless NGX document management,
using the project's BaseState and existing infrastructure with proper helper utilities.
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import time
from typing import TYPE_CHECKING, Any, cast
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.graph import END, START, StateGraph
from biz_bud.core.caching.cache_manager import LLMCache
from biz_bud.core.config.constants import (
HISTORY_MANAGEMENT_STRATEGY,
MAX_CONTEXT_TOKENS,
MAX_MESSAGE_WINDOW,
MAX_SUMMARY_TOKENS,
MESSAGE_SUMMARY_THRESHOLD,
PRESERVE_RECENT_MESSAGES,
)
from biz_bud.core.langgraph.cross_cutting import standard_node, track_metrics
from biz_bud.core.utils.graph_helpers import (
create_initial_state_dict,
format_raw_input,
process_state_query,
)
from biz_bud.logging import get_logger
from biz_bud.prompts.paperless import PAPERLESS_SYSTEM_PROMPT
from biz_bud.states.base import BaseState
from biz_bud.tools.capabilities.batch.receipt_processing import batch_process_receipt_items
from biz_bud.tools.capabilities.database.tool import (
postgres_reconcile_receipt_items,
postgres_search_normalized_items,
postgres_update_normalized_description,
)
from biz_bud.tools.capabilities.external.paperless.tool import (
create_paperless_tag,
get_paperless_correspondent,
get_paperless_document,
get_paperless_document_type,
get_paperless_statistics,
get_paperless_tag,
get_paperless_tags_by_ids,
list_paperless_correspondents,
list_paperless_document_types,
list_paperless_tags,
search_paperless_documents,
update_paperless_document,
)
from biz_bud.tools.capabilities.search.tool import list_search_providers, web_search
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from biz_bud.services.factory import ServiceFactory
logger = get_logger(__name__)
# Module-level caches for performance optimization
_compiled_graph_cache: dict[str, CompiledStateGraph[BaseState]] = {}
_global_factory: ServiceFactory | None = None
_llm_cache: LLMCache[str] | None = None
_cache_ttl = 300 # 5 minutes default TTL for LLM responses
# Batch processing wrapper for tag operations
@tool
async def get_paperless_tags_batch(tag_ids: list[int]) -> dict[str, Any]:
"""Get multiple Paperless tags by their IDs with optimized batch processing.
This is an optimized version that fetches tags concurrently for better performance
when dealing with multiple tag IDs.
Args:
tag_ids: List of tag IDs to fetch
Returns:
Dictionary with tag information for each ID
"""
try:
# Execute tag fetches concurrently
tasks = []
for tag_id in tag_ids:
task = asyncio.create_task(get_paperless_tag.ainvoke({"tag_id": tag_id}))
tasks.append((tag_id, task))
# Gather results with timeout
results = {}
for tag_id, task in tasks:
try:
result = await asyncio.wait_for(task, timeout=5.0)
results[str(tag_id)] = result
except asyncio.TimeoutError:
results[str(tag_id)] = {
"success": False,
"error": "Timeout fetching tag",
}
except Exception as e:
results[str(tag_id)] = {"success": False, "error": str(e)}
return {
"success": True,
"tags": results,
"count": len(results),
"requested_ids": tag_ids,
}
except Exception as e:
logger.error(f"Batch tag fetch failed: {e}")
return {"success": False, "error": str(e), "requested_ids": tag_ids}
# Define the tools list and create tools_by_name dictionary
PAPERLESS_TOOLS = [
# Paperless tools
search_paperless_documents,
get_paperless_document,
update_paperless_document,
create_paperless_tag,
list_paperless_tags,
get_paperless_tag,
get_paperless_tags_by_ids,
get_paperless_tags_batch, # Add batch version
list_paperless_correspondents,
get_paperless_correspondent,
list_paperless_document_types,
get_paperless_document_type,
get_paperless_statistics,
# Batch processing tools
batch_process_receipt_items, # Batch process multiple receipt items
# PostgreSQL reconciliation tools
postgres_reconcile_receipt_items,
postgres_search_normalized_items,
postgres_update_normalized_description,
# Web search tools for validation and canonicalization
web_search,
list_search_providers,
]
# Create tools dictionary for easy lookup
TOOLS_BY_NAME = {tool.name: tool for tool in PAPERLESS_TOOLS}
async def _apply_message_history_management(
messages: list[Any],
state: dict[str, Any]
) -> tuple[list[Any], asyncio.Task[Any] | None]:
"""Apply message history management with background summarization.
This enhanced function manages the message history AFTER the AI has responded,
using concurrent summarization that doesn't block the main execution thread.
It integrates completed summaries from previous iterations and launches new ones as needed.
Args:
messages: Current messages including the latest response
state: Current state containing service factory and other context
Returns:
Tuple of (managed_messages, optional_summarization_task)
"""
from langchain_core.messages import SystemMessage
from biz_bud.core.utils.message_helpers import (
count_message_tokens,
manage_message_history_concurrent,
)
original_message_count = len(messages)
if original_message_count == 0:
return messages, None
# Check for pending summarization from previous iteration
if "_summarization_task" in state and state["_summarization_task"]:
task = state["_summarization_task"]
if isinstance(task, asyncio.Task) and task.done():
try:
summary = await task
# Find where to integrate the summary
system_messages = [m for m in messages if isinstance(m, SystemMessage)]
other_messages = [m for m in messages if not isinstance(m, SystemMessage)]
# Replace old summary or add new one
summary_found = False
for i, msg in enumerate(system_messages):
if hasattr(msg, 'content') and "CONVERSATION SUMMARY" in str(msg.content):
system_messages[i] = summary
summary_found = True
break
if not summary_found:
system_messages.append(summary)
messages = system_messages + other_messages
logger.info("Integrated background summarization into message history")
except Exception as e:
logger.warning(f"Background summarization integration failed: {e}")
# Check if we need message history management
current_tokens = count_message_tokens(messages, "gpt-4.1-mini")
if current_tokens > MAX_CONTEXT_TOKENS or original_message_count > MAX_MESSAGE_WINDOW:
logger.info(
f"Post-response message history management: {original_message_count} messages, "
f"{current_tokens} tokens (limit: {MAX_CONTEXT_TOKENS})"
)
try:
# Get the service factory for LLM-based summarization
service_factory = state.get("service_factory")
if not service_factory:
global _global_factory
if _global_factory is None:
from biz_bud.services.factory import get_global_factory
_global_factory = await get_global_factory()
service_factory = _global_factory
# Apply concurrent message history management
managed_messages, metrics, new_task = await manage_message_history_concurrent(
messages=messages,
max_tokens=MAX_CONTEXT_TOKENS,
strategy=HISTORY_MANAGEMENT_STRATEGY,
service_factory=service_factory,
model_name="gpt-4",
preserve_recent=PRESERVE_RECENT_MESSAGES,
summarization_threshold=MESSAGE_SUMMARY_THRESHOLD,
max_summary_tokens=MAX_SUMMARY_TOKENS
)
logger.info(
f"Post-response history managed: {metrics['original_count']} -> {metrics['final_count']} messages, "
f"{metrics['original_tokens']} -> {metrics['final_tokens']} tokens "
f"(saved {metrics['tokens_saved']} tokens), "
f"strategy: {metrics['strategy_used']}, "
f"summarization_pending: {'yes' if metrics.get('summarization_pending') else 'no'}, "
f"trimming: {'yes' if metrics.get('trimming_used') else 'no'}"
)
if metrics.get("summarization_pending"):
logger.info("Background summarization task launched")
return managed_messages, new_task
except Exception as e:
logger.warning(f"Post-response message management failed, falling back to simple truncation: {e}")
# Fallback to simple truncation
if len(messages) > MAX_MESSAGE_WINDOW:
system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
non_system_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
# Keep system messages and most recent non-system messages
messages_to_keep = MAX_MESSAGE_WINDOW - len(system_messages)
if messages_to_keep > 0:
recent_messages = non_system_messages[-messages_to_keep:]
return system_messages + recent_messages, None
else:
return system_messages, None
return messages, None
# Return original messages if no management needed
return messages, None
@standard_node(node_name="paperless_agent", metric_name="paperless_call")
async def paperless_agent_node(
state: dict[str, Any],
config: RunnableConfig,
) -> dict[str, Any]:
"""Paperless agent node that binds tools to the LLM with caching.
Args:
state: Current state with messages
config: Optional runtime configuration
Returns:
Updated state with agent response
"""
from biz_bud.core.utils import get_messages
from biz_bud.services.factory import get_global_factory
# Declare global variables at function start
global _global_factory
# Get messages from state - check for managed history first
messages = get_messages(state)
# If we have managed history from a previous iteration, use it instead
if "_managed_history" in state and state["_managed_history"]:
messages = state["_managed_history"]
logger.debug(f"Using managed message history: {len(messages)} messages")
# Check if we need to add system prompt
has_system = any(isinstance(msg, SystemMessage) for msg in messages)
if not has_system:
# Only copy when we need to modify
messages = [SystemMessage(content=PAPERLESS_SYSTEM_PROMPT)] + list(messages)
# Debug logging
logger.info(f"Agent node invoked with {len(messages)} messages")
# Get the service factory (may have been set during message management above)
service_factory = state.get("service_factory")
if not service_factory:
# Use cached global factory
if _global_factory is None:
_global_factory = await get_global_factory()
service_factory = _global_factory
# Get the LLM client directly for tool binding
llm_client = await service_factory.get_llm_client()
await llm_client.initialize()
llm = llm_client.llm
if llm is None:
raise ValueError("Failed to get LLM from service")
# Bind tools to the LLM
llm_with_tools = llm.bind_tools(PAPERLESS_TOOLS)
# Invoke the LLM with tools
start_time = time.time()
# Log message context for debugging
message_count = len(messages)
from biz_bud.core.utils.message_helpers import count_message_tokens
token_count = count_message_tokens(messages, "gpt-4")
logger.info(f"Invoking LLM with {message_count} messages, ~{token_count} tokens, {len(PAPERLESS_TOOLS)} tools")
try:
response = await llm_with_tools.ainvoke(messages)
elapsed_time = time.time() - start_time
if elapsed_time > 1.0:
logger.warning(f"Slow LLM call: {elapsed_time:.2f}s")
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"LLM invocation failed after {elapsed_time:.2f}s with {message_count} messages ({token_count} tokens): {e}")
# Add specific handling for OpenAI API errors
if "server had an error" in str(e).lower():
logger.warning("OpenAI API server error detected - this is typically temporary. Retrying with reduced tool set.")
# Attempt retry with essential tools only
try:
essential_tools = [
search_paperless_documents,
get_paperless_document,
list_paperless_tags,
get_paperless_tag
]
llm_minimal = llm.bind_tools(essential_tools)
logger.info(f"Retrying with {len(essential_tools)} essential tools instead of {len(PAPERLESS_TOOLS)}")
response = await llm_minimal.ainvoke(messages)
elapsed_time = time.time() - start_time
logger.info(f"Retry successful with minimal tools in {elapsed_time:.2f}s")
except Exception as retry_e:
logger.error(f"Retry with minimal tools also failed: {retry_e}")
raise e # Raise original error
elif "context_length_exceeded" in str(e).lower() or "maximum context length" in str(e).lower():
logger.error(f"Context length exceeded with {token_count} tokens - message history management may have failed")
raise e
else:
# Re-raise the original exception for the error handling system
raise e
# Apply message history management AFTER generating response
# This ensures the AI can see tool results and respond before history is summarized
updated_messages = messages + [response]
# Only apply management if the conversation is getting long
# For shorter conversations, just return the response normally
from biz_bud.core.utils.message_helpers import count_message_tokens
if len(updated_messages) > MESSAGE_SUMMARY_THRESHOLD or count_message_tokens(updated_messages, "gpt-4") > MAX_CONTEXT_TOKENS:
managed_messages, summarization_task = await _apply_message_history_management(updated_messages, state)
# Store task for next iteration if pending
if summarization_task:
return {
"messages": [response],
"_managed_history": managed_messages,
"_summarization_task": summarization_task # Store for next iteration
}
else:
return {"messages": [response], "_managed_history": managed_messages}
# Return the response normally for shorter conversations
return {"messages": [response]}
@track_metrics("tool_execution")
async def execute_single_tool(tool_call: dict[str, Any]) -> ToolMessage:
"""Execute a single tool call and return the result with automatic error handling and metrics.
Args:
tool_call: Tool call dictionary with id, name, and args
Returns:
ToolMessage with the execution result
"""
tool_id = tool_call.get("id")
tool_name = tool_call.get("name", "unknown")
tool_args = tool_call.get("args", {})
# Every tool_call MUST have a corresponding ToolMessage response
if not tool_id:
logger.error(f"Tool call missing ID for tool: {tool_name}")
raise ValueError(f"Tool call missing ID for tool: {tool_name} - this should have been filtered out earlier")
if tool_name not in TOOLS_BY_NAME:
logger.error(f"Tool {tool_name} not found")
return ToolMessage(
content=f"Error: Tool {tool_name} not found",
tool_call_id=tool_id,
)
# Get the tool function
tool_func = TOOLS_BY_NAME[tool_name]
# Execute the tool with proper error handling
logger.debug(f"Starting execution of tool: {tool_name}")
try:
result = await tool_func.ainvoke(tool_args)
# Convert result to string for ToolMessage content
if isinstance(result, str):
content = result
elif isinstance(result, (dict, list)):
content = json.dumps(result)
else:
# Handle unexpected result types
logger.warning(f"Tool {tool_name} returned unexpected type: {type(result)}")
content = str(result)
except Exception as e:
logger.error(f"Tool {tool_name} execution failed with error: {e}")
# Return error message but don't raise - let the agent handle it
content = f"Error executing {tool_name}: {str(e)}"
# Always return a valid ToolMessage
return ToolMessage(
content=content,
tool_call_id=tool_id,
)
@standard_node(node_name="tool_executor", metric_name="tool_execution")
async def tool_executor_node(
state: dict[str, Any],
config: RunnableConfig,
) -> dict[str, Any]:
"""Execute tool calls from the last AI message with concurrent execution.
Args:
state: Current state with messages
config: Runtime configuration
Returns:
Updated state with tool execution results
"""
from biz_bud.core.utils import get_messages
messages = get_messages(state)
if not messages:
return {"messages": []}
last_message = messages[-1]
# Check if last message is an AIMessage with tool calls
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
logger.warning("No tool calls found in last message")
return {"messages": []}
# Execute tool calls concurrently
start_time = time.time()
# Create tasks for concurrent execution
tool_tasks = []
logger.debug(f"Processing {len(last_message.tool_calls)} tool calls")
for i, tool_call in enumerate(last_message.tool_calls):
# Extract tool call data, ensuring we get a valid ID
tool_call_dict = None
tool_id = None
if hasattr(tool_call, 'get'):
# Tool call is already a dict-like object
tool_call_dict = tool_call
tool_id = tool_call.get("id") or tool_call.get("tool_call_id")
else:
# Handle LangChain ToolCall objects
tool_id = getattr(tool_call, "id", None) or getattr(tool_call, "tool_call_id", None)
tool_call_dict = {
"id": tool_id,
"name": getattr(tool_call, "name", ""),
"args": getattr(tool_call, "args", {}),
}
# Skip tool calls without valid IDs - this prevents OpenAI API errors
if not tool_id:
logger.error(f"Tool call {i} has no valid ID, skipping to prevent API error. Tool: {tool_call}")
continue
logger.debug(
f"Tool call {i}: id='{tool_id}', "
f"name='{tool_call_dict.get('name', 'missing')}', "
f"type={type(tool_call).__name__}"
)
task = asyncio.create_task(execute_single_tool(tool_call_dict))
tool_tasks.append((tool_id, task))
# If no valid tool calls, return early
if not tool_tasks:
logger.warning("No valid tool calls found (all had missing IDs)")
return {"messages": []}
# Execute all tools concurrently with generous timeout for batch processing
# Initialize timeout_seconds for the entire function scope
timeout_seconds = 120.0 # Default 2 minutes for most tools
try:
# Extract just the tasks for execution
tasks = [task for _, task in tool_tasks]
# Determine timeout based on tools being called
# Check if any tool is batch_process_receipt_items (needs extended timeout)
for tool_call in last_message.tool_calls:
tool_name = None
if hasattr(tool_call, 'get'):
tool_name = tool_call.get("name")
else:
tool_name = getattr(tool_call, "name", None)
if tool_name == "batch_process_receipt_items":
timeout_seconds = 600.0 # 10 minutes for batch processing
logger.info("Using extended timeout (10 minutes) for batch_process_receipt_items")
break
tool_results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logger.error(f"Tool execution timed out after {timeout_seconds} seconds")
# Create timeout messages using the tracked tool IDs
tool_results = []
for tool_id, task in tool_tasks:
tool_results.append(
ToolMessage(
content="Error: Tool execution timed out",
tool_call_id=tool_id,
)
)
# Process results using the tracked tool IDs
tool_messages = []
for i, result in enumerate(tool_results):
# Get the corresponding tool ID from our tracked list
tool_id = tool_tasks[i][0] if i < len(tool_tasks) else None
if isinstance(result, Exception):
# Handle exceptions from gather
if tool_id:
logger.error(f"Tool with ID {tool_id} failed with exception: {result}")
tool_messages.append(
ToolMessage(
content=f"Error executing tool: {str(result)}",
tool_call_id=tool_id,
)
)
else:
logger.error(f"Tool execution failed but no tool ID available: {result}")
# Skip this one to avoid invalid tool messages
elif isinstance(result, ToolMessage):
# Verify the tool message has a valid tool_call_id
if result.tool_call_id and result.tool_call_id != "unknown":
tool_messages.append(result)
else:
# Fix the tool message with the correct ID
if tool_id:
tool_messages.append(
ToolMessage(
content=result.content,
tool_call_id=tool_id,
)
)
else:
logger.error("ToolMessage has invalid tool_call_id and no fallback ID available")
else:
# Unexpected result type - try to convert to string representation
if tool_id:
logger.error(f"Unexpected tool result type for tool {tool_id}: {type(result)}, value: {result}")
# Try to extract useful content from the unexpected result
try:
if hasattr(result, '__dict__'):
content = json.dumps(result.__dict__)
else:
content = str(result)
except Exception as e:
content = f"Error: Unexpected tool result (type: {type(result).__name__})"
logger.error(f"Failed to serialize unexpected result: {e}")
tool_messages.append(
ToolMessage(
content=content,
tool_call_id=tool_id,
)
)
else:
logger.error(f"Unexpected tool result type with no tool ID: {type(result)}")
# Ensure we have the correct number of tool messages for valid tool calls
# This is critical for OpenAI API compatibility
if len(tool_messages) < len(tool_tasks):
logger.error(f"Insufficient tool messages generated: {len(tool_messages)} messages for {len(tool_tasks)} valid tool calls")
# Generate error messages for any missing tool responses
for i in range(len(tool_messages), len(tool_tasks)):
if i < len(tool_tasks):
tool_id = tool_tasks[i][0]
tool_messages.append(
ToolMessage(
content="Error: Tool execution failed",
tool_call_id=tool_id,
)
)
# Performance logging
elapsed_time = time.time() - start_time
logger.info(
f"Tool executor completed {len(tool_messages)} tools in {elapsed_time:.2f}s "
f"(concurrent execution)"
)
return {"messages": tool_messages}
def should_continue(state: dict[str, Any]) -> str:
"""Determine whether to continue to tools or end.
Args:
state: Current state with messages
Returns:
"tools" if tool calls are present, "end" otherwise
"""
from biz_bud.core.utils import get_messages
messages = get_messages(state)
if not messages:
return "end"
last_message = messages[-1]
# Check if last message has tool calls
if isinstance(last_message, AIMessage) and last_message.tool_calls:
return "tools"
return "end"
def create_paperless_agent(
config: dict[str, Any] | str | None = None,
) -> CompiledStateGraph[BaseState]:
"""Create a Paperless agent using Business Buddy patterns with caching.
This creates an agent that can bind tools to the LLM and execute them
for Paperless document management operations. The compiled graph is cached
for performance.
Args:
config: Configuration dict from LangGraph API or string cache key
Returns:
Compiled LangGraph agent with tool execution flow
"""
# Handle different config types
if config is None:
config_hash = "default"
elif isinstance(config, str):
config_hash = config
else:
# Assume it's a dict - create a stable hash from it
try:
config_str = json.dumps(config, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()[:8]
except (TypeError, ValueError):
# Fallback if config is not JSON serializable
config_hash = "default"
# Check cache first (sync check is safe for read operations)
if config_hash in _compiled_graph_cache:
logger.debug(f"Returning cached graph for config: {config_hash}")
return _compiled_graph_cache[config_hash]
# Create the graph with BaseState
builder = StateGraph(BaseState)
# Add nodes
builder.add_node("agent", paperless_agent_node)
builder.add_node("tools", tool_executor_node)
# Define the flow
# 1. Start -> Agent
builder.add_edge(START, "agent")
# 2. Agent -> Conditional (tools or end)
builder.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
"end": END,
},
)
# 3. Tools -> Agent (for another iteration)
builder.add_edge("tools", "agent")
# Compile the graph
graph = builder.compile()
# Cache the compiled graph (thread-safe dict operation)
_compiled_graph_cache[config_hash] = graph
logger.info(f"Created and cached Paperless agent for config: {config_hash}")
return graph
async def process_paperless_request(
user_input: str,
thread_id: str | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Process a Paperless request using the agent with optimized caching.
Args:
user_input: Natural language request
thread_id: Optional thread ID for conversation tracking
**kwargs: Additional parameters
Returns:
Agent response with results
Examples:
# Search for documents
response = await process_paperless_request("Find all receipts from Target")
# List all documents
response = await process_paperless_request("List all my documents")
# Get document details
response = await process_paperless_request("Show me details for document 123")
# View statistics
response = await process_paperless_request("Show system statistics")
"""
request_start = time.time()
logger.info(f"Processing paperless request: {user_input[:100]}...")
try:
# Use cached factory if available
factory_start = time.time()
global _global_factory
if _global_factory is None:
from biz_bud.core.config.loader import load_config
from biz_bud.services.factory import get_global_factory
config = load_config()
_global_factory = await get_global_factory(config)
factory_elapsed = time.time() - factory_start
logger.info(f"Initialized global factory in {factory_elapsed:.2f}s")
else:
logger.debug("Using cached global factory")
factory = _global_factory
# Create or get cached agent
graph_start = time.time()
agent = create_paperless_agent()
graph_elapsed = time.time() - graph_start
if graph_elapsed > 0.1:
logger.info(f"Graph creation/retrieval took {graph_elapsed * 1000:.2f}ms")
# Use proper helper functions to create initial state
# Process the query
user_query = process_state_query(
query=user_input, messages=None, state_update=None, default_query=user_input
)
# Format raw input
raw_input_str, extracted_query = format_raw_input(
raw_input={"query": user_input}, user_query=user_query
)
# Create messages with system prompt - use proper Message objects
from langchain_core.messages import SystemMessage
messages = [
SystemMessage(content=PAPERLESS_SYSTEM_PROMPT),
HumanMessage(content=user_input),
]
# Convert messages to dict format for the helper function
messages_dicts = [
{"role": "system", "content": PAPERLESS_SYSTEM_PROMPT},
{"role": "user", "content": user_input},
]
# Use the standard helper to create initial state
initial_state = create_initial_state_dict(
raw_input_str=raw_input_str,
user_query=extracted_query,
messages=messages_dicts,
thread_id=thread_id or f"paperless-{time.time()}",
config_dict={},
)
# Override with proper message objects
initial_state["messages"] = messages
# Add the service factory and tools to state
initial_state["service_factory"] = factory
initial_state["tools"] = PAPERLESS_TOOLS
# Invoke the agent
invoke_start = time.time()
result = await agent.ainvoke(cast(BaseState, initial_state))
invoke_elapsed = time.time() - invoke_start
# Log total request time
total_elapsed = time.time() - request_start
logger.info(
f"Paperless request completed in {total_elapsed:.2f}s "
f"(graph invoke: {invoke_elapsed:.2f}s)"
)
# Add performance metrics to result if it's a dict
if isinstance(result, dict):
result["_performance_metrics"] = {
"total_time_seconds": total_elapsed,
"graph_invoke_seconds": invoke_elapsed,
"cached_factory": True, # Always True at this point
"cached_graph": "default" in _compiled_graph_cache,
}
return result
except Exception as e:
total_elapsed = time.time() - request_start
logger.error(f"Error processing request after {total_elapsed:.2f}s: {e}")
return {
"status": "error",
"errors": [{"message": str(e)}],
"messages": [HumanMessage(content=user_input)],
"_performance_metrics": {
"total_time_seconds": total_elapsed,
"error": True,
},
}
async def initialize_paperless_agent() -> None:
"""Pre-initialize agent resources for better performance.
This function warms up the caches and initializes critical services
to improve first-request latency. Call this during application startup.
"""
logger.info("Initializing paperless agent resources...")
start_time = time.time()
try:
# Initialize global factory and critical services
global _global_factory
if _global_factory is None:
from biz_bud.core.config.loader import load_config
from biz_bud.services.factory import get_global_factory
config = load_config()
_global_factory = await get_global_factory(config)
# Pre-initialize critical services
await _global_factory.initialize_critical_services()
logger.info("Initialized global factory and critical services")
# Initialize LLM cache
global _llm_cache
if _llm_cache is None:
_llm_cache = LLMCache[str](ttl=_cache_ttl)
logger.info(f"Initialized LLM cache with TTL={_cache_ttl}s")
# Pre-compile default graph
create_paperless_agent("default")
logger.info("Pre-compiled default paperless agent graph")
# Pre-warm LLM service
if _global_factory:
await _global_factory.get_llm_for_node(
node_context="paperless_agent", temperature_override=0.3
)
logger.info("Pre-warmed LLM service for paperless agent")
elapsed_time = time.time() - start_time
logger.info(f"Paperless agent initialization completed in {elapsed_time:.2f}s")
except Exception as e:
logger.error(f"Failed to initialize paperless agent: {e}")
# Don't raise - allow the agent to work even if pre-initialization fails
# Export key functions
__all__ = [
"create_paperless_agent",
"process_paperless_request",
"initialize_paperless_agent",
"PAPERLESS_TOOLS",
]

View File

@@ -12,20 +12,32 @@ from typing import TYPE_CHECKING, Any, TypedDict
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from biz_bud.core.edge_helpers import create_enum_router
from biz_bud.core.edge_helpers import (
WorkflowRouters,
check_confidence_level,
create_enum_router,
validate_required_fields,
)
from biz_bud.core.langgraph import configure_graph_with_injection
from biz_bud.graphs.paperless.nodes import ( # Core document management nodes; API integration nodes; Processing nodes
analyze_document_node,
from biz_bud.graphs.paperless.nodes import (
analyze_document_node, # Core document management nodes; API integration nodes; Processing nodes; Receipt processing nodes
)
from biz_bud.graphs.paperless.nodes import (
execute_document_search_node,
extract_document_text_node,
paperless_document_retrieval_node,
paperless_document_validator_node,
paperless_metadata_management_node,
paperless_search_node,
process_document_node,
receipt_item_validation_node,
receipt_line_items_parser_node,
receipt_llm_extraction_node,
suggest_document_tags_node,
)
from biz_bud.logging import get_logger
from biz_bud.states.base import BaseState
from biz_bud.states.receipt import ReceiptState
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
@@ -40,12 +52,16 @@ GRAPH_METADATA = {
"capabilities": [
"document_search",
"document_retrieval",
"document_validation",
"metadata_management",
"tag_management",
"paperless_ngx",
"document_processing",
"text_extraction",
"metadata_extraction",
"receipt_reconciliation",
"product_validation",
"llm_structured_extraction",
],
"example_queries": [
"Search for invoices from last month",
@@ -54,6 +70,9 @@ GRAPH_METADATA = {
"Extract text and metadata from PDF",
"Suggest tags for new document",
"Update document metadata",
"Validate if Paperless document exists in PostgreSQL",
"Process receipt for product reconciliation",
"Extract and validate receipt line items",
],
"input_requirements": ["query", "operation"],
"output_format": "structured document results with metadata and processing status",
@@ -64,7 +83,7 @@ class PaperlessStateRequired(TypedDict):
"""Required fields for Paperless NGX workflow."""
query: str
operation: str # orchestrate, search, retrieve, manage_metadata
operation: str # orchestrate, search, retrieve, manage_metadata, validate, receipt_reconciliation
class PaperlessStateOptional(TypedDict, total=False):
@@ -93,21 +112,29 @@ class PaperlessStateOptional(TypedDict, total=False):
limit: int
offset: int
# Document validation fields
paperless_document_id: str | None
document_content: str | None
raw_receipt_text: str | None
validation_method: str | None
similarity_threshold: float | None
# Results
paperless_results: list[dict[str, Any]] | None
metadata_results: dict[str, Any] | None
validation_results: dict[str, Any] | None
# Workflow control
workflow_step: str
needs_search: bool
needs_retrieval: bool
needs_metadata: bool
needs_validation: bool
class PaperlessState(BaseState, PaperlessStateRequired, PaperlessStateOptional):
"""State for Paperless NGX document management workflow."""
pass
# Simplified routing function
@@ -119,73 +146,32 @@ _route_by_operation = create_enum_router(
"analyze": "analyze",
"extract": "extract",
"suggest_tags": "suggest_tags",
"validate": "validate",
"receipt_reconciliation": "extract", # Use standard extraction for receipt content
},
state_key="operation",
default_target="search",
)
def create_paperless_graph(
def _create_receipt_processing_graph_internal(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph:
"""Create the standardized Paperless NGX document management graph.
) -> CompiledStateGraph[ReceiptState]:
"""Internal function to create a focused receipt processing graph using ReceiptState."""
builder = StateGraph(ReceiptState)
This simplified graph provides core document management workflows:
1. Document search and retrieval
2. Document processing and analysis
3. Text and metadata extraction
4. Tag suggestions and management
# Add receipt processing nodes
builder.add_node("receipt_llm_extraction", receipt_llm_extraction_node)
builder.add_node("receipt_line_items_parser", receipt_line_items_parser_node)
builder.add_node("receipt_item_validation", receipt_item_validation_node)
Args:
config: Optional configuration dictionary (deprecated, use app_config)
app_config: Application configuration object
service_factory: Service factory for dependency injection
Returns:
Compiled StateGraph for Paperless NGX operations
"""
# Simplified flow:
# START -> route_by_operation -> [search|retrieve|process|analyze|extract|suggest_tags] -> END
builder = StateGraph(PaperlessState)
# Add core nodes following standard pattern
builder.add_node("search", execute_document_search_node)
builder.add_node("retrieve", paperless_document_retrieval_node)
builder.add_node("process", process_document_node)
builder.add_node("analyze", analyze_document_node)
builder.add_node("extract", extract_document_text_node)
builder.add_node("suggest_tags", suggest_document_tags_node)
# Add API integration nodes
builder.add_node("api_search", paperless_search_node)
builder.add_node("metadata_mgmt", paperless_metadata_management_node)
# Simple routing from start
builder.add_conditional_edges(
START,
_route_by_operation,
{
"search": "search",
"retrieve": "retrieve",
"process": "process",
"analyze": "analyze",
"extract": "extract",
"suggest_tags": "suggest_tags",
},
)
# All nodes end the workflow
builder.add_edge("search", END)
builder.add_edge("retrieve", END)
builder.add_edge("process", END)
builder.add_edge("analyze", END)
builder.add_edge("extract", END)
builder.add_edge("suggest_tags", END)
builder.add_edge("api_search", END)
builder.add_edge("metadata_mgmt", END)
# Simple sequential flow
builder.add_edge(START, "receipt_llm_extraction")
builder.add_edge("receipt_llm_extraction", "receipt_line_items_parser")
builder.add_edge("receipt_line_items_parser", "receipt_item_validation")
builder.add_edge("receipt_item_validation", END)
# Configure with dependency injection if provided
if app_config or service_factory:
@@ -196,7 +182,203 @@ def create_paperless_graph(
return builder.compile()
def paperless_graph_factory(config: RunnableConfig) -> CompiledStateGraph:
def create_receipt_processing_graph(config: RunnableConfig) -> CompiledStateGraph[ReceiptState]:
"""Create a focused receipt processing graph for LangGraph API.
Args:
config: RunnableConfig from LangGraph API
Returns:
Compiled StateGraph for receipt processing operations
"""
return _create_receipt_processing_graph_internal()
def create_receipt_processing_graph_direct(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph[ReceiptState]:
"""Create a focused receipt processing graph for direct usage.
Args:
config: Optional configuration dictionary
app_config: Application configuration object
service_factory: Service factory for dependency injection
Returns:
Compiled StateGraph for receipt processing operations
"""
return _create_receipt_processing_graph_internal(config, app_config, service_factory)
def create_paperless_graph(
config: dict[str, Any] | None = None,
app_config: object | None = None,
service_factory: object | None = None,
) -> CompiledStateGraph[PaperlessState]:
"""Create the standardized Paperless NGX document management graph.
This graph provides intelligent document management workflows with:
1. Smart routing based on confidence and validation
2. Error handling and recovery
3. Integrated API and local processing
4. Quality gates and fallback paths
Args:
config: Optional configuration dictionary (deprecated, use app_config)
app_config: Application configuration object
service_factory: Service factory for dependency injection
Returns:
Compiled StateGraph for Paperless NGX operations
"""
# Use PaperlessState for proper state typing
# Node functions are compatible as PaperlessState extends TypedDict
builder = StateGraph(PaperlessState)
# Add core processing nodes
# Type ignore needed only for nodes that use dict[str, Any] signatures
# but PaperlessState is compatible at runtime as TypedDict extends dict
builder.add_node("search", execute_document_search_node)
builder.add_node("retrieve", paperless_document_retrieval_node) # type: ignore[arg-type]
builder.add_node("process", process_document_node)
builder.add_node("analyze", analyze_document_node)
builder.add_node("extract", extract_document_text_node)
builder.add_node("suggest_tags", suggest_document_tags_node)
builder.add_node("validate", paperless_document_validator_node)
# Add API integration nodes (now properly connected)
builder.add_node("api_search", paperless_search_node) # type: ignore[arg-type]
builder.add_node("metadata_mgmt", paperless_metadata_management_node) # type: ignore[arg-type]
# Add error handling and quality gates
def error_handler(state: PaperlessState) -> dict[str, Any]:
return {"status": "error_handled", **state}
def quality_check(state: PaperlessState) -> dict[str, Any]:
return {**state, "quality_checked": True}
builder.add_node("error_handler", error_handler)
builder.add_node("quality_check", quality_check)
# Create smart routers using edge_helpers
confidence_router = check_confidence_level(threshold=0.8, confidence_key="validation_confidence")
required_fields_router = validate_required_fields(["query", "operation"])
# error_router = handle_error("errors") # Reserved for future error handling enhancement
# Workflow stage router for complex processing
workflow_router = WorkflowRouters.route_workflow_stage(
stage_key="workflow_step",
stage_mapping={
"search_completed": "quality_check",
"retrieval_completed": "analyze",
"validation_completed": "metadata_mgmt",
"processing_completed": "suggest_tags",
"analysis_completed": "extract",
},
default_stage="END"
)
# Initial routing from START
builder.add_conditional_edges(
START,
required_fields_router,
{
"valid": "route_operation",
"missing_fields": "error_handler",
},
)
# Add operation routing node
builder.add_node("route_operation", lambda state: state)
builder.add_conditional_edges(
"route_operation",
_route_by_operation,
{
"search": "api_search", # Use API search first
"retrieve": "retrieve",
"process": "process",
"analyze": "analyze",
"extract": "extract",
"suggest_tags": "suggest_tags",
"validate": "validate",
"receipt_reconciliation": "extract",
},
)
# Connect API search to local search with confidence check
builder.add_conditional_edges(
"api_search",
confidence_router,
{
"high_confidence": "quality_check",
"low_confidence": "search", # Fallback to local search
},
)
# Local search flows to quality check
builder.add_edge("search", "quality_check")
# Quality check routes based on workflow stage
builder.add_conditional_edges(
"quality_check",
workflow_router,
)
# Validation flows to metadata management or ends
builder.add_conditional_edges(
"validate",
confidence_router,
{
"high_confidence": "metadata_mgmt",
"low_confidence": "error_handler",
},
)
# Process node can flow to analysis or extraction
builder.add_conditional_edges(
"process",
lambda state: "analyze" if state.get("needs_analysis", False) else "extract",
{
"analyze": "analyze",
"extract": "extract",
},
)
# Analysis flows to metadata management
builder.add_edge("analyze", "metadata_mgmt")
# Extraction and tag suggestion end the workflow
builder.add_edge("extract", END)
builder.add_edge("suggest_tags", END)
builder.add_edge("retrieve", END)
# Metadata management can suggest tags or end
builder.add_conditional_edges(
"metadata_mgmt",
lambda state: "suggest_tags" if state.get("operation") == "process" else "END",
{
"suggest_tags": "suggest_tags",
"END": END,
},
)
# Receipt processing through standard extraction path
# Error handler always ends
builder.add_edge("error_handler", END)
# Configure with dependency injection if provided
if app_config or service_factory:
builder = configure_graph_with_injection(
builder, app_config=app_config, service_factory=service_factory
)
return builder.compile()
def paperless_graph_factory(config: RunnableConfig) -> CompiledStateGraph[PaperlessState]:
"""Create Paperless graph for LangGraph API.
Args:
@@ -215,13 +397,37 @@ async def paperless_graph_factory_async(config: RunnableConfig) -> Any: # noqa:
return await asyncio.to_thread(paperless_graph_factory, config)
def receipt_processing_graph_factory(config: RunnableConfig) -> CompiledStateGraph[ReceiptState]:
"""Create receipt processing graph for LangGraph API.
Args:
config: RunnableConfig from LangGraph API
Returns:
Compiled receipt processing graph
"""
return create_receipt_processing_graph(config)
# Async factory function for LangGraph API
async def receipt_processing_graph_factory_async(config: RunnableConfig) -> Any: # noqa: ANN401
"""Async wrapper for receipt_processing_graph_factory to avoid blocking calls."""
import asyncio
return await asyncio.to_thread(receipt_processing_graph_factory, config)
# Create function reference for direct imports
paperless_graph = create_paperless_graph
__all__ = [
"create_paperless_graph",
"create_receipt_processing_graph",
"create_receipt_processing_graph_direct",
"receipt_processing_graph_factory",
"receipt_processing_graph_factory_async",
"paperless_graph_factory",
"paperless_graph_factory_async",
"paperless_graph",
"PaperlessState",
"GRAPH_METADATA",

View File

@@ -1,208 +0,0 @@
"""Paperless-NGX specific nodes for document management workflow.
This module contains nodes that are specific to the Paperless-NGX integration,
handling document upload, retrieval, and management operations.
"""
from __future__ import annotations
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import create_error_info
from biz_bud.core.langgraph import standard_node
from biz_bud.logging import debug_highlight, error_highlight, info_highlight
# Import from local nodes directory
try:
from .nodes.paperless import (
paperless_document_retrieval_node,
paperless_orchestrator_node,
paperless_search_node,
)
from .nodes.processing import process_document_node
# Legacy imports for compatibility
paperless_upload_node = paperless_orchestrator_node # Alias
paperless_retrieve_node = paperless_document_retrieval_node # Alias
_legacy_imports_available = True
except ImportError:
_legacy_imports_available = False
paperless_upload_node = None
paperless_search_node = None
paperless_retrieve_node = None
paperless_orchestrator_node = None
paperless_document_retrieval_node = None
paperless_metadata_management_node = None
process_document_node = None
@standard_node(node_name="paperless_query_builder", metric_name="query_building")
async def build_paperless_query_node(
state: dict[str, Any], config: RunnableConfig | None = None
) -> dict[str, Any]:
"""Build search queries for Paperless-NGX API.
This node transforms natural language queries into Paperless-NGX
search syntax with appropriate filters and parameters.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with search query
"""
debug_highlight("Building Paperless-NGX search query...", category="QueryBuilder")
user_query = state.get("query", "")
filters = state.get("search_filters", {})
try:
# Build query components
query_parts = []
# Add text search
if user_query:
query_parts.append(user_query)
# Add filters
if filters.get("document_type"):
query_parts.append(f'type:"{filters["document_type"]}"')
if filters.get("correspondent"):
query_parts.append(f'correspondent:"{filters["correspondent"]}"')
if filters.get("tags"):
query_parts.extend(f'tag:"{tag}"' for tag in filters["tags"])
if filters.get("date_range"):
date_range = filters["date_range"]
if date_range.get("start"):
query_parts.append(f'created:[{date_range["start"]} TO *]')
if date_range.get("end"):
query_parts.append(f'created:[* TO {date_range["end"]}]')
# Combine query parts
paperless_query = " AND ".join(query_parts) if query_parts else "*"
# Add sorting and pagination
query_params = {
"query": paperless_query,
"ordering": filters.get("ordering", "-created"),
"page_size": filters.get("page_size", 20),
"page": filters.get("page", 1),
}
info_highlight(f"Built Paperless query: {paperless_query}", category="QueryBuilder")
return {"paperless_query": paperless_query, "query_params": query_params}
except Exception as e:
error_msg = f"Query building failed: {str(e)}"
error_highlight(error_msg, category="QueryBuilder")
return {
"paperless_query": "*",
"query_params": {"query": "*"},
"errors": [
create_error_info(
message=error_msg,
node="build_paperless_query",
severity="warning",
category="query_error",
)
],
}
@standard_node(node_name="paperless_result_formatter", metric_name="result_formatting")
async def format_paperless_results_node(
state: dict[str, Any], config: RunnableConfig | None = None
) -> dict[str, Any]:
"""Format Paperless-NGX search results for presentation.
This node transforms raw API responses into user-friendly formats
with relevant metadata and summaries.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with formatted results
"""
info_highlight("Formatting Paperless-NGX results...", category="ResultFormatter")
raw_results = state.get("search_results", [])
try:
formatted_results = []
for result in raw_results:
formatted = {
"id": result.get("id"),
"title": result.get("title", "Untitled"),
"correspondent": result.get("correspondent_name", "Unknown"),
"document_type": result.get("document_type_name", "General"),
"created": result.get("created"),
"tags": [tag.get("name", "") for tag in result.get("tags", [])],
"summary": result.get("content", "")[:200] + "...",
"download_url": result.get("download_url"),
"relevance_score": result.get("relevance_score", 0),
}
formatted_results.append(formatted)
# Sort by relevance if available
formatted_results.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
# Generate summary
summary = {
"total_results": len(formatted_results),
"document_types": list({r["document_type"] for r in formatted_results}),
"correspondents": list({r["correspondent"] for r in formatted_results}),
"date_range": {
"earliest": min((r["created"] for r in formatted_results), default=None),
"latest": max((r["created"] for r in formatted_results), default=None),
},
}
info_highlight(
f"Formatted {len(formatted_results)} Paperless-NGX results", category="ResultFormatter"
)
return {"formatted_results": formatted_results, "results_summary": summary}
except Exception as e:
error_msg = f"Result formatting failed: {str(e)}"
error_highlight(error_msg, category="ResultFormatter")
return {
"formatted_results": [],
"results_summary": {"total_results": 0},
"errors": [
create_error_info(
message=error_msg,
node="format_paperless_results",
severity="warning",
category="formatting_error",
)
],
}
# Export all Paperless-specific nodes
__all__ = [
"process_document_node",
"build_paperless_query_node",
"format_paperless_results_node",
]
# Re-export legacy nodes if available
if _legacy_imports_available:
if paperless_upload_node:
__all__.append("paperless_upload_node")
if paperless_search_node:
__all__.append("paperless_search_node")
if paperless_retrieve_node:
__all__.append("paperless_retrieve_node")

View File

@@ -12,6 +12,7 @@ from .core import (
extract_document_text_node,
suggest_document_tags_node,
)
from .document_validator import paperless_document_validator_node
from .paperless import (
paperless_document_retrieval_node,
paperless_metadata_management_node,
@@ -23,6 +24,11 @@ from .processing import (
format_paperless_results_node,
process_document_node,
)
from .receipt_processing import (
receipt_item_validation_node,
receipt_line_items_parser_node,
receipt_llm_extraction_node,
)
__all__ = [
# core document management nodes
@@ -40,4 +46,10 @@ __all__ = [
"paperless_search_node",
"paperless_document_retrieval_node",
"paperless_metadata_management_node",
# document validation nodes
"paperless_document_validator_node",
# receipt processing nodes
"receipt_llm_extraction_node",
"receipt_line_items_parser_node",
"receipt_item_validation_node",
]

View File

@@ -33,7 +33,7 @@ class DocumentResult(TypedDict):
@standard_node(node_name="analyze_document", metric_name="document_analysis")
async def analyze_document_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Analyze document to determine processing requirements.
@@ -110,7 +110,7 @@ async def analyze_document_node(
@standard_node(node_name="extract_document_text", metric_name="text_extraction")
async def extract_document_text_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract text from document using appropriate method.
@@ -180,7 +180,7 @@ async def extract_document_text_node(
@standard_node(node_name="extract_document_metadata", metric_name="metadata_extraction")
async def extract_document_metadata_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract metadata from document.
@@ -266,7 +266,7 @@ async def extract_document_metadata_node(
@standard_node(node_name="suggest_document_tags", metric_name="tag_suggestion")
async def suggest_document_tags_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Suggest tags for document based on content analysis.
@@ -360,7 +360,7 @@ async def suggest_document_tags_node(
@standard_node(node_name="execute_document_search", metric_name="document_search")
async def execute_document_search_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Execute document search in Paperless-NGX.

View File

@@ -0,0 +1,412 @@
"""Document existence validator node for Paperless NGX to PostgreSQL validation.
This module provides functionality to check if documents from Paperless NGX exist
in the PostgreSQL database, specifically in the receipt processing tables.
"""
from __future__ import annotations
from difflib import SequenceMatcher
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import DatabaseError
from biz_bud.core.langgraph import handle_errors, log_node_execution, standard_node
from biz_bud.logging import get_logger
from biz_bud.services.factory import get_global_factory
logger = get_logger(__name__)
def _calculate_text_similarity(text1: str, text2: str) -> float:
"""Calculate similarity between two text strings using SequenceMatcher.
Args:
text1: First text string
text2: Second text string
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Normalize text for better matching
norm_text1 = text1.lower().strip()
norm_text2 = text2.lower().strip()
matcher = SequenceMatcher(None, norm_text1, norm_text2)
return matcher.ratio()
@standard_node(node_name="validate_document", metric_name="document_validation")
@handle_errors()
@log_node_execution("document_validation")
async def paperless_document_validator_node(
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Validate if a Paperless NGX document exists in PostgreSQL database.
This node checks if a document from Paperless NGX has been processed and stored
in the PostgreSQL receipt processing tables. It supports multiple validation
methods including exact text matching, similarity matching, and metadata matching.
Args:
state: Current workflow state containing validation parameters:
- paperless_document_id: Optional Paperless document ID
- document_content: Optional raw text content to validate
- raw_receipt_text: Optional direct receipt text to search
- validation_method: 'exact_match', 'similarity_match', or 'metadata_match'
- similarity_threshold: Minimum similarity score for fuzzy matching (default: 0.8)
config: Configuration with database credentials
Returns:
State updates with validation results:
- document_exists: Boolean indicating if document was found
- postgres_receipt_id: Receipt ID if found, None otherwise
- receipt_metadata: Dictionary with receipt details if found
- validation_confidence: Float score (0.0-1.0) indicating match confidence
- match_method: String indicating how the match was found
- validation_status: 'found', 'not_found', or 'error'
"""
try:
logger.info("Starting Paperless document validation")
# Extract validation parameters from state
paperless_doc_id = state.get("paperless_document_id")
document_content = state.get("document_content")
raw_receipt_text = state.get("raw_receipt_text")
validation_method = state.get("validation_method", "exact_match")
similarity_threshold = state.get("similarity_threshold", 0.8)
# Determine what content to validate
content_to_validate = raw_receipt_text or document_content
if not content_to_validate:
if paperless_doc_id:
# If we have a Paperless document ID but no content, we could
# potentially fetch it from Paperless API, but for now return error
return {
"document_exists": False,
"error": "Content required for validation - provide document_content or raw_receipt_text",
"validation_status": "error",
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "none"
}
else:
return {
"document_exists": False,
"error": "No document content or Paperless document ID provided",
"validation_status": "error",
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "none"
}
# Get database connection through service factory
factory = await get_global_factory()
# Try to get a database service - the exact service name may vary
# We'll need to check what's available in the project
try:
db_connection = await factory.get_database_connection() # type: ignore[attr-defined]
except Exception as e:
logger.error(f"Failed to get database connection: {e}")
raise DatabaseError(f"Database connection failed: {e}")
# Perform validation based on selected method
validation_result = None
if validation_method == "exact_match":
validation_result = await _perform_exact_match(db_connection, content_to_validate)
elif validation_method == "similarity_match":
validation_result = await _perform_similarity_match(
db_connection, content_to_validate, similarity_threshold
)
elif validation_method == "metadata_match":
validation_result = await _perform_metadata_match(db_connection, state)
else:
# Try all methods in order of confidence
logger.info("Attempting validation with all available methods")
# 1. Try exact match first
validation_result = await _perform_exact_match(db_connection, content_to_validate)
# 2. If no exact match, try similarity matching
if not validation_result["document_exists"]:
validation_result = await _perform_similarity_match(
db_connection, content_to_validate, similarity_threshold
)
# 3. If still no match, try metadata matching
if not validation_result["document_exists"]:
metadata_result = await _perform_metadata_match(db_connection, state)
if metadata_result["document_exists"]:
validation_result = metadata_result
# Add Paperless document ID to result if provided
if paperless_doc_id:
validation_result["paperless_document_id"] = paperless_doc_id
# Set final validation status
if validation_result["document_exists"]:
validation_result["validation_status"] = "found"
else:
validation_result["validation_status"] = "not_found"
logger.info(
f"Document validation completed: exists={validation_result['document_exists']}, "
f"method={validation_result['match_method']}, "
f"confidence={validation_result['validation_confidence']}"
)
return validation_result
except Exception as e:
logger.error(f"Error in document validation: {e}")
return {
"document_exists": False,
"error": str(e),
"validation_status": "error",
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "error"
}
async def _perform_exact_match(db_connection: Any, content: str) -> dict[str, Any]:
"""Perform exact text matching against database.
Args:
db_connection: Database connection object
content: Content to match exactly
Returns:
Dictionary with validation results
"""
try:
# Query for exact match in raw_receipt_text field
query = """
SELECT id, vendor_name, transaction_date, transaction_time,
final_total, payment_method, created_at, updated_at
FROM rpt_receipts
WHERE raw_receipt_text = %s
LIMIT 1
"""
async with db_connection.cursor() as cursor:
await cursor.execute(query, (content,))
result = await cursor.fetchone()
if result:
receipt_metadata = {
"receipt_id": result[0],
"vendor_name": result[1],
"transaction_date": result[2].isoformat() if result[2] else None,
"transaction_time": str(result[3]) if result[3] else None,
"final_total": float(result[4]) if result[4] else None,
"payment_method": result[5],
"created_at": result[6].isoformat() if result[6] else None,
"updated_at": result[7].isoformat() if result[7] else None,
}
return {
"document_exists": True,
"postgres_receipt_id": result[0],
"receipt_metadata": receipt_metadata,
"validation_confidence": 1.0,
"match_method": "exact_text_match"
}
else:
return {
"document_exists": False,
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "exact_text_match"
}
except Exception as e:
logger.error(f"Exact match validation failed: {e}")
raise DatabaseError(f"Exact match query failed: {e}")
async def _perform_similarity_match(
db_connection: Any, content: str, threshold: float
) -> dict[str, Any]:
"""Perform similarity-based matching against database.
Args:
db_connection: Database connection object
content: Content to match using similarity
threshold: Minimum similarity score to consider a match
Returns:
Dictionary with validation results
"""
try:
# Get recent receipts to compare against (limit for performance)
query = """
SELECT id, vendor_name, transaction_date, transaction_time,
final_total, payment_method, created_at, updated_at, raw_receipt_text
FROM rpt_receipts
WHERE raw_receipt_text IS NOT NULL
ORDER BY created_at DESC
LIMIT 1000
"""
async with db_connection.cursor() as cursor:
await cursor.execute(query)
results = await cursor.fetchall()
best_match = None
best_similarity = 0.0
for row in results:
if row[8]: # raw_receipt_text exists
similarity = _calculate_text_similarity(content, row[8])
if similarity > best_similarity and similarity >= threshold:
best_similarity = similarity
best_match = row
if best_match:
receipt_metadata = {
"receipt_id": best_match[0],
"vendor_name": best_match[1],
"transaction_date": best_match[2].isoformat() if best_match[2] else None,
"transaction_time": str(best_match[3]) if best_match[3] else None,
"final_total": float(best_match[4]) if best_match[4] else None,
"payment_method": best_match[5],
"created_at": best_match[6].isoformat() if best_match[6] else None,
"updated_at": best_match[7].isoformat() if best_match[7] else None,
}
return {
"document_exists": True,
"postgres_receipt_id": best_match[0],
"receipt_metadata": receipt_metadata,
"validation_confidence": best_similarity,
"match_method": "similarity_match"
}
else:
return {
"document_exists": False,
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "similarity_match"
}
except Exception as e:
logger.error(f"Similarity match validation failed: {e}")
raise DatabaseError(f"Similarity match query failed: {e}")
async def _perform_metadata_match(db_connection: Any, state: dict[str, Any]) -> dict[str, Any]:
"""Perform metadata-based matching against database.
Args:
db_connection: Database connection object
state: State containing metadata fields like vendor_name, transaction_date, etc.
Returns:
Dictionary with validation results
"""
try:
# Extract metadata from state
vendor_name = state.get("vendor_name")
transaction_date = state.get("transaction_date")
final_total = state.get("final_total")
if not any([vendor_name, transaction_date, final_total]):
return {
"document_exists": False,
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "metadata_match"
}
# Build dynamic query based on available metadata
conditions = []
params = []
if vendor_name:
conditions.append("LOWER(vendor_name) = LOWER(%s)")
params.append(vendor_name)
if transaction_date:
conditions.append("transaction_date = %s")
params.append(transaction_date)
if final_total:
# Allow small variance in total amount (e.g., for rounding differences)
conditions.append("ABS(final_total - %s) < 0.01")
params.append(final_total)
if not conditions:
return {
"document_exists": False,
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "metadata_match"
}
query = f"""
SELECT id, vendor_name, transaction_date, transaction_time,
final_total, payment_method, created_at, updated_at
FROM rpt_receipts
WHERE {' AND '.join(conditions)}
ORDER BY created_at DESC
LIMIT 1
"""
async with db_connection.cursor() as cursor:
await cursor.execute(query, params)
result = await cursor.fetchone()
if result:
# Calculate confidence based on how many metadata fields matched
confidence = len([x for x in [vendor_name, transaction_date, final_total] if x]) / 3.0
receipt_metadata = {
"receipt_id": result[0],
"vendor_name": result[1],
"transaction_date": result[2].isoformat() if result[2] else None,
"transaction_time": str(result[3]) if result[3] else None,
"final_total": float(result[4]) if result[4] else None,
"payment_method": result[5],
"created_at": result[6].isoformat() if result[6] else None,
"updated_at": result[7].isoformat() if result[7] else None,
}
return {
"document_exists": True,
"postgres_receipt_id": result[0],
"receipt_metadata": receipt_metadata,
"validation_confidence": confidence,
"match_method": "metadata_match"
}
else:
return {
"document_exists": False,
"postgres_receipt_id": None,
"receipt_metadata": None,
"validation_confidence": 0.0,
"match_method": "metadata_match"
}
except Exception as e:
logger.error(f"Metadata match validation failed: {e}")
raise DatabaseError(f"Metadata match query failed: {e}")
__all__ = [
"paperless_document_validator_node",
]

View File

@@ -6,11 +6,10 @@ providing document management capabilities through structured workflow nodes.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import ToolNode
from biz_bud.core.config.constants import STATE_KEY_MESSAGES
from biz_bud.core.errors import ConfigurationError
@@ -26,14 +25,14 @@ if TYPE_CHECKING:
# Check Paperless client availability and define types
_paperless_available: bool = False
PaperlessClient: type[Any] | None = None
search_paperless_documents: Callable[..., Any] | None = None
get_paperless_document: Callable[..., Any] | None = None
update_paperless_document: Callable[..., Any] | None = None
create_paperless_tag: Callable[..., Any] | None = None
list_paperless_tags: Callable[..., Any] | None = None
list_paperless_correspondents: Callable[..., Any] | None = None
list_paperless_document_types: Callable[..., Any] | None = None
get_paperless_statistics: Callable[..., Any] | None = None
search_paperless_documents: Any | None = None
get_paperless_document: Any | None = None
update_paperless_document: Any | None = None
create_paperless_tag: Any | None = None
list_paperless_tags: Any | None = None
list_paperless_correspondents: Any | None = None
list_paperless_document_types: Any | None = None
get_paperless_statistics: Any | None = None
try:
from biz_bud.tools.capabilities.external.paperless.tool import (
@@ -55,7 +54,7 @@ except ImportError as e:
# All function variables remain None as initialized
def _get_paperless_tools() -> list[Callable[..., Any]]:
def _get_paperless_tools() -> list[Any]:
"""Get Paperless NGX tools directly."""
if not _paperless_available:
logger.warning("Paperless client not available, returning empty tool list")
@@ -111,7 +110,7 @@ async def _get_paperless_client() -> Any | None:
return PaperlessClient()
async def _get_paperless_llm(config: RunnableConfig | None = None) -> BaseChatModel:
async def _get_paperless_llm(config: RunnableConfig) -> BaseChatModel:
"""Get LLM client configured for Paperless operations."""
factory = await get_global_factory()
llm_client = await factory.get_llm_for_node(
@@ -138,7 +137,7 @@ async def _get_paperless_llm(config: RunnableConfig | None = None) -> BaseChatMo
return llm
def _validate_paperless_config(config: RunnableConfig | None) -> dict[str, str]:
def _validate_paperless_config(config: RunnableConfig) -> dict[str, str]:
"""Validate and extract Paperless NGX configuration."""
import os
@@ -183,7 +182,7 @@ def _validate_paperless_config(config: RunnableConfig | None) -> dict[str, str]:
async def paperless_orchestrator_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Orchestrate Paperless NGX document management operations.
@@ -260,17 +259,48 @@ async def paperless_orchestrator_node(
"routing_decision": "end",
}
logger.info(f"Executing {len(tool_calls)} tool calls using ToolNode")
logger.info(f"Executing {len(tool_calls)} tool calls directly")
# Use ToolNode for proper tool execution
tool_node = ToolNode(tools)
tool_result = await tool_node.ainvoke(
{STATE_KEY_MESSAGES: conversation_messages}, config
)
# Execute tools directly without ToolNode
tool_results = []
for tool_call in tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_call_id = tool_call["id"]
# Find and execute the tool
tool_func = None
for tool in tools:
if hasattr(tool, 'name') and getattr(tool, 'name', None) == tool_name:
tool_func = tool
break
if tool_func:
try:
# Execute the tool
result = await tool_func.ainvoke(tool_args)
tool_message = ToolMessage(
content=str(result),
tool_call_id=tool_call_id
)
tool_results.append(tool_message)
except Exception as e:
logger.error(f"Error executing tool {tool_name}: {e}")
tool_message = ToolMessage(
content=f"Error: {str(e)}",
tool_call_id=tool_call_id
)
tool_results.append(tool_message)
else:
logger.error(f"Tool {tool_name} not found")
tool_message = ToolMessage(
content=f"Error: Tool {tool_name} not found",
tool_call_id=tool_call_id
)
tool_results.append(tool_message)
# Add tool messages to conversation
if STATE_KEY_MESSAGES in tool_result:
conversation_messages.extend(tool_result[STATE_KEY_MESSAGES])
conversation_messages.extend(tool_results)
# Get final response after tool execution
final_response = await llm_with_tools.ainvoke(conversation_messages, config)
@@ -291,7 +321,7 @@ async def paperless_orchestrator_node(
async def paperless_search_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Execute document search operations in Paperless NGX.
@@ -356,7 +386,7 @@ async def paperless_search_node(
async def paperless_document_retrieval_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Retrieve detailed document information from Paperless NGX.
@@ -426,7 +456,7 @@ async def paperless_document_retrieval_node(
async def paperless_metadata_management_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Manage document metadata and tags in Paperless NGX.
@@ -483,13 +513,18 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
update_result = await update_paperless_document(
doc_id=doc_id,
title=title,
correspondent_id=correspondent_id,
document_type_id=document_type_id,
tag_ids=tag_ids,
)
# Build arguments for tool invocation
update_args = {"document_id": doc_id}
if title is not None:
update_args["title"] = title
if correspondent_id is not None:
update_args["correspondent_id"] = correspondent_id
if document_type_id is not None:
update_args["document_type_id"] = document_type_id
if tag_ids is not None:
update_args["tag_ids"] = tag_ids
update_result = await update_paperless_document.ainvoke(update_args)
results.append(update_result)
elif operation_type == "create_tag":
@@ -512,10 +547,10 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
create_result = await create_paperless_tag(
name=tag_name,
color=tag_color,
)
create_result = await create_paperless_tag.ainvoke({
"name": tag_name,
"color": tag_color,
})
results.append(create_result)
elif operation_type == "list_tags":
@@ -528,7 +563,7 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
tags_result = await list_paperless_tags()
tags_result = await list_paperless_tags.ainvoke({})
results.append(tags_result)
elif operation_type == "list_correspondents":
@@ -541,7 +576,7 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
correspondents_result = await list_paperless_correspondents()
correspondents_result = await list_paperless_correspondents.ainvoke({})
results.append(correspondents_result)
elif operation_type == "list_document_types":
@@ -554,7 +589,7 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
types_result = await list_paperless_document_types()
types_result = await list_paperless_document_types.ainvoke({})
results.append(types_result)
elif operation_type == "get_statistics":
@@ -567,7 +602,7 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
stats_result = await get_paperless_statistics()
stats_result = await get_paperless_statistics.ainvoke({})
results.append(stats_result)
else:
@@ -580,9 +615,9 @@ async def paperless_metadata_management_node(
"metadata_results": {},
}
tags_result = await list_paperless_tags()
correspondents_result = await list_paperless_correspondents()
types_result = await list_paperless_document_types()
tags_result = await list_paperless_tags.ainvoke({})
correspondents_result = await list_paperless_correspondents.ainvoke({})
types_result = await list_paperless_document_types.ainvoke({})
results.extend([tags_result, correspondents_result, types_result])

View File

@@ -19,7 +19,7 @@ from biz_bud.logging import debug_highlight, error_highlight, info_highlight, wa
node_name="paperless_document_processor", metric_name="document_processing"
)
async def process_document_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Process documents for Paperless-NGX upload.
@@ -109,7 +109,7 @@ async def process_document_node(
@standard_node(node_name="paperless_query_builder", metric_name="query_building")
async def build_paperless_query_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Build search queries for Paperless-NGX API.
@@ -188,7 +188,7 @@ async def build_paperless_query_node(
@standard_node(node_name="paperless_result_formatter", metric_name="result_formatting")
async def format_paperless_results_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Format Paperless-NGX search results for presentation.

View File

@@ -0,0 +1,875 @@
"""Receipt processing nodes for Paperless-NGX integration.
This module provides specialized nodes for receipt reconciliation workflows,
including LLM-based structured extraction, product validation against web catalogs,
and canonicalization of product information.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
from biz_bud.core.errors import create_error_info
from biz_bud.core.langgraph import standard_node
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
from biz_bud.states.receipt import ReceiptLineItem, ReceiptState
# Pydantic models for structured receipt data (LLM output schemas)
class ReceiptLineItemPydantic(BaseModel):
"""Pydantic model for LLM structured extraction of line items."""
product_name: str = Field(description="Name of the product")
product_code: str = Field(description="Product code or SKU if available")
quantity: str = Field(description="Quantity purchased")
unit_of_measure: str = Field(description="Unit of measurement")
unit_price: str = Field(description="Price per unit")
total_price: str = Field(description="Total price for this line item")
raw_text_snippet: str = Field(description="Original OCR text for this item")
uncertain_fields: str = Field(description="Fields that were ambiguous or uncertain")
class ReceiptExtractionPydantic(BaseModel):
"""Pydantic model for complete structured receipt extraction."""
metadata: str = Field(description="Receipt metadata as string")
line_items: list[ReceiptLineItemPydantic] = Field(description="List of line items")
@standard_node(node_name="receipt_llm_extraction", metric_name="receipt_extraction")
async def receipt_llm_extraction_node(
state: ReceiptState, config: RunnableConfig
) -> dict[str, Any]:
"""Extract structured receipt data using LLM.
This node takes OCR'd receipt text and uses an LLM with structured output
to extract vendor information, line items, and transaction details.
Args:
state: Current workflow state with raw receipt text
config: Optional runtime configuration
Returns:
Updated state with structured receipt data
"""
info_highlight("Extracting structured data from receipt...", category="ReceiptLLMExtraction")
# Get raw receipt text from state
raw_text = state.get("document_content") or state.get("raw_receipt_text", "")
if not raw_text:
warning_highlight("No receipt text provided for extraction", category="ReceiptLLMExtraction")
return {
"receipt_extraction": {},
"processing_stage": "extraction_failed",
"validation_errors": ["No receipt text provided"]
}
try:
# Get LLM service from service factory if available
service_factory = getattr(config, 'service_factory', None) if config else None
if service_factory:
# Use service factory to get LLM service
try:
llm_service = await service_factory.get_service('llm')
llm = llm_service.get_model(model_name="gpt-4-mini") # Use efficient model for structured extraction
except Exception:
# Fallback if service factory not available
llm = None
else:
llm = None
if llm:
# Create structured extraction prompt based on the YAML workflow
system_prompt = """From the following OCR'd receipt text, extract these fields:
- Vendor, Merchant, or Business Name
- Vendor Address
- Transaction Date
- Receipt or Invoice Number
- Total, Subtotal, Tax, Payment Method, Card Info
- For each line item: product_name, product_code (if present), quantity, unit_of_measure, unit_price, total_price, and the snippet of raw OCR that led to these values.
If a value is ambiguous or uncertain, mark it accordingly.
Respond as JSON with fields "metadata" and "line_items" ([list of objects])."""
human_prompt = f"Text: {raw_text}"
# Create messages
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=human_prompt)
]
# Get structured output using Pydantic model
structured_llm = llm.with_structured_output(ReceiptExtractionPydantic)
response = await structured_llm.ainvoke(messages)
# Convert to dict for state management
extraction_result = {
"metadata": response.metadata,
"line_items": [item.model_dump() for item in response.line_items]
}
info_highlight(
f"LLM extraction complete: found {len(response.line_items)} line items",
category="ReceiptLLMExtraction"
)
else:
# Fallback to mock extraction for development/testing
warning_highlight("No LLM service available, using mock extraction", category="ReceiptLLMExtraction")
extraction_result = _mock_receipt_extraction(raw_text)
return {
"receipt_extraction": extraction_result,
"processing_stage": "extraction_complete",
"validation_errors": []
}
except Exception as e:
error_msg = f"Receipt LLM extraction failed: {str(e)}"
error_highlight(error_msg, category="ReceiptLLMExtraction")
return {
"receipt_extraction": {},
"processing_stage": "extraction_failed",
"validation_errors": [error_msg],
"errors": [
create_error_info(
message=error_msg,
node="receipt_llm_extraction",
severity="error",
category="extraction_error"
)
]
}
@standard_node(node_name="receipt_line_items_parser", metric_name="line_item_parsing")
async def receipt_line_items_parser_node(
state: ReceiptState, config: RunnableConfig
) -> dict[str, Any]:
"""Parse line items from structured receipt extraction.
This node extracts the line_items array from the LLM extraction result
and prepares them for validation processing.
Args:
state: Current workflow state with receipt extraction results
config: Optional runtime configuration
Returns:
Updated state with parsed line items
"""
debug_highlight("Parsing line items from receipt extraction...", category="LineItemParser")
extraction_result = state.get("receipt_extraction", {})
if not extraction_result:
warning_highlight("No extraction result available for parsing", category="LineItemParser")
return {
"line_items": [],
"receipt_metadata": {},
"processing_stage": "parsing_failed",
"validation_errors": ["No extraction result available"]
}
try:
# Extract line items
line_items = extraction_result.get("line_items", [])
metadata = extraction_result.get("metadata", "")
# Parse metadata if it's a string
if isinstance(metadata, str) and metadata.strip():
try:
# Try to parse as JSON if possible
parsed_metadata = json.loads(metadata)
except json.JSONDecodeError:
# Otherwise treat as plain text description
parsed_metadata = {"description": metadata}
else:
parsed_metadata = metadata if isinstance(metadata, dict) else {}
# Validate line items structure
valid_line_items: list[ReceiptLineItem] = []
for i, item in enumerate(line_items):
if item.get("product_name"):
valid_line_item: ReceiptLineItem = {
"index": i,
"product_name": item.get("product_name", ""),
"product_code": item.get("product_code", ""),
"quantity": item.get("quantity", ""),
"unit_of_measure": item.get("unit_of_measure", ""),
"unit_price": item.get("unit_price", ""),
"total_price": item.get("total_price", ""),
"raw_text_snippet": item.get("raw_text_snippet", ""),
"uncertain_fields": item.get("uncertain_fields", ""),
"validation_status": "pending"
}
valid_line_items.append(valid_line_item)
info_highlight(
f"Parsed {len(valid_line_items)} valid line items from extraction",
category="LineItemParser"
)
return {
"line_items": valid_line_items,
"receipt_metadata": parsed_metadata,
"processing_stage": "parsing_complete",
"validation_errors": []
}
except Exception as e:
error_msg = f"Line item parsing failed: {str(e)}"
error_highlight(error_msg, category="LineItemParser")
return {
"line_items": [],
"receipt_metadata": {},
"processing_stage": "parsing_failed",
"validation_errors": [error_msg],
"errors": [
create_error_info(
message=error_msg,
node="receipt_line_items_parser",
severity="error",
category="parsing_error"
)
]
}
@standard_node(node_name="receipt_item_validation", metric_name="item_validation")
async def receipt_item_validation_node(
state: ReceiptState, config: RunnableConfig
) -> dict[str, Any]:
"""Validate receipt line items against web catalogs.
This node implements the iterative validation logic from the YAML workflow,
using agent-based search to validate and canonicalize product information.
Args:
state: Current workflow state with parsed line items
config: Optional runtime configuration
Returns:
Updated state with validated line items
"""
info_highlight("Validating receipt line items against web catalogs...", category="ItemValidation")
line_items = state.get("line_items", [])
if not line_items:
warning_highlight("No line items available for validation", category="ItemValidation")
return {
"validated_items": [],
"canonical_products": [],
"processing_stage": "validation_failed",
"validation_errors": ["No line items to validate"]
}
try:
# Get services from factory if available
service_factory = getattr(config, 'service_factory', None) if config else None
web_search_service = None
llm_service = None
if service_factory:
try:
web_search_service = await service_factory.get_service('web_search')
llm_service = await service_factory.get_service('llm')
except Exception:
pass # Fallback to mock validation
validated_items = []
canonical_products = []
# Process items with concurrency limit to avoid overwhelming services
max_concurrent = 3 # Based on parallel_nums from YAML
semaphore = asyncio.Semaphore(max_concurrent)
async def validate_single_item(item: ReceiptLineItem) -> tuple[dict[str, Any], dict[str, Any]]:
"""Validate a single line item."""
async with semaphore:
return await _validate_product_item(
item, web_search_service, llm_service
)
# Create validation tasks
validation_tasks = [
validate_single_item(item) for item in line_items
]
# Execute validation with timeout
try:
results = await asyncio.wait_for(
asyncio.gather(*validation_tasks, return_exceptions=True),
timeout=120.0 # 2 minute timeout for all validations
)
# Process results
for i, result in enumerate(results):
if isinstance(result, Exception):
error_msg = f"Validation failed for item {i}: {str(result)}"
warning_highlight(error_msg, category="ItemValidation")
# Add failed item with error info
failed_item = dict(line_items[i])
failed_item.update({
"validation_status": "failed",
"validation_error": error_msg
})
validated_items.append(failed_item)
# Add empty canonical product
canonical_products.append({
"original_description": line_items[i].get("product_name", ""),
"verified_description": None,
"canonical_description": None,
"search_variations_used": [],
"successful_variation": None,
"notes": error_msg
})
else:
# Successful validation - result is a tuple
if isinstance(result, tuple) and len(result) == 2:
validated_item, canonical_product = result
validated_items.append(validated_item)
canonical_products.append(canonical_product)
else:
# Unexpected result format
warning_highlight(f"Unexpected result format for item {i}: {type(result)}", category="ItemValidation")
failed_item = dict(line_items[i])
failed_item.update({
"validation_status": "failed",
"validation_error": "Unexpected result format"
})
validated_items.append(failed_item)
canonical_products.append({
"original_description": line_items[i].get("product_name", ""),
"verified_description": None,
"canonical_description": None,
"search_variations_used": [],
"successful_variation": None,
"notes": "Unexpected result format"
})
except asyncio.TimeoutError:
error_msg = "Item validation timed out"
error_highlight(error_msg, category="ItemValidation")
return {
"validated_items": [],
"canonical_products": [],
"processing_stage": "validation_timeout",
"validation_errors": [error_msg]
}
success_count = sum(bool(item.get("validation_status") == "success")
for item in validated_items)
info_highlight(
f"Item validation complete: {success_count}/{len(validated_items)} items successfully validated",
category="ItemValidation"
)
return {
"validated_items": validated_items,
"canonical_products": canonical_products,
"processing_stage": "validation_complete",
"validation_errors": []
}
except Exception as e:
error_msg = f"Item validation failed: {str(e)}"
error_highlight(error_msg, category="ItemValidation")
return {
"validated_items": [],
"canonical_products": [],
"processing_stage": "validation_failed",
"validation_errors": [error_msg],
"errors": [
create_error_info(
message=error_msg,
node="receipt_item_validation",
severity="error",
category="validation_error"
)
]
}
# Helper functions
def _mock_receipt_extraction(raw_text: str) -> dict[str, Any]:
"""Mock receipt extraction for development/testing."""
# Look for common receipt patterns
lines = raw_text.split('\n')
line_items = [
{
"product_name": line.strip(),
"product_code": "",
"quantity": "1",
"unit_of_measure": "each",
"unit_price": "0.00",
"total_price": "0.00",
"raw_text_snippet": line.strip(),
"uncertain_fields": "all fields are mock data",
}
for line in lines[:5]
if any(char.isdigit() for char in line) and len(line.strip()) > 3
]
return {
"metadata": "Mock extraction - vendor and transaction details would be extracted here",
"line_items": line_items
}
async def _validate_product_item(
item: ReceiptLineItem,
web_search_service: Any,
llm_service: Any
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Validate a single product item using web search and LLM canonicalization."""
product_name = item.get("product_name", "")
if not web_search_service or not llm_service:
# Mock validation for development
canonical_product = {
"original_description": product_name,
"verified_description": f"Mock verified: {product_name}",
"canonical_description": f"Mock canonical: {product_name}",
"search_variations_used": [product_name],
"successful_variation": product_name,
"validation_sources": [],
"notes": "Mock validation - no web search service available",
"detailed_validation_notes": "Development mode: Web search and LLM services not available. Using mock canonicalization."
}
validated_item = dict(item) | {
"validation_status": "success",
"canonical_info": canonical_product,
}
return validated_item, canonical_product
try:
# Generate search variations (based on YAML agent prompt)
search_variations = _generate_search_variations(product_name)
# Perform web searches and track sources
search_results = []
validation_sources = []
successful_searches = []
# Wholesale distributor search contexts - target real suppliers
distributor_contexts = [
"Jetro Restaurant Depot wholesale",
"Costco business center food service",
"US Foods distributor catalog",
"Sysco food service products",
"food service distributor wholesale"
]
# Do up to 5 search iterations per item for comprehensive coverage
max_search_iterations = min(5, len(search_variations))
for i, variation in enumerate(search_variations[:max_search_iterations]):
current_distributor_context = distributor_contexts[i % len(distributor_contexts)]
try:
# Use different distributor context for each iteration
query = f"{variation} {current_distributor_context}"
results = await web_search_service.search(query=query, count=4) # More results per search
if results:
search_results.extend(results)
successful_searches.append(variation)
# Track sources from search results
for result in results:
source_info = {
"search_query": query,
"search_variation": variation,
"distributor_context": current_distributor_context,
"title": result.get("title", ""),
"url": result.get("url", ""),
"snippet": result.get("snippet", "")[:200], # Limit snippet length
"relevance_score": result.get("relevance_score", 0.0),
"source": result.get("source", "unknown")
}
validation_sources.append(source_info)
except Exception as e:
debug_highlight(f"Search failed for variation '{variation}' with {current_distributor_context}: {e}", category="ItemValidation")
continue # Skip failed searches
# Use LLM to canonicalize based on search results
canonical_product = await _canonicalize_product(
product_name, search_variations, search_results, llm_service, validation_sources
)
# Add comprehensive validation tracking
canonical_product.update({
"validation_sources": validation_sources,
"search_results_count": len(search_results),
"successful_search_variations": successful_searches,
"total_variations_tried": len(search_variations)
})
# Update validated item
validated_item = dict(item)
validated_item |= {
"validation_status": "success",
"canonical_info": canonical_product,
"search_results_count": len(search_results),
}
return validated_item, canonical_product
except Exception as e:
# Handle validation failure with detailed error tracking
# Initialize variables with defaults to avoid unbound errors
safe_search_variations = locals().get('search_variations', [])
safe_validation_sources = locals().get('validation_sources', [])
canonical_product = {
"original_description": product_name,
"verified_description": None,
"canonical_description": None,
"search_variations_used": safe_search_variations,
"successful_variation": None,
"validation_sources": safe_validation_sources,
"notes": f"Validation failed: {str(e)}",
"detailed_validation_notes": f"Error during validation process: {str(e)}. Search variations attempted: {safe_search_variations or 'none'}. Sources found: {len(safe_validation_sources)}"
}
validated_item = dict(item)
validated_item |= {
"validation_status": "failed",
"validation_error": str(e),
}
return validated_item, canonical_product
def _generate_search_variations(product_name: str) -> list[str]:
"""Generate intelligent search variations for restaurant/food service products.
Focuses on creating meaningful search terms that will find real products
in wholesale food distributor catalogs, not random word combinations.
"""
import re
product_lower = product_name.lower().strip()
# Skip if product name is too short or nonsensical
if len(product_lower) < 3:
return [product_name]
# 1. Original product name (cleaned)
original_clean = ' '.join(product_name.split())
variations = [original_clean]
# 2. Expand common restaurant/food service abbreviations first
abbreviation_map = {
'chkn': 'chicken', 'chk': 'chicken',
'brst': 'breast', 'bst': 'breast',
'frz': 'frozen', 'frzn': 'frozen',
'oz': 'ounce', 'lb': 'pound', 'lbs': 'pounds',
'gal': 'gallon', 'qt': 'quart', 'pt': 'pint',
'pkg': 'package', 'bx': 'box', 'cs': 'case',
'btl': 'bottle', 'can': 'can', 'jar': 'jar',
'reg': 'regular', 'org': 'organic', 'nat': 'natural',
'whl': 'whole', 'slcd': 'sliced', 'chpd': 'chopped',
'grnd': 'ground', 'dcd': 'diced', 'shrd': 'shredded',
'hvy': 'heavy', 'lt': 'light', 'md': 'medium', 'lg': 'large',
'bev': 'beverage', 'drk': 'drink',
'veg': 'vegetable', 'vegs': 'vegetables',
'ppr': 'paper', 'twr': 'towel', 'nap': 'napkin',
'utl': 'utensil', 'utn': 'utensil', 'fork': 'forks', 'spn': 'spoon',
'pls': 'plastic', 'styr': 'styrofoam', 'fom': 'foam',
'wht': 'white', 'blk': 'black', 'clr': 'clear'
}
# Create expanded version with abbreviations
expanded = product_lower
for abbr, full in abbreviation_map.items():
expanded = re.sub(r'\b' + abbr + r'\b', full, expanded)
if expanded != product_lower:
variations.append(expanded.title())
# 3. Extract meaningful product keywords (ignore meaningless words)
meaningless_words = {
'wrapped', 'heavy', 'light', 'medium', 'large', 'small', 'big', 'little',
'good', 'bad', 'best', 'nice', 'fine', 'great', 'super', 'ultra',
'special', 'premium', 'quality', 'grade', 'standard', 'regular',
'new', 'old', 'fresh', 'dry', 'wet', 'hot', 'cold', 'warm', 'cool'
}
# Extract meaningful words (longer than 3 chars, not meaningless)
words = re.findall(r'\b\w{3,}\b', expanded.lower())
if meaningful_words := [
w for w in words if w not in meaningless_words and len(w) > 3
]:
# Use the most meaningful words
if len(meaningful_words) >= 2:
# Combine top 2 meaningful words
variations.append(f"{meaningful_words[0]} {meaningful_words[1]}")
# Add single most meaningful word if it's likely a product
main_word = meaningful_words[0] if meaningful_words else None
if main_word and len(main_word) > 4:
variations.append(main_word)
# 5. Handle specific product categories intelligently
if any(word in product_lower for word in ['fork', 'spoon', 'knife', 'utensil']):
variations.extend(['disposable utensils', 'plastic forks', 'food service utensils'])
elif any(word in product_lower for word in ['cup', 'container', 'bowl', 'plate']):
variations.extend(['disposable containers', 'food service containers', 'takeout containers'])
elif any(word in product_lower for word in ['napkin', 'towel', 'tissue']):
variations.extend(['paper napkins', 'food service napkins', 'restaurant napkins'])
elif any(word in product_lower for word in ['chicken', 'beef', 'pork', 'meat']):
variations.extend(['frozen meat', 'food service meat', 'restaurant protein'])
elif any(word in product_lower for word in ['cheese', 'dairy', 'milk', 'butter']):
variations.extend(['dairy products', 'food service dairy', 'restaurant cheese'])
elif any(word in product_lower for word in ['bread', 'bun', 'roll', 'bakery']):
variations.extend(['bakery products', 'food service bread', 'restaurant buns'])
# 6. Remove duplicates and empty strings while preserving order
clean_variations = []
seen = set()
for var in variations:
var_clean = var.strip()
if var_clean and len(var_clean) > 2 and var_clean.lower() not in seen:
clean_variations.append(var_clean)
seen.add(var_clean.lower())
# 7. Limit to most relevant variations (prioritize quality over quantity)
return clean_variations[:8] # Reduced from 10 to focus on quality
async def _canonicalize_product(
original: str,
variations: list[str],
search_results: list[Any],
llm_service: Any,
validation_sources: list[dict[str, Any]]
) -> dict[str, Any]:
"""Use LLM to canonicalize product information based on web search results.
IMPORTANT: Only produces canonical names when we have genuine verification from
wholesale food distributors. Returns None for canonical_description if validation
is insufficient to avoid polluting database with fake data.
"""
try:
# Check if we have sufficient validation sources from food distributors
has_distributor_sources = any(
any(distributor in source.get('distributor_context', '').lower() for distributor in
['jetro', 'restaurant depot', 'costco', 'us foods', 'sysco'])
for source in validation_sources
)
# Require minimum validation threshold
sufficient_validation = (
len(validation_sources) >= 2 and # At least 2 sources
has_distributor_sources and # At least one from known distributors
len(search_results) >= 3 # At least 3 search results
)
# Create rich context for canonicalization with source information
source_summary = []
for source in validation_sources[:5]: # Use top 5 sources
distributor = source.get('distributor_context', 'unknown')
source_summary.append(f"- {source['title']} ({distributor}): {source['snippet']}")
search_context = f"""
Original product: {original}
Search variations tried: {', '.join(variations)}
Results found: {len(search_results)}
Distributor sources: {has_distributor_sources}
Sufficient validation: {sufficient_validation}
Top sources consulted:
{chr(10).join(source_summary) if source_summary else 'No sources found'}
"""
if sufficient_validation:
# Use LLM for detailed canonicalization with reasoning
llm = llm_service.get_model(model_name="gpt-4-mini")
prompt = f"""You are a restaurant inventory specialist. Canonicalize this product description ONLY if you can verify it from wholesale food distributor sources.
{search_context}
STRICT RULES:
- Only canonicalize if you find genuine matches in Jetro, Restaurant Depot, Costco Business, US Foods, or Sysco
- Do NOT create generic names like "General Groceries" or "Unknown Product"
- Do NOT guess or make up product names
- If uncertain, return the original description unchanged
Analyze the search results and provide:
1. A verified description (what you confirmed from distributor sources)
2. A canonical description (standardized format ONLY if genuinely verified)
3. Detailed reasoning explaining which distributor sources confirmed the product
Return JSON with: original_description, verified_description, canonical_description, notes, detailed_validation_notes
"""
# Get LLM response and parse (simplified for now - would use structured output in production)
_ = await llm.ainvoke(prompt) # Response not currently used, avoiding unused variable
# Extract key information from sources for reasoning
distributor_sources = [s for s in validation_sources if
any(dist in s.get('distributor_context', '').lower() for dist in
['jetro', 'restaurant depot', 'costco', 'us foods', 'sysco'])]
reasoning_notes = []
if distributor_sources:
reasoning_notes.append(f"Verified through {len(distributor_sources)} distributor sources:")
for i, source in enumerate(distributor_sources[:3], 1):
reasoning_notes.append(f"{i}. {source['distributor_context']}{source['title']} (score: {source['relevance_score']:.2f})")
else:
reasoning_notes.append("No distributor sources found - using fallback")
# Build canonical description ONLY if we have genuine verification
canonical_desc = None
verified_desc = original # Default to original
best_source = None # Initialize to avoid unbound error
if distributor_sources:
# Look for genuine product matches in distributor sources
best_source = max(distributor_sources, key=lambda s: s.get('relevance_score', 0))
if best_source.get('relevance_score', 0) > 0.7: # High confidence threshold
canonical_desc = _build_canonical_description(original, search_results, validation_sources)
verified_desc = canonical_desc
detailed_notes = f"""
Verification Process:
- Original: {original}
- Search variations: {', '.join(variations)}
- Total sources: {len(validation_sources)}
- Distributor sources: {len(distributor_sources)}
- Verification threshold met: {canonical_desc is not None}
- Best matching results: {', '.join([s['title'][:50] for s in distributor_sources[:2]])}
- Reasoning: {' '.join(reasoning_notes)}
"""
return {
"original_description": original,
"verified_description": verified_desc,
"canonical_description": canonical_desc, # None if not verified
"search_variations_used": variations,
"successful_variation": variations[0] if variations else None,
"notes": f"Validated against {len(distributor_sources)} distributor sources" if canonical_desc else "Insufficient verification - kept original",
"detailed_validation_notes": detailed_notes.strip(),
"confidence_score": best_source.get('relevance_score', 0.5) if best_source is not None else 0.5
}
else:
# Insufficient validation - return original unchanged
detailed_notes = f"""
Verification Process (Insufficient):
- Original: {original}
- Search variations attempted: {', '.join(variations)}
- Sources found: {len(validation_sources)}
- Distributor sources: {has_distributor_sources}
- Verification threshold: NOT MET
- Result: Keeping original description unchanged to avoid fake data
"""
return {
"original_description": original,
"verified_description": original,
"canonical_description": None, # No canonicalization
"search_variations_used": variations,
"successful_variation": None,
"notes": "Insufficient distributor verification - avoided fake canonicalization",
"detailed_validation_notes": detailed_notes.strip(),
"confidence_score": 0.3 # Low confidence
}
except Exception as e:
# Error handling - return original unchanged
detailed_notes = f"""
Verification Process (Error):
- Original: {original}
- Error: {str(e)}
- Search variations attempted: {', '.join(variations)}
- Sources found: {len(validation_sources)}
- Result: Error occurred, keeping original unchanged
"""
return {
"original_description": original,
"verified_description": original,
"canonical_description": None, # No canonicalization due to error
"search_variations_used": variations,
"successful_variation": None,
"notes": f"Error during validation: {str(e)} - kept original",
"detailed_validation_notes": detailed_notes.strip(),
"confidence_score": 0.1 # Very low confidence
}
def _build_canonical_description(
original: str,
search_results: list[Any],
validation_sources: list[dict[str, Any]]
) -> str:
"""Build canonical description using rule-based approach."""
# Start with original and apply transformations
canonical = original.strip().title()
# Apply common canonicalization rules
replacements = {
'Chkn': 'Chicken',
'Brst': 'Breast',
'Frz': 'Frozen',
'Oz': 'ounce',
'Lb': 'pound',
'Gal': 'gallon',
'Btl': 'Bottle',
'Can': 'Can',
'Pkg': 'Package',
'Bx': 'Box',
'Coca': 'Coca-Cola',
'Pepsi': 'Pepsi-Cola'
}
for abbr, full in replacements.items():
canonical = canonical.replace(abbr, full)
# If we have good source information, try to extract brand/product info
if validation_sources:
# Look for brand names in titles
for source in validation_sources:
title = source.get('title', '').lower()
if 'coca-cola' in title and 'coca' in original.lower():
canonical = canonical.replace('Coca', 'Coca-Cola')
elif 'pepsi' in title and 'pepsi' in original.lower():
canonical = canonical.replace('Pepsi', 'Pepsi-Cola')
return canonical
__all__ = [
"receipt_llm_extraction_node",
"receipt_line_items_parser_node",
"receipt_item_validation_node",
"ReceiptExtractionPydantic",
"ReceiptLineItemPydantic"
]

View File

@@ -17,7 +17,7 @@ from biz_bud.core.langgraph import standard_node
from biz_bud.logging import get_logger, info_highlight
if TYPE_CHECKING:
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph as CompiledGraph
logger = get_logger(__name__)
@@ -39,7 +39,7 @@ class DocumentProcessingState(TypedDict, total=False):
@standard_node(node_name="analyze_document", metric_name="document_analysis")
async def analyze_document_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Analyze document to determine processing requirements."""
document_path = state.get("document_path", "")
@@ -59,7 +59,7 @@ async def analyze_document_node(
@standard_node(node_name="extract_text", metric_name="text_extraction")
async def extract_text_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract text from document."""
document_type = state.get("document_type", "unknown")
@@ -79,7 +79,7 @@ async def extract_text_node(
@standard_node(node_name="extract_metadata", metric_name="metadata_extraction")
async def extract_metadata_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Extract metadata from document."""
document_type = state.get("document_type", "unknown")
@@ -96,7 +96,7 @@ async def extract_metadata_node(
return {"extracted_metadata": metadata}
def create_document_processing_subgraph() -> CompiledGraph:
def create_document_processing_subgraph() -> CompiledGraph[Any]:
"""Create document processing subgraph.
This subgraph handles the specialized task of processing documents
@@ -134,7 +134,7 @@ class TagSuggestionState(TypedDict, total=False):
@standard_node(node_name="analyze_content_for_tags", metric_name="tag_analysis")
async def analyze_content_for_tags_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> Command[Literal["suggest_tags", "skip_suggestions"]]:
"""Analyze content to determine if tag suggestions are needed."""
content = state.get("document_content", "")
@@ -158,7 +158,7 @@ async def analyze_content_for_tags_node(
@standard_node(node_name="suggest_tags", metric_name="tag_suggestion")
async def suggest_tags_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Suggest tags based on document content."""
content = state.get("document_content", "").lower()
@@ -193,7 +193,7 @@ async def suggest_tags_node(
@standard_node(node_name="return_to_parent", metric_name="parent_return")
async def return_to_parent_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> Command[str]:
"""Return control to parent graph with results."""
info_highlight("Returning to parent graph with tag suggestions", category="TagSubgraph")
@@ -209,7 +209,7 @@ async def return_to_parent_node(
)
def create_tag_suggestion_subgraph() -> CompiledGraph:
def create_tag_suggestion_subgraph() -> CompiledGraph[Any]:
"""Create tag suggestion subgraph.
This subgraph demonstrates returning to parent graph using Command.PARENT.
@@ -246,7 +246,7 @@ class DocumentSearchState(TypedDict, total=False):
@standard_node(node_name="execute_search", metric_name="document_search")
async def execute_search_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Execute document search."""
query = state.get("search_query", "")
@@ -281,7 +281,7 @@ async def execute_search_node(
@standard_node(node_name="rank_results", metric_name="result_ranking")
async def rank_results_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Rank search results by relevance."""
results = state.get("search_results", [])
@@ -295,7 +295,7 @@ async def rank_results_node(
}
def create_document_search_subgraph() -> CompiledGraph:
def create_document_search_subgraph() -> CompiledGraph[Any]:
"""Create document search subgraph.
This subgraph handles specialized document search operations.

View File

@@ -109,7 +109,7 @@ CONSIDERATIONS: <any special notes>
@standard_node(node_name="input_processing", metric_name="planner_input_processing")
@ensure_immutable_node
async def input_processing_node(
state: PlannerState, config: RunnableConfig | None = None
state: PlannerState, config: RunnableConfig
) -> dict[str, Any]:
"""Process and validate input using existing input.py functions.
@@ -281,7 +281,7 @@ async def query_decomposition_node(state: PlannerState) -> dict[str, Any]:
@standard_node(node_name="agent_selection", metric_name="planner_agent_selection")
@ensure_immutable_node
async def agent_selection_node(
state: PlannerState, config: RunnableConfig | None = None
state: PlannerState, config: RunnableConfig
) -> dict[str, Any]:
"""Select appropriate graphs for each step using LLM reasoning.
@@ -598,7 +598,7 @@ async def router_node(
@standard_node(node_name="execute_graph", metric_name="planner_execute_graph")
@ensure_immutable_node
async def execute_graph_node(
state: PlannerState, config: RunnableConfig | None = None
state: PlannerState, config: RunnableConfig
) -> Command[Literal["router", "__end__"]]:
"""Execute the selected graph as a subgraph.
@@ -877,7 +877,7 @@ def compile_planner_graph():
return create_planner_graph(config)
def planner_graph_factory(config: RunnableConfig | None = None):
def planner_graph_factory(config: RunnableConfig):
"""Create planner graph for LangGraph API.
Args:

View File

@@ -151,7 +151,7 @@ _should_scrape_or_skip = create_list_length_router(
_should_process_next_url = create_bool_router("finalize", "check_duplicate", "batch_complete")
def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> CompiledStateGraph:
def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> CompiledStateGraph[URLToRAGState]:
"""Create the URL to R2R processing graph with iterative URL processing.
This graph processes URLs one at a time through the complete pipeline,
@@ -222,7 +222,7 @@ def create_url_to_r2r_graph(config: dict[str, Any] | None = None) -> CompiledSta
builder.add_node("scrape_url", batch_process_urls_node) # Process URL batch
# Repomix for git repos
builder.add_node("repomix_process", repomix_process_node)
builder.add_node("repomix_process", repomix_process_node) # type: ignore[arg-type]
# Analysis and upload
builder.add_node("analyze_content", analyze_content_for_rag_node)
@@ -536,7 +536,7 @@ stream_url_to_r2r = _stream_url_to_r2r
process_url_to_r2r_with_streaming = _process_url_to_r2r_with_streaming
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledStateGraph":
def url_to_rag_graph_factory(config: RunnableConfig) -> "CompiledStateGraph[URLToRAGState]":
"""Create URL to RAG graph for graph-as-tool pattern.
This factory function follows the standard pattern for graphs

View File

@@ -31,7 +31,7 @@ firecrawl_batch_scrape_node = None
@standard_node(node_name="vector_store_upload", metric_name="vector_upload")
async def vector_store_upload_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Upload prepared content to vector store.
@@ -104,7 +104,7 @@ async def vector_store_upload_node(
@standard_node(node_name="git_repo_processor", metric_name="git_processing")
async def process_git_repository_node(
state: dict[str, Any], config: RunnableConfig | None = None
state: dict[str, Any], config: RunnableConfig
) -> dict[str, Any]:
"""Process Git repository for RAG ingestion.

View File

@@ -0,0 +1,184 @@
"""Integration nodes for the RAG workflow.
This module contains nodes that integrate with external services
specifically for RAG workflows, including Repomix for Git repositories
and various vector store integrations.
"""
from __future__ import annotations
from typing import Any
from langchain_core.runnables import RunnableConfig
from biz_bud.core.errors import create_error_info
from biz_bud.core.langgraph import standard_node
from biz_bud.logging import debug_highlight, error_highlight, info_highlight, warning_highlight
# Import from local nodes directory
try:
from .nodes.integrations import repomix_process_node
_legacy_imports_available = True
except ImportError:
_legacy_imports_available = False
repomix_process_node = None
# Firecrawl nodes don't currently exist but placeholders for future implementation
firecrawl_scrape_node = None
firecrawl_batch_scrape_node = None
@standard_node(node_name="vector_store_upload", metric_name="vector_upload")
async def vector_store_upload_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Upload prepared content to vector store.
This node handles the upload of prepared content to various vector stores
(R2R, Pinecone, Qdrant, etc.) based on configuration.
Args:
state: Current workflow state with prepared content
config: Optional runtime configuration
Returns:
Updated state with upload results
"""
info_highlight("Uploading content to vector store...", category="VectorUpload")
prepared_content = state.get("rag_prepared_content", [])
if not prepared_content:
warning_highlight("No prepared content to upload", category="VectorUpload")
return {"upload_results": {"documents_uploaded": 0}}
collection_name = state.get("collection_name", "default")
try:
# Simulate upload (real implementation would use actual vector store service)
uploaded_count = 0
failed_uploads = []
for item in prepared_content:
if not item.get("ready_for_upload"):
continue
# In real implementation, this would upload to vector store
# For now, just track the upload
uploaded_count += 1
debug_highlight(
f"Uploaded document from {item.get('url', 'unknown')} to {collection_name}",
category="VectorUpload",
)
upload_results = {
"documents_uploaded": uploaded_count,
"failed_uploads": failed_uploads,
"collection_name": collection_name,
"vector_store": state.get("vector_store_type", "r2r"),
}
info_highlight(
f"Successfully uploaded {uploaded_count} documents to {collection_name}",
category="VectorUpload",
)
return {"upload_results": upload_results}
except Exception as e:
error_msg = f"Vector store upload failed: {str(e)}"
error_highlight(error_msg, category="VectorUpload")
return {
"upload_results": {"documents_uploaded": 0},
"errors": [
create_error_info(
message=error_msg,
node="vector_store_upload",
severity="error",
category="upload_error",
)
],
}
@standard_node(node_name="git_repo_processor", metric_name="git_processing")
async def process_git_repository_node(
state: dict[str, Any], config: RunnableConfig | None
) -> dict[str, Any]:
"""Process Git repository for RAG ingestion.
This node handles the special case of processing Git repositories,
including code analysis and documentation extraction.
Args:
state: Current workflow state
config: Optional runtime configuration
Returns:
Updated state with repository processing results
"""
info_highlight("Processing Git repository...", category="GitProcessor")
if not state.get("is_git_repo"):
return state
input_url = state.get("input_url", "")
try:
# Use repomix if available
if repomix_process_node and _legacy_imports_available:
return await repomix_process_node(state, config)
# Fallback implementation
warning_highlight(
"Repomix not available, using fallback Git processing",
category="GitProcessor",
)
# Extract repo info from URL
repo_info = {
"url": input_url,
"type": "git",
"files_processed": 0,
"documentation_found": False,
}
# In a real implementation, this would clone and analyze the repo
# For now, return placeholder results
return {
"repomix_output": {
"content": f"# Repository: {input_url}\n\nRepository processing not fully implemented.",
"metadata": repo_info,
},
"git_processing_complete": True,
}
except Exception as e:
error_msg = f"Git repository processing failed: {str(e)}"
error_highlight(error_msg, category="GitProcessor")
return {
"git_processing_complete": False,
"errors": [
create_error_info(
message=error_msg,
node="git_repo_processor",
severity="error",
category="git_processing_error",
)
],
}
# Export all integration nodes
__all__ = [
"vector_store_upload_node",
"process_git_repository_node",
]
# Re-export legacy nodes if available
if _legacy_imports_available:
if repomix_process_node:
__all__.append("repomix_process_node")
if firecrawl_scrape_node:
__all__.extend(["firecrawl_scrape_node", "firecrawl_batch_scrape_node"])

View File

@@ -20,7 +20,7 @@ logger = get_logger(__name__)
async def check_existing_content_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Check if URL content already exists in knowledge stores.
@@ -109,7 +109,7 @@ async def check_existing_content_node(
async def decide_processing_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Decide whether to process the URL based on existing content.
@@ -166,7 +166,7 @@ async def decide_processing_node(
async def determine_processing_params_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Determine optimal parameters for URL processing using LLM analysis.
@@ -359,7 +359,7 @@ async def determine_processing_params_node(
async def invoke_url_to_rag_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Invoke the url_to_rag graph with determined parameters.

View File

@@ -0,0 +1,507 @@
"""Node implementations for the RAG agent with content deduplication."""
from __future__ import annotations
import hashlib
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
from langchain_core.runnables import RunnableConfig
# Removed broken core import
from biz_bud.logging import get_logger, info_highlight
if TYPE_CHECKING:
from biz_bud.services.vector_store import VectorStore
from biz_bud.states.rag_agent import RAGAgentState
logger = get_logger(__name__)
async def check_existing_content_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Check if URL content already exists in knowledge stores.
Query the VectorStore to find existing content for the given URL.
Calculate content age if found and return metadata for decision making.
Args:
state: Current agent state containing input_url and config.
Returns:
State updates with existing content information:
- url_hash: SHA256 hash of URL (16 chars)
- existing_content: Found content metadata or None
- content_age_days: Age in days or None
- rag_status: Updated to "checking"
"""
from biz_bud.core.config.schemas import AppConfig
from biz_bud.services.factory import ServiceFactory
url = state["input_url"]
state_config = state["config"]
# Generate URL hash for efficient lookup
url_hash = hashlib.sha256(url.encode()).hexdigest()[:16]
vector_store: VectorStore | None = None
try:
# Initialize services using dependency injection
# Convert dict config to AppConfig
if isinstance(state_config, AppConfig):
app_config = state_config
else:
app_config = AppConfig.model_validate(state_config)
factory = ServiceFactory(app_config)
vector_store = await factory.get_vector_store()
await vector_store.initialize()
await vector_store.initialize_collection()
# Search for existing content by URL with high similarity threshold
existing_results = await vector_store.semantic_search(
query=url, filters={"source_url": url}, top_k=1, score_threshold=0.9
)
if existing_results:
result = existing_results[0]
metadata = result["metadata"]
# Calculate content age for freshness checking
if "indexed_at" in metadata:
indexed_date = datetime.fromisoformat(metadata["indexed_at"])
age_days = (datetime.now(UTC) - indexed_date).days
else:
age_days = None
logger.info(f"Found existing content for {url}, age: {age_days} days")
return {
"url_hash": url_hash,
"existing_content": result,
"content_age_days": age_days,
"rag_status": "checking",
}
logger.info(f"No existing content found for {url}")
return {
"url_hash": url_hash,
"existing_content": None,
"content_age_days": None,
"rag_status": "checking",
}
except Exception as e:
logger.error(f"Error checking existing content: {e}")
result = {
"url_hash": url_hash,
"existing_content": None,
"content_age_days": None,
"rag_status": "checking",
"error": str(e),
}
if vector_store is not None:
await vector_store.cleanup()
return result
async def decide_processing_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Decide whether to process the URL based on existing content.
Apply business logic to determine if content should be processed:
- New content is always processed
- Stale content (older than max_age) is reprocessed
- Force refresh overrides all checks
- Fresh content is skipped unless forced
Args:
state: Current agent state with content check results.
Returns:
State updates with processing decision:
- should_process: Boolean decision
- processing_reason: Human-readable explanation
- rag_status: Updated to "decided"
"""
force_refresh = state.get("force_refresh", False)
existing_content = state.get("existing_content")
content_age_days = state.get("content_age_days")
# Initialize decision variables
should_process = True
reason = "New content"
if existing_content and not force_refresh:
# Check content freshness against configured threshold
max_age_days = (
state["config"].get("rag_config", {}).get("max_content_age_days", 7)
)
if content_age_days is not None and content_age_days <= max_age_days:
should_process = False
reason = f"Content is fresh ({content_age_days} days old)"
elif content_age_days is not None:
reason = f"Content is stale ({content_age_days} days old)"
else:
reason = "Content exists but age unknown, reprocessing"
elif force_refresh:
reason = "Force refresh requested"
info_highlight(
f"Processing decision for {state['input_url']}: {should_process} - {reason}"
)
return {
"should_process": should_process,
"processing_reason": reason,
"rag_status": "decided",
}
async def determine_processing_params_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Determine optimal parameters for URL processing using LLM analysis.
Uses an LLM to analyze:
- User input/query to understand intent
- URL structure and type
- Context from conversation history
Then intelligently sets parameters for both scraping and RAG processing.
Args:
state: Current agent state with URL and metadata.
Returns:
State updates with processing parameters:
- scrape_params: Firecrawl/scraping configuration
- r2r_params: R2R chunking configuration
"""
from biz_bud.core.utils.url_analyzer import analyze_url_type
from .scraping.url_analyzer import analyze_url_for_params_node
url = state["input_url"]
parsed = urlparse(url)
# Check if scrape_params were already provided in state (user override)
if state.get("scrape_params") and state["scrape_params"]:
logger.info("Using user-provided scrape_params from state")
return {
"scrape_params": state["scrape_params"],
"r2r_params": state.get(
"r2r_params",
{
"chunk_method": "naive",
"chunk_token_num": 512,
"layout_recognize": "DeepDOC",
},
),
}
# Check if it's a Git repository first using centralized analyzer
url_analysis = analyze_url_type(url)
is_git_repo = url_analysis.get("is_git_repo", False)
# If it's a git repository, skip LLM analysis and return minimal params
if is_git_repo:
logger.info(
f"Detected git repository URL: {url}, skipping scraping parameter analysis"
)
# Return minimal params since repomix will handle the repository
return {
"scrape_params": {
"max_depth": 0,
"max_pages": 1,
"include_subdomains": False,
"wait_for_selector": None,
"screenshot": False,
},
"r2r_params": {
"chunk_method": "markdown",
"chunk_token_num": 1024,
"layout_recognize": "DeepDOC",
},
}
# For non-git URLs, use LLM to analyze and get recommendations
# Cast to dict to satisfy type checker
llm_result = await analyze_url_for_params_node(dict(state))
url_params = llm_result.get("url_processing_params")
# Initialize scrape_params with proper type
scrape_params: dict[str, Any]
# Use LLM recommendations if available, otherwise use defaults
if url_params:
logger.info(f"Using LLM-recommended parameters: {url_params['rationale']}")
# Validate and sanitize LLM-provided parameters
def validate_int_param(
value: int | str | None, min_val: int, max_val: int, default: int
) -> int:
"""Validate integer parameter is within acceptable range."""
if value is None:
return default
try:
val = int(value)
if min_val <= val <= max_val:
return val
logger.warning(
f"Value {val} out of range [{min_val}, {max_val}], using default {default}"
)
return default
except (TypeError, ValueError):
logger.warning(
f"Invalid integer value {value}, using default {default}"
)
return default
def validate_bool_param(value: bool | str | int | None, default: bool) -> bool:
"""Validate boolean parameter."""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "yes", "1")
if type(value) is int:
return bool(value)
logger.warning(f"Invalid boolean value {value}, using default {default}")
return default
# Apply validation to each parameter
scrape_params = {
"max_depth": validate_int_param(url_params.get("max_depth", 2), 1, 10, 2),
"max_pages": validate_int_param(
url_params.get("max_pages", 50), 1, 1000, 50
),
"include_subdomains": validate_bool_param(
url_params.get("include_subdomains", False), False
),
"follow_external_links": validate_bool_param(
url_params.get("follow_external_links", False), False
),
"wait_for_selector": None,
"screenshot": False,
}
# Also use the priority paths if available
if url_params.get("priority_paths"):
scrape_params["priority_paths"] = url_params["priority_paths"]
else:
# Fallback to defaults if LLM analysis failed
logger.info("Using default parameters (LLM analysis unavailable)")
# Get values from config if available
state_config = state.get("config", {})
rag_config = state_config.get("rag_config", {})
scrape_params = {
"max_depth": rag_config.get("crawl_depth", 2),
"max_pages": rag_config.get("max_pages_to_crawl", 50),
"include_subdomains": False,
"wait_for_selector": None,
"screenshot": False,
}
# Default R2R parameters for general content
r2r_params = {
"chunk_method": "naive",
"chunk_token_num": 512,
"layout_recognize": "DeepDOC",
}
# Customize R2R params based on URL patterns and content type
# BUT don't override scrape_params if they came from user instructions via LLM
if "docs" in parsed.path or "documentation" in parsed.path:
# Documentation sites benefit from markdown chunking
r2r_params["chunk_method"] = "markdown"
# Only adjust scrape params if we're using defaults (no LLM recommendations)
if not url_params:
# Get values from config if available
state_config = state.get("config", {})
rag_config = state_config.get("rag_config", {})
# Use config values with fallbacks
scrape_params["max_depth"] = rag_config.get("crawl_depth", 3)
scrape_params["max_pages"] = rag_config.get("max_pages_to_crawl", 100)
elif any(ext in parsed.path for ext in [".pdf", ".docx", ".xlsx"]):
# Document URLs need specialized handling
r2r_params["layout_recognize"] = "DeepDOC"
scrape_params["max_pages"] = 1
# Adjust for large sites based on existing content
if state.get("existing_content"):
existing_content = state["existing_content"]
metadata = existing_content.get("metadata", {}) if existing_content else {}
if metadata.get("total_pages", 0) > 100:
# Large site - be more selective to avoid overload
scrape_params["max_depth"] = 1
scrape_params["max_pages"] = 25
logger.info(
f"Determined processing params for {url}: scrape={scrape_params}, r2r={r2r_params}"
)
return {"scrape_params": scrape_params, "r2r_params": r2r_params}
async def invoke_url_to_rag_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Invoke the url_to_rag graph with determined parameters.
Execute the existing url_to_rag graph if processing is needed.
Store processing metadata for future deduplication.
Args:
state: Current agent state with processing decision and parameters.
Returns:
State updates with processing results:
- processing_result: Output from url_to_rag or skip reason
- rag_status: Updated to "completed" or "error"
- error: Set if processing fails
"""
if not state.get("should_process", True):
logger.info("Skipping processing as content is fresh")
return {
"processing_result": {
"skipped": True,
"reason": state.get("processing_reason", "Content already exists"),
},
"rag_status": "completed",
}
from langgraph.config import get_stream_writer
from biz_bud.graphs.rag.graph import process_url_to_r2r_with_streaming
try:
# Enhance config with optimized parameters
enhanced_config = {
**state["config"],
"scrape_params": state.get("scrape_params", {}),
}
# Add r2r_params to rag_config in the config
if "rag_config" not in enhanced_config:
enhanced_config["rag_config"] = {}
r2r_params = state.get("r2r_params", {})
# Get the rag_config dict to ensure it's not Any type
rag_config: dict[str, Any] = enhanced_config["rag_config"]
if rag_config and r2r_params:
rag_config |= r2r_params
# Ensure we don't create duplicate datasets
rag_config["reuse_existing_dataset"] = True
logger.info(f"Processing URL: {state['input_url']}")
# Get the stream writer for this node
try:
writer = get_stream_writer()
except (RuntimeError, TypeError) as e:
# Outside of a LangGraph runnable context or stream initialization error
logger.debug(f"Stream writer unavailable: {e}")
writer = None
# Use the streaming version and forward updates
result = await process_url_to_r2r_with_streaming(
state["input_url"],
enhanced_config,
on_update=lambda update: writer(update) if writer else None,
collection_name=state.get("collection_name"),
)
# Store metadata for future lookups
# Store metadata with the result dict
await _store_processing_metadata(state, cast("dict[str, Any]", result))
return {"processing_result": result, "rag_status": "completed"}
except Exception as e:
logger.error(f"Error processing URL: {e}")
return {"error": str(e), "rag_status": "error"}
async def _store_processing_metadata(
state: RAGAgentState, result: dict[str, Any]
) -> None:
"""Store processing metadata in vector store for deduplication.
Create a searchable record of processed content with metadata
for future deduplication and tracking.
Args:
state: Current agent state with processing parameters.
result: Processing result from url_to_rag graph.
"""
from biz_bud.core.config.schemas import AppConfig
from biz_bud.services.factory import ServiceFactory
vector_store: VectorStore | None = None
try:
# Convert dict config back to AppConfig if needed
state_config = state["config"]
# Convert dict config to AppConfig
if isinstance(state_config, AppConfig):
app_config = state_config
else:
app_config = AppConfig.model_validate(state_config)
factory = ServiceFactory(app_config)
vector_store = await factory.get_vector_store()
await vector_store.initialize()
# Prepare comprehensive metadata - flatten nested structures for VectorStore
metadata: dict[str, str | int | float | list[str]] = {
"source_url": state["input_url"],
"url_hash": str(state["url_hash"]) if state["url_hash"] else "",
"indexed_at": datetime.now(UTC).isoformat(),
"is_git_repo": str(result.get("is_git_repo", False)),
"r2r_dataset_id": str(result.get("r2r_document_id", "")),
"total_pages": len(result.get("scraped_content", [])),
# Flatten processing params
"scrape_depth": state.get("scrape_params", {}).get("depth", 1),
}
# Handle ragflow_chunk_size separately to avoid type issues
r2r_params = state.get("r2r_params", {})
# r2r_params is always a dict from state.get with default {}
metadata["ragflow_chunk_size"] = r2r_params.get("chunk_size", 500)
# Create searchable summary
summary = f"Knowledge base entry for {state['input_url']}. "
if result.get("is_git_repo"):
summary += "Git repository processed with Repomix. "
else:
summary += f"Website with {metadata['total_pages']} pages scraped. "
summary += f"Stored in R2R dataset: {metadata['r2r_dataset_id']}"
# Store in vector store with metadata namespace
await vector_store.upsert_with_metadata(
content=summary, metadata=metadata, namespace="rag_metadata"
)
info_highlight(f"Stored processing metadata for {state['input_url']}")
except Exception as e:
logger.error(f"Failed to store processing metadata: {e}")
finally:
if vector_store is not None:
await vector_store.cleanup()

View File

@@ -10,9 +10,7 @@ from langchain_core.runnables import RunnableConfig
# Removed broken core import
from biz_bud.logging import get_logger
from biz_bud.tools.capabilities.database.tool import r2r_rag_completion as r2r_rag
from biz_bud.tools.capabilities.database.tool import (
r2r_search_documents as r2r_deep_research, # Using search as fallback for deep research
)
from biz_bud.tools.capabilities.database.tool import r2r_search_documents as r2r_deep_research
from biz_bud.tools.capabilities.database.tool import r2r_search_documents as r2r_search
if TYPE_CHECKING:
@@ -22,7 +20,7 @@ logger = get_logger(__name__)
async def r2r_search_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Perform search using R2R's hybrid search capabilities.
@@ -80,7 +78,7 @@ async def r2r_search_node(
async def r2r_rag_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Perform RAG using R2R for intelligent responses.
@@ -138,7 +136,7 @@ async def r2r_rag_node(
async def r2r_deep_research_node(
state: RAGAgentState, config: RunnableConfig | None = None
state: RAGAgentState, config: RunnableConfig
) -> dict[str, Any]:
"""Perform deep research using R2R's agentic capabilities.

View File

@@ -0,0 +1,179 @@
"""RAG agent nodes using R2R for advanced retrieval."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
# Removed broken core import
from biz_bud.logging import get_logger
from biz_bud.tools.capabilities.database.tool import r2r_rag_completion as r2r_rag
from biz_bud.tools.capabilities.database.tool import r2r_search_documents as r2r_deep_research
from biz_bud.tools.capabilities.database.tool import r2r_search_documents as r2r_search
if TYPE_CHECKING:
from biz_bud.states.rag_agent import RAGAgentState
logger = get_logger(__name__)
async def r2r_search_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Perform search using R2R's hybrid search capabilities.
Args:
state: Current agent state
Returns:
Updated state with search results
"""
query = state.get("query", "")
search_type = state.get("search_type", "hybrid")
logger.info(f"Performing R2R {search_type} search for: {query}")
try:
results = await r2r_search.ainvoke(
{
"query": query,
"search_type": search_type,
"limit": 10,
}
)
formatted_results = [
{
"content": result["content"],
"score": result["score"],
"metadata": result["metadata"],
"source": result["metadata"].get("source", "Unknown"),
}
for result in results
]
return {
"search_results": formatted_results,
"messages": [
AIMessage(
content=f"Found {len(results)} results using R2R {search_type} search"
)
],
}
except Exception as e:
logger.error(f"R2R search failed: {e}")
return {
"errors": state.get("errors", [])
+ [
{
"error_type": "R2R_SEARCH_ERROR",
"error_message": str(e),
"component": "r2r_search_node",
}
],
}
async def r2r_rag_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Perform RAG using R2R for intelligent responses.
Args:
state: Current agent state
Returns:
Updated state with RAG response
"""
query = state.get("query", "")
logger.info(f"Performing R2R RAG for: {query}")
try:
response = await r2r_rag.ainvoke(
{
"query": query,
"use_citations": True,
"temperature": 0.7,
}
)
# Format response with citations
formatted_answer = response.get("answer", "No answer generated")
if citations := response.get("citations", []):
formatted_answer += "\n\nCitations:"
for i, citation in enumerate(citations, 1):
# Handle citation format defensively
if isinstance(citation, dict):
source = citation.get("source", "Unknown source")
elif isinstance(citation, str):
source = citation
else:
source = str(citation)
formatted_answer += f"\n[{i}] {source}"
return {
"rag_response": response,
"messages": [AIMessage(content=formatted_answer)],
}
except Exception as e:
logger.error(f"R2R RAG failed: {e}")
return {
"errors": state.get("errors", [])
+ [
{
"error_type": "R2R_RAG_ERROR",
"error_message": str(e),
"component": "r2r_rag_node",
}
],
}
async def r2r_deep_research_node(
state: RAGAgentState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Perform deep research using R2R's agentic capabilities.
Args:
state: Current agent state
Returns:
Updated state with deep research results
"""
query = state.get("query", "")
logger.info(f"Performing R2R deep research for: {query}")
try:
response = await r2r_deep_research.ainvoke(
{
"query": query,
"thinking_budget": 4096,
"max_tokens": 16000,
}
)
return {
"deep_research_results": response,
"messages": [AIMessage(content=response.get("answer", "No results found"))],
}
except Exception as e:
logger.error(f"R2R deep research failed: {e}")
return {
"errors": state.get("errors", [])
+ [
{
"error_type": "R2R_DEEP_RESEARCH_ERROR",
"error_message": str(e),
"component": "r2r_deep_research_node",
}
],
}

View File

@@ -106,7 +106,7 @@ def _analyze_content_characteristics(
async def analyze_content_for_rag_node(
state: "URLToRAGState", config: RunnableConfig | None = None
state: "URLToRAGState", config: RunnableConfig
) -> dict[str, Any]:
"""Analyze scraped content and determine optimal RAGFlow configuration.

View File

@@ -0,0 +1,388 @@
"""Analyze scraped content to determine optimal R2R upload configuration."""
from typing import TYPE_CHECKING, Any, TypedDict
from langchain_core.runnables import RunnableConfig
from biz_bud.core import preserve_url_fields
from biz_bud.logging import get_logger
if TYPE_CHECKING:
from biz_bud.states.url_to_rag import URLToRAGState
logger = get_logger(__name__)
class _R2RConfig(TypedDict):
"""Recommended configuration for R2R document upload."""
chunk_size: int
extract_entities: bool
metadata: dict[str, Any]
rationale: str
_ANALYSIS_PROMPT = """Analyze the following scraped web content and determine the optimal R2R configuration.
Content Overview:
- URL: {url}
- Number of pages: {page_count}
- Total content length: {total_length} characters
- Has tables: {has_tables}
- Has code blocks: {has_code}
- Has Q&A format: {has_qa}
- Content types: {content_types}
Sample content from first page:
{sample_content}
Based on this analysis, recommend:
1. chunk_size: Optimal chunk size (500-2000 characters)
2. extract_entities: Whether to extract entities for knowledge graph (true/false)
3. metadata: Additional metadata tags (category, importance, content_type)
4. rationale: Brief explanation of your choices
Consider:
- Smaller chunks (500-800) for Q&A or reference content
- Larger chunks (1500-2000) for narrative or documentation
- Enable entity extraction for content with many people, places, products
- Add relevant metadata tags for better searchability
- Larger chunks (512-1024) for narrative content, smaller (128-256) for technical/reference
Respond ONLY with a JSON object in this exact format:
{{
"chunk_method": "string",
"parser_config": {{"chunk_token_num": number}},
"pagerank": number,
"rationale": "string"
}}"""
def _analyze_content_characteristics(
scraped_content: list[dict[str, Any]],
) -> dict[str, Any]:
"""Analyze characteristics of scraped content."""
characteristics: dict[str, Any] = {
"page_count": len(scraped_content),
"total_length": 0,
"has_tables": False,
"has_code": False,
"has_qa": False,
"content_types": set(),
}
for page in scraped_content:
content = page.get("markdown", "") or page.get("content", "")
characteristics["total_length"] += len(content)
# Check for tables
if "| " in content and " |" in content:
characteristics["has_tables"] = True
# Check for code blocks
if "```" in content or "<code>" in content:
characteristics["has_code"] = True
characteristics["content_types"].add("code")
# Check for Q&A patterns
if any(
pattern in content.lower()
for pattern in ["q:", "a:", "question:", "answer:", "faq"]
):
characteristics["has_qa"] = True
characteristics["content_types"].add("qa")
# Check content types from metadata
if isinstance(page.get("metadata"), dict):
if content_type := page["metadata"].get("contentType", ""):
characteristics["content_types"].add(content_type)
# Convert set to list for JSON serialization
characteristics["content_types"] = list(characteristics["content_types"])
return characteristics
async def analyze_content_for_rag_node(
state: "URLToRAGState", config: RunnableConfig | None
) -> dict[str, Any]:
"""Analyze scraped content and determine optimal RAGFlow configuration.
This node now analyzes each document individually for optimal configuration
while processing them concurrently for efficiency.
Args:
state: Current workflow state with scraped content
Returns:
Updated state with r2r_info field containing recommendations
"""
logger.info("Analyzing content for RAG processing")
scraped_content = state.get("scraped_content", [])
repomix_output = state.get("repomix_output")
# Track how many pages we've already processed
last_processed_count = state.get("last_processed_page_count", 0)
# Check if this is a repomix repository first (before checking scraped_content)
if repomix_output and not scraped_content:
logger.info("Processing repomix output (no scraped content check needed)")
# Continue to repomix handling below
elif last_processed_count < len(scraped_content):
# Only process NEW pages that haven't been analyzed yet
new_content = scraped_content[last_processed_count:]
logger.info(
f"Processing {len(new_content)} new pages (indices {last_processed_count} to {len(scraped_content) - 1})"
)
scraped_content = new_content
else:
logger.info("No new content to process")
no_content_result: dict[str, Any] = {
"last_processed_page_count": len(scraped_content)
}
# Preserve URL fields for collection naming
if state.get("url"):
no_content_result["url"] = state.get("url")
if state.get("input_url"):
no_content_result["input_url"] = state.get("input_url")
return no_content_result
logger.info(
f"Scraped content items: {len(scraped_content) if scraped_content else 0}"
)
logger.info(f"Has repomix output: {bool(repomix_output)}")
# Handle repomix output (Git repositories)
if repomix_output and not scraped_content:
# For repomix, create a single-page structure wrapped in pages array
processed_content = {
"pages": [
{
"content": repomix_output,
"markdown": repomix_output, # Repomix output is already markdown-like
"title": f"Repository: {state.get('input_url', 'Unknown')}",
"metadata": {"content_type": "repository", "source": "repomix"},
"url": state.get("input_url") or state.get("url", ""),
}
],
"metadata": {
"page_count": 1,
"total_length": len(repomix_output),
"content_types": ["repository"],
"source": "repomix",
},
}
# Return appropriate config for repository content
repo_config: _R2RConfig = {
"chunk_size": 2000, # Larger chunks for code
"extract_entities": True, # Extract entities like function/class names
"metadata": {"content_type": "repository"},
"rationale": "Repository content benefits from larger chunks and entity extraction",
}
repo_result: dict[str, Any] = {
"r2r_info": repo_config,
"processed_content": processed_content,
}
# Preserve URL fields for collection naming
repo_result = preserve_url_fields(repo_result, state)
return repo_result
if not scraped_content:
logger.warning("No scraped content to analyze, using default configuration")
empty_default_config: _R2RConfig = {
"chunk_size": 1000,
"extract_entities": False,
"metadata": {"content_type": "unknown"},
"rationale": "Default configuration used due to missing content",
}
empty_result: dict[str, Any] = {
"r2r_info": empty_default_config,
"processed_content": {
"pages": [
{
"content": "",
"title": "Empty Document",
"metadata": {"source_url": state.get("input_url", "")},
}
]
},
}
# Preserve URL fields for collection naming
empty_result = preserve_url_fields(empty_result, state)
return empty_result
# Analyze content characteristics (used for logging side effects)
try:
_analyze_content_characteristics(scraped_content)
except Exception as e:
logger.warning(f"Content characteristics analysis failed: {e}")
# Continue with processing even if characteristic analysis fails
# Use smaller model for analysis tasks
try:
# Skip LLM analysis for now and use simple rule-based analysis
logger.info(
f"Analyzing {len(scraped_content)} documents with rule-based approach"
)
analyzed_pages = []
for doc in scraped_content:
# Analyze content characteristics without LLM
content = doc.get("markdown", "") or doc.get("content", "")
content_length = len(content)
# Simple rule-based analysis
if content_length > 50000: # Very large content
chunk_size = 2000
extract_entities = True
content_type = "reference"
elif content_length > 10000: # Large content
chunk_size = 1500
extract_entities = True
content_type = "documentation"
elif (
"```" in content or "def " in content or "class " in content
): # Code content
chunk_size = 1000
extract_entities = False
content_type = "code"
elif "?" in content and len(content.split("?")) > 5: # Q&A content
chunk_size = 800
extract_entities = True
content_type = "qa"
else: # General content
chunk_size = 1000
extract_entities = False
content_type = "general"
# Create document with rule-based config
doc_with_config = {
**doc,
"r2r_config": {
"chunk_size": chunk_size,
"extract_entities": extract_entities,
"metadata": {"content_type": content_type},
"rationale": f"Rule-based analysis: {content_type} content ({content_length} chars)",
},
}
analyzed_pages.append(doc_with_config)
logger.info(
f"Analyzed document: {doc.get('title', 'Untitled')} -> {content_type} ({content_length} chars, chunk_size: {chunk_size})"
)
# Calculate overall statistics
total_length = sum(
len(page.get("markdown", "") or page.get("content", ""))
for page in analyzed_pages
)
content_types = {
page["r2r_config"]["metadata"].get("content_type", "general")
for page in analyzed_pages
if "r2r_config" in page and "metadata" in page["r2r_config"]
}
# Prepare processed content with individually analyzed pages
processed_content = {
"pages": analyzed_pages,
"metadata": {
"page_count": len(analyzed_pages),
"total_length": total_length,
"content_types": list(content_types),
"analysis_method": "per_document",
},
}
logger.info(
f"Analyzer completed: {len(analyzed_pages)} documents analyzed individually"
)
# Log summary of configurations
config_summary: dict[int, int] = {}
for page in analyzed_pages:
if "r2r_config" in page:
chunk_size = page["r2r_config"]["chunk_size"]
if isinstance(chunk_size, int):
config_summary[chunk_size] = config_summary.get(chunk_size, 0) + 1
logger.info(f"Chunk size distribution: {config_summary}")
# Return with a general r2r_info for compatibility
# Individual pages have their own configs
analysis_result: dict[str, Any] = {
"r2r_info": {
"chunk_size": 1000, # Default, overridden per document
"extract_entities": False, # Default, overridden per document
"metadata": {"analysis_method": "per_document"},
"rationale": "Each document analyzed individually for optimal configuration",
},
"processed_content": processed_content,
"last_processed_page_count": last_processed_count + len(scraped_content),
}
# IMPORTANT: Preserve URL fields for collection naming
# Check both url and input_url to ensure we don't lose the URL
if state.get("url"):
analysis_result["url"] = state.get("url")
if state.get("input_url"):
analysis_result["input_url"] = state.get("input_url")
return analysis_result
except Exception as e:
logger.error(f"Error analyzing content: {e}")
# Return default config on error
default_config: _R2RConfig = {
"chunk_size": 1000,
"extract_entities": False,
"metadata": {"content_type": "general"},
"rationale": "Using default configuration due to analysis error",
}
# Still prepare processed content even on error
# Pass pages individually for upload
processed_content = {
"pages": scraped_content,
"metadata": {
"page_count": len(scraped_content),
"total_length": sum(
len(page.get("markdown", "") or page.get("content", ""))
for page in scraped_content
),
},
}
result: dict[str, Any] = {
"r2r_info": default_config,
"processed_content": processed_content,
"last_processed_page_count": last_processed_count + len(scraped_content),
}
# IMPORTANT: Preserve URL fields for collection naming
# Check both url and input_url to ensure we don't lose the URL
if state.get("url"):
result["url"] = state.get("url")
if state.get("input_url"):
result["input_url"] = state.get("input_url")
logger.info(f"Analyzer (error case) returning with keys: {list(result.keys())}")
logger.info(
f"Processed content has {len(processed_content.get('pages', []))} pages"
)
return result

View File

@@ -132,7 +132,7 @@ async def _upload_single_page_to_r2r(
async def batch_check_duplicates_node(
state: URLToRAGState, config: RunnableConfig | None = None
state: URLToRAGState, config: RunnableConfig
) -> dict[str, Any]:
"""Check multiple URLs for duplicates in parallel.
@@ -263,7 +263,7 @@ async def batch_check_duplicates_node(
async def batch_scrape_and_upload_node(
state: URLToRAGState, config: RunnableConfig | None = None
state: URLToRAGState, config: RunnableConfig
) -> dict[str, Any]:
"""Scrape and upload multiple URLs concurrently."""
from firecrawl import AsyncFirecrawlApp

View File

@@ -0,0 +1,388 @@
"""Batch processing node for concurrent URL handling."""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Protocol, cast
from langchain_core.runnables import RunnableConfig
from r2r import R2RClient
# Removed broken core import
from biz_bud.core.networking.async_utils import gather_with_concurrency
from biz_bud.logging import get_logger
if TYPE_CHECKING:
from biz_bud.states.url_to_rag import URLToRAGState
logger = get_logger(__name__)
class ScrapedDataProtocol(Protocol):
"""Protocol for scraped data objects with content and markdown."""
@property
def markdown(self) -> str | None:
"""Get markdown content."""
...
@property
def content(self) -> str | None:
"""Get raw content."""
...
class ScrapeResultProtocol(Protocol):
"""Protocol for scrape result objects."""
@property
def success(self) -> bool:
"""Whether the scrape was successful."""
...
@property
def data(self) -> ScrapedDataProtocol | None:
"""The scraped data if successful."""
...
async def _upload_single_page_to_r2r(
url: str,
scraped_data: ScrapedDataProtocol,
config: dict[str, Any],
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Upload a single page to R2R.
Args:
url: The URL of the page
scraped_data: The scraped data from Firecrawl
config: Application configuration
metadata: Optional metadata to include
Returns:
Dictionary with success status and optional error message
"""
from r2r import R2RClient
# Get R2R config
api_config = config.get("api_config", {})
r2r_base_url = api_config.get("r2r_base_url", "http://localhost:7272")
r2r_api_key = api_config.get(
"r2r_api_key"
) # Retrieve API key from config if available
try:
# Initialize R2R client
client = R2RClient(base_url=r2r_base_url)
# Login if API key provided
if r2r_api_key:
await asyncio.wait_for(
asyncio.to_thread(
lambda: client.users.login(
email="admin@example.com", password=r2r_api_key or ""
)
),
timeout=5.0,
)
# Extract content
content = ""
if hasattr(scraped_data, "markdown") and scraped_data.markdown:
content = scraped_data.markdown
elif hasattr(scraped_data, "content") and scraped_data.content:
content = scraped_data.content
if not content:
return {"success": False, "error": "No content to upload"}
# Create document metadata
doc_metadata = {
"source_url": url,
"scraped_at": datetime.now(UTC).isoformat(),
**(metadata or {}),
}
# Upload document
result = await asyncio.to_thread(
lambda: client.documents.create(
raw_text=content,
metadata=doc_metadata,
)
)
if (
result
and hasattr(result, "results")
and hasattr(result.results, "document_id")
):
return {
"success": True,
"document_id": result.results.document_id,
}
else:
return {"success": False, "error": "Upload returned no document ID"}
except Exception as e:
logger.error(f"Error uploading to R2R: {e}")
return {"success": False, "error": str(e)}
async def batch_check_duplicates_node(
state: URLToRAGState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Check multiple URLs for duplicates in parallel.
This node processes URLs in batches to improve performance.
"""
urls_to_process = state.get("urls_to_process", [])
current_index = state.get("current_url_index", 0)
batch_size = 10 # Process 10 URLs at a time
# Get batch of URLs to check
end_index = min(current_index + batch_size, len(urls_to_process))
batch_urls = urls_to_process[current_index:end_index]
if not batch_urls:
return {"batch_complete": True}
logger.info(
f"Checking {len(batch_urls)} URLs for duplicates (batch starting at {current_index + 1}/{len(urls_to_process)})"
)
# Get R2R config
state_config = state.get("config", {})
api_config_raw = state_config.get("api_config", {})
# Cast to dict to allow flexible key access
api_config = dict(api_config_raw)
r2r_base_url = api_config.get("r2r_base_url", "http://localhost:7272")
r2r_api_key = api_config.get(
"r2r_api_key"
) # Retrieve API key from config if available
# Initialize R2R client
client = R2RClient(base_url=r2r_base_url)
# Login if API key provided
if r2r_api_key:
try:
await asyncio.wait_for(
asyncio.to_thread(
lambda: client.users.login(
email="admin@example.com", password=r2r_api_key or ""
)
),
timeout=5.0,
)
except TimeoutError:
logger.warning("R2R login timed out")
raise # Re-raise to fail fast on auth issues
except Exception as e:
logger.error(f"R2R login failed: {e}")
raise # Re-raise to fail fast on auth issues
# Check all URLs concurrently
async def check_single_url(url: str, index: int) -> tuple[int, bool, str | None]:
"""Check if a single URL exists in R2R."""
try:
def search_sync() -> object:
return client.retrieval.search(
query=f'source_url:"{url}"',
search_settings={
"filters": {"source_url": {"$eq": url}},
"limit": 1,
},
)
search_results = await asyncio.wait_for(
asyncio.to_thread(search_sync),
timeout=5.0, # 5 second timeout per search
)
if search_results and hasattr(search_results, "results"):
chunk_results = getattr(
getattr(search_results, "results"), "chunk_search_results", []
)
if chunk_results and len(chunk_results) > 0:
doc_id = getattr(chunk_results[0], "document_id", "unknown")
logger.info(f"URL already exists: {url} (doc_id: {doc_id})")
return index, True, doc_id
return index, False, None
except TimeoutError:
logger.warning(f"Timeout checking URL: {url}")
return index, False, None
except Exception as e:
logger.warning(f"Error checking URL {url}: {e}")
return index, False, None
# Create tasks for all URLs in batch
tasks = [
check_single_url(url, current_index + i) for i, url in enumerate(batch_urls)
]
# Run all checks concurrently with controlled concurrency
results = await gather_with_concurrency(5, *tasks)
# Process results
urls_to_skip = []
urls_to_scrape = []
skipped_count = state.get("skipped_urls_count", 0)
for idx, (abs_index, is_duplicate, doc_id) in enumerate(results):
url = batch_urls[idx]
if is_duplicate:
urls_to_skip.append(
{
"url": url,
"index": abs_index,
"doc_id": doc_id,
"reason": f"Already exists (ID: {doc_id})",
}
)
skipped_count += 1
else:
urls_to_scrape.append({"url": url, "index": abs_index})
logger.info(
f"Batch check complete: {len(urls_to_scrape)} to scrape, {len(urls_to_skip)} to skip"
)
return {
"batch_urls_to_scrape": urls_to_scrape,
"batch_urls_to_skip": urls_to_skip,
"current_url_index": end_index,
"skipped_urls_count": skipped_count,
"batch_complete": end_index >= len(urls_to_process),
}
async def batch_scrape_and_upload_node(
state: URLToRAGState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Scrape and upload multiple URLs concurrently."""
from firecrawl import AsyncFirecrawlApp
from biz_bud.nodes.integrations.firecrawl.config import load_firecrawl_settings
batch_urls_to_scrape = state.get("batch_urls_to_scrape", [])
if not batch_urls_to_scrape:
return {"batch_scrape_complete": True}
logger.info(f"Batch scraping {len(batch_urls_to_scrape)} URLs")
# Get state config for upload function
state_config = state.get("config", {})
# Get Firecrawl config
settings = await load_firecrawl_settings(cast(dict[str, Any], state))
api_key, base_url = settings.api_key, settings.base_url
# Extract just the URLs
urls = [cast("dict[str, Any]", item)["url"] for item in batch_urls_to_scrape]
# Scrape all URLs concurrently
firecrawl = AsyncFirecrawlApp(api_key=api_key, api_url=base_url)
# Use batch scraping with the new SDK
scraped_results = []
# Process URLs in batches since the new SDK may not have batch_scrape
batch_size = 5
for i in range(0, len(urls), batch_size):
batch_urls = urls[i : i + batch_size]
batch_tasks = [
firecrawl.scrape_url(
url,
formats=["markdown", "content", "links"],
only_main_content=True,
timeout=40000,
)
for url in batch_urls
]
batch_results = await gather_with_concurrency(3, *batch_tasks, return_exceptions=True)
# Convert results to expected format
for result in batch_results:
if isinstance(result, Exception):
# Handle exception as failed scrape
# Create a simple mock result object instead of dynamic type
mock_result = object.__new__(object)
mock_result.success = False # type: ignore
mock_result.data = None # type: ignore
mock_result.error = str(result) # type: ignore
scraped_results.append(mock_result)
else:
scraped_results.append(result)
# Upload all successful results to R2R concurrently
successful_uploads = 0
failed_uploads = 0
async def upload_single_result(url: str, result: ScrapeResultProtocol) -> bool:
"""Upload a single scrape result to R2R."""
if result.success and result.data:
try:
# Use the existing upload function
upload_result = await _upload_single_page_to_r2r(
url=url,
scraped_data=result.data,
config=state_config,
metadata={"source": "batch_process", "batch_size": len(urls)},
)
if upload_result.get("success"):
return True
logger.error(f"Failed to upload {url}: {upload_result.get('error')}")
return False
except Exception as e:
logger.error(f"Exception uploading {url}: {e}")
return False
else:
logger.warning(f"Skipped failed scrape: {url}")
return False
# Upload all results concurrently
# scraped_results is always a list at this point
if scraped_results:
results_list = list(scraped_results)
else:
logger.error("Empty scraped_results returned from batch_scrape")
total_failed = cast("int", state.get("failed_uploads", 0)) + len(urls)
return {
"batch_scrape_complete": True,
"successful_uploads": state.get("successful_uploads", 0),
"failed_uploads": total_failed,
}
upload_tasks = [
upload_single_result(urls[i], cast("ScrapeResultProtocol", result))
for i, result in enumerate(results_list)
]
upload_results = await gather_with_concurrency(3, *upload_tasks)
successful_uploads = sum(bool(success) for success in upload_results)
failed_uploads = len(upload_results) - successful_uploads
logger.info(
f"Batch upload complete: {successful_uploads} succeeded, {failed_uploads} failed"
)
# Update total counts
total_succeeded = (
cast("int", state.get("successful_uploads", 0)) + successful_uploads
)
total_failed = cast("int", state.get("failed_uploads", 0)) + failed_uploads
return {
"successful_uploads": total_succeeded,
"failed_uploads": total_failed,
"batch_scrape_complete": True,
}

View File

@@ -124,7 +124,7 @@ def _validate_collection_name(name: str | None) -> str | None:
async def check_r2r_duplicate_node(
state: URLToRAGState, config: RunnableConfig | None = None
state: URLToRAGState, config: RunnableConfig
) -> dict[str, Any]:
"""Check multiple URLs for duplicates in R2R concurrently.

View File

@@ -0,0 +1,675 @@
"""Node for checking if a URL has already been processed in R2R."""
from __future__ import annotations
import re
import time
from typing import TYPE_CHECKING, Any, cast
from langchain_core.runnables import RunnableConfig
from biz_bud.core import URLNormalizer
from biz_bud.core.langgraph.state_immutability import StateUpdater
from biz_bud.core.networking.async_utils import gather_with_concurrency
from biz_bud.logging import get_logger
from biz_bud.tools.clients.r2r_utils import (
authenticate_r2r_client,
get_r2r_config,
r2r_direct_api_call,
)
from .utils import extract_collection_name
if TYPE_CHECKING:
from biz_bud.states.url_to_rag import URLToRAGState
logger = get_logger(__name__)
# Simple in-memory cache for duplicate check results (TTL: 5 minutes)
_duplicate_cache: dict[str, tuple[bool, float]] = {}
_CACHE_TTL = 300 # 5 minutes in seconds
def _get_cached_result(url: str, collection_id: str | None) -> bool | None:
"""Get cached duplicate check result if still valid."""
cache_key = f"{_normalize_url(url)}#{collection_id or 'global'}"
if cache_key in _duplicate_cache:
is_duplicate, timestamp = _duplicate_cache[cache_key]
if time.time() - timestamp < _CACHE_TTL:
logger.debug(f"Cache hit for {url}")
return is_duplicate
else:
# Expired entry
del _duplicate_cache[cache_key]
return None
def _cache_result(url: str, collection_id: str | None, is_duplicate: bool) -> None:
"""Cache duplicate check result."""
cache_key = f"{_normalize_url(url)}#{collection_id or 'global'}"
_duplicate_cache[cache_key] = (is_duplicate, time.time())
logger.debug(f"Cached result for {url}: {is_duplicate}")
def clear_duplicate_cache() -> None:
"""Clear the duplicate check cache. Useful for testing."""
global _duplicate_cache
_duplicate_cache.clear()
logger.debug("Duplicate cache cleared")
# Create a shared URLNormalizer instance for consistent normalization
_url_normalizer = URLNormalizer(
default_protocol="https",
normalize_protocol=True,
remove_fragments=True,
remove_www=True,
lowercase_domain=True,
sort_query_params=True,
remove_trailing_slash=True,
)
def _normalize_url(url: str) -> str:
"""Normalize URL for consistent comparison.
Args:
url: The URL to normalize
Returns:
Normalized URL
"""
return _url_normalizer.normalize(url)
def _get_url_variations(url: str) -> list[str]:
"""Get variations of a URL for flexible matching.
Args:
url: The URL to get variations for
Returns:
List of URL variations
"""
return _url_normalizer.get_variations(url)
def _validate_collection_name(name: str | None) -> str | None:
"""Validate and sanitize collection name for R2R compatibility.
Applies the same sanitization rules as extract_collection_name to ensure
collection name overrides follow R2R requirements.
Args:
name: Collection name to validate (can be None or empty)
Returns:
Sanitized collection name or None if invalid/empty
"""
if not name or not name.strip():
return None
# Apply same sanitization as extract_collection_name
sanitized = name.lower().strip()
# Only allow alphanumeric characters, hyphens, and underscores
sanitized = re.sub(r"[^a-z0-9\-_]", "_", sanitized)
# Reject if sanitized is empty
return sanitized or None
async def check_r2r_duplicate_node(
state: URLToRAGState, config: RunnableConfig | None
) -> dict[str, Any]:
"""Check multiple URLs for duplicates in R2R concurrently.
This node now processes URLs in batches for better performance and
determines the collection name for the entire batch.
Args:
state: Current workflow state
Returns:
State updates with batch duplicate check results and collection info
"""
import asyncio
from r2r import R2RClient
# Get URLs to process
urls_to_process = state.get("urls_to_process", [])
current_index = state.get("current_url_index", 0)
batch_size = 20 # Process 20 URLs at a time
# Get batch of URLs to check
end_index = min(current_index + batch_size, len(urls_to_process))
batch_urls = urls_to_process[current_index:end_index]
if not batch_urls:
batch_result: dict[str, Any] = {"batch_complete": True}
# Preserve URL fields for collection naming
if state.get("url"):
batch_result["url"] = state.get("url")
if state.get("input_url"):
batch_result["input_url"] = state.get("input_url")
return batch_result
logger.info(
f"Checking {len(batch_urls)} URLs for duplicates (batch {current_index + 1}-{end_index} of {len(urls_to_process)})"
)
# Check for override collection name first
override_collection_name = state.get("collection_name")
collection_name = None
if override_collection_name:
# Validate the override collection name
collection_name = _validate_collection_name(override_collection_name)
if collection_name:
logger.info(
f"Using override collection name: '{collection_name}' (original: '{override_collection_name}')"
)
else:
logger.warning(
f"Invalid override collection name '{override_collection_name}', falling back to URL-derived name"
)
if not collection_name:
# Extract collection name from the main URL (not batch URLs)
# Use input_url first, fall back to url if not available
main_url = state.get("input_url") or state.get("url", "")
if not main_url and batch_urls:
# If no main URL, use the first batch URL
main_url = batch_urls[0]
collection_name = extract_collection_name(main_url)
logger.info(
f"Derived collection name: '{collection_name}' from URL: {main_url}"
)
# Expose the final collection name in the state for transparency
updater = StateUpdater(state)
updater.set("final_collection_name", collection_name)
state = cast("URLToRAGState", updater.build())
state_config = state.get("config", {})
r2r_config = get_r2r_config(state_config)
logger.info(f"Using R2R base URL: {r2r_config['base_url']}")
try:
# Initialize R2R client
client = R2RClient(base_url=r2r_config["base_url"])
# Authenticate if API key is provided
await authenticate_r2r_client(
client,
r2r_config["api_key"],
r2r_config["email"]
)
# Get collection ID for the extracted collection name
collection_id = None
try:
# Try to list collections to find the ID
logger.info(f"Looking for collection '{collection_name}'...")
collections_list = await asyncio.wait_for(
asyncio.to_thread(lambda: client.collections.list(limit=100)),
timeout=10.0,
)
# Look for existing collection
if hasattr(collections_list, "results"):
for collection in collections_list.results:
if (
hasattr(collection, "name")
and collection.name == collection_name
):
collection_id = str(collection.id)
logger.info(
f"Found existing collection '{collection_name}' with ID: {collection_id}"
)
break
if not collection_id:
logger.info(
f"Collection '{collection_name}' not found, will be created during upload"
)
except Exception as e:
logger.warning(f"Could not retrieve collection ID: {e}")
# Try fallback API approach if SDK failed
try:
logger.info("Trying direct API fallback for collection lookup...")
collections_response = await r2r_direct_api_call(
client,
"GET",
"/v3/collections",
params={"limit": 100},
timeout=30.0,
)
# Look for existing collection
for collection in collections_response.get("results", []):
if collection.get("name") == collection_name:
collection_id = str(collection.get("id"))
logger.info(
f"Found existing collection '{collection_name}' with ID: {collection_id} (via API fallback)"
)
break
if not collection_id:
logger.info(
f"Collection '{collection_name}' not found via API fallback, will be created during upload"
)
except Exception as api_e:
logger.warning(f"API fallback also failed: {api_e}")
# Continue without collection ID - collection will be created during upload
# Check all URLs concurrently
async def check_single_url(url: str) -> tuple[str, bool, str | None]:
"""Check if a single URL exists in R2R with caching."""
# Check cache first
cached_result = _get_cached_result(url, collection_id)
if cached_result is not None:
return url, cached_result, None
try:
async def search_direct() -> dict[str, Any]:
"""Use optimized direct API call with hierarchical URL matching."""
# Start with canonical normalized URL for fast exact matching
canonical_url = _normalize_url(url)
logger.debug(f"Checking canonical URL: {canonical_url}")
# Build simple canonical filter first (most common case)
canonical_filters = {
"$or": [
{"source_url": {"$eq": canonical_url}},
{"parent_url": {"$eq": canonical_url}},
{"sourceURL": {"$eq": canonical_url}},
]
}
# Add collection filter if we have a collection ID
if collection_id:
canonical_filters = {
"$and": [
canonical_filters,
{"collection_id": {"$eq": collection_id}},
]
}
# Try canonical URL first (fast path)
try:
result = await r2r_direct_api_call(
client,
"POST",
"/v3/retrieval/search",
json_data={
"query": "*",
"search_settings": {
"filters": canonical_filters,
"limit": 1,
},
},
timeout=3.0, # Quick timeout for canonical check
)
# If we found results with canonical URL, return them
if (
result
and "results" in result
and result["results"]
and "chunk_search_results" in result["results"]
and result["results"]["chunk_search_results"]
):
logger.debug(f"Found canonical match for {url}")
return result
except Exception as e:
logger.debug(f"Canonical search failed for {url}: {e}")
# If canonical search didn't find anything, try variations (slower path)
logger.debug(f"Trying URL variations for {url}")
url_variations = _get_url_variations(url)
# Only use essential variations to keep query size reasonable
essential_variations = url_variations[
:3
] # Limit to 3 most important variations
logger.debug(f"Using essential variations: {essential_variations}")
variation_filters = []
for variation in essential_variations:
if variation != canonical_url: # Skip canonical (already tried)
variation_filters.extend(
[
{"source_url": {"$eq": variation}},
{"parent_url": {"$eq": variation}},
{"sourceURL": {"$eq": variation}},
]
)
try:
# Build variation filters with collection_id if available
variation_filters_final = {"$or": variation_filters}
if collection_id:
variation_filters_final = {
"$and": [
variation_filters_final,
{"collection_id": {"$eq": collection_id}},
]
}
# Search with variation filters
result = await r2r_direct_api_call(
client,
"POST",
"/v3/retrieval/search",
json_data={
"query": "*",
"search_settings": {
"filters": variation_filters_final,
"limit": 1,
},
},
timeout=3.0,
)
if (
result
and "results" in result
and result["results"]
and "chunk_search_results" in result["results"]
and result["results"]["chunk_search_results"]
):
logger.debug(f"Found variation match for {url}")
return result
else:
logger.debug(f"No results structure found for {url}")
except Exception as e:
# Log the actual error for debugging
error_msg = str(e)
logger.error(f"Search error for URL {url}: {e}")
if "400" in error_msg or "Query cannot be empty" in error_msg:
logger.debug(
f"R2R search returned 400 error for URL {url}: {e}"
)
# For 400 errors, assume URL doesn't exist rather than failing
raise
elif (
"'Response' object has no attribute 'model_dump_json'"
in error_msg
):
# Handle R2R client version compatibility issue
logger.debug(
f"R2R client compatibility issue for URL {url}: {e}"
)
# Try to handle the response directly
raise
else:
raise
if not variation_filters:
# No additional variations to try
return {}
filters = {"$or": variation_filters}
if collection_id:
filters = {
"$and": [
filters,
{"collection_id": {"$eq": collection_id}},
]
}
# Search with variations (with longer timeout)
return await r2r_direct_api_call(
client,
"POST",
"/v3/retrieval/search",
json_data={
"query": "*",
"search_settings": {
"filters": filters,
"limit": 1,
},
},
timeout=5.0, # Longer timeout for variation search
)
try:
search_results = await asyncio.wait_for(
search_direct(),
timeout=8.0, # Optimized: 3s canonical + 5s variations + buffer
)
logger.debug(
f"Search results for {url}: {type(search_results)} with {len(search_results.get('results', {}).get('chunk_search_results', [])) if search_results else 'no results'} results"
)
# Debug: log the search results structure
if search_results and "results" in search_results:
results = search_results["results"]
if "chunk_search_results" in results:
chunk_results = results["chunk_search_results"]
logger.debug(
f"Found {len(chunk_results)} chunk results for {url}"
)
for i, chunk in enumerate(
chunk_results[:1]
): # Log first chunk only
logger.debug(
f" Chunk {i + 1}: doc_id={chunk.get('document_id', 'N/A')}"
)
else:
logger.debug(
f"No chunk_search_results in results for {url}"
)
else:
logger.debug(f"No results structure found for {url}")
except Exception as e:
# Log the actual error for debugging
error_msg = str(e)
logger.error(f"Search error for URL {url}: {e}")
if "400" in error_msg or "Query cannot be empty" in error_msg:
logger.debug(
f"R2R search returned 400 error for URL {url}: {e}"
)
# For 400 errors, assume URL doesn't exist rather than failing
return url, False, None
elif (
"'Response' object has no attribute 'model_dump_json'"
in error_msg
):
# Handle R2R client version compatibility issue
logger.debug(
f"R2R client compatibility issue for URL {url}: {e}"
)
# Try to handle the response directly
return url, False, None
raise
# Handle the response - it's now a dictionary from the direct API call
if search_results:
# Try to access results directly first
try:
# For R2R v3, results are in a dict format
if "results" in search_results:
results = search_results["results"]
# Check if it's an AggregateSearchResult
if (
results
and isinstance(results, dict)
and "chunk_search_results" in results
):
chunk_results = results["chunk_search_results"]
elif isinstance(results, list):
chunk_results = results
else:
chunk_results = []
else:
chunk_results = []
if chunk_results and len(chunk_results) > 0:
first_result = chunk_results[0]
doc_id = first_result.get("document_id", "unknown")
# Get metadata to verify it's really this URL
metadata = first_result.get("metadata", {})
found_source_url = metadata.get("source_url", "")
found_parent_url = metadata.get("parent_url", "")
found_sourceURL = metadata.get("sourceURL", "")
logger.info(f"URL already exists: {url} (doc_id: {doc_id})")
logger.debug(f" Found source_url: {found_source_url}")
logger.debug(f" Found parent_url: {found_parent_url}")
logger.debug(f" Found sourceURL: {found_sourceURL}")
# Log which type of match we found
url_variations = _get_url_variations(url)
if found_parent_url in url_variations:
logger.info(
" -> Found as parent URL (site already scraped)"
)
elif found_source_url in url_variations:
logger.info(" -> Found as exact source URL match")
elif found_sourceURL in url_variations:
logger.info(" -> Found as sourceURL match")
else:
logger.warning(
f"WARNING: URL mismatch! Searched for '{url}' (variations: {url_variations}) but got source='{found_source_url}', parent='{found_parent_url}', sourceURL='{found_sourceURL}'"
)
# Cache positive result
_cache_result(url, collection_id, True)
return url, True, doc_id
except Exception as e:
logger.error(f"Error parsing search results for {url}: {e}")
# If we can't parse results, assume no duplicate
_cache_result(url, collection_id, False)
return url, False, None
# Cache negative result
_cache_result(url, collection_id, False)
return url, False, None
except TimeoutError:
logger.warning(f"Timeout checking URL: {url}")
# Don't cache timeout results as they may be transient
return url, False, None
except Exception as e:
# Handle specific error for Response objects lacking model_dump_json
error_msg = str(e)
if "'Response' object has no attribute 'model_dump_json'" in error_msg:
logger.debug(
f"R2R client compatibility issue for URL {url}: Response serialization error"
)
elif (
"400" in error_msg
or "Bad Request" in error_msg
or "Query cannot be empty" in error_msg
):
logger.debug(f"R2R search API returned 400 for URL {url}")
else:
logger.warning(f"Error checking URL {url}: {e}")
# For any error, assume URL doesn't exist to allow processing to continue
# Don't cache error results as they may be transient
return url, False, None
# Create tasks for all URLs in batch - add timeout protection
tasks = [check_single_url(url) for url in batch_urls]
# Run all checks concurrently with optimized timeout and controlled concurrency
try:
raw_results = await asyncio.wait_for(
gather_with_concurrency(4, *tasks, return_exceptions=True),
timeout=40.0, # Optimized: 8s per URL * 20 URLs / 4 (concurrent) = ~40s
)
# Filter out exceptions and convert to proper results
results: list[tuple[str, bool, str | None]] = []
for i, raw_result in enumerate(raw_results):
if isinstance(raw_result, Exception):
logger.warning(
f"URL check failed for {batch_urls[i]}: {raw_result}"
)
results.append((batch_urls[i], False, None))
else:
# raw_result is guaranteed to be tuple[str, bool, str | None] here
results.append(cast("tuple[str, bool, str | None]", raw_result))
except TimeoutError:
logger.error(
f"Batch duplicate check timed out after 40 seconds for {len(batch_urls)} URLs"
)
# Return all URLs as non-duplicates to allow processing to continue
results = [(url, False, None) for url in batch_urls]
# Process results
urls_to_scrape = []
urls_to_skip = []
skipped_count = state.get("skipped_urls_count", 0)
for url, is_duplicate, _ in results:
if is_duplicate:
urls_to_skip.append(url)
skipped_count += 1
logger.info(f"Skipping duplicate: {url}")
else:
urls_to_scrape.append(url)
logger.info(
f"Batch check complete: {len(urls_to_scrape)} to scrape, {len(urls_to_skip)} to skip"
)
success_result: dict[str, Any] = {
"batch_urls_to_scrape": urls_to_scrape,
"batch_urls_to_skip": urls_to_skip,
"current_url_index": end_index,
"skipped_urls_count": skipped_count,
"batch_complete": end_index >= len(urls_to_process),
# Add collection information to state
"collection_name": collection_name,
"collection_id": collection_id, # May be None if collection doesn't exist yet
}
# Preserve URL fields for collection naming
if state.get("url"):
success_result["url"] = state.get("url")
if state.get("input_url"):
success_result["input_url"] = state.get("input_url")
return success_result
except Exception as e:
logger.error(f"Error checking R2R for duplicates: {e}")
# Handle API errors gracefully by proceeding with all URLs
# This prevents blocking the workflow when R2R is temporarily unavailable
logger.warning(f"R2R duplicate check failed, proceeding with all URLs: {e}")
# Return all URLs for processing since we can't verify duplicates
error_result: dict[str, Any] = {
"batch_urls_to_scrape": batch_urls,
"batch_urls_to_skip": [],
"batch_complete": end_index >= len(urls_to_process),
"collection_name": collection_name,
"duplicate_check_status": "failed",
"duplicate_check_error": str(e),
}
# Preserve URL fields for collection naming
if state.get("url"):
error_result["url"] = state.get("url")
if state.get("input_url"):
error_result["input_url"] = state.get("input_url")
return error_result

Some files were not shown because too many files have changed in this diff Show More