Enforce embedding_func validation in BaseVectorStorage

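- Add an embedding_func None-check to BaseVectorStorage.__post_init__ (raises ValueError)
- Call super().__post_init__() in the Faiss, Milvus, NanoVectorDB and PGVector storage subclasses
- Simplify the collection suffix logic in _generate_collection_suffix, which now reads self.embedding_func directly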
This commit is contained in:
yangdx
2025-12-20 03:49:31 +08:00
parent 81a0d632ca
commit 864131a622
5 changed files with 17 additions and 14 deletions

View File

@@ -220,33 +220,32 @@ class BaseVectorStorage(StorageNameSpace, ABC):
cosine_better_than_threshold: float = field(default=0.2)
meta_fields: set[str] = field(default_factory=set)
def __post_init__(self):
"""Validate required embedding_func for vector storage.

Raises:
ValueError: If ``self.embedding_func`` is None.
"""
# Subclasses that define their own __post_init__ must call
# super().__post_init__() for this check to run.
if self.embedding_func is None:
raise ValueError(
"embedding_func is required for vector storage. "
"Please provide a valid EmbeddingFunc instance."
)
def _generate_collection_suffix(self) -> str | None:
"""Generates collection/table suffix from embedding_func.
Return suffix if model_name exists in embedding_func, otherwise return None.
Note: embedding_func is guaranteed to exist (validated in __post_init__).
Returns:
str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available
"""
import re
# Try to get embedding_func from self or global_config
embedding_func = self.embedding_func
if embedding_func is None and "embedding_func" in self.global_config:
embedding_func = self.global_config["embedding_func"]
if embedding_func is None:
return None
# Check if model_name exists
model_name = getattr(embedding_func, "model_name", None)
# Check if model_name exists (model_name is optional in EmbeddingFunc)
model_name = getattr(self.embedding_func, "model_name", None)
if not model_name:
return None
# Get embedding_dim
embedding_dim = getattr(embedding_func, "embedding_dim", None)
if embedding_dim is None:
return None
# embedding_dim is required in EmbeddingFunc
embedding_dim = self.embedding_func.embedding_dim
# Generate suffix: clean model name and append dimension
safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())

View File

@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
def __post_init__(self):
super().__post_init__()
# Grab config values if available
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")

View File

@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
raise
def __post_init__(self):
super().__post_init__()
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Milvus storage instances
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")

View File

@@ -25,6 +25,7 @@ from .shared_storage import (
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
super().__post_init__()
# Initialize basic attributes
self._client = None
self._storage_lock = None

View File

@@ -2298,6 +2298,7 @@ class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
super().__post_init__()
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")