refactor: update configuration handling and enhance tool execution results (#49)

* refactor: update configuration handling and enhance tool execution results

- Changed metadata and configurable handling in ConfigurationProvider (and create_runnable_config) to use dictionary item assignment instead of attribute writes, removing the type: ignore[attr-defined] workarounds.
- Added new capabilities for document management and retrieval in CapabilityInferenceEngine.
- Enhanced error handling and logging in buddy agent nodes to improve robustness during execution.
- Updated service initialization logic in ServiceFactory for faster startup by initializing critical services concurrently.
- Improved message handling in tool execution to support enhanced results with message merging.

These changes optimize configuration management, enhance tool execution, and improve overall application performance.
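For reference, a minimal illustrative sketch (not part of the diff) of the enhanced-result shape that the executor and synthesizer nodes now look for; the key names come from the changes below:

from langchain_core.messages import AIMessage

# Hypothetical example of an "enhanced" tool result. buddy_executor_node and
# buddy_synthesizer_node detect it via result.get("tool_execution_complete")
# and merge the "messages" list into state through the add_messages reducer.
enhanced_result = {
    "synthesis": "Summary of findings...",                   # primary payload (final_result/response also supported)
    "messages": [AIMessage(content="Tool output details")],  # appended to existing state messages
    "tool_execution_complete": True,                          # marker distinguishing enhanced results from plain strings
}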

* refactor: centralize multimodal text extraction and improve error handling across agents

* refactor: standardize logging of unsupported content types with unified handler

* refactor: standardize logging by replacing logging.getLogger with bb_core.logging.get_logger

* refactor: improve error handling and type validation across service initialization and node registration

* feat: add configurable postgres pool settings and improve error handling across services

* refactor: replace Union types with | operator and improve type handling across packages

* refactor: simplify multimodal content handling and improve error logging
2025-07-23 16:07:45 -04:00
committed by GitHub
parent 21aa778214
commit 5b1d90bd66
56 changed files with 1085 additions and 281 deletions

View File

@@ -4,16 +4,17 @@ This module provides factory functions for creating various types of
command-based routers that follow LangGraph best practices.
"""
import logging
from collections.abc import Callable
from typing import Any
from langgraph.types import Command
from bb_core.logging import get_logger
from .routing_rules import CommandRoutingRule
from ..validation.condition_security import ConditionValidator, ConditionSecurityError
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
def create_command_router(
@@ -57,7 +58,7 @@ def create_command_router(
except Exception as e:
# Log evaluation error and continue to next rule
# This provides graceful degradation
logging.warning(f"Rule evaluation failed: {e}, continuing to next rule")
logger.warning(f"Rule evaluation failed: {e}, continuing to next rule")
continue
# No rules matched, use default target

View File

@@ -4,19 +4,20 @@ This module provides the base classes and protocols for building
command routing rules with security validation.
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol, Union
from typing import Any, Callable, Protocol
from langgraph.types import Command
from bb_core.logging import get_logger
from ..validation.condition_security import (
ConditionSecurityError,
ConditionValidator,
validate_condition_for_security,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class StateProtocol(Protocol):
@@ -44,7 +45,7 @@ class CommandRoutingRule:
description: Human-readable description of the rule
"""
condition: Union[str, Callable[[Any], bool]]
condition: str | Callable[[Any], bool]
target: str
state_updates: dict[str, Any] = field(default_factory=dict)
priority: int = 0

View File

@@ -1492,12 +1492,11 @@ def aggregate_errors(
error_type = error.__class__.__name__
category = error.category.value
severity = error.severity.value
elif isinstance(error, dict):
else:
# error is dict[str, Any] based on union type
error_type = error.get("type", "Unknown")
category = error.get("category", "unknown")
severity = error.get("severity", "error")
else:
continue
by_type[error_type] = by_type.get(error_type, 0) + 1
by_category[category] = by_category.get(category, 0) + 1
@@ -1540,10 +1539,9 @@ def should_halt_on_errors(
for error in errors:
if isinstance(error, BusinessBuddyError):
severity = error.severity.value
elif isinstance(error, dict):
severity = error.get("severity", "error")
else:
continue
# error is dict[str, Any] based on union type
severity = error.get("severity", "error")
if severity == "critical":
critical_count += 1

View File

@@ -10,7 +10,7 @@ This module provides a consistent error message formatting system that:
from __future__ import annotations
import re
from typing import Any, cast
from typing import Any
from bb_core.errors.base import (
ErrorCategory,
@@ -21,7 +21,28 @@ from bb_core.errors.base import (
class ErrorMessageFormatter:
"""Formatter for standardized error messages."""
"""Formatter for standardized error messages with security sanitization."""
# Sanitization rules: (pattern, replacement, flags, user_only)
SANITIZE_RULES = [
# Always apply these rules
(r'(api_key|token|password|secret|auth|credential)[:=]\s*[\'\"]*[^\s\'\"]+', r'\1=***', re.IGNORECASE, False),
(r'(https?://)[^/]*:[^@]*@', r'\1***:***@', re.IGNORECASE, False),
(r'(mongodb|redis|postgres)://[^/]*:[^@]*@', r'\1://***:***@', re.IGNORECASE, False),
(r'([a-zA-Z0-9._%+-]+)@([a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', r'***@\2', 0, False),
(r'\b(\d{1,3}\.)(\d{1,3}\.)(\d{1,3}\.)(\d{1,3})\b', r'\1***.\3***', 0, False),
# User-only rules (more aggressive sanitization)
(r'/[a-zA-Z0-9/_\-\.]+/', '/***/', 0, True),
(r'[A-Z]:\\[a-zA-Z0-9\\/_\-\.]+', r'C:\***\\', 0, True),
(r'File "[^"]*", line \d+', 'File "***", line ***', 0, True),
# Logging-only rules (preserve some debugging info) - these apply when NOT for_user
]
# Additional rules that only apply for logging (when for_user=False)
LOGGING_ONLY_RULES = [
(r'/home/[^/]+/', '/home/***/', 0),
(r'/users/[^/]+/', '/users/***/', re.IGNORECASE),
]
# Standard error message templates by category
TEMPLATES = {
@@ -233,6 +254,43 @@ class ErrorMessageFormatter:
return details
@classmethod
def sanitize_error_message(cls, error: Any, for_user: bool = True) -> str:
"""Sanitize error messages to prevent sensitive information leakage.
Removes or masks common sensitive patterns like API keys, passwords,
file paths, and other potentially sensitive data from error messages.
Args:
error: The exception or error message to sanitize
for_user: If True, applies aggressive sanitization for user display.
If False, applies lighter sanitization for logging.
Returns:
Sanitized error message safe for the specified context
"""
message = str(error)
# Apply sanitization rules
for pattern, replacement, flags, user_only in cls.SANITIZE_RULES:
if not user_only or for_user:
message = re.sub(pattern, replacement, message, flags=flags)
# Apply logging-only rules when not for user
if not for_user:
for pattern, replacement, flags in cls.LOGGING_ONLY_RULES:
message = re.sub(pattern, replacement, message, flags=flags)
if for_user:
# Stack traces (keep only the error message part)
message = message.split('\n')[0] if '\n' in message else message
# Truncate very long messages
if len(message) > 200:
message = message[:200] + "..."
return message
def create_formatted_error(
message: str,
@@ -295,14 +353,16 @@ def create_formatted_error(
error_code = auto_code
# Get ErrorCategory enum
try:
if isinstance(category, ErrorCategory):
error_category: ErrorCategory = category
else:
category_str = category or "unknown"
error_category = cast(ErrorCategory, ErrorCategory(category_str))
except ValueError:
if isinstance(category, ErrorCategory):
error_category: ErrorCategory = category
else:
category_str = category or "unknown"
# Use explicit enum lookup to satisfy type checker
error_category = ErrorCategory.UNKNOWN
for enum_member in ErrorCategory.__members__.values():
if enum_member.value == category_str:
error_category = enum_member
break
# Format the message
formatted_message = ErrorMessageFormatter.format_error_message(
@@ -330,19 +390,22 @@ def create_formatted_error(
def format_error_for_user(error: ErrorInfo) -> str:
"""Format an error for user-friendly display.
"""Format an error for user-friendly display with sanitization.
Args:
error: ErrorInfo to format
Returns:
User-friendly error message
User-friendly error message with sensitive data sanitized
"""
message = error.get("message", "An error occurred")
details = error.get("details", {})
# Start with the main message
user_message = message
# Sanitize the message for user display
sanitized_message = ErrorMessageFormatter.sanitize_error_message(message, for_user=True)
# Start with the sanitized main message
user_message = sanitized_message
# Add suggestion if available
suggestion = details.get("context", {}).get("suggestion")
@@ -351,7 +414,7 @@ def format_error_for_user(error: ErrorInfo) -> str:
# Add node information if relevant
node = error.get("node")
if node and node not in message:
if node and node not in sanitized_message:
user_message += f"\n📍 Location: {node}"
return user_message
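A quick editorial illustration (not part of the change) of how the new sanitizer should behave, given the SANITIZE_RULES above:

from bb_core.errors.formatter import ErrorMessageFormatter

err = RuntimeError("Auth failed: token=sk-abc123 while reading /home/alice/config.yaml")
safe = ErrorMessageFormatter.sanitize_error_message(err, for_user=True)
# Expected to yield roughly: "Auth failed: token=*** while reading /***/config.yaml"
# (the credential rule masks the token; the user-only path rule collapses the directory)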

View File

@@ -6,7 +6,6 @@ deduplication, and rate limiting into the error handling workflow.
from __future__ import annotations
import logging
from typing import Any, cast
from bb_core.errors.aggregator import get_error_aggregator
@@ -20,8 +19,9 @@ from bb_core.errors.base import (
from bb_core.errors.formatter import categorize_error, create_formatted_error
from bb_core.errors.logger import get_error_logger
from bb_core.errors.router import RouteAction, get_error_router
from bb_core.logging import get_logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
async def report_error(

View File

@@ -7,9 +7,8 @@ common patterns and custom handlers.
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any, cast
from typing import Any
from bb_core.errors.base import ErrorCategory, ErrorInfo, ErrorNamespace, ErrorSeverity
from bb_core.errors.router import (
@@ -20,8 +19,9 @@ from bb_core.errors.router import (
RouteCondition,
get_error_router,
)
from bb_core.logging import get_logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class RouterConfig:
@@ -163,7 +163,17 @@ class RouterConfig:
categories_list: list[ErrorCategory] = []
for cat in config["categories"]:
if isinstance(cat, str):
categories_list.append(cast(ErrorCategory, ErrorCategory(cat)))
# Use explicit enum lookup to satisfy type checker
category_enum: ErrorCategory | None = None
for enum_member in ErrorCategory.__members__.values():
if enum_member.value == cat:
category_enum = enum_member
break
if category_enum is None:
raise ValueError(f"Invalid ErrorCategory value: {cat!r}")
categories_list.append(category_enum)
elif isinstance(cat, ErrorCategory):
categories_list.append(cat)
else:
@@ -175,7 +185,17 @@ class RouterConfig:
severities_list: list[ErrorSeverity] = []
for sev in config["severities"]:
if isinstance(sev, str):
severities_list.append(cast(ErrorSeverity, ErrorSeverity(sev)))
# Use explicit enum lookup to satisfy type checker
severity_enum: ErrorSeverity | None = None
for enum_member in ErrorSeverity.__members__.values():
if enum_member.value == sev:
severity_enum = enum_member
break
if severity_enum is None:
raise ValueError(f"Invalid ErrorSeverity value: {sev!r}")
severities_list.append(severity_enum)
elif isinstance(sev, ErrorSeverity):
severities_list.append(sev)
else:

View File

@@ -10,6 +10,7 @@ This module provides telemetry hooks and monitoring integration for errors:
from __future__ import annotations
import logging
import time
from collections import deque
from collections.abc import Callable
@@ -19,6 +20,9 @@ from typing import Any, Protocol
from bb_core.errors.base import ErrorInfo
from bb_core.errors.logger import ErrorLogEntry, TelemetryHook
from bb_core.logging import get_logger
logger = get_logger(__name__)
class MetricsClient(Protocol):
@@ -374,9 +378,6 @@ class ErrorTelemetry:
self.alert_callback(severity, message, context)
else:
# Default: log the alert
import logging
logger = logging.getLogger(__name__)
level = (
logging.ERROR if severity in ["critical", "error"] else logging.WARNING
)

View File

@@ -166,12 +166,20 @@ class ConfigurationProvider:
# Copy metadata
self_metadata = getattr(self._config, "metadata", {})
other_metadata = getattr(other, "metadata", {})
merged.metadata = {**self_metadata, **other_metadata} # type: ignore[attr-defined]
if hasattr(merged, "__setitem__"):
merged["metadata"] = {**self_metadata, **other_metadata}
else:
# Fallback for non-dict-like objects
setattr(merged, "metadata", {**self_metadata, **other_metadata})
# Copy configurable
self_configurable = getattr(self._config, "configurable", {})
other_configurable = getattr(other, "configurable", {})
merged.configurable = {**self_configurable, **other_configurable} # type: ignore[attr-defined]
if hasattr(merged, "__setitem__"):
merged["configurable"] = {**self_configurable, **other_configurable}
else:
# Fallback for non-dict-like objects
setattr(merged, "configurable", {**self_configurable, **other_configurable})
# Copy other attributes
for attr in ["tags", "callbacks", "recursion_limit"]:
@@ -211,10 +219,10 @@ class ConfigurationProvider:
"app_config": app_config,
"service_factory": service_factory,
}
config.configurable = configurable_dict # type: ignore[attr-defined]
config["configurable"] = configurable_dict
# Set metadata
config.metadata = metadata # type: ignore[attr-defined]
config["metadata"] = metadata
return cls(config)
@@ -274,7 +282,7 @@ def create_runnable_config(
if max_tokens_override is not None:
configurable["max_tokens_override"] = max_tokens_override
config.configurable = configurable # type: ignore[attr-defined]
config["configurable"] = configurable
# Set metadata
metadata = {}
@@ -288,6 +296,6 @@ def create_runnable_config(
if session_id is not None:
metadata["session_id"] = session_id
config.metadata = metadata # type: ignore[attr-defined]
config["metadata"] = metadata
return config
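The switch to item assignment reflects that RunnableConfig is a TypedDict (a plain dict at runtime), so dictionary access is the supported pattern and the previous attribute writes required type: ignore[attr-defined]. A minimal sketch:

from langchain_core.runnables import RunnableConfig

config: RunnableConfig = {"configurable": {"thread_id": "abc"}}
config["metadata"] = {"user_id": "u-123"}        # dict-style write, as used above
thread_id = config["configurable"]["thread_id"]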

View File

@@ -7,12 +7,11 @@ from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, cast
if TYPE_CHECKING:
pass
T = TypeVar("T")
# Global stats tracker
_stats_lock = asyncio.Lock()
@@ -156,7 +155,7 @@ async def retry_with_backoff(
config: RetryConfig,
*args,
**kwargs,
) -> T:
) -> Any:
"""Execute function with retry and exponential backoff.
Args:
@@ -177,13 +176,12 @@ async def retry_with_backoff(
try:
if asyncio.iscoroutinefunction(func):
# For async functions
async_func = cast(Callable[..., Awaitable[T]], func)
async_func = cast(Callable[..., Awaitable[Any]], func)
result = await async_func(*args, **kwargs)
return result
else:
# For sync functions
sync_func = cast(Callable[..., T], func)
result = sync_func(*args, **kwargs)
result = func(*args, **kwargs)
return result
except Exception as e:
# Check if exception is in the allowed list

View File

@@ -7,7 +7,6 @@ self-register with the appropriate registry when they are defined.
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar
from bb_core.logging import get_logger
@@ -61,8 +60,26 @@ def register_component(
if "category" not in metadata_dict:
metadata_dict["category"] = "default"
# Ensure capabilities, tags, and dependencies are lists
if "capabilities" in metadata_dict and isinstance(metadata_dict["capabilities"], str):
metadata_dict["capabilities"] = [metadata_dict["capabilities"]]
if "tags" in metadata_dict and isinstance(metadata_dict["tags"], str):
metadata_dict["tags"] = [metadata_dict["tags"]]
if "dependencies" in metadata_dict and isinstance(metadata_dict["dependencies"], str):
metadata_dict["dependencies"] = [metadata_dict["dependencies"]]
# Ensure examples is a list if it exists
if "examples" in metadata_dict and not isinstance(metadata_dict["examples"], list):
metadata_dict["examples"] = []
# Ensure input_schema and output_schema are dicts or None
if "input_schema" in metadata_dict and isinstance(metadata_dict["input_schema"], str):
metadata_dict["input_schema"] = None
if "output_schema" in metadata_dict and isinstance(metadata_dict["output_schema"], str):
metadata_dict["output_schema"] = None
# Create metadata object
metadata = RegistryMetadata(**metadata_dict)
metadata = RegistryMetadata(**metadata_dict) # type: ignore[arg-type]
# Get registry manager and register
manager = get_registry_manager()
@@ -289,7 +306,6 @@ def auto_register_pending() -> None:
"""
import gc
import inspect
import sys
manager = get_registry_manager()
registered_count = 0

View File

@@ -5,7 +5,7 @@ to improve type safety and reduce reliance on Any types.
"""
from collections.abc import AsyncIterator, Sequence
from typing import Any, Protocol, TypeVar, Union
from typing import Any, Protocol, TypeVar
# Type variable for LLM models
ModelT = TypeVar('ModelT', bound='BaseLanguageModel')
@@ -68,7 +68,7 @@ class BaseLanguageModel(Protocol):
# Union type for models that may or may not have tools bound
ModelWithOptionalTools = Union[ToolBoundModel, BaseLanguageModel]
ModelWithOptionalTools = ToolBoundModel | BaseLanguageModel
# Type for tool factory results
ToolList = list[Any] # More specific than just Any

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import threading
import time
from collections import OrderedDict
from typing import Any, Generic, TypeVar
from typing import Generic, TypeVar
K = TypeVar('K')
V = TypeVar('V')
@@ -106,4 +106,4 @@ class LRUCacheWithExpiration(Generic[K, V]):
return len(expired_keys)
__all__ = ["LRUCacheWithExpiration"]
__all__ = ["LRUCacheWithExpiration"]

View File

@@ -8,9 +8,13 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from bb_core.logging import get_logger
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
logger = get_logger(__name__)
class CapabilityInferenceEngine:
"""Centralized engine for inferring capabilities from queries and messages.
@@ -28,6 +32,11 @@ class CapabilityInferenceEngine:
"planning": ["plan", "strategy", "roadmap", "organize"],
"market_research": ["market", "competition", "industry", "trends", "business", "company"],
"web_tools": ["website", "url", "web page", "online"],
"paperless_ngx": ["ngx", "paperless", "paperless-ngx", "document management", "docs in ngx", "paperless documents"],
"document_management": ["documents", "paperless", "ngx", "file management", "document search", "docs", "files", "archives"],
"document_search": ["search documents", "find documents", "document lookup", "search docs", "find files"],
"document_retrieval": ["get document", "retrieve document", "pull document", "fetch document", "document details"],
"metadata_management": ["document metadata", "tags", "document types", "correspondents", "document info"],
}
# Default capabilities when none are detected
@@ -131,15 +140,11 @@ class CapabilityInferenceEngine:
except (AttributeError, TypeError, ValueError) as e:
# Log the error for debugging but continue processing
import logging
logger = logging.getLogger(__name__)
logger.debug(f"Error processing message {type(msg).__name__}: {e}")
continue
except Exception as e:
# Catch any unexpected errors and return defaults
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Unexpected error in message processing: {e}")
return cls.DEFAULT_CAPABILITIES.copy()
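The keyword table above drives simple substring matching against the lowercased query; a rough standalone sketch of that idea (function and variable names here are hypothetical, not the engine's actual API):

keyword_map = {
    "document_management": ["documents", "paperless", "ngx"],
    "web_tools": ["website", "url", "web page", "online"],
}

def infer_from_keywords(query: str) -> list[str]:
    """Return capability names whose keywords occur in the query (case-insensitive)."""
    q = query.lower()
    return [cap for cap, words in keyword_map.items() if any(w in q for w in words)]

print(infer_from_keywords("find my paperless documents"))  # -> ['document_management']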

View File

@@ -89,6 +89,7 @@ from .numeric import (
# Import from text submodule
from .text import (
MultimodalContentHandler,
clean_extracted_text,
clean_text,
count_tokens,
@@ -180,6 +181,7 @@ __all__ = [
"extract_sentences",
"extract_text_from_multimodal_content",
"merge_extraction_results",
"MultimodalContentHandler",
"normalize_whitespace",
"remove_html_tags",
"truncate_text",

View File

@@ -1,7 +1,9 @@
"""Base classes and interfaces for extraction."""
import logging
from abc import ABC, abstractmethod
from typing import Any, Union, cast
from collections import defaultdict
from typing import Any, Iterable, cast
from bb_core import get_logger
@@ -9,6 +11,16 @@ from bb_extraction.core.types import JsonDict, JsonValue
logger = get_logger(__name__)
# Mapping from content type to log level for media/unsupported content
MEDIA_LOG_LEVELS = {
"image": logging.INFO,
"audio": logging.INFO,
"video": logging.INFO,
"file": logging.INFO,
"document": logging.INFO,
}
SUPPORTED_TEXT = {"text", "markdown", "plain_text"}
class BaseExtractor(ABC):
"""Abstract base class for extractors."""
@@ -79,31 +91,43 @@ def merge_extraction_results(results: list[JsonDict]) -> JsonDict:
return merged
def extract_text_from_multimodal_content(content: Union[list[Any], Any]) -> str:
"""Extract text content from multimodal content blocks.
def extract_text_from_multimodal_content(
content: str | dict | Iterable[Any],
context: str = ""
) -> str:
"""Extract text from multimodal content with inline dispatch and rate-limiting."""
counts = defaultdict(int)
pieces: list[str] = []
Handles various content formats including lists of content blocks
and single content items. Logs warnings for unsupported content types.
for block in (content if isinstance(content, Iterable) and not isinstance(content, (str, dict))
else [content]):
if isinstance(block, str):
pieces.append(block)
continue
Args:
content: Content to extract text from. Can be a list of content blocks,
a single string, or other content types.
if isinstance(block, dict):
t = block.get("type", "unknown")
if t in SUPPORTED_TEXT:
pieces.append(block.get("text", ""))
continue
Returns:
Extracted text content as a string.
"""
if isinstance(content, list):
# Extract text content from list of content blocks
text_content = ""
for content_block in content:
if isinstance(content_block, str):
text_content += content_block + " "
elif isinstance(content_block, dict) and content_block.get("type") == "text":
text_content += content_block.get("text", "") + " "
else:
# Log unsupported content block types for debugging
logger.debug(
f"Unsupported content block type in multimodal content: {type(content_block).__name__}"
)
return text_content.strip()
return str(content) if content else ""
# fallback for both dicts (non-text) and any other type
key = block.get("type", type(block).__name__) if isinstance(block, dict) else type(block).__name__
counts[key] += 1
lvl = MEDIA_LOG_LEVELS.get(key, logging.WARNING)
msg = f"Skipped {key} content#{counts[key]}{' in '+context if context else ''}"
logging.getLogger(__name__).log(lvl, msg)
return " ".join(pieces).strip()
class MultimodalContentHandler:
"""Simplified backwards-compatible handler that wraps the new function."""
def extract_text(self, content: str | dict | Iterable[Any], context: str = "") -> str:
"""Extract text from multimodal content (backwards compatibility wrapper)."""
return extract_text_from_multimodal_content(content, context)
# Global instance for backwards compatibility
_multimodal_handler = MultimodalContentHandler()

View File

@@ -1,14 +1,15 @@
"""Extract statistics from text content."""
import logging
import re
from datetime import UTC, datetime
from re import Pattern
from bb_core.logging import get_logger
from .models import ExtractedStatistic, StatisticType
from .quality import assess_quality
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class StatisticsExtractor:

View File

@@ -1,6 +1,7 @@
"""Text processing and structured extraction utilities."""
from ..core.base import (
MultimodalContentHandler,
extract_text_from_multimodal_content,
merge_extraction_results,
)
@@ -43,6 +44,7 @@ __all__ = [
"extract_sentences",
"extract_text_from_multimodal_content",
"merge_extraction_results",
"MultimodalContentHandler",
"normalize_whitespace",
"remove_html_tags",
"truncate_text",

View File

@@ -25,9 +25,6 @@ logger = get_logger(__name__)
def extract_json_from_text(text: str) -> JsonDict | None:
"""Extract JSON object from text containing markdown code blocks or JSON strings."""
# Log for debugging
import logging
logger = logging.getLogger(__name__)
# Use warning level for now to ensure visibility
logger.warning(f"extract_json_from_text called with text length: {len(text)}")
logger.warning(f"Text starts with: {repr(text[:100])}")

View File

@@ -5,9 +5,9 @@ supporting integration with agent frameworks.
"""
import json
import logging
from typing import Annotated, Any, cast
from bb_core.logging import get_logger
from bb_core.validation import chunk_text, is_valid_url, merge_chunk_results
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
@@ -25,7 +25,7 @@ from .numeric.quality import (
rate_statistic_quality,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class StatisticsExtractionInput(BaseModel):

View File

@@ -0,0 +1,144 @@
"""Unit tests for multimodal content handling."""
import unittest
from unittest.mock import patch
from bb_extraction.core.base import (
MultimodalContentHandler,
extract_text_from_multimodal_content,
)
class TestExtractTextFromMultimodalContent(unittest.TestCase):
"""Test the extract_text_from_multimodal_content function."""
def test_extract_text_from_string(self):
"""Test extraction from simple string."""
content = "Hello world"
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "Hello world")
def test_extract_text_from_empty_string(self):
"""Test extraction from empty string."""
content = ""
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "")
def test_extract_text_from_none(self):
"""Test extraction from None."""
content = None
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "")
def test_extract_text_from_list_with_strings(self):
"""Test extraction from list of strings."""
content = ["Hello", "world", "test"]
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "Hello world test")
def test_extract_text_from_deeply_nested_multimodal_content(self):
"""Test extraction from deeply nested multimodal content."""
content = [
"Hello",
[
"nested",
[
{"type": "text", "text": "deep"},
[
"structure",
{"type": "text", "text": "test"}
]
]
],
{"type": "text", "text": "done"}
]
result = extract_text_from_multimodal_content(content)
# The simplified implementation doesn't handle nested lists - they are treated as unsupported content
self.assertEqual(result, "Hello done")
def test_extract_text_from_list_with_text_blocks(self):
"""Test extraction from list with text type blocks."""
content = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "world"},
]
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "Hello world")
def test_extract_text_from_mixed_content(self):
"""Test extraction from mixed content types."""
content = [
"Direct string",
{"type": "text", "text": "Text block"},
{"type": "text", "text": "Another block"},
]
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "Direct string Text block Another block")
@patch('logging.getLogger')
def test_extract_text_with_unsupported_types(self, mock_get_logger):
"""Test extraction logs unsupported content types."""
mock_logger = mock_get_logger.return_value
content = [
{"type": "text", "text": "Valid text"},
{"type": "image", "url": "image.jpg"},
{"type": "audio", "url": "audio.mp3"},
123, # Invalid type
]
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "Valid text")
# Assert that the logger was called for each unsupported type
self.assertEqual(mock_logger.log.call_count, 3)
def test_extract_text_from_empty_list(self):
"""Test extraction from empty list."""
content = []
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "")
def test_extract_text_from_whitespace_and_empty_blocks(self):
"""Test extraction from list with only whitespace or empty text blocks."""
content = [
{"type": "text", "text": ""},
{"type": "text", "text": " "},
{"type": "text", "text": "\n\t"},
"",
" ",
]
result = extract_text_from_multimodal_content(content)
self.assertEqual(result, "")
class TestMultimodalContentHandler(unittest.TestCase):
"""Test the MultimodalContentHandler class."""
def setUp(self):
"""Set up test fixtures."""
self.handler = MultimodalContentHandler()
def test_extract_text_basic(self):
"""Test basic text extraction."""
content = [
{"type": "text", "text": "Hello"},
{"type": "markdown", "text": "world"},
]
result = self.handler.extract_text(content)
self.assertEqual(result, "Hello world")
def test_markdown_support(self):
"""Test that markdown type is supported."""
content = [{"type": "markdown", "text": "# Header"}]
result = self.handler.extract_text(content)
self.assertEqual(result, "# Header")
def test_plain_text_support(self):
"""Test that plain_text type is supported."""
content = [{"type": "plain_text", "text": "Plain text"}]
result = self.handler.extract_text(content)
self.assertEqual(result, "Plain text")
if __name__ == "__main__":
unittest.main() # type: ignore[misc]

View File

@@ -392,6 +392,7 @@ def _get_client_from_config(config: RunnableConfig | None = None) -> PaperlessNG
cast("dict[str, object]", config.get("configurable", {})) if config else {}
)
# Get config values, but ensure we pass None when not available to allow env var fallback
base_url = cast(
"str | None", configurable.get("paperless_base_url") if configurable else None
)
@@ -399,6 +400,10 @@ def _get_client_from_config(config: RunnableConfig | None = None) -> PaperlessNG
"str | None", configurable.get("paperless_token") if configurable else None
)
# Convert empty strings to None to allow environment variable fallback
base_url = base_url if base_url else None
token = token if token else None
return PaperlessNGXClient(base_url=base_url, token=token)

View File

@@ -21,6 +21,8 @@ from bb_tools.constants import JINA_GROUNDING_ENDPOINT
# Import shared logic from jina_old (until refactored into a shared module)
from .utils import RetryConfig, get_jina_api_key, make_request_with_retry
from bb_core.errors.formatter import ErrorMessageFormatter
# Define InjectedState as a TypeVar for type checking
# This avoids import issues with langgraph.prebuilt
InjectedState = TypeVar("InjectedState")
@@ -151,7 +153,8 @@ async def _grounding[InjectedState](
except Exception as e:
error_highlight(f"Invalid grounding request: {str(e)}", "JinaTool")
# Always raise ToolException with message matching test regex
raise ToolException(f"Invalid grounding request: {str(e)}") from None
sanitized_error = ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
raise ToolException(f"Invalid grounding request: {sanitized_error}") from None
if should_use_cache:
cached = await _get_cached_grounding(cache_key)

View File

@@ -26,6 +26,8 @@ from .utils import (
make_request_with_retry,
)
from bb_core.errors.formatter import ErrorMessageFormatter
class RerankerRequest(BaseModel):
"""Request model for Jina's Re-Ranker API."""
@@ -99,7 +101,8 @@ async def _rerank(
# state=cast("dict[str, Any] | None", state),
# )
# _ = await llm_cache.get(cache_key)
raise ToolException(f"Invalid rerank request: {str(e)}") from e
sanitized_error = ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
raise ToolException(f"Invalid rerank request: {sanitized_error}") from e
# Caching disabled - LLMCache not available in bb_core
# cached_data = await llm_cache.get(cache_key)

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import asyncio
import uuid
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Type
from typing import TYPE_CHECKING, Annotated, Any, Literal, Type
from langchain.tools import BaseTool
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
@@ -141,7 +141,7 @@ class ResearchGraphTool(BaseTool):
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
run_manager: AsyncCallbackManagerForToolRun | None = None,
**kwargs: Any,
) -> str:
"""Asynchronously run the research graph.
@@ -220,7 +220,7 @@ class ResearchGraphTool(BaseTool):
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
run_manager: CallbackManagerForToolRun | None = None,
**kwargs: Any,
) -> str:
"""Synchronous wrapper for the research graph.

View File

@@ -46,33 +46,7 @@ try:
except ImportError:
_jina_available = False
# Set up logging
logger = logging.getLogger(__name__)
# Try to import internal utilities
try:
pass # get_logger not used
except ImportError:
# Fallback if internal utilities are not available
def get_logger(name: str) -> logging.Logger:
"""Get a logger instance.
Args:
name: Logger name
Returns:
Configured logger instance
"""
logger = logging.getLogger(name)
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
return logger
# Already have logger from get_logger import above
# Type variables for generic types
@@ -142,7 +116,7 @@ class FirecrawlStrategy(ScraperStrategyBase):
api_key: Optional API key for Firecrawl service
"""
self.api_key: str | None = api_key
self.logger: logging.Logger = logging.getLogger(__name__)
self.logger: logging.Logger = get_logger(__name__)
async def can_handle(self, url: str, **kwargs: object) -> bool:
"""Check if Firecrawl can handle the URL.

View File

@@ -19,7 +19,6 @@ components while providing a flexible orchestration layer.
"""
import asyncio
import re
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
@@ -44,6 +43,7 @@ from biz_bud.config.loader import load_config
from biz_bud.config.schemas import AppConfig
from biz_bud.services.factory import ServiceFactory
from biz_bud.states.buddy import BuddyState
from bb_core.errors.formatter import ErrorMessageFormatter
if TYPE_CHECKING:
from langgraph.graph.graph import CompiledGraph
@@ -51,35 +51,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _sanitize_error_message(error: Exception) -> str:
"""Sanitize error messages to prevent sensitive information leakage.
Removes or masks common sensitive patterns like API keys, passwords,
file paths, and other potentially sensitive data from error messages.
Args:
error: The exception to sanitize
Returns:
Sanitized error message safe for user display
"""
message = str(error)
# Remove or mask sensitive patterns
# API keys, tokens, passwords
message = re.sub(r'(api_key|token|password|secret|auth|credential)[:=]\s*[\'"]*[^\s\'"]+', r'\1=***', message, flags=re.IGNORECASE)
# File paths (remove system paths)
message = re.sub(r'/[a-zA-Z0-9/_\-\.]+/', '/***/', message)
# Stack traces (keep only the error message part)
message = message.split('\n')[0] if '\n' in message else message
# Truncate very long messages
if len(message) > 200:
message = message[:200] + "..."
return message
__all__ = [
@@ -321,9 +292,13 @@ async def run_buddy_agent(
.build()
)
# Run configuration
# Run configuration with app config and service factory for tool access
run_config = RunnableConfig(
configurable={"thread_id": initial_state["thread_id"]},
configurable={
"thread_id": initial_state["thread_id"],
"app_config": config,
# Service factory will be accessed via global factory pattern
},
recursion_limit=1000,
)
@@ -345,7 +320,7 @@ async def run_buddy_agent(
except Exception as e:
error_highlight(f"Buddy agent failed: {str(e)}")
return _sanitize_error_message(e)
return ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
async def stream_buddy_agent(
@@ -376,9 +351,13 @@ async def stream_buddy_agent(
.build()
)
# Run configuration
# Run configuration with app config and service factory for tool access
run_config = RunnableConfig(
configurable={"thread_id": initial_state["thread_id"]},
configurable={
"thread_id": initial_state["thread_id"],
"app_config": config,
# Service factory will be accessed via global factory pattern
},
recursion_limit=1000,
)
@@ -419,8 +398,11 @@ async def stream_buddy_agent(
yield update["final_response"]
except Exception as e:
error_highlight(f"Buddy agent streaming failed: {str(e)}")
yield f"Error: {str(e)}"
import logging
logging.exception("Buddy agent streaming failed with an exception")
sanitized_error = ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
error_highlight(f"Buddy agent streaming failed: {sanitized_error}")
yield f"Error: {sanitized_error}"
# Export for LangGraph API

View File

@@ -459,13 +459,24 @@ class ResponseFormatter:
total_executions = len(execution_history)
successful_executions = sum(
1 for record in execution_history
if record["status"] == "completed"
if record.get("status") == "completed"
)
failed_executions = sum(
1 for record in execution_history
if record["status"] == "failed"
if record.get("status") == "failed"
)
# Check for records missing status and handle them explicitly
records_without_status = [
record for record in execution_history
if "status" not in record
]
if records_without_status:
raise ValueError(
f"Found {len(records_without_status)} execution records missing 'status' key. "
f"All execution records must include a 'status' key. Offending records: {records_without_status}"
)
# Build the response
response_parts = [
"# Buddy Orchestration Complete",

View File

@@ -411,7 +411,22 @@ async def buddy_orchestrator_node(
query_complexity = state.get("query_complexity", "complex")
# For simple queries, bypass the planner and do direct web search
if query_complexity == "simple" and not StateHelper.has_execution_plan(state):
# BUT respect intelligent tool selection if specific capabilities were detected
# Centralized capability names for document-related features
from biz_bud.constants import DOCUMENT_CAPABILITIES
detected_capabilities = state.get("available_capabilities", [])
# Ensure detected_capabilities is iterable to prevent type errors with 'any'
if not isinstance(detected_capabilities, (list, tuple, set)):
logger.warning(
f"'available_capabilities' in state is not a list/tuple/set (got {type(detected_capabilities).__name__}); resetting to empty list. Upstream data issue likely."
)
detected_capabilities = []
has_specific_capabilities = any(cap in detected_capabilities for cap in DOCUMENT_CAPABILITIES)
if (query_complexity == "simple" and
not StateHelper.has_execution_plan(state) and
not has_specific_capabilities):
logger.info("Simple query detected - using lightweight web search approach")
try:
@@ -523,6 +538,8 @@ async def buddy_orchestrator_node(
except Exception as e:
logger.warning(f"Lightweight search failed: {e}, falling back to planner")
elif has_specific_capabilities:
logger.info(f"Specific capabilities detected {detected_capabilities} - using intelligent tool selection instead of lightweight search")
# Complex queries or fallback: use the original planner approach
if not StateHelper.has_execution_plan(state):
@@ -551,9 +568,11 @@ async def buddy_orchestrator_node(
if execution_plan:
# Get the first executable step to start execution
first_step = None
if execution_plan["steps"]:
first_step = execution_plan["steps"][0]
if "steps" not in execution_plan:
logger.error("Execution plan is missing 'steps' key. Full execution plan: %s", execution_plan)
raise ValueError(f"Execution plan is missing 'steps' key. Execution plan: {execution_plan}")
steps = execution_plan.get("steps", [])
first_step = steps[0] if steps else None
if first_step:
return (
@@ -674,12 +693,39 @@ async def buddy_executor_node(
executor = tool_factory.create_graph_tool(graph_name)
result = await executor._arun(query=step_query, context=context)
# Process enhanced tool result if it includes messages
tool_messages = []
processed_result = result
if isinstance(result, dict) and result.get("tool_execution_complete"):
# This is an enhanced result with messages to merge
tool_messages = result.get("messages", [])
# Extract the actual result content
if "synthesis" in result:
processed_result = result["synthesis"]
elif "final_result" in result:
processed_result = result["final_result"]
elif "response" in result:
processed_result = result["response"]
elif "final_response" in result:
processed_result = result["final_response"]
elif "status_message" in result:
processed_result = result["status_message"]
else:
processed_result = result
logger.info(f"Tool execution for {step_id} returned {len(tool_messages)} messages to merge")
else:
# Regular result, no special message handling needed
processed_result = result
# Create execution record using factory
execution_record = ExecutionRecordFactory.create_success_record(
step_id=step_id,
graph_name=graph_name,
start_time=start_time,
result=result,
result=processed_result,
)
# Update state
@@ -691,16 +737,25 @@ async def buddy_executor_node(
completed_steps.append(step_id)
intermediate_results = dict(state.get("intermediate_results", {}))
intermediate_results[step_id] = result
intermediate_results[step_id] = processed_result
return (
# Prepare state updates
state_updates = (
updater.set("execution_history", execution_history)
.set("completed_step_ids", completed_steps)
.set("intermediate_results", intermediate_results)
.set("last_execution_status", "success")
.build()
)
# Add tool messages if we have any (they will be merged via add_messages reducer)
# Note: set("messages", tool_messages) works correctly with add_messages reducer -
# it will append these messages to existing ones, not overwrite them
if tool_messages:
logger.info(f"Adding {len(tool_messages)} messages from tool execution to state")
state_updates = state_updates.set("messages", tool_messages)
return state_updates.build()
except Exception as e:
# Create failed execution record using factory
failed_execution_record = ExecutionRecordFactory.create_failure_record(
@@ -820,29 +875,61 @@ async def buddy_synthesizer_node(
tool_factory = get_tool_factory()
synthesizer = tool_factory.create_node_tool("synthesize_search_results")
synthesis = await synthesizer._arun(
synthesis_result = await synthesizer._arun(
query=user_query,
extracted_info=extracted_info,
sources=sources,
)
# Process enhanced tool result if it includes messages
synthesis_messages = []
synthesis = synthesis_result
if isinstance(synthesis_result, dict) and synthesis_result.get("tool_execution_complete"):
# This is an enhanced result with messages to merge
synthesis_messages = synthesis_result.get("messages", [])
# Extract the actual synthesis content
if "synthesis" in synthesis_result:
synthesis = synthesis_result["synthesis"]
elif "final_result" in synthesis_result:
synthesis = synthesis_result["final_result"]
elif "response" in synthesis_result:
synthesis = synthesis_result["response"]
else:
synthesis = synthesis_result
logger.info(f"Synthesis tool returned {len(synthesis_messages)} messages to merge")
else:
# Regular result, no special message handling needed
synthesis = synthesis_result
# Format final response using formatter
final_response = ResponseFormatter.format_final_response(
query=user_query,
synthesis=synthesis,
synthesis=str(synthesis),
execution_history=state.get("execution_history", []),
completed_steps=state.get("completed_step_ids", []),
adaptation_count=state.get("adaptation_count", 0),
)
# Prepare state updates
updater = StateUpdater(dict(state))
return (
state_updates = (
updater.set("final_response", final_response)
.set("orchestration_phase", "completed")
.set("status", "success")
.build()
)
# Add synthesis messages if we have any (they will be merged via add_messages reducer)
# Note: set("messages", synthesis_messages) works correctly with add_messages reducer -
# it will append these messages to existing ones, not overwrite them
if synthesis_messages:
logger.info(f"Adding {len(synthesis_messages)} messages from synthesis to state")
state_updates = state_updates.set("messages", synthesis_messages)
return state_updates.build()
except Exception as e:
error_msg = f"Failed to synthesize results: {str(e)}"
updater = StateUpdater(dict(state))

View File

@@ -7,12 +7,15 @@ reducing duplication and improving consistency across the Buddy agent.
import uuid
from typing import Any, Literal
from bb_core.logging import get_logger
from bb_extraction import extract_text_from_multimodal_content
from langchain_core.messages import HumanMessage
from biz_bud.config.schemas import AppConfig
from biz_bud.states.buddy import BuddyState
logger = get_logger(__name__)
class BuddyStateBuilder:
"""Builder for creating BuddyState instances with sensible defaults.
@@ -169,8 +172,29 @@ class StateHelper:
The extracted query string, or empty string if not found
"""
# First try the direct user_query field
if state.get("user_query"):
return state["user_query"]
user_query = state.get("user_query")
if user_query:
if isinstance(user_query, str):
return user_query
elif isinstance(user_query, dict):
logger.warning(
f"Expected 'user_query' to be str, got dict. Value: {user_query!r}. "
"Consider serializing this object before assignment."
)
# Optionally, you could return json.dumps(user_query) if that's appropriate for your use case.
return ""
elif isinstance(user_query, list):
logger.warning(
f"Expected 'user_query' to be str, got list. Value: {user_query!r}. "
"Consider joining or serializing this list before assignment."
)
return ""
else:
logger.warning(
f"Expected 'user_query' to be str, got {type(user_query).__name__}. "
f"Value: {user_query!r}. Converting to str as a last resort."
)
return str(user_query)
# Then try to find in messages
messages = state.get("messages", [])
@@ -187,8 +211,7 @@ class StateHelper:
if isinstance(query_value, str):
return query_value
else:
import logging
logging.warning(f"Expected 'query' in context to be str, got {type(query_value).__name__}. Converting to str.")
logger.warning(f"Expected 'query' in context to be str, got {type(query_value).__name__}. Converting to str.")
return str(query_value)
return ""

View File

@@ -21,6 +21,15 @@ from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field, create_model
from biz_bud.registries import get_graph_registry, get_node_registry, get_tool_registry
from bb_core.errors.formatter import ErrorMessageFormatter
# Import configuration modules at module level to avoid repeated imports
try:
from biz_bud.config.loader import load_config
from bb_core.langgraph.runnable_config import create_runnable_config
CONFIG_IMPORTS_AVAILABLE = True
except ImportError:
CONFIG_IMPORTS_AVAILABLE = False
logger = get_logger(__name__)
@@ -95,8 +104,9 @@ class ToolFactory:
return self._format_result(result)
except Exception as e:
error_msg = f"Failed to execute {node_name}: {str(e)}"
logger.error(error_msg)
sanitized_error = ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
error_msg = f"Failed to execute {node_name}: {sanitized_error}"
logger.error(f"Failed to execute {node_name}: {str(e)}") # Log full error for debugging
return error_msg
def _run(self, **kwargs: Any) -> str:
@@ -230,7 +240,11 @@ class ToolFactory:
model_config = {"arbitrary_types_allowed": True}
async def _arun(self, **kwargs: Any) -> str:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._last_execution_result: Any = None # Store full result for message merging
async def _arun(self, **kwargs: Any) -> Any:
"""Execute the graph."""
try:
# Create graph instance
@@ -238,18 +252,24 @@ class ToolFactory:
# Prepare initial state
query = kwargs.get("query", "")
context = kwargs.get("context", {})
state = {
"messages": [HumanMessage(content=query)],
"query": query,
"user_query": query,
"initial_input": kwargs,
"config": {},
"context": kwargs.get("context", {}),
"context": context,
"errors": [],
"status": "running",
"run_metadata": {},
"thread_id": f"{graph_name}-{uuid.uuid4().hex[:8]}",
"is_last_step": False,
# Preserve orchestrator's tool selection to prevent sub-graphs from overriding
"available_capabilities": context.get("available_capabilities", []),
"selected_tools": context.get("selected_tools", []),
"tool_selection_reasoning": context.get("tool_selection_reasoning", ""),
}
# Add any additional kwargs to state
@@ -257,30 +277,85 @@ class ToolFactory:
if key not in state:
state[key] = value
# Execute graph
result = await graph.ainvoke(state)
# Execute graph with proper configuration
# Load global configuration for service integrations
if CONFIG_IMPORTS_AVAILABLE:
try:
app_config = load_config()
logger.debug(f"Loaded app_config type: {type(app_config)}")
run_config = create_runnable_config(app_config=app_config)
logger.debug(f"Created run_config type: {type(run_config)}")
logger.debug(f"Run config has configurable: {hasattr(run_config, 'configurable')}")
result = await graph.ainvoke(state, config=run_config)
except Exception as config_error:
logger.warning(f"Failed to load config for graph execution: {config_error}")
import traceback
logger.debug(f"Config error traceback: {traceback.format_exc()}")
# Fallback to execution without config
result = await graph.ainvoke(state)
else:
logger.warning("Configuration modules not available, executing graph without config")
result = await graph.ainvoke(state)
# Extract result based on graph type
# Store full result in tool context for potential message merging
self._last_execution_result = result
# Extract result based on graph type, but preserve message handling
if graph_name == "planner":
# Planner graphs return complex structured data that downstream components
# expect to parse as structured dictionaries. PlanParser.parse_planner_result()
# handles dict input directly, avoiding JSON serialization/parsing issues.
return result
elif "synthesis" in result:
return result["synthesis"]
return self._merge_result_with_messages("synthesis", result)
elif "final_result" in result:
return result["final_result"]
return self._merge_result_with_messages("final_result", result)
elif "response" in result:
return result["response"]
return self._merge_result_with_messages("response", result)
elif "final_response" in result:
# Handle paperless-style results that have final_response (special case)
final_response = result["final_response"]
if "messages" in result and result["messages"]:
return {
"final_response": final_response,
"messages": result["messages"],
"paperless_results": result.get("paperless_results", []),
"tool_execution_complete": True
}
return final_response
else:
return f"Graph execution completed. Status: {result.get('status', 'unknown')}"
# For any other graph, if it has messages, preserve them
status_msg = f"Graph execution completed. Status: {result.get('status', 'unknown')}"
if "messages" in result and result["messages"]:
return {
"status_message": status_msg,
"messages": result["messages"],
"tool_execution_complete": True,
"full_result": result
}
return status_msg
except Exception as e:
error_msg = f"Failed to execute graph {graph_name}: {str(e)}"
logger.error(error_msg)
return error_msg
def _run(self, **kwargs: Any) -> str:
def _merge_result_with_messages(self, result_key: str, result: dict[str, Any]) -> Any:
"""Helper to merge the main result with messages if present.
Returns a dict with the result and messages if messages exist,
otherwise returns the result value directly.
"""
main_result = result[result_key]
if "messages" in result and result["messages"]:
return {
result_key: main_result,
"messages": result["messages"],
"tool_execution_complete": True
}
return main_result
def _run(self, **kwargs: Any) -> Any:
"""Execute the graph synchronously."""
return asyncio.run(self._arun(**kwargs))

View File

@@ -75,6 +75,15 @@ class DatabaseConfigModel(BaseModel):
postgres_port: int | None = Field(
None, description="Port number for PostgreSQL database server."
)
postgres_min_pool_size: int = Field(
2, description="Minimum size of PostgreSQL connection pool."
)
postgres_max_pool_size: int = Field(
15, description="Maximum size of PostgreSQL connection pool."
)
postgres_command_timeout: int = Field(
10, description="Command timeout in seconds for PostgreSQL operations."
)
@field_validator("postgres_port", mode="after")
@classmethod
@@ -84,6 +93,30 @@ class DatabaseConfigModel(BaseModel):
raise ValueError("postgres_port must be between 1 and 65535")
return value
@field_validator("postgres_min_pool_size", mode="after")
@classmethod
def validate_postgres_min_pool_size(cls, value: int) -> int:
"""Validate PostgreSQL minimum pool size."""
if value < 1:
raise ValueError("postgres_min_pool_size must be >= 1")
return value
@field_validator("postgres_max_pool_size", mode="after")
@classmethod
def validate_postgres_max_pool_size(cls, value: int) -> int:
"""Validate PostgreSQL maximum pool size."""
if value < 1:
raise ValueError("postgres_max_pool_size must be >= 1")
return value
@field_validator("postgres_command_timeout", mode="after")
@classmethod
def validate_postgres_command_timeout(cls, value: int) -> int:
"""Validate PostgreSQL command timeout."""
if value < 1:
raise ValueError("postgres_command_timeout must be >= 1")
return value
@field_validator("default_page_size", mode="after")
@classmethod
def validate_default_page_size(cls, value: int) -> int:
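A minimal sketch of the new pool settings in use (illustrative values; the import path is assumed to be biz_bud.config.schemas alongside AppConfig, and other fields are omitted for brevity):

from biz_bud.config.schemas import DatabaseConfigModel  # assumed location

db = DatabaseConfigModel(
    postgres_port=5432,
    postgres_min_pool_size=2,       # defaults per the diff: 2 / 15 / 10
    postgres_max_pool_size=15,
    postgres_command_timeout=10,
)
# Values below 1 are rejected by the validators above, e.g. postgres_min_pool_size=0
# raises a validation error ("postgres_min_pool_size must be >= 1").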

View File

@@ -218,6 +218,42 @@ class AgentType(str, Enum):
return self.value
class CapabilityNames(str, Enum):
"""Enumeration of supported capability names in the Business Buddy framework.
This enumeration defines the different capabilities that can be detected and
used for tool selection and routing decisions throughout the system.
"""
# Document-related capabilities
PAPERLESS_NGX = "paperless_ngx"
DOCUMENT_MANAGEMENT = "document_management"
DOCUMENT_SEARCH = "document_search"
DOCUMENT_RETRIEVAL = "document_retrieval"
# Search and analysis capabilities
WEB_SEARCH = "web_search"
RESEARCH = "research"
ANALYSIS = "analysis"
# Data processing capabilities
TEXT_PROCESSING = "text_processing"
IMAGE_ANALYSIS = "image_analysis"
def __str__(self) -> str:
"""Return string representation of the capability name."""
return self.value
# Commonly used capability groups
DOCUMENT_CAPABILITIES: Final[list[str]] = [
CapabilityNames.PAPERLESS_NGX,
CapabilityNames.DOCUMENT_MANAGEMENT,
CapabilityNames.DOCUMENT_SEARCH,
CapabilityNames.DOCUMENT_RETRIEVAL,
]
# Note: Web-related API endpoints have been moved to bb_tools.constants
# System prompt for LLMs
@@ -239,6 +275,8 @@ UNREACHABLE_ERROR: Final = "Unreachable code path - this should never be execute
__all__ = [
"AgentType",
"CapabilityNames",
"DOCUMENT_CAPABILITIES",
"UNREACHABLE_ERROR",
"SYSTEM_PROMPT",
"DEFAULT_LANGUAGE",

View File

@@ -389,12 +389,14 @@ async def execution_planning_node(state: PlannerState) -> dict[str, Any]:
updated_execution_plan = execution_plan.copy()
updated_execution_plan["execution_mode"] = execution_mode
if first_step:
updated_execution_plan["current_step_id"] = first_step["id"]
if "id" not in first_step:
logger.warning("First step is missing 'id' key. Defaulting current_step_id to 'unknown'.")
updated_execution_plan["current_step_id"] = first_step.get("id", "unknown")
first_step["status"] = "pending"
# Determine next routing decision
routing_decision = "route_to_agent" if first_step else "no_steps_available"
next_agent = first_step["agent_name"] if first_step else None
next_agent = first_step.get("agent_name") if first_step else None
updater = StateUpdater(dict(state))
return (updater
@@ -666,7 +668,14 @@ async def execute_graph_node(state: PlannerState, config: RunnableConfig | None
updated_execution_plan = execution_plan.copy()
updated_execution_plan["completed_steps"] = completed_steps
updated_execution_plan["current_step_id"] = next_step["id"] if next_step else None
if next_step:
if "id" in next_step:
updated_execution_plan["current_step_id"] = next_step["id"]
else:
logger.warning("Next step is present but missing 'id': %s", next_step)
updated_execution_plan["current_step_id"] = None
else:
updated_execution_plan["current_step_id"] = None
if next_step:
# Reset routing depth for successful step progression to prevent infinite recursion issues
@@ -675,14 +684,20 @@ async def execute_graph_node(state: PlannerState, config: RunnableConfig | None
# Check if this is the same step being retried vs new step progression
current_step_id = state.get("execution_plan", {}).get("current_step_id")
is_step_progression = current_step_id != next_step["id"]
next_step_id = next_step.get("id")
# Handle case where next_step_id is None to avoid logic errors
if next_step_id is None:
logger.warning("Cannot determine step progression - next_step missing 'id'")
is_step_progression = False
else:
is_step_progression = current_step_id != next_step_id
routing_depth = 0 if is_step_progression else current_routing_depth
logger.info(
f"Routing from execute_graph to router. "
f"Step progression: {is_step_progression}, "
f"Current step: {current_step_id}, Next step: {next_step['id']}, "
f"Current step: {current_step_id}, Next step: {next_step_id}, "
f"Routing depth: {routing_depth}"
)

View File

@@ -374,8 +374,12 @@ def get_research_graph(
config_dict = config.model_dump()
# Use query from config if not provided
if not query and "inputs" in config_dict and "query" in config_dict["inputs"]:
query = config_dict["inputs"]["query"]
if not query:
inputs = config_dict.get("inputs", {})
if isinstance(inputs, dict):
query = inputs.get("query")
else:
query = None
# Create default initial state
default_state: ResearchState = {

View File

@@ -36,10 +36,11 @@ async def error_analyzer_node(state: ErrorHandlingState, config: RunnableConfig
Dictionary with error analysis results
"""
error = state["current_error"]
context = state["error_context"]
error = state.get("current_error")
context = state.get("error_context", {})
logger.info(f"Analyzing error from node: {context['node_name']}")
node_name = context.get('node_name', 'unknown') if context else 'unknown'
logger.info(f"Analyzing error from node: {node_name}")
# First, apply rule-based classification
initial_analysis = _rule_based_analysis(error, context)

View File

@@ -97,18 +97,37 @@ async def _get_paperless_llm(config: RunnableConfig | None = None) -> BaseChatMo
def _validate_paperless_config(config: RunnableConfig | None) -> dict[str, str]:
"""Validate and extract Paperless NGX configuration."""
if not config or "configurable" not in config:
raise ValueError("Paperless NGX configuration is missing")
import os
configurable = config.get("configurable", {})
# Try to get from config first
base_url = None
token = None
base_url = configurable.get("paperless_base_url")
if config and "configurable" in config:
configurable = config.get("configurable", {})
base_url = configurable.get("paperless_base_url")
token = configurable.get("paperless_token")
# Fallback to environment variables if not in config
if not base_url:
raise ValueError("Paperless NGX base URL is required in configuration")
token = configurable.get("paperless_token")
logger.warning(
"Paperless NGX base URL not found in config; falling back to PAPERLESS_BASE_URL environment variable. "
"Checked 'configurable.paperless_base_url' and PAPERLESS_BASE_URL env var."
)
base_url = os.getenv("PAPERLESS_BASE_URL")
if not token:
raise ValueError("Paperless NGX API token is required in configuration")
logger.warning(
"Paperless NGX token not found in config; falling back to PAPERLESS_TOKEN environment variable. "
"Checked 'configurable.paperless_token' and PAPERLESS_TOKEN env var."
)
token = os.getenv("PAPERLESS_TOKEN")
# Validate that we have the required configuration
if not base_url:
raise ValueError("Paperless NGX base URL is required but was not found in configuration or the PAPERLESS_BASE_URL environment variable.")
if not token:
raise ValueError("Paperless NGX API token is required but was not found in configuration or the PAPERLESS_TOKEN environment variable.")
return {
"paperless_base_url": base_url,

View File

@@ -419,8 +419,9 @@ async def invoke_url_to_rag_node(state: RAGAgentState, config: RunnableConfig |
# Get the stream writer for this node
try:
writer = get_stream_writer()
except RuntimeError:
# Outside of a LangGraph runnable context (e.g., in tests)
except (RuntimeError, TypeError) as e:
# Outside of a LangGraph runnable context (e.g., in tests), or the stream writer failed to initialize
logger.debug(f"Stream writer unavailable: {e}")
writer = None
# Use the streaming version and forward updates

View File

@@ -16,6 +16,7 @@ from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from biz_bud.nodes.rag.utils import extract_collection_name
from bb_core.errors.formatter import ErrorMessageFormatter
try:
from langgraph.config import get_stream_writer
@@ -215,7 +216,8 @@ async def _r2r_direct_api_call(
) from e
except httpx.ConnectError as e:
logger.error(f"Connection error to R2R server at {url}")
raise Exception(f"Cannot connect to R2R server at {base_url}. Error: {str(e)}") from e
sanitized_error = ErrorMessageFormatter.sanitize_error_message(e, for_user=True)
raise Exception(f"Cannot connect to R2R server at {base_url}. Error: {sanitized_error}") from e
except Exception as e:
logger.error(f"Unexpected error calling R2R API: {e}")
raise
@@ -620,8 +622,9 @@ async def upload_to_r2r_node(state: URLToRAGState, config: RunnableConfig | None
"""
try:
writer = get_stream_writer() if get_stream_writer else None
except RuntimeError:
# Handle case where we're not in a runnable context (e.g., during tests)
except (RuntimeError, TypeError) as e:
# Handle the case where we're not in a runnable context (e.g., during tests) or the stream writer fails to initialize
logger.debug(f"Stream writer unavailable: {e}")
writer = None
result: dict[str, Any]

View File

@@ -5,16 +5,16 @@ using LangGraph's interrupt functionality with improved type safety
and state management.
"""
import logging
from typing import Any, TypedDict, cast
from langchain_core.runnables import RunnableConfig
from langgraph.errors import NodeInterrupt
from bb_core.langgraph import standard_node
from bb_core.logging import get_logger
from biz_bud.states.unified import BusinessBuddyState
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
def _is_error_info(obj: object) -> bool:

View File

@@ -151,9 +151,11 @@ class GraphRegistry(BaseRegistry[GraphFactoryProtocol]):
if factory_func:
# Create RegistryMetadata from GRAPH_METADATA
# Add input requirements as dependencies
dependencies = []
if "input_requirements" in metadata_dict:
dependencies = metadata_dict["input_requirements"]
input_requirements = metadata_dict.get("input_requirements", [])
if isinstance(input_requirements, list):
dependencies = input_requirements
else:
dependencies = []
metadata = RegistryMetadata.model_validate({
"name": metadata_dict.get("name", module_name),

View File

@@ -242,18 +242,31 @@ class NodeRegistry(BaseRegistry[NodeProtocol]):
reg_info = getattr(obj, "_registry_metadata", None)
if reg_info and reg_info.get("registry") == "nodes":
# Check if already registered before attempting registration
if reg_info["metadata"].name not in self.list_all():
self.register(
reg_info["metadata"].name,
obj,
reg_info["metadata"],
)
delattr(obj, "_registry_metadata")
discovered += 1
logger.debug(f"Registered decorated node: {name}")
metadata = reg_info.get("metadata")
if metadata and hasattr(metadata, 'name'):
# Additional validation for metadata structure
if not isinstance(metadata.name, str) or not metadata.name.strip():
logger.warning(f"Node {name} has invalid metadata.name (not a non-empty string): {metadata.name}")
continue
if metadata.name not in self.list_all():
try:
self.register(
metadata.name,
obj,
metadata,
)
delattr(obj, "_registry_metadata")
discovered += 1
logger.debug(f"Registered decorated node: {name}")
except Exception as e:
logger.error(f"Failed to register node {name} with metadata {metadata}: {e}")
else:
logger.debug(f"Skipping already registered decorated node: {name}")
delattr(obj, "_registry_metadata")
elif metadata:
logger.warning(f"Node {name} has metadata but missing 'name' attribute: {type(metadata).__name__}")
else:
logger.debug(f"Skipping already registered decorated node: {name}")
delattr(obj, "_registry_metadata")
logger.warning(f"Node {name} has empty metadata in registry info")
continue
# Check if it looks like a node function

View File

@@ -193,13 +193,20 @@ class PostgresStore(BaseService[PostgresStoreConfig]):
"database": config.postgres_db,
"user": config.postgres_user,
"password": config.postgres_password,
"min_size": 5,
"max_size": 20,
"min_size": getattr(config, "postgres_min_pool_size", 2), # Configurable, default 2
"max_size": getattr(config, "postgres_max_pool_size", 15), # Configurable, default 15
"command_timeout": getattr(config, "postgres_command_timeout", 10), # Configurable, default 10s
"server_settings": {
"application_name": "business_buddy",
"tcp_keepalives_idle": "600",
"tcp_keepalives_interval": "30",
"tcp_keepalives_count": "3",
},
}
# Add schema search_path if specified
if config.postgres_schema and config.postgres_schema != "public":
conn_params["server_settings"] = {"search_path": f"{config.postgres_schema},public"}
conn_params["server_settings"]["search_path"] = f"{config.postgres_schema},public"
self.pool = await asyncpg.create_pool(**conn_params)
info_success("Database connection pool initialized")

View File

@@ -111,6 +111,7 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
from bb_core import get_logger
from biz_bud.config.schemas import AppConfig
from biz_bud.services.base import BaseService
from biz_bud.types.llm import (
LLMCallKwargsTypedDict,
LLMErrorResponseTypedDict,
@@ -120,7 +121,6 @@ from biz_bud.types.llm import (
if TYPE_CHECKING:
from langchain_core.messages import AIMessage, BaseMessage
from biz_bud.services.base import BaseService
from biz_bud.services.db import PostgresStore
from biz_bud.services.llm import LangchainLLMClient
from biz_bud.services.redis_backend import RedisCacheBackend
@@ -282,8 +282,8 @@ class ServiceFactory:
# Create service instance
service = service_class(self._config)
# Initialize with timeout (default 60 seconds, configurable)
timeout = getattr(self._config, "service_init_timeout", 60.0)
# Initialize with timeout (reduced from 60 to 30 seconds for faster startup)
timeout = getattr(self._config, "service_init_timeout", 30.0)
await asyncio.wait_for(service.initialize(), timeout=timeout)
logger.info(f"Successfully initialized service: {service_class.__name__}")
@@ -432,6 +432,76 @@ class ServiceFactory:
# but added to satisfy static analysis tools
return cast("T", None)
def _partition_results(
self, results: dict[type[BaseService[Any]], object]
) -> tuple[dict[type[BaseService[Any]], BaseService[Any]], list[str]]:
"""Partition initialization results into succeeded and failed services."""
succeeded = {
k: cast("BaseService[Any]", v)
for k, v in results.items()
if not isinstance(v, Exception)
}
failed = [
k.__name__
for k, v in results.items()
if isinstance(v, Exception)
]
return succeeded, failed
async def initialize_services(self, service_classes: list[type[BaseService[Any]]]) -> dict[type[BaseService[Any]], BaseService[Any]]:
"""Initialize multiple services concurrently for faster startup.
This method optimizes startup time by initializing multiple services
in parallel rather than sequentially. Services that don't depend on
each other can be initialized concurrently.
Args:
service_classes: List of service classes to initialize
Returns:
Dictionary mapping service class to initialized service instance
Raises:
Exception: If any critical service fails to initialize
"""
logger.info(f"Initializing {len(service_classes)} services concurrently")
existing = {cls: self._services[cls] for cls in service_classes if cls in self._services}
pending = [cls for cls in service_classes if cls not in existing]
raw_results = await asyncio.gather(
*(self.get_service(cls) for cls in pending),
return_exceptions=True
)
results: dict[type[BaseService[Any]], object] = dict(zip(pending, raw_results))
succeeded, failed = self._partition_results(results)
for cls in succeeded:
logger.info(f"Successfully initialized {cls.__name__}")
if failed:
logger.error(f"Failed services: {', '.join(failed)}")
return {**existing, **succeeded}
async def initialize_critical_services(self) -> None:
"""Initialize only critical services required for basic application functionality.
Critical services are those needed for core operations:
- Database (PostgresStore) for data persistence
- Cache (RedisCacheBackend) for performance
Non-critical services, such as the vector store and extraction services, are initialized lazily.
This optimizes startup time by deferring heavy initialization until needed.
"""
from biz_bud.services.db import PostgresStore
from biz_bud.services.redis_backend import RedisCacheBackend
critical_services: list[type[BaseService[Any]]] = [PostgresStore, RedisCacheBackend]
logger.info("Initializing critical services for faster startup")
await self.initialize_services(critical_services)
logger.info("Critical services initialized successfully")
async def cleanup(self) -> None:
"""Cleanup all services.

View File

@@ -65,12 +65,13 @@ Error Handling:
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Protocol, cast
logger = logging.getLogger(__name__)
from bb_core.logging import get_logger
logger = get_logger(__name__)
class SingletonState(Enum):

View File

@@ -188,6 +188,7 @@ from __future__ import annotations
import asyncio
import functools
import os
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -362,11 +363,42 @@ class VectorStore(BaseService[VectorStoreConfig]):
self.client: QdrantClient | None = None
self._app_config = app_config
def _warn_if_low_timeout(self, timeout: int) -> None:
"""Warn if timeout is set to a low value."""
if timeout < 10:
logger.warning(
"Vector store operation timeout is set to a low value (%d seconds). "
"This may cause failures for large or slow queries. Consider increasing it if you encounter timeouts.",
timeout,
)
def _get_operation_timeout(self) -> int:
"""Get the operation timeout from config or use default."""
"""Get the operation timeout from environment, config, or use default."""
env_timeout = os.getenv("VECTOR_STORE_OPERATION_TIMEOUT")
if env_timeout is not None:
try:
timeout = int(env_timeout)
except ValueError:
logger.warning("Invalid VECTOR_STORE_OPERATION_TIMEOUT value: %s. Falling back to config/default.", env_timeout)
timeout = None
else:
if timeout <= 0:
logger.warning(
"VECTOR_STORE_OPERATION_TIMEOUT must be a positive integer. Got %d. Falling back to config/default.",
timeout,
)
timeout = None
else:
self._warn_if_low_timeout(timeout)
return timeout
if hasattr(self._app_config, "vector_store_enhanced"):
return getattr(self._app_config.vector_store_enhanced, "operation_timeout", 10)
return 10
timeout = getattr(self._app_config.vector_store_enhanced, "operation_timeout", 5)
else:
timeout = 5 # Reduced from 10 to 5 seconds for faster startup
self._warn_if_low_timeout(timeout)
return timeout
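The precedence is environment variable first, then config, then the built-in default; a standalone sketch of that resolution (the function name and defaults below are illustrative):

import os

def resolve_operation_timeout(config_timeout: int | None = None, default: int = 5) -> int:
    """Sketch of the precedence above: env var wins, then config, then the default."""
    env_value = os.getenv("VECTOR_STORE_OPERATION_TIMEOUT")
    if env_value is not None:
        try:
            timeout = int(env_value)
        except ValueError:
            pass  # invalid value: fall back to config/default below
        else:
            if timeout > 0:
                return timeout  # the real method also warns when this is below 10 seconds
    return config_timeout if config_timeout is not None else default

print(resolve_operation_timeout())  # 5 (default) when the env var is unset
os.environ["VECTOR_STORE_OPERATION_TIMEOUT"] = "30"
print(resolve_operation_timeout())  # 30 (env var overrides config and default)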
@classmethod
def _validate_config(cls, app_config: AppConfig) -> VectorStoreConfig:
@@ -758,13 +790,17 @@ class VectorStore(BaseService[VectorStoreConfig]):
# Filter by score threshold and format results
formatted_results = []
for result in results:
if result["score"] >= score_threshold:
score = result.get("score", 0.0)
if score == 0.0 and "score" not in result:
logger.warning("Vector search result missing score - defaulting to 0.0: %s", result.get("id", "unknown"))
if score >= score_threshold:
metadata = result.get("metadata", {})
formatted_results.append(
{
"content": result["metadata"].get("content", ""),
"score": result["score"],
"metadata": result["metadata"],
"vector_id": result["id"],
"content": metadata.get("content", ""),
"score": score,
"metadata": metadata,
"vector_id": result.get("id", ""),
}
)

View File

@@ -9,6 +9,7 @@ from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Any
@@ -392,7 +393,6 @@ async def main() -> int:
# Configure logging
if args.verbose:
import logging
logging.getLogger("bb_core").setLevel(logging.DEBUG)
logging.getLogger("biz_bud").setLevel(logging.DEBUG)
@@ -413,4 +413,4 @@ async def main() -> int:
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)
sys.exit(exit_code)

View File

@@ -8,7 +8,6 @@ deployment.
import os
import sys
import logging
from contextlib import asynccontextmanager
from typing import cast
@@ -17,12 +16,11 @@ from starlette.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from bb_core.logging import get_logger
from biz_bud.config.loader import load_config
from biz_bud.services.factory import get_global_factory
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class HealthResponse(BaseModel):
@@ -62,24 +60,14 @@ async def lifespan(app: FastAPI):
setattr(app.state, 'config', config)
setattr(app.state, 'service_factory', service_factory)
# Verify critical services are available
logger.info("Verifying service connectivity...")
# Test database connection if configured
if hasattr(config, 'database_config') and config.database_config:
try:
await service_factory.get_db_service()
logger.info("Database service initialized successfully")
except Exception as e:
logger.warning(f"Database service initialization failed: {e}")
# Test Redis connection if configured
if hasattr(config, 'redis_config') and config.redis_config:
try:
await service_factory.get_redis_cache()
logger.info("Redis service initialized successfully")
except Exception as e:
logger.warning(f"Redis service initialization failed: {e}")
# Initialize critical services in parallel for faster startup
logger.info("Initializing critical services...")
try:
await service_factory.initialize_critical_services()
logger.info("All critical services initialized successfully")
except Exception as e:
logger.warning(f"Some critical services failed to initialize: {e}")
# Continue startup even if some services fail - they will be lazily initialized when needed
logger.info("Business Buddy application started successfully")

View File

@@ -137,6 +137,9 @@ class TestNetworkFailures:
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key="test-key",

View File

@@ -149,6 +149,9 @@ def database_config() -> DatabaseConfigModel:
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,

View File

@@ -219,6 +219,9 @@ def create_database_config(
postgres_db=postgres_db,
postgres_user=postgres_user,
postgres_password=postgres_password,
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host=qdrant_host,
qdrant_port=qdrant_port,
qdrant_api_key=qdrant_api_key,

View File

@@ -331,7 +331,7 @@ class TestBuddyPerformanceHandoffs:
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
initial_memory = process.memory_info().rss # type: ignore[misc]
with patch("biz_bud.agents.tool_factory.get_tool_factory") as mock_factory:
mock_planner = AsyncMock()
@@ -368,7 +368,7 @@ class TestBuddyPerformanceHandoffs:
result = await buddy_graph.ainvoke(performance_state)
final_memory = process.memory_info().rss
final_memory = process.memory_info().rss # type: ignore[misc]
memory_increase = final_memory - initial_memory
# Memory increase should be reasonable (< 50MB)

View File

@@ -292,17 +292,33 @@ class TestIntrospectionDetection:
{"type": "text", "text": " capabilities"}
]
# Extract text content
query_text = ""
for content_block in multimodal_query:
if isinstance(content_block, dict) and content_block.get("type") == "text":
query_text += content_block.get("text", "") + " "
query_text = query_text.strip()
# Extract text content using centralized utility
from bb_extraction import extract_text_from_multimodal_content
query_text = extract_text_from_multimodal_content(multimodal_query)
introspection_keywords = ["capabilities"]
is_introspection = any(keyword in query_text.lower() for keyword in introspection_keywords)
assert is_introspection is True
def test_multimodal_content_with_only_unsupported_types(self):
"""Test multimodal content extraction with only unsupported types returns empty string."""
from bb_extraction import extract_text_from_multimodal_content
# Content with only unsupported types
multimodal_query = [
{"type": "image", "url": "image.jpg"},
{"type": "audio", "url": "audio.mp3"},
{"type": "video", "url": "video.mp4"}
]
query_text = extract_text_from_multimodal_content(multimodal_query)
assert query_text == ""
# Should not trigger introspection since no text content
introspection_keywords = ["capabilities", "tools", "functions"]
is_introspection = any(keyword in query_text.lower() for keyword in introspection_keywords)
assert is_introspection is False
class TestCapabilityDiscoveryIntegration:
"""Test integration with capability discovery for introspection."""
@@ -359,4 +375,4 @@ class TestCapabilityDiscoveryIntegration:
assert result["orchestration_phase"] == "synthesizing"
assert result["is_capability_introspection"] is True
assert "extracted_info" in result
assert "sources" in result
assert "sources" in result

View File

@@ -272,6 +272,9 @@ class TestDatabaseConfig:
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
postgres_port=0, # Invalid port
)
@@ -287,9 +290,92 @@ class TestDatabaseConfig:
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
postgres_port=70000, # Port too high
)
def test_postgres_pool_size_validation(self) -> None:
"""Test PostgreSQL pool size validation."""
# Test negative min pool size
with pytest.raises(ValidationError):
DatabaseConfigModel(
qdrant_host=None,
qdrant_port=None,
qdrant_api_key=None,
default_page_size=100,
max_page_size=1000,
qdrant_collection_name=None,
postgres_user=None,
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=-1, # Invalid negative pool size
postgres_max_pool_size=15,
postgres_command_timeout=10,
postgres_port=5432,
)
# Test zero max pool size
with pytest.raises(ValidationError):
DatabaseConfigModel(
qdrant_host=None,
qdrant_port=None,
qdrant_api_key=None,
default_page_size=100,
max_page_size=1000,
qdrant_collection_name=None,
postgres_user=None,
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=2,
postgres_max_pool_size=0, # Invalid zero pool size
postgres_command_timeout=10,
postgres_port=5432,
)
def test_postgres_timeout_validation(self) -> None:
"""Test PostgreSQL timeout validation."""
# Test negative timeout
with pytest.raises(ValidationError):
DatabaseConfigModel(
qdrant_host=None,
qdrant_port=None,
qdrant_api_key=None,
default_page_size=100,
max_page_size=1000,
qdrant_collection_name=None,
postgres_user=None,
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=-5, # Invalid negative timeout
postgres_port=5432,
)
# Test zero timeout
with pytest.raises(ValidationError):
DatabaseConfigModel(
qdrant_host=None,
qdrant_port=None,
qdrant_api_key=None,
default_page_size=100,
max_page_size=1000,
qdrant_collection_name=None,
postgres_user=None,
postgres_password=None,
postgres_db=None,
postgres_host=None,
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=0, # Invalid zero timeout
postgres_port=5432,
)
class TestRedisConfig:
"""Test Redis configuration validation."""

View File

@@ -112,6 +112,9 @@ async def test_find_affected_catalog_items_node_success(mock_catalog_items):
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,
@@ -194,6 +197,9 @@ async def test_find_affected_catalog_items_node_error():
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,
@@ -257,6 +263,9 @@ async def test_batch_analyze_components_node_success(batch_component_reports, ma
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,
@@ -350,6 +359,9 @@ async def test_batch_analyze_components_node_error():
postgres_db="test_db",
postgres_user="test_user",
postgres_password="test_pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,

View File

@@ -30,6 +30,9 @@ def create_test_config() -> AppConfig:
postgres_db="testdb",
postgres_user="user",
postgres_password="pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,

View File

@@ -71,6 +71,9 @@ class TestServiceFactory:
postgres_db="test",
postgres_user="user",
postgres_password="pass",
postgres_min_pool_size=2,
postgres_max_pool_size=15,
postgres_command_timeout=10,
qdrant_host="localhost",
qdrant_port=6333,
qdrant_api_key=None,