Enforce embedding_func validation in BaseVectorStorage
- Add validation in BaseVectorStorage - Call super().__post_init__ in subclasses - Simplify collection suffix logic
This commit is contained in:
@@ -220,33 +220,32 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
cosine_better_than_threshold: float = field(default=0.2)
|
||||
meta_fields: set[str] = field(default_factory=set)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate required embedding_func for vector storage."""
|
||||
if self.embedding_func is None:
|
||||
raise ValueError(
|
||||
"embedding_func is required for vector storage. "
|
||||
"Please provide a valid EmbeddingFunc instance."
|
||||
)
|
||||
|
||||
def _generate_collection_suffix(self) -> str | None:
|
||||
"""Generates collection/table suffix from embedding_func.
|
||||
|
||||
Return suffix if model_name exists in embedding_func, otherwise return None.
|
||||
Note: embedding_func is guaranteed to exist (validated in __post_init__).
|
||||
|
||||
Returns:
|
||||
str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available
|
||||
"""
|
||||
import re
|
||||
|
||||
# Try to get embedding_func from self or global_config
|
||||
embedding_func = self.embedding_func
|
||||
if embedding_func is None and "embedding_func" in self.global_config:
|
||||
embedding_func = self.global_config["embedding_func"]
|
||||
|
||||
if embedding_func is None:
|
||||
return None
|
||||
|
||||
# Check if model_name exists
|
||||
model_name = getattr(embedding_func, "model_name", None)
|
||||
# Check if model_name exists (model_name is optional in EmbeddingFunc)
|
||||
model_name = getattr(self.embedding_func, "model_name", None)
|
||||
if not model_name:
|
||||
return None
|
||||
|
||||
# Get embedding_dim
|
||||
embedding_dim = getattr(embedding_func, "embedding_dim", None)
|
||||
if embedding_dim is None:
|
||||
return None
|
||||
# embedding_dim is required in EmbeddingFunc
|
||||
embedding_dim = self.embedding_func.embedding_dim
|
||||
|
||||
# Generate suffix: clean model name and append dimension
|
||||
safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
|
||||
|
||||
@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Grab config values if available
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
|
||||
@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
raise
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
|
||||
# This allows administrators to force a specific workspace for all Milvus storage instances
|
||||
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
|
||||
|
||||
@@ -25,6 +25,7 @@ from .shared_storage import (
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Initialize basic attributes
|
||||
self._client = None
|
||||
self._storage_lock = None
|
||||
|
||||
@@ -2298,6 +2298,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
db: PostgreSQLDB | None = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
|
||||
Reference in New Issue
Block a user