diff --git a/lightrag/base.py b/lightrag/base.py
index 4e32bf25..75059377 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -220,33 +220,32 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     cosine_better_than_threshold: float = field(default=0.2)
     meta_fields: set[str] = field(default_factory=set)
 
+    def __post_init__(self):
+        """Validate required embedding_func for vector storage."""
+        if self.embedding_func is None:
+            raise ValueError(
+                "embedding_func is required for vector storage. "
+                "Please provide a valid EmbeddingFunc instance."
+            )
+
     def _generate_collection_suffix(self) -> str | None:
         """Generates collection/table suffix from embedding_func.
 
         Return suffix if model_name exists in embedding_func, otherwise return None.
+        Note: embedding_func is guaranteed to exist (validated in __post_init__).
 
         Returns:
             str | None: Suffix string e.g. "text_embedding_3_large_3072d",
                         or None if model_name not available
         """
         import re
 
-        # Try to get embedding_func from self or global_config
-        embedding_func = self.embedding_func
-        if embedding_func is None and "embedding_func" in self.global_config:
-            embedding_func = self.global_config["embedding_func"]
-
-        if embedding_func is None:
-            return None
-
-        # Check if model_name exists
-        model_name = getattr(embedding_func, "model_name", None)
+        # Check if model_name exists (model_name is optional in EmbeddingFunc)
+        model_name = getattr(self.embedding_func, "model_name", None)
         if not model_name:
             return None
-        # Get embedding_dim
-        embedding_dim = getattr(embedding_func, "embedding_dim", None)
-        if embedding_dim is None:
-            return None
+        # embedding_dim is required in EmbeddingFunc
+        embedding_dim = self.embedding_func.embedding_dim
 
         # Generate suffix: clean model name and append dimension
         safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index 5c304d65..e299211b 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
     """
 
     def __post_init__(self):
+        super().__post_init__()
         # Grab config values if available
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index d42c91a7..50c233a8 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             raise
 
     def __post_init__(self):
+        super().__post_init__()
         # Check for MILVUS_WORKSPACE environment variable first (higher priority)
         # This allows administrators to force a specific workspace for all Milvus storage instances
         milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index d390c37b..9b868c11 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -25,6 +25,7 @@ from .shared_storage import (
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
+        super().__post_init__()
         # Initialize basic attributes
         self._client = None
         self._storage_lock = None
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index c427d54c..42e3e6de 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -2298,6 +2298,7 @@ class PGVectorStorage(BaseVectorStorage):
     db: PostgreSQLDB | None = field(default=None)
 
     def __post_init__(self):
+        super().__post_init__()
         self._max_batch_size = self.global_config["embedding_batch_num"]
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = config.get("cosine_better_than_threshold")