# rag-manager/ingest_pipeline/cli/tui/utils/storage_manager.py
"""Storage management utilities for TUI applications."""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncGenerator, Coroutine, Sequence
from typing import TYPE_CHECKING, Protocol
from pydantic import SecretStr
from ....core.exceptions import StorageError
from ....core.models import Document, StorageBackend, StorageConfig
from ....storage.base import BaseStorage
from ....storage.openwebui import OpenWebUIStorage
from ....storage.r2r.storage import R2RStorage
from ....storage.weaviate import WeaviateStorage
from ..models import CollectionInfo, StorageCapabilities
if TYPE_CHECKING:
    from ....config.settings import Settings

logger = logging.getLogger(__name__)


class StorageBackendProtocol(Protocol):
    """Protocol defining storage backend interface."""

    async def initialize(self) -> None: ...

    async def count(self, *, collection_name: str | None = None) -> int: ...

    async def list_collections(self) -> list[str]: ...

    # Declared as a plain ``def``: an async-generator method is typed by its
    # return value, not as a coroutine, so implementations written as
    # ``async def ... yield`` match this signature.
    def search(
        self,
        query: str,
        limit: int = 10,
        threshold: float = 0.7,
        *,
        collection_name: str | None = None,
    ) -> AsyncGenerator[Document, None]: ...

    async def close(self) -> None: ...


class MultiStorageAdapter(BaseStorage):
    """Mirror writes to multiple storage backends."""

    def __init__(self, storages: Sequence[BaseStorage]) -> None:
        if not storages:
            raise ValueError("MultiStorageAdapter requires at least one storage backend")
        # Deduplicate by object identity while preserving order; the first
        # backend becomes the primary.
        unique: list[BaseStorage] = []
        seen_ids: set[int] = set()
        for storage in storages:
            storage_id = id(storage)
            if storage_id in seen_ids:
                continue
            seen_ids.add(storage_id)
            unique.append(storage)
        self._storages: list[BaseStorage] = unique
        self._primary: BaseStorage = unique[0]
        super().__init__(self._primary.config)

    async def initialize(self) -> None:
        for storage in self._storages:
            await storage.initialize()

    async def store(self, document: Document, *, collection_name: str | None = None) -> str:
        # Store in primary backend first
        primary_id: str = await self._primary.store(document, collection_name=collection_name)
        # Replicate to secondary backends concurrently
        if len(self._storages) > 1:

            async def replicate_to_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.store(document, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [replicate_to_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            failures: list[str] = []
            errors: list[Exception] = []
            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)
            if failures:
                backends = ", ".join(failures)
                primary_error = errors[0] if errors else Exception("Unknown replication error")
                raise StorageError(
                    f"Document stored in primary backend but replication failed for: {backends}"
                ) from primary_error
        return primary_id

    async def store_batch(
        self, documents: list[Document], *, collection_name: str | None = None
    ) -> list[str]:
        # Store in primary backend first
        primary_ids: list[str] = await self._primary.store_batch(
            documents, collection_name=collection_name
        )
        # Replicate to secondary backends concurrently
        if len(self._storages) > 1:

            async def replicate_batch_to_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.store_batch(documents, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [replicate_batch_to_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            failures: list[str] = []
            errors: list[Exception] = []
            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)
            if failures:
                backends = ", ".join(failures)
                primary_error = (
                    errors[0] if errors else Exception("Unknown batch replication error")
                )
                raise StorageError(
                    f"Batch stored in primary backend but replication failed for: {backends}"
                ) from primary_error
        return primary_ids

    async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
        # Delete from primary backend first
        primary_deleted: bool = await self._primary.delete(
            document_id, collection_name=collection_name
        )
        # Delete from secondary backends concurrently
        if len(self._storages) > 1:

            async def delete_from_backend(
                storage: BaseStorage,
            ) -> tuple[BaseStorage, bool, Exception | None]:
                try:
                    await storage.delete(document_id, collection_name=collection_name)
                    return storage, True, None
                except Exception as exc:
                    return storage, False, exc

            tasks = [delete_from_backend(storage) for storage in self._storages[1:]]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            failures: list[str] = []
            errors: list[Exception] = []
            for result in results:
                if isinstance(result, tuple):
                    storage, success, error = result
                    if not success and error is not None:
                        failures.append(self._format_backend_label(storage))
                        errors.append(error)
                elif isinstance(result, Exception):
                    failures.append("unknown")
                    errors.append(result)
            if failures:
                backends = ", ".join(failures)
                primary_error = errors[0] if errors else Exception("Unknown deletion error")
                raise StorageError(
                    f"Document deleted from primary backend but failed for: {backends}"
                ) from primary_error
        return primary_deleted

    async def count(self, *, collection_name: str | None = None) -> int:
        count_result: int = await self._primary.count(collection_name=collection_name)
        return count_result

    async def list_collections(self) -> list[str]:
        list_fn = getattr(self._primary, "list_collections", None)
        if list_fn is None:
            return []
        collections_result: list[str] = await list_fn()
        return collections_result

    async def search(
        self,
        query: str,
        limit: int = 10,
        threshold: float = 0.7,
        *,
        collection_name: str | None = None,
    ) -> AsyncGenerator[Document, None]:
        async for item in self._primary.search(
            query,
            limit=limit,
            threshold=threshold,
            collection_name=collection_name,
        ):
            yield item

    async def close(self) -> None:
        for storage in self._storages:
            close_fn = getattr(storage, "close", None)
            if close_fn is not None:
                await close_fn()

    def _format_backend_label(self, storage: BaseStorage) -> str:
        backend = getattr(storage.config, "backend", None)
        if isinstance(backend, StorageBackend):
            backend_value: str = backend.value
            return backend_value
        class_name: str = storage.__class__.__name__
        return class_name
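

# Usage sketch (illustrative; the storage variables below are hypothetical,
# standing in for already-constructed BaseStorage instances):
#
#     adapter = MultiStorageAdapter([weaviate_storage, openwebui_storage])
#     await adapter.initialize()
#     doc_id = await adapter.store(document)  # written to both; primary ID returned
#
# Note that reads (count, search, list_collections) are served by the primary
# backend only; mirroring applies to writes and deletes.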


class StorageManager:
    """Centralized manager for all storage backend operations."""

    def __init__(self, settings: Settings) -> None:
        """Initialize storage manager with application settings."""
        self.settings: Settings = settings
        self.backends: dict[StorageBackend, BaseStorage] = {}
        self.capabilities: dict[StorageBackend, StorageCapabilities] = {}
        self._initialized: bool = False

    async def initialize_all_backends(self) -> dict[StorageBackend, bool]:
        """Initialize all available storage backends with timeout protection."""
        results: dict[StorageBackend, bool] = {}

        async def init_backend(
            backend_type: StorageBackend, config: StorageConfig, storage_class: type[BaseStorage]
        ) -> bool:
            """Initialize a single backend with timeout."""
            try:
                storage = storage_class(config)
                await asyncio.wait_for(storage.initialize(), timeout=30.0)
                self.backends[backend_type] = storage
                if backend_type == StorageBackend.WEAVIATE:
                    self.capabilities[backend_type] = StorageCapabilities.VECTOR_SEARCH
                elif backend_type == StorageBackend.OPEN_WEBUI:
                    self.capabilities[backend_type] = StorageCapabilities.KNOWLEDGE_BASE
                elif backend_type == StorageBackend.R2R:
                    self.capabilities[backend_type] = StorageCapabilities.FULL_FEATURED
                return True
            except Exception:  # also covers TimeoutError raised by asyncio.wait_for
                return False

        # Initialize backends concurrently with timeout protection
        tasks: list[tuple[StorageBackend, Coroutine[None, None, bool]]] = []
        # Try Weaviate
        if self.settings.weaviate_endpoint:
            config = StorageConfig(
                backend=StorageBackend.WEAVIATE,
                endpoint=self.settings.weaviate_endpoint,
                api_key=SecretStr(self.settings.weaviate_api_key)
                if self.settings.weaviate_api_key
                else None,
                collection_name="default",
            )
            tasks.append(
                (
                    StorageBackend.WEAVIATE,
                    init_backend(StorageBackend.WEAVIATE, config, WeaviateStorage),
                )
            )
        else:
            results[StorageBackend.WEAVIATE] = False
        # Try OpenWebUI (requires both an endpoint and an API key)
        if self.settings.openwebui_endpoint and self.settings.openwebui_api_key:
            config = StorageConfig(
                backend=StorageBackend.OPEN_WEBUI,
                endpoint=self.settings.openwebui_endpoint,
                api_key=SecretStr(self.settings.openwebui_api_key),
                collection_name="default",
            )
            tasks.append(
                (
                    StorageBackend.OPEN_WEBUI,
                    init_backend(StorageBackend.OPEN_WEBUI, config, OpenWebUIStorage),
                )
            )
        else:
            results[StorageBackend.OPEN_WEBUI] = False
        # Try R2R
        if self.settings.r2r_endpoint:
            config = StorageConfig(
                backend=StorageBackend.R2R,
                endpoint=self.settings.r2r_endpoint,
                api_key=SecretStr(self.settings.r2r_api_key)
                if self.settings.r2r_api_key
                else None,
                collection_name="default",
            )
            tasks.append((StorageBackend.R2R, init_backend(StorageBackend.R2R, config, R2RStorage)))
        else:
            results[StorageBackend.R2R] = False
        # Execute initialization tasks concurrently
        if tasks:
            backend_types, task_coroutines = zip(*tasks, strict=False)
            task_results: Sequence[bool | BaseException] = await asyncio.gather(
                *task_coroutines, return_exceptions=True
            )
            for backend_type, task_result in zip(backend_types, task_results, strict=False):
                results[backend_type] = task_result if isinstance(task_result, bool) else False
        self._initialized = True
        return results
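
    # The returned mapping covers every known backend, e.g. (illustrative):
    #     {StorageBackend.WEAVIATE: True, StorageBackend.OPEN_WEBUI: False,
    #      StorageBackend.R2R: True}
    # Backends with missing settings are reported as False without being attempted.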

    def get_backend(self, backend_type: StorageBackend) -> BaseStorage | None:
        """Get storage backend by type."""
        return self.backends.get(backend_type)

    def build_multi_storage_adapter(
        self, backends: Sequence[StorageBackend]
    ) -> MultiStorageAdapter:
        """Build a MultiStorageAdapter over the given, already initialized backends."""
        storages: list[BaseStorage] = []
        seen: set[StorageBackend] = set()
        for backend in backends:
            backend_enum = (
                backend if isinstance(backend, StorageBackend) else StorageBackend(backend)
            )
            if backend_enum in seen:
                continue
            seen.add(backend_enum)
            storage = self.backends.get(backend_enum)
            if storage is None:
                raise ValueError(f"Storage backend {backend_enum.value} is not initialized")
            storages.append(storage)
        return MultiStorageAdapter(storages)

    def get_available_backends(self) -> list[StorageBackend]:
        """Get list of successfully initialized backends."""
        return list(self.backends.keys())

    def has_capability(self, backend: StorageBackend, capability: StorageCapabilities) -> bool:
        """Check if backend has specific capability."""
        # Assumes StorageCapabilities values are ordered levels (BASIC < ... < FULL_FEATURED).
        backend_caps = self.capabilities.get(backend, StorageCapabilities.BASIC)
        return capability.value <= backend_caps.value

    async def get_all_collections(self) -> list[CollectionInfo]:
        """Get collections from all available backends, listing same-named collections per backend."""
        collection_map: dict[str, CollectionInfo] = {}
        for backend_type, storage in self.backends.items():
            try:
                backend_collections = await storage.list_collections()
                for collection_name in backend_collections:
                    # Validate collection name
                    if not collection_name or not isinstance(collection_name, str):
                        continue
                    try:
                        count = await storage.count(collection_name=collection_name)
                        # Clamp count to be non-negative
                        count = max(count, 0)
                    except StorageError as e:
                        # Storage-specific errors: log and fall back to a zero count
                        logger.warning(
                            f"Failed to get count for {collection_name} on {backend_type.value}: {e}"
                        )
                        count = 0
                    except Exception as e:
                        # Unexpected errors: log and skip this collection on this backend
                        logger.warning(
                            f"Unexpected error counting {collection_name} on {backend_type.value}: {e}"
                        )
                        continue
                    size_mb = count * 0.01  # Rough estimate: 10KB per document
                    # Key combines collection name and backend so entries stay separate
                    collection_key = f"{collection_name}#{backend_type.value}"
                    # Create a new collection entry (no aggregation across backends)
                    collection_info: CollectionInfo = {
                        "name": collection_name,
                        "type": self._get_collection_type(collection_name, backend_type),
                        "count": count,
                        "backend": backend_type.value,
                        "status": "active",
                        "last_updated": "2024-01-01T00:00:00Z",  # placeholder timestamp
                        "size_mb": size_mb,
                    }
                    collection_map[collection_key] = collection_info
            except Exception:
                # Skip backends whose collections cannot be listed
                continue
        return list(collection_map.values())

    def _get_collection_type(self, collection_name: str, backend: StorageBackend) -> str:
        """Determine collection type based on name and backend."""
        # Prioritize definitive backend type first
        if backend == StorageBackend.R2R:
            return "r2r"
        elif backend == StorageBackend.WEAVIATE:
            return "weaviate"
        elif backend == StorageBackend.OPEN_WEBUI:
            return "openwebui"
        # Fall back to name-based guessing if the backend is not specific
        name_lower = collection_name.lower()
        if "web" in name_lower or "doc" in name_lower:
            return "documentation"
        elif "repo" in name_lower or "code" in name_lower:
            return "repository"
        else:
            return "general"

    async def search_across_backends(
        self,
        query: str,
        limit: int = 10,
        backends: list[StorageBackend] | None = None,
    ) -> dict[StorageBackend, list[Document]]:
        """Search across multiple backends and return grouped results."""
        if backends is None:
            backends = self.get_available_backends()
        results: dict[StorageBackend, list[Document]] = {}

        async def search_backend(backend_type: StorageBackend) -> None:
            storage = self.backends.get(backend_type)
            if storage:
                try:
                    documents: list[Document] = []
                    async for doc in storage.search(query, limit=limit):
                        documents.append(doc)
                    results[backend_type] = documents
                except Exception:
                    results[backend_type] = []

        # Run searches in parallel
        tasks = [search_backend(backend) for backend in backends]
        await asyncio.gather(*tasks, return_exceptions=True)
        return results

    def get_r2r_storage(self) -> R2RStorage | None:
        """Get R2R storage instance if available."""
        storage = self.backends.get(StorageBackend.R2R)
        return storage if isinstance(storage, R2RStorage) else None

    async def get_backend_status(
        self,
    ) -> dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]]:
        """Get comprehensive status for all backends."""
        status: dict[StorageBackend, dict[str, str | int | bool | StorageCapabilities]] = {}
        for backend_type, storage in self.backends.items():
            try:
                collections = await storage.list_collections()
                total_docs = 0
                for collection in collections:
                    total_docs += await storage.count(collection_name=collection)
                backend_status: dict[str, str | int | bool | StorageCapabilities] = {
                    "available": True,
                    "collections": len(collections),
                    "total_documents": total_docs,
                    "capabilities": self.capabilities.get(backend_type, StorageCapabilities.BASIC),
                    "endpoint": getattr(storage.config, "endpoint", "unknown"),
                }
                status[backend_type] = backend_status
            except Exception as e:
                status[backend_type] = {
                    "available": False,
                    "error": str(e),
                    "capabilities": StorageCapabilities.NONE,
                }
        return status

    async def close_all(self) -> None:
        """Close all storage connections."""
        for storage in self.backends.values():
            try:
                await storage.close()
            except Exception:
                pass  # best-effort close; ignore shutdown errors
        self.backends.clear()
        self.capabilities.clear()
        self._initialized = False

    @property
    def is_initialized(self) -> bool:
        """Check if storage manager is initialized."""
        return self._initialized

    def supports_advanced_features(self, backend: StorageBackend) -> bool:
        """Check if backend supports advanced features like chunks and entities."""
        return self.has_capability(backend, StorageCapabilities.FULL_FEATURED)
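

# Minimal smoke-test sketch (hypothetical): assumes the project's Settings can be
# constructed directly (pydantic-style, reading endpoints from the environment)
# and that the module is executed with package context (e.g. via ``python -m``).
if __name__ == "__main__":

    async def _demo() -> None:
        # Runtime import; the top-level import is TYPE_CHECKING-only.
        from ....config.settings import Settings

        manager = StorageManager(Settings())  # Settings() signature is an assumption
        init_results = await manager.initialize_all_backends()
        print({backend.value: ok for backend, ok in init_results.items()})
        await manager.close_all()

    asyncio.run(_demo())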