225 lines
8.2 KiB
Python
225 lines
8.2 KiB
Python
"""Application settings and configuration."""
|
|
|
|
from functools import lru_cache
|
|
from typing import Annotated, ClassVar, Literal
|
|
|
|
from prefect.variables import Variable
|
|
from pydantic import Field, HttpUrl, model_validator
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
|
|
class Settings(BaseSettings):
    """Central configuration model for the application.

    Field values come from the defaults declared below unless overridden
    through the settings sources configured in ``model_config`` (a ``.env``
    file read as UTF-8, with case-insensitive names and unknown entries
    ignored).
    """

    model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",  # unknown environment variables are dropped, not rejected
    )

    # --- Credentials (all optional) ---
    firecrawl_api_key: str | None = None
    openwebui_api_key: str | None = None
    weaviate_api_key: str | None = None
    r2r_api_key: str | None = None

    # --- Service endpoints ---
    llm_endpoint: HttpUrl = HttpUrl("http://llm.lab")
    weaviate_endpoint: HttpUrl = HttpUrl("http://weaviate.yo")
    openwebui_endpoint: HttpUrl = HttpUrl("http://chat.lab")  # used as the API URL
    firecrawl_endpoint: HttpUrl = HttpUrl("http://crawl.lab:30002")
    # Populated through the alias ``r2r_api_url``; there is no default endpoint.
    r2r_endpoint: HttpUrl | None = Field(default=None, alias="r2r_api_url")

    # --- Embedding model ---
    embedding_model: str = "ollama/bge-m3:latest"
    embedding_dimension: int = 1024

    # --- Ingestion limits ---
    default_batch_size: Annotated[int, Field(gt=0, le=500)] = 50
    max_file_size: int = 1_000_000
    max_crawl_depth: Annotated[int, Field(ge=1, le=20)] = 5
    max_crawl_pages: Annotated[int, Field(ge=1, le=1000)] = 100

    # --- Storage ---
    default_storage_backend: Literal["weaviate", "open_webui", "r2r"] = "weaviate"
    default_collection_prefix: str = "docs"

    # --- Prefect ---
    prefect_api_url: HttpUrl | None = None
    prefect_api_key: str | None = None
    prefect_work_pool: str = "default"

    # --- Scheduling ---
    # Upper bound 10080 corresponds to at most one week.
    default_schedule_interval: Annotated[int, Field(ge=1, le=10080)] = 60

    # --- Performance ---
    max_concurrent_tasks: Annotated[int, Field(ge=1, le=20)] = 5
    request_timeout: Annotated[int, Field(ge=10, le=300)] = 60

    # --- Logging ---
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"

    def get_storage_endpoint(self, backend: str) -> HttpUrl:
        """
        Look up the endpoint URL of a storage backend.

        Args:
            backend: One of ``"weaviate"``, ``"open_webui"`` or ``"r2r"``.

        Returns:
            The endpoint configured for that backend.

        Raises:
            ValueError: If ``backend`` is not a supported name, or if ``"r2r"``
                is requested while ``r2r_endpoint`` is unset.
        """
        always_configured = {
            "weaviate": self.weaviate_endpoint,
            "open_webui": self.openwebui_endpoint,
        }
        try:
            return always_configured[backend]
        except KeyError:
            # Not one of the backends with a built-in default; handled below.
            pass

        if backend != "r2r":
            raise ValueError(f"Unknown backend: {backend}. Supported: weaviate, open_webui, r2r")

        r2r_url = self.r2r_endpoint
        if not r2r_url:
            # R2R has no default endpoint, so it can legitimately be missing here
            # when R2R is not the default backend.
            raise ValueError(
                "R2R_API_URL must be set in environment variables. "
                "This should have been caught during settings validation."
            )
        return r2r_url

    def get_api_key(self, service: str) -> str | None:
        """
        Look up the API key configured for a service.

        Args:
            service: One of ``"firecrawl"``, ``"openwebui"``, ``"weaviate"``
                or ``"r2r"``.

        Returns:
            The configured key, or ``None`` when the key is unset or the
            service name is not recognised.
        """
        keys_by_service: dict[str, str | None] = dict(
            firecrawl=self.firecrawl_api_key,
            openwebui=self.openwebui_api_key,
            weaviate=self.weaviate_api_key,
            r2r=self.r2r_api_key,
        )
        return keys_by_service.get(service)

    @model_validator(mode="after")
    def validate_backend_configuration(self) -> "Settings":
        """Check that the default storage backend is usable.

        A missing R2R endpoint is a hard error when R2R is the default
        backend. A missing API key for the default backend only produces a
        ``UserWarning``.

        Returns:
            The validated settings instance itself.

        Raises:
            ValueError: If the default backend is ``"r2r"`` and
                ``r2r_endpoint`` is unset.
        """
        chosen = self.default_storage_backend

        # Hard requirement: R2R cannot work without an endpoint.
        r2r_selected = chosen == "r2r"
        if r2r_selected and not self.r2r_endpoint:
            raise ValueError(
                "R2R_API_URL must be set in environment variables when using R2R as default backend"
            )

        # Soft requirement: (environment variable name, current value) per backend.
        credential_by_backend = {
            "weaviate": ("WEAVIATE_API_KEY", self.weaviate_api_key),
            "open_webui": ("OPENWEBUI_API_KEY", self.openwebui_api_key),
            "r2r": ("R2R_API_KEY", self.r2r_api_key),
        }
        entry = credential_by_backend.get(chosen)
        if entry is not None and not entry[1]:
            import warnings

            env_name = entry[0]
            warnings.warn(
                f"{env_name} not set - authentication may fail for {chosen} backend",
                UserWarning,
                stacklevel=2,
            )

        return self
|
|
|
|
|
|
@lru_cache
def get_settings() -> Settings:
    """
    Return the shared ``Settings`` instance.

    The function is wrapped in ``lru_cache``, so ``Settings()`` is built on the
    first call and the same object is returned on later calls.

    Returns:
        The cached ``Settings`` object.
    """
    settings = Settings()
    return settings
|
|
|
|
|
|
class PrefectVariableConfig:
|
|
"""Helper class for managing Prefect variables with fallbacks to settings."""
|
|
|
|
def __init__(self) -> None:
|
|
self._settings: Settings = get_settings()
|
|
self._variable_names: list[str] = [
|
|
"default_batch_size", "max_file_size", "max_crawl_depth", "max_crawl_pages",
|
|
"default_storage_backend", "default_collection_prefix", "max_concurrent_tasks",
|
|
"request_timeout", "default_schedule_interval"
|
|
]
|
|
|
|
def _get_fallback_value(self, name: str, default_value: object = None) -> object:
|
|
"""Get fallback value from settings or default."""
|
|
return default_value or getattr(self._settings, name, default_value)
|
|
|
|
def get_with_fallback(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
|
"""Get variable value with fallback synchronously."""
|
|
fallback = self._get_fallback_value(name, default_value)
|
|
# Ensure fallback is a type that Variable expects
|
|
variable_fallback = str(fallback) if fallback is not None else None
|
|
try:
|
|
result = Variable.get(name, default=variable_fallback)
|
|
# Variable can return various types, convert to our expected types
|
|
if isinstance(result, (str, int, float)):
|
|
return result
|
|
elif result is None:
|
|
return None
|
|
else:
|
|
# Convert other types to string
|
|
return str(result)
|
|
except Exception:
|
|
# Return fallback with proper type
|
|
if isinstance(fallback, (str, int, float)) or fallback is None:
|
|
return fallback
|
|
return str(fallback) if fallback is not None else None
|
|
|
|
async def get_with_fallback_async(self, name: str, default_value: str | int | float | None = None) -> str | int | float | None:
|
|
"""Get variable value with fallback asynchronously."""
|
|
fallback = self._get_fallback_value(name, default_value)
|
|
variable_fallback = str(fallback) if fallback is not None else None
|
|
try:
|
|
result = await Variable.aget(name, default=variable_fallback)
|
|
# Variable can return various types, convert to our expected types
|
|
if isinstance(result, (str, int, float)):
|
|
return result
|
|
elif result is None:
|
|
return None
|
|
else:
|
|
# Convert other types to string
|
|
return str(result)
|
|
except Exception:
|
|
# Return fallback with proper type
|
|
if isinstance(fallback, (str, int, float)) or fallback is None:
|
|
return fallback
|
|
return str(fallback) if fallback is not None else None
|
|
|
|
def get_ingestion_config(self) -> dict[str, str | int | float | None]:
|
|
"""Get all ingestion-related configuration variables synchronously."""
|
|
return {name: self.get_with_fallback(name) for name in self._variable_names}
|
|
|
|
async def get_ingestion_config_async(self) -> dict[str, str | int | float | None]:
|
|
"""Get all ingestion-related configuration variables asynchronously."""
|
|
result: dict[str, str | int | float | None] = {}
|
|
for name in self._variable_names:
|
|
result[name] = await self.get_with_fallback_async(name)
|
|
return result
|
|
|
|
|
|
@lru_cache
def get_prefect_config() -> PrefectVariableConfig:
    """Return the shared ``PrefectVariableConfig`` helper.

    Wrapped in ``lru_cache``, so the helper is constructed on the first call
    and the same object is returned afterwards.
    """
    helper = PrefectVariableConfig()
    return helper
|