"""Storage management utilities for TUI applications."""

from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncGenerator, Coroutine, Sequence
from typing import TYPE_CHECKING, Protocol

from pydantic import SecretStr

from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, StorageCapabilities

if TYPE_CHECKING:
    from ....config.settings import Settings

logger = logging.getLogger(__name__)


class StorageBackendProtocol(Protocol):
    """Protocol defining storage backend interface."""

    async def initialize(self) -> None: ...
    async def count(self, *, collection_name: str | None = None) -> int: ...
    async def list_collections(self) -> list[str]: ...
    async def search(
        self,
        query: str,
        limit: int = 10,
        threshold: float = 0.7,
        *,
        collection_name: str | None = None,
    ) -> AsyncGenerator[Document, None]: ...
    async def close(self) -> None: ...


class MultiStorageAdapter(BaseStorage):
    """Mirror writes to multiple storage backends."""

    def __init__(self, storages: Sequence[BaseStorage]) -> None:
        if not storages:
            raise ValueError("MultiStorageAdapter requires at least one storage backend")

        # De-duplicate by object identity while preserving order; the first
        # backend becomes the primary.
        unique: list[BaseStorage] = []
        seen_ids: set[int] = set()
        for storage in storages:
            storage_id = id(storage)
            if storage_id in seen_ids:
                continue
            seen_ids.add(storage_id)
            unique.append(storage)

        self._storages: list[BaseStorage] = unique
        self._primary: BaseStorage = unique[0]
        super().__init__(self._primary.config)

    async def initialize(self) -> None:
        for storage in self._storages:
            await storage.initialize()

    async def store(self, document: Document, *, collection_name: str | None = None) -> str:
        # Store in primary backend first
        primary_id: str = await self._primary.store(document, collection_name=collection_name)

        # Replicate to secondary backends concurrently
        if len(self._storages) > 1:

            async def replicate_to_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.store(document, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [replicate_to_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            failures: list[str] = []
            errors: list[Exception] = []

            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)

            if failures:
                backends = ", ".join(failures)
                primary_error = errors[0] if errors else Exception("Unknown replication error")
                raise StorageError(
                    f"Document stored in primary backend but replication failed for: {backends}"
                ) from primary_error

        return primary_id

    async def store_batch(
        self, documents: list[Document], *, collection_name: str | None = None
    ) -> list[str]:
        # Store in primary backend first
        primary_ids: list[str] = await self._primary.store_batch(
            documents, collection_name=collection_name
        )

        # Replicate to secondary backends concurrently
        if len(self._storages) > 1:

            async def replicate_batch_to_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.store_batch(documents, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            failures: list[str] = []
            errors: list[Exception] = []

            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)

            if failures:
                backends = ", ".join(failures)
                primary_error = (
                    errors[0] if errors else Exception("Unknown batch replication error")
                )
                raise StorageError(
                    f"Batch stored in primary backend but replication failed for: {backends}"
                ) from primary_error

        return primary_ids

    async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
        # Delete from primary backend first
        primary_deleted: bool = await self._primary.delete(
            document_id, collection_name=collection_name
        )

        # Delete from secondary backends concurrently
        if len(self._storages) > 1:

            async def delete_from_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.delete(document_id, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [delete_from_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            failures: list[str] = []
            errors: list[Exception] = []

            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)

            if failures:
                backends = ", ".join(failures)
                primary_error = errors[0] if errors else Exception("Unknown deletion error")
                raise StorageError(
                    f"Document deleted from primary backend but failed for: {backends}"
                ) from primary_error

        return primary_deleted

    async def count(self, *, collection_name: str | None = None) -> int:
        count_result: int = await self._primary.count(collection_name=collection_name)
        return count_result

    async def list_collections(self) -> list[str]:
        list_fn = getattr(self._primary, "list_collections", None)
        if list_fn is None:
            return []
        collections_result: list[str] = await list_fn()
        return collections_result

    async def search(
        self,
        query: str,
        limit: int = 10,
        threshold: float = 0.7,
        *,
        collection_name: str | None = None,
    ) -> AsyncGenerator[Document, None]:
        async for item in self._primary.search(
            query,
            limit=limit,
            threshold=threshold,
            collection_name=collection_name,
        ):
            yield item

    async def close(self) -> None:
        for storage in self._storages:
            close_fn = getattr(storage, "close", None)
            if close_fn is not None:
                await close_fn()

    def _format_backend_label(self, storage: BaseStorage) -> str:
        backend = getattr(storage.config, "backend", None)
        if isinstance(backend, StorageBackend):
            backend_value: str = backend.value
            return backend_value
        class_name: str = storage.__class__.__name__
        return class_name


class StorageManager:
    """Centralized manager for all storage backend operations."""

    def __init__(self, settings: Settings) -> None:
        """Initialize storage manager with application settings."""
        self.settings: Settings = settings
        self.backends: dict[StorageBackend, BaseStorage] = {}
        self.capabilities: dict[StorageBackend, StorageCapabilities] = {}
        self._initialized: bool = False

    async def initialize_all_backends(self) -> dict[StorageBackend, bool]:
        """Initialize all available storage backends with timeout protection."""
        results: dict[StorageBackend, bool] = {}

        async def init_backend(
            backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
        ) -> bool:
            """Initialize a single backend with timeout."""
            try:
                storage = storage_class(config)
                await asyncio.wait_for(storage.initialize(), timeout=30.0)
                self.backends[backend_type] = storage
                if backend_type == StorageBackend.WEAVIATE:
                    self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH
                elif backend_type == StorageBackend.OPEN_WEBUI:
                    self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE
                elif backend_type == StorageBackend.R2R:
                    self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED
                return True
            except Exception as exc:  # TimeoutError from wait_for is an Exception too
                logger.debug("Backend %s failed to initialize: %s", backend_type.value, exc)
                return False

        # Initialize backends concurrently with timeout protection
        tasks: list[tuple[StorageBackend, Coroutine[None, None, bool]]] = []

        # Try Weaviate
        if self.settings.weaviate_endpoint:
            config = StorageConfig(
                backend=StorageBackend.WEAVIATE,
                endpoint=self.settings.weaviate_endpoint,
                api_key=SecretStr(self.settings.weaviate_api_key)
                if self.settings.weaviate_api_key
                else None,
                collection_name="default",
            )
            tasks.append(
                (
                    StorageBackend.WEAVIATE,
                    init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
                )
            )
        else:
            results[StorageBackend.WEAVIATE] = False

        # Try OpenWebUI (requires both endpoint and API key)
        if self.settings.openwebui_endpoint and self.settings.openwebui_api_key:
            config = StorageConfig(
                backend=StorageBackend.OPEN_WEBUI,
                endpoint=self.settings.openwebui_endpoint,
                api_key=SecretStr(self.settings.openwebui_api_key),
                collection_name="default",
            )
            tasks.append(
                (
                    StorageBackend.OPEN_WEBUI,
                    init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
                )
            )
        else:
            results[StorageBackend.OPEN_WEBUI] = False

        # Try R2R
        if self.settings.r2r_endpoint:
            config = StorageConfig(
                backend=StorageBackend.R2R,
                endpoint=self.settings.r2r_endpoint,
                api_key=SecretStr(self.settings.r2r_api_key) if self.settings.r2r_api_key else None,
                collection_name="default",
            )
            tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage)))
        else:
            results[StorageBackend.R2R] = False

        # Execute initialization tasks concurrently
        if tasks:
            backend_types, task_coroutines = zip(*tasks, strict=False)
            task_results: Sequence[bool | BaseException] = await asyncio.gather(
                *task_coroutines, return_exceptions=True
            )

            for backend_type, task_result in zip(backend_types, task_results, strict=False):
                results[backend_type] = task_result if isinstance(task_result, bool) else False

        self._initialized = True
        return results

    def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None:
        """Get storage backend by type."""
        return self.backends.get(backend_type)

    def build_multi_storage_adapter(
        self, backends: Sequence[StorageBackend]
    ) -> MultiStorageAdapter:
        """Build a MultiStorageAdapter over the requested, already-initialized backends."""
        storages: list[BaseStorage] = []
        seen: set[StorageBackend] = set()
        for backend in backends:
            backend_enum = (
                backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
            )
            if backend_enum in seen:
                continue
            seen.add(backend_enum)
            storage = self.backends.get(backend_enum)
            if storage is None:
                raise ValueError(f"Storage backend {backend_enum.value} is not initialized")
            storages.append(storage)
        return MultiStorageAdapter(storages)
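
    # Illustrative: mirror writes to every initialized backend; the order of
    # the sequence determines the primary backend:
    #
    #     adapter = manager.build_multi_storage_adapter(
    #         manager.get_available_backends()
    #     )
    #     await adapter.store(document)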

    def get_available_backends(self) -> list[StorageBackend]:
        """Get list of successfully initialized backends."""
        return list(self.backends.keys())

    def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool:
        """Check if backend has specific capability."""
        # Assumes StorageCapabilities values are ordered: a backend with a
        # higher capability value also offers all lower-valued capabilities.
        backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC)
        return capability.value <= backend_caps.value
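
    # Illustrative: gate advanced UI features on backend capability
    # (`enable_semantic_search` is a hypothetical call site):
    #
    #     if manager.has_capability(backend, StorageCapabilities.VECTOR_SEARCH):
    #         enable_semantic_search()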

    async def get_all_collections(self) -> list[CollectionInfo]:
        """Get collections from all available backends, listing each backend's copy separately."""
        collection_map: dict[str, CollectionInfo] = {}

        for backend_type, storage in self.backends.items():
            try:
                backend_collections = await storage.list_collections()
                for collection_name in backend_collections:
                    # Validate collection name
                    if not collection_name or not isinstance(collection_name, str):
                        continue

                    try:
                        count = await storage.count(collection_name=collection_name)
                        # Validate count is non-negative
                        count = max(count, 0)
                    except StorageError as e:
                        # Storage-specific errors - log and use 0 count
                        logger.warning(
                            "Failed to get count for %s on %s: %s",
                            collection_name,
                            backend_type.value,
                            e,
                        )
                        count = 0
                    except Exception as e:
                        # Unexpected errors - log and skip this collection from this backend
                        logger.warning(
                            "Unexpected error counting %s on %s: %s",
                            collection_name,
                            backend_type.value,
                            e,
                        )
                        continue

                    size_mb = count * 0.01  # Rough estimate: 10KB per document

                    # Key on collection name plus backend so the same collection
                    # shows up once per backend (no aggregation)
                    collection_key = f"{collection_name}#{backend_type.value}"

                    collection_info: CollectionInfo = {
                        "name": collection_name,
                        "type": self._get_collection_type(collection_name, backend_type),
                        "count": count,
                        "backend": backend_type.value,
                        "status": "active",
                        # Placeholder: these backends do not expose a
                        # last-modified timestamp here
                        "last_updated": "2024-01-01T00:00:00Z",
                        "size_mb": size_mb,
                    }
                    collection_map[collection_key] = collection_info
            except Exception:
                continue

        return list(collection_map.values())

    def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str:
        """Determine collection type based on name and backend."""
        # Prioritize definitive backend type first
        if backend == StorageBackend.R2R:
            return "r2r"
        elif backend == StorageBackend.WEAVIATE:
            return "weaviate"
        elif backend == StorageBackend.OPEN_WEBUI:
            return "openwebui"

        # Fallback to name-based guessing if backend is not specific
        name_lower = collection_name.lower()
        if "web" in name_lower or "doc" in name_lower:
            return "documentation"
        elif "repo" in name_lower or "code" in name_lower:
            return "repository"
        else:
            return "general"

    async def search_across_backends(
        self,
        query: str,
        limit: int = 10,
        backends: list[StorageBackend] | None = None,
    ) -> dict[StorageBackend, list[Document]]:
        """Search across multiple backends and return grouped results."""
        if backends is None:
            backends = self.get_available_backends()

        results: dict[StorageBackend, list[Document]] = {}

        async def search_backend(backend_type: StorageBackend) -> None:
            storage = self.backends.get(backend_type)
            if storage:
                try:
                    documents: list[Document] = []
                    async for doc in storage.search(query, limit=limit):
                        documents.append(doc)
                    results[backend_type] = documents
                except Exception:
                    results[backend_type] = []

        # Run searches in parallel
        tasks = [search_backend(backend) for backend in backends]
        await asyncio.gather(*tasks, return_exceptions=True)

        return results
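
    # Illustrative: fan a query out to every available backend and tally
    # hits per backend (hypothetical query string):
    #
    #     grouped = await manager.search_across_backends("vector index", limit=5)
    #     for backend, docs in grouped.items():
    #         logger.info("%s returned %d documents", backend.value, len(docs))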

    def get_r2r_storage(self) -> R2RStorage | None:
        """Get R2R storage instance if available."""
        storage = self.backends.get(StorageBackend.R2R)
        return storage if isinstance(storage, R2RStorage) else None

    async def get_backend_status(
        self,
    ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
        """Get comprehensive status for all backends."""
        status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}

        for backend_type, storage in self.backends.items():
            try:
                collections = await storage.list_collections()
                total_docs = 0
                for collection in collections:
                    total_docs += await storage.count(collection_name=collection)

                backend_status: dict[str, str | int | bool | StorageCapabilities] = {
                    "available": True,
                    "collections": len(collections),
                    "total_documents": total_docs,
                    "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC),
                    "endpoint": getattr(storage.config, "endpoint", "unknown"),
                }
                status[backend_type] = backend_status
            except Exception as e:
                status[backend_type] = {
                    "available": False,
                    "error": str(e),
                    "capabilities": StorageCapabilities.NONE,
                }

        return status

    async def close_all(self) -> None:
        """Close all storage connections."""
        for storage in self.backends.values():
            try:
                await storage.close()
            except Exception:
                # Best-effort shutdown: one failing backend should not block
                # closing the others
                pass

        self.backends.clear()
        self.capabilities.clear()
        self._initialized = False

    @property
    def is_initialized(self) -> bool:
        """Check if storage manager is initialized."""
        return self._initialized

    def supports_advanced_features(self, backend: StorageBackend) -> bool:
        """Check if backend supports advanced features like chunks and entities."""
        return self.has_capability(backend, StorageCapabilities.FULL_FEATURED)