1017 lines
36 KiB
Python
1017 lines
36 KiB
Python
"""Weaviate storage adapter."""
|
|
|
|
from collections.abc import AsyncGenerator, Mapping, Sequence
|
|
from datetime import UTC, datetime
|
|
from typing import Literal, Self, TypeAlias, cast, overload
|
|
from uuid import UUID
|
|
|
|
import weaviate
|
|
from typing_extensions import override
|
|
from weaviate.classes.config import Configure, DataType, Property
|
|
from weaviate.classes.data import DataObject
|
|
from weaviate.classes.query import Filter
|
|
from weaviate.collections import Collection
|
|
from weaviate.exceptions import (
|
|
WeaviateBatchError,
|
|
WeaviateConnectionError,
|
|
WeaviateQueryError,
|
|
)
|
|
|
|
from ..core.exceptions import StorageError
|
|
from ..core.models import Document, DocumentMetadata, IngestionSource, StorageConfig
|
|
from ..utils.vectorizer import Vectorizer
|
|
from .base import BaseStorage
|
|
from .types import CollectionSummary, DocumentInfo
|
|
|
|
VectorContainer: TypeAlias = Mapping[str, object] | Sequence[object] | None
|
|
|
|
|
|
class WeaviateStorage(BaseStorage):
|
|
"""Storage adapter for Weaviate."""
|
|
|
|
client: weaviate.WeaviateClient | None
|
|
vectorizer: Vectorizer
|
|
_default_collection: str
|
|
|
|
def __init__(self, config: StorageConfig):
    """
    Prepare the adapter; no connection is opened here.

    Args:
        config: Storage configuration, also passed to the base class
            and to the vectorizer.
    """
    super().__init__(config)

    # The client stays unset until ``initialize`` creates it.
    self.client = None
    self.vectorizer = Vectorizer(config)

    configured_name = config.collection_name
    self._default_collection = self._normalize_collection_name(configured_name)
|
|
|
|
@override
async def initialize(self) -> None:
    """
    Connect to Weaviate and make sure the default collection exists.

    If any step fails, the partially set-up client is closed and
    ``self.client`` is reset to ``None`` so that no open connection is
    left behind on the instance.

    Raises:
        StorageError: If connecting fails or the default collection
            cannot be prepared.
    """
    client: weaviate.WeaviateClient | None = None
    try:
        # Let Weaviate client handle URL parsing
        client = weaviate.WeaviateClient(
            connection_params=weaviate.connect.ConnectionParams.from_url(
                url=str(self.config.endpoint),
                grpc_port=50051,  # Default gRPC port
            ),
            additional_config=weaviate.classes.init.AdditionalConfig(
                timeout=weaviate.classes.init.Timeout(init=30, query=60, insert=120),
            ),
        )

        # Only publish the client on the instance once it is connected.
        client.connect()
        self.client = client

        # Ensure the default collection exists
        await self._ensure_collection(self._default_collection)

    except Exception as e:
        # Release the connection before reporting the failure.
        self.client = None
        if client is not None:
            try:
                client.close()
            except Exception:
                pass  # Best effort: the original failure is what gets reported

        if isinstance(e, WeaviateConnectionError):
            raise StorageError(f"Failed to connect to Weaviate: {e}") from e
        raise StorageError(f"Failed to initialize Weaviate: {e}") from e
|
|
|
|
async def _create_collection(self, collection_name: str) -> None:
    """
    Create a collection with the document schema used by this adapter.

    The collection is created with no server-side vectorizer.

    Args:
        collection_name: Name of the collection to create.

    Raises:
        StorageError: If the client is not initialized or creation fails.
    """
    if not self.client:
        raise StorageError("Weaviate client not initialized")

    try:
        # (property name, data type, description) for every stored field.
        field_specs = (
            ("content", DataType.TEXT, "Document content"),
            ("source_url", DataType.TEXT, "Source URL"),
            ("title", DataType.TEXT, "Document title"),
            ("description", DataType.TEXT, "Document description"),
            ("timestamp", DataType.DATE, "Ingestion timestamp"),
            ("content_type", DataType.TEXT, "Content type"),
            ("word_count", DataType.INT, "Word count"),
            ("char_count", DataType.INT, "Character count"),
            ("source", DataType.TEXT, "Ingestion source"),
        )
        schema = [
            Property(name=field, data_type=kind, description=summary)
            for field, kind, summary in field_specs
        ]

        self.client.collections.create(
            name=collection_name,
            properties=schema,
            vectorizer_config=Configure.Vectorizer.none(),
        )
    except Exception as e:
        raise StorageError(f"Failed to create collection: {e}") from e
|
|
|
|
@staticmethod
def _extract_vector(vector_raw: VectorContainer) -> list[float] | None:
    """
    Turn a vector payload into a flat list of floats.

    Accepted shapes:
      * a mapping, whose ``"default"`` entry is processed recursively;
      * a non-empty sequence whose first item is a number;
      * a sequence whose first item is itself an all-numeric sequence,
        in which case only that first inner sequence is returned.

    Anything else (``None``, strings, bytes, empty sequences, values
    that cannot be converted) yields ``None``.
    """
    if isinstance(vector_raw, Mapping):
        return WeaviateStorage._extract_vector(
            cast(VectorContainer, vector_raw.get("default"))
        )

    text_like = (str, bytes, bytearray)
    if isinstance(vector_raw, text_like) or not isinstance(vector_raw, Sequence):
        return None

    values = list(vector_raw)
    if not values:
        return None

    head = values[0]
    if isinstance(head, (int, float)):
        # Flat vector: only the first item is type-checked, the rest are
        # validated by the float() conversion below.
        candidates = values
    elif isinstance(head, Sequence) and not isinstance(head, text_like):
        # Nested payload: use the first inner sequence when fully numeric.
        candidates = list(head)
        if not all(isinstance(entry, (int, float)) for entry in candidates):
            return None
    else:
        return None

    try:
        return [float(entry) for entry in cast(list[int | float], candidates)]
    except (TypeError, ValueError):
        return None
|
|
|
|
@staticmethod
def _parse_source(source_raw: object) -> IngestionSource:
    """
    Map a stored source value onto an ``IngestionSource`` member.

    Enum members pass through unchanged and strings are looked up by
    value. Unknown strings and any other type fall back to
    ``IngestionSource.WEB``.
    """
    if isinstance(source_raw, IngestionSource):
        return source_raw

    if not isinstance(source_raw, str):
        return IngestionSource.WEB

    try:
        return IngestionSource(source_raw)
    except ValueError:
        return IngestionSource.WEB
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
) -> Mapping[str, object]:
|
|
...
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: Literal[False],
|
|
) -> Mapping[str, object]:
|
|
...
|
|
|
|
@staticmethod
|
|
@overload
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: Literal[True],
|
|
) -> Mapping[str, object] | None:
|
|
...
|
|
|
|
@staticmethod
|
|
def _coerce_properties(
|
|
properties: object,
|
|
*,
|
|
context: str,
|
|
allow_missing: bool = False,
|
|
) -> Mapping[str, object] | None:
|
|
"""Ensure Weaviate properties payloads are mappings."""
|
|
if properties is None:
|
|
if allow_missing:
|
|
return None
|
|
raise StorageError(f"{context} returned object without properties")
|
|
|
|
if not isinstance(properties, Mapping):
|
|
raise StorageError(
|
|
f"{context} returned invalid properties payload of type {type(properties)!r}"
|
|
)
|
|
|
|
return cast(Mapping[str, object], properties)
|
|
|
|
def _normalize_collection_name(self, collection_name: str | None) -> str:
|
|
"""Return a canonicalized collection name, defaulting to configured value."""
|
|
candidate = collection_name or self.config.collection_name
|
|
if not candidate:
|
|
raise StorageError("Collection name is required")
|
|
|
|
if normalized := candidate.strip():
|
|
return normalized[0].upper() + normalized[1:]
|
|
else:
|
|
raise StorageError("Collection name cannot be empty")
|
|
|
|
async def _ensure_collection(self, collection_name: str) -> None:
    """
    Create the collection if it does not exist yet.

    Uses a single-name existence check instead of listing every
    collection, because this runs on each ``store``/``store_batch`` call.

    Args:
        collection_name: Normalized name of the collection.

    Raises:
        StorageError: If the client is not initialized or the collection
            cannot be created.
    """
    if not self.client:
        raise StorageError("Weaviate client not initialized")

    if not self.client.collections.exists(collection_name):
        await self._create_collection(collection_name)
|
|
|
|
async def _prepare_collection(
    self,
    collection_name: str | None,
    *,
    ensure_exists: bool,
) -> tuple[Collection, str]:
    """
    Resolve a collection name and return its handle.

    Args:
        collection_name: Requested name, or ``None`` for the configured default.
        ensure_exists: When true, the collection is created if it is missing.

    Returns:
        A ``(collection handle, normalized name)`` pair.

    Raises:
        StorageError: If the name is invalid or the client is not initialized.
    """
    # Name validation comes first so a bad name is reported even without a client.
    resolved = self._normalize_collection_name(collection_name)

    if not self.client:
        raise StorageError("Weaviate client not initialized")

    if ensure_exists:
        await self._ensure_collection(resolved)

    handle = self.client.collections.get(resolved)
    return handle, resolved
|
|
|
|
@override
async def store(self, document: Document, *, collection_name: str | None = None) -> str:
    """
    Store a single document in Weaviate.

    When the document has no vector, one is computed from its content
    and written back onto ``document.vector`` before the insert.

    Args:
        document: Document to store.
        collection_name: Target collection; ``None`` selects the default.
            The collection is created if it is missing.

    Returns:
        The ID returned by the insert, as a string.

    Raises:
        StorageError: If vectorization, preparation, or the insert fails.
    """
    try:
        if document.vector is None:
            document.vector = await self.vectorizer.vectorize(document.content)

        collection, _ = await self._prepare_collection(collection_name, ensure_exists=True)

        meta = document.metadata
        payload = {
            "content": document.content,
            "source_url": meta["source_url"],
            "title": meta.get("title", ""),
            "description": meta.get("description", ""),
            "timestamp": meta["timestamp"].isoformat(),
            "content_type": meta["content_type"],
            "word_count": meta["word_count"],
            "char_count": meta["char_count"],
            "source": document.source.value,
        }

        # The object is inserted under the document's own ID, with its vector.
        inserted_id = collection.data.insert(
            properties=payload,
            vector=document.vector,
            uuid=str(document.id),
        )
        return str(inserted_id)

    except Exception as e:
        raise StorageError(f"Failed to store document: {e}") from e
|
|
|
|
@override
async def store_batch(
    self, documents: list[Document], *, collection_name: str | None = None
) -> list[str]:
    """
    Store multiple documents with a single ``insert_many`` call.

    Documents without a vector are vectorized first, and the vector is
    written back onto the document. Objects that Weaviate rejects are
    logged as a warning and left out of the result; they do not fail the
    whole batch.

    Args:
        documents: Documents to store. An empty list returns ``[]``
            without submitting a batch.
        collection_name: Target collection; ``None`` selects the default.
            The collection is created if it is missing.

    Returns:
        IDs of the documents that were stored successfully.

    Raises:
        StorageError: If preparation, vectorization, or the batch call fails.
    """
    import logging

    try:
        collection, resolved_name = await self._prepare_collection(
            collection_name, ensure_exists=True
        )

        # Nothing to send: skip the round trip instead of submitting an empty batch.
        if not documents:
            return []

        # Vectorize documents without vectors
        for doc in documents:
            if doc.vector is None:
                doc.vector = await self.vectorizer.vectorize(doc.content)

        batch_objects = [
            DataObject(
                properties={
                    "content": doc.content,
                    "source_url": doc.metadata["source_url"],
                    "title": doc.metadata.get("title", ""),
                    "description": doc.metadata.get("description", ""),
                    "timestamp": doc.metadata["timestamp"].isoformat(),
                    "content_type": doc.metadata["content_type"],
                    "word_count": doc.metadata["word_count"],
                    "char_count": doc.metadata["char_count"],
                    "source": doc.source.value,
                },
                vector=doc.vector,
                uuid=str(doc.id),
            )
            for doc in documents
        ]

        response = collection.data.insert_many(batch_objects)

        # Both mappings are keyed by the position of the object in the batch.
        errors = response.errors if response else {}
        uuids = response.uuids if response else {}

        successful_ids: list[str] = []
        for index, doc in enumerate(documents):
            if index in errors:
                continue
            stored_uuid = uuids.get(index)
            successful_ids.append(str(doc.id if stored_uuid is None else stored_uuid))

        if errors:
            # Sorted for a stable message; errors lacking ``message`` are
            # reported via their string form rather than dropped.
            details = ", ".join(
                f"{documents[index].id}: {getattr(errors[index], 'message', errors[index])}"
                for index in sorted(errors)
            )
            logging.warning(
                "Weaviate partial batch failure for collection %s: %s",
                resolved_name,
                details,
            )

        return successful_ids

    except WeaviateBatchError as e:
        raise StorageError(f"Batch operation failed: {e}") from e
    except WeaviateConnectionError as e:
        raise StorageError(f"Connection to Weaviate failed: {e}") from e
    except Exception as e:
        raise StorageError(f"Failed to store batch: {e}") from e
|
|
|
|
@override
async def retrieve(
    self, document_id: str, *, collection_name: str | None = None
) -> Document | None:
    """
    Retrieve a document from Weaviate by ID.

    Args:
        document_id: Document ID (UUID string).
        collection_name: Collection to read from; ``None`` selects the default.

    Returns:
        The reconstructed document, or ``None`` when it is not found or
        when a connection or unexpected error occurs (those are logged
        as warnings).

    Raises:
        StorageError: If Weaviate reports a query error.
    """
    import logging

    try:
        collection, resolved_name = await self._prepare_collection(
            collection_name, ensure_exists=False
        )

        # The vector has to be requested explicitly; without this flag the
        # reconstructed document would always carry ``vector=None``.
        result = collection.query.fetch_object_by_id(document_id, include_vector=True)
        if not result:
            return None

        props = self._coerce_properties(
            result.properties,
            context="fetch_object_by_id",
        )
        # Same metadata shape as search results, built by the shared helper.
        metadata = self._build_search_metadata(props)
        vector = self._extract_vector(cast(VectorContainer, result.vector))

        return Document(
            id=UUID(document_id),
            content=str(props["content"]),
            metadata=metadata,
            vector=vector,
            source=self._parse_source(props.get("source")),
            collection=resolved_name,
        )

    except WeaviateQueryError as e:
        raise StorageError(f"Query failed: {e}") from e
    except WeaviateConnectionError as e:
        # Connection issues should be logged and return None
        logging.warning(f"Weaviate connection error retrieving document {document_id}: {e}")
        return None
    except Exception as e:
        # Log unexpected errors for debugging
        logging.warning(f"Unexpected error retrieving document {document_id}: {e}")
        return None
|
|
|
|
def _build_search_metadata(self, props: Mapping[str, object]) -> DocumentMetadata:
    """
    Build document metadata from Weaviate properties.

    ``source_url``, ``timestamp``, ``content_type``, ``word_count`` and
    ``char_count`` are required. ``title`` and ``description`` become
    ``None`` when absent or empty.

    The timestamp is returned as a ``datetime``, matching what ``store``
    expects when it calls ``.isoformat()`` on it. ``datetime`` values are
    passed through and anything else is parsed as an ISO-8601 string.

    Raises:
        KeyError: If a required property is missing.
        ValueError: If the timestamp or a count cannot be parsed.
    """
    source_url = str(props["source_url"])
    title = str(props.get("title")) if props.get("title") else None
    description = str(props.get("description")) if props.get("description") else None

    raw_timestamp = props["timestamp"]
    if isinstance(raw_timestamp, datetime):
        timestamp = raw_timestamp
    else:
        timestamp = datetime.fromisoformat(str(raw_timestamp))

    metadata_dict = {
        "source_url": source_url,
        "title": title,
        "description": description,
        "timestamp": timestamp,
        "content_type": str(props["content_type"]),
        "word_count": int(str(props["word_count"])),
        "char_count": int(str(props["char_count"])),
    }
    return cast(DocumentMetadata, cast(object, metadata_dict))
|
|
|
|
def _extract_search_score(self, result: object) -> float | None:
|
|
"""Extract and convert search score from result metadata."""
|
|
metadata_obj = getattr(result, "metadata", None)
|
|
if metadata_obj is None:
|
|
return None
|
|
|
|
raw_distance = getattr(metadata_obj, "distance", None)
|
|
if raw_distance is None:
|
|
return None
|
|
|
|
try:
|
|
distance_value = float(raw_distance)
|
|
return max(0.0, 1.0 - distance_value)
|
|
except (TypeError, ValueError) as e:
|
|
import logging
|
|
logging.debug(f"Invalid distance value {raw_distance}: {e}")
|
|
return None
|
|
|
|
def _build_search_document(
    self,
    result: object,
    resolved_name: str,
) -> Document:
    """
    Convert one vector-search hit into a ``Document``.

    Args:
        result: Result object exposing ``properties``, ``vector``,
            ``metadata`` and ``uuid`` attributes.
        resolved_name: Normalized name of the collection that was searched.

    Raises:
        StorageError: If the hit has no usable properties or no uuid.
    """
    fields = self._coerce_properties(
        getattr(result, "properties", None),
        context="search result",
    )
    metadata = self._build_search_metadata(fields)
    vector = self._extract_vector(cast(VectorContainer, getattr(result, "vector", None)))
    score_value = self._extract_search_score(result)

    identifier = getattr(result, "uuid", None)
    if identifier is None:
        raise StorageError("Weaviate search result missing uuid")
    if not isinstance(identifier, UUID):
        identifier = UUID(str(identifier))

    return Document(
        id=identifier,
        content=str(fields["content"]),
        metadata=metadata,
        vector=vector,
        source=self._parse_source(fields.get("source")),
        collection=resolved_name,
        score=score_value,
    )
|
|
|
|
@override
async def search(
    self,
    query: str,
    limit: int = 10,
    threshold: float = 0.7,
    *,
    collection_name: str | None = None,
) -> AsyncGenerator[Document, None]:
    """
    Run a vector similarity search and yield the matches.

    The query text is vectorized and sent as a ``near_vector`` query
    whose maximum distance is ``1 - threshold``.

    Args:
        query: Search query text.
        limit: Maximum number of results.
        threshold: Similarity threshold.
        collection_name: Collection to search; ``None`` selects the default.

    Yields:
        Matching documents, one per hit, in the order returned.

    Raises:
        StorageError: If vectorization, the query, or result conversion fails.
    """
    try:
        embedding = await self.vectorizer.vectorize(query)
        collection, resolved_name = await self._prepare_collection(
            collection_name, ensure_exists=False
        )

        max_distance = 1 - threshold
        hits = collection.query.near_vector(
            near_vector=embedding,
            limit=limit,
            distance=max_distance,
            return_metadata=["distance"],
        )

        for hit in hits.objects:
            yield self._build_search_document(hit, resolved_name)

    except WeaviateQueryError as e:
        raise StorageError(f"Search query failed: {e}") from e
    except WeaviateConnectionError as e:
        raise StorageError(f"Connection to Weaviate failed during search: {e}") from e
    except Exception as e:
        raise StorageError(f"Search failed: {e}") from e
|
|
|
|
@override
async def delete(self, document_id: str, *, collection_name: str | None = None) -> bool:
    """
    Delete a document from Weaviate.

    Args:
        document_id: Document ID.
        collection_name: Collection to delete from; ``None`` selects the default.

    Returns:
        The outcome reported by ``delete_by_id``: ``True`` if the object
        was deleted, ``False`` if it was not (for example, not found).
        ``False`` is also returned, with a warning logged, on any
        non-query error.

    Raises:
        StorageError: If Weaviate reports a query error.
    """
    try:
        collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
        # Propagate the client's verdict instead of assuming success.
        return bool(collection.data.delete_by_id(document_id))
    except WeaviateQueryError as e:
        raise StorageError(f"Delete operation failed: {e}") from e
    except Exception as e:
        import logging

        logging.warning(f"Failed to delete document {document_id}: {e}")
        return False
|
|
|
|
@override
async def count(self, *, collection_name: str | None = None) -> int:
    """
    Get the number of documents in a collection.

    Args:
        collection_name: Collection to count; ``None`` selects the default.

    Returns:
        Number of documents. ``0`` is returned when the client is not
        initialized, and also when a non-query error occurs; the latter
        is logged as a warning so it can be told apart from an empty
        collection.

    Raises:
        StorageError: If Weaviate reports a query error.
    """
    try:
        if not self.client:
            return 0

        collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)
        result = collection.aggregate.over_all(total_count=True)
        return result.total_count or 0
    except WeaviateQueryError as e:
        raise StorageError(f"Count query failed: {e}") from e
    except Exception as e:
        import logging

        logging.warning(f"Unable to count documents, reporting 0: {e}")
        return 0
|
|
|
|
async def list_collections(self) -> list[str]:
    """
    Return the names of all collections known to the client.

    Raises:
        StorageError: If the client is not initialized or listing fails.
    """
    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        # Iterating the listing yields the collection names.
        names = [*self.client.collections.list_all()]
        return names

    except Exception as e:
        raise StorageError(f"Failed to list collections: {e}") from e
|
|
|
|
async def describe_collections(self) -> list[CollectionSummary]:
    """
    Summarize every collection: name, object count, and a size figure.

    ``size_mb`` is not measured; it is computed as ``count * 0.01``.

    Raises:
        StorageError: If the client is not initialized or a lookup fails.
    """
    if not self.client:
        raise StorageError("Weaviate client not initialized")

    try:
        summaries: list[CollectionSummary] = []
        for name in self.client.collections.list_all():
            handle = self.client.collections.get(name)
            if not handle:
                continue

            aggregate = handle.aggregate.over_all(total_count=True)
            count = aggregate.total_count or 0
            summary: CollectionSummary = {
                "name": name,
                "count": count,
                "size_mb": count * 0.01,
            }
            summaries.append(summary)

        return summaries
    except Exception as e:
        raise StorageError(f"Failed to describe collections: {e}") from e
|
|
|
|
async def sample_documents(
    self, limit: int = 5, *, collection_name: str | None = None
) -> list[Document]:
    """
    Get sample documents from the collection.

    Objects without properties or without a uuid are skipped. Metadata is
    built by the shared ``_build_document_metadata`` helper rather than a
    private copy of the same conversions.

    Args:
        limit: Maximum number of documents to return.
        collection_name: Collection to sample; ``None`` selects the default.

    Returns:
        List of sample documents.

    Raises:
        StorageError: If the query or the conversion of a result fails.
    """
    try:
        collection, resolved_name = await self._prepare_collection(
            collection_name, ensure_exists=False
        )

        response = collection.query.fetch_objects(limit=limit)

        documents: list[Document] = []
        for obj in response.objects:
            props = self._coerce_properties(
                getattr(obj, "properties", None),
                context="sample_documents",
                allow_missing=True,
            )
            if props is None:
                continue

            uuid_raw = getattr(obj, "uuid", None)
            if uuid_raw is None:
                continue
            document_id = uuid_raw if isinstance(uuid_raw, UUID) else UUID(str(uuid_raw))

            documents.append(
                Document(
                    id=document_id,
                    content=str(props.get("content", "")),
                    source=self._parse_source(props.get("source")),
                    metadata=self._build_document_metadata(props),
                    collection=resolved_name,
                )
            )

        return documents

    except Exception as e:
        raise StorageError(f"Failed to sample documents: {e}") from e
|
|
|
|
def _safe_convert_count(self, value: object) -> int:
|
|
"""Safely convert a value to integer count."""
|
|
if isinstance(value, (int, float)):
|
|
return int(value)
|
|
elif value:
|
|
return int(str(value))
|
|
else:
|
|
return 0
|
|
|
|
def _build_document_metadata(self, props: Mapping[str, object]) -> DocumentMetadata:
    """
    Build metadata from document properties, tolerating missing fields.

    Missing values get defaults: empty ``source_url``, ``None`` title and
    description, ``"text/plain"`` content type and zero counts. The
    timestamp is taken as-is when it is already a ``datetime``, parsed
    from ISO-8601 text otherwise, and set to the current UTC time when it
    is absent or empty.

    Raises:
        ValueError: If a non-empty timestamp value cannot be parsed.
    """
    raw_timestamp = props.get("timestamp")
    if isinstance(raw_timestamp, datetime):
        timestamp = raw_timestamp
    elif raw_timestamp:
        timestamp = datetime.fromisoformat(str(raw_timestamp))
    else:
        # Covers both a missing key and a key stored with a null value.
        timestamp = datetime.now(UTC)

    return {
        "source_url": str(props.get("source_url", "")),
        "title": str(props.get("title", "")) if props.get("title") else None,
        "description": str(props.get("description", ""))
        if props.get("description")
        else None,
        "timestamp": timestamp,
        "content_type": str(props.get("content_type", "text/plain")),
        "word_count": self._safe_convert_count(props.get("word_count")),
        "char_count": self._safe_convert_count(props.get("char_count")),
    }
|
|
|
|
def _extract_document_score(self, obj: object) -> float | None:
|
|
"""Extract score from document search result."""
|
|
metadata_obj = getattr(obj, "metadata", None)
|
|
if metadata_obj is None:
|
|
return None
|
|
|
|
raw_score = getattr(metadata_obj, "score", None)
|
|
if raw_score is None:
|
|
return None
|
|
|
|
try:
|
|
return float(raw_score)
|
|
except (TypeError, ValueError) as e:
|
|
import logging
|
|
logging.debug(f"Invalid score value {raw_score}: {e}")
|
|
return None
|
|
|
|
def _build_document_from_search(
    self,
    obj: object,
    resolved_name: str,
) -> Document:
    """
    Convert one keyword/hybrid search hit into a ``Document``.

    Args:
        obj: Result object exposing ``properties``, ``metadata`` and
            ``uuid`` attributes.
        resolved_name: Normalized name of the collection that was searched.

    Raises:
        StorageError: If the hit has no usable properties or no uuid.
    """
    fields = self._coerce_properties(
        getattr(obj, "properties", None),
        context="document search result",
    )
    metadata = self._build_document_metadata(fields)
    score_value = self._extract_document_score(obj)

    identifier = getattr(obj, "uuid", None)
    if identifier is None:
        raise StorageError("Weaviate search document result missing uuid")
    if not isinstance(identifier, UUID):
        identifier = UUID(str(identifier))

    return Document(
        id=identifier,
        content=str(fields.get("content", "")),
        source=self._parse_source(fields.get("source")),
        metadata=metadata,
        collection=resolved_name,
        score=score_value,
    )
|
|
|
|
async def search_documents(
    self, query: str, limit: int = 10, *, collection_name: str | None = None
) -> list[Document]:
    """
    Search documents in the collection.

    Hybrid search is tried first, with the query vector computed locally.
    This adapter creates its collections without a server-side vectorizer,
    so the vector half of a hybrid query has to be supplied by the caller.
    If vectorization or the hybrid query fails, the search falls back to
    BM25 keyword search and the reason is logged at debug level.

    Args:
        query: Search query.
        limit: Maximum number of results.
        collection_name: Collection to search; ``None`` selects the default.

    Returns:
        List of matching documents.

    Raises:
        StorageError: If the client is not initialized or both search
            strategies fail.
    """
    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        collection, resolved_name = await self._prepare_collection(
            collection_name, ensure_exists=False
        )

        # Try hybrid search first, fall back to BM25 keyword search
        try:
            query_vector = await self.vectorizer.vectorize(query)
            response = collection.query.hybrid(
                query=query,
                vector=query_vector,
                limit=limit,
                return_metadata=["score"],
            )
        except Exception as hybrid_error:
            import logging

            logging.debug(f"Hybrid search unavailable, falling back to BM25: {hybrid_error}")
            response = collection.query.bm25(
                query=query, limit=limit, return_metadata=["score"]
            )

        return [
            self._build_document_from_search(obj, resolved_name)
            for obj in response.objects
        ]

    except Exception as e:
        raise StorageError(f"Failed to search documents: {e}") from e
|
|
|
|
async def list_documents(
    self,
    limit: int = 100,
    offset: int = 0,
    *,
    collection_name: str | None = None,
) -> list[DocumentInfo]:
    """
    List documents in the collection with pagination.

    Objects without properties are skipped. A missing, null or empty
    title is reported as ``"Untitled"``, and a missing or null
    description as ``""``. Content previews are cut to 200 characters
    followed by ``"..."``.

    Args:
        limit: Maximum number of documents to return.
        offset: Number of documents to skip.
        collection_name: Collection to list; ``None`` selects the default.

    Returns:
        One summary dictionary per document, with id, title, source_url,
        description, content_type, content_preview, word_count and timestamp.

    Raises:
        StorageError: If the client is not initialized or the query fails.
    """
    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)

        # Query documents with pagination
        response = collection.query.fetch_objects(
            limit=limit, offset=offset, return_metadata=["creation_time"]
        )

        documents: list[DocumentInfo] = []
        for obj in response.objects:
            props = self._coerce_properties(
                obj.properties,
                context="list_documents",
                allow_missing=True,
            )
            if props is None:
                continue

            content = str(props.get("content", ""))
            doc_info: DocumentInfo = {
                "id": str(obj.uuid),
                # ``or`` also covers null/empty values, which would
                # otherwise render as "None" or bypass the default.
                "title": str(props.get("title") or "Untitled"),
                "source_url": str(props.get("source_url", "")),
                "description": str(props.get("description") or ""),
                "content_type": str(props.get("content_type", "text/plain")),
                "content_preview": (f"{content[:200]}..." if len(content) > 200 else content),
                "word_count": self._safe_convert_count(props.get("word_count", 0)),
                "timestamp": str(props.get("timestamp", "")),
            }
            documents.append(doc_info)

        return documents

    except Exception as e:
        raise StorageError(f"Failed to list documents: {e}") from e
|
|
|
|
async def delete_documents(
    self, document_ids: list[str], *, collection_name: str | None = None
) -> dict[str, bool]:
    """
    Delete several documents in one ``delete_many`` call.

    Per-object outcomes reported by Weaviate are recorded under the uuid
    it returns. When fewer outcomes than requested IDs come back, every
    requested ID still missing from the result is filled in with
    ``True`` if the response reports no failures, else ``False``.

    Args:
        document_ids: IDs of the documents to delete.
        collection_name: Collection to delete from; ``None`` selects the default.

    Returns:
        Mapping of document ID to deletion success. Empty for empty input.

    Raises:
        StorageError: If the client is not initialized or the deletion fails.
    """
    outcome: dict[str, bool] = {}

    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        if not document_ids:
            return outcome

        collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)

        id_filter = Filter.by_id().contains_any(document_ids)
        response = collection.data.delete_many(where=id_filter, verbose=True)

        for entry in getattr(response, "objects", None) or ():
            key = str(getattr(entry, "uuid", ""))
            if key:
                outcome[key] = bool(getattr(entry, "successful", False))

        if len(outcome) < len(document_ids):
            fallback = getattr(response, "failed", 0) == 0
            for requested_id in document_ids:
                if requested_id not in outcome:
                    outcome[requested_id] = fallback

        return outcome

    except Exception as e:
        raise StorageError(f"Failed to delete documents: {e}") from e
|
|
|
|
async def delete_by_filter(
    self, filter_dict: dict[str, str], *, collection_name: str | None = None
) -> int:
    """
    Delete documents matching a filter.

    Every key in ``filter_dict`` becomes an equality condition on the
    property of that name, and all conditions must hold. An empty filter
    is rejected so that a missing or mistyped criterion can never delete
    unrelated documents.

    Args:
        filter_dict: Filter criteria (e.g., {"source_url": "example.com"}).
        collection_name: Collection to delete from; ``None`` selects the default.

    Returns:
        Number of documents deleted, as reported by ``delete_many``.
        NOTE(review): the server may cap how many objects one
        ``delete_many`` call removes; confirm against deployment limits.

    Raises:
        StorageError: If the client is not initialized, the filter is
            empty, or the deletion fails.
    """
    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        if not filter_dict:
            raise StorageError("Refusing to delete without filter criteria")

        collection, _ = await self._prepare_collection(collection_name, ensure_exists=False)

        # AND together one equality clause per requested property.
        where_filter = None
        for property_name, expected in filter_dict.items():
            clause = Filter.by_property(property_name).equal(expected)
            where_filter = clause if where_filter is None else where_filter & clause

        outcome = collection.data.delete_many(where=where_filter)
        return int(getattr(outcome, "successful", 0) or 0)

    except Exception as e:
        raise StorageError(f"Failed to delete by filter: {e}") from e
|
|
|
|
async def delete_collection(self, collection_name: str | None = None) -> bool:
    """
    Drop an entire collection.

    Args:
        collection_name: Collection to delete; ``None`` selects the default.

    Returns:
        ``True`` once the delete call has completed.

    Raises:
        StorageError: If the client is not initialized, the name is
            invalid, or the deletion fails.
    """
    try:
        if not self.client:
            raise StorageError("Weaviate client not initialized")

        doomed = self._normalize_collection_name(collection_name)
        self.client.collections.delete(doomed)
    except Exception as e:
        raise StorageError(f"Failed to delete collection: {e}") from e
    else:
        return True
|
|
|
|
async def __aenter__(self) -> Self:
|
|
"""Async context manager entry."""
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: object | None,
|
|
) -> None:
|
|
"""Async context manager exit with proper cleanup."""
|
|
await self.close()
|
|
|
|
async def close(self) -> None:
|
|
"""Close client connection."""
|
|
if self.client:
|
|
try:
|
|
client = cast(weaviate.WeaviateClient, self.client)
|
|
client.close()
|
|
except Exception as e:
|
|
import logging
|
|
logging.warning(f"Error closing Weaviate client: {e}")
|
|
|
|
def __del__(self) -> None:
|
|
"""Clean up client connection as fallback."""
|
|
if self.client:
|
|
try:
|
|
client = cast(weaviate.WeaviateClient, self.client)
|
|
client.close()
|
|
except Exception:
|
|
pass # Ignore errors in destructor
|