Merge branch 'main' of github.com:HKUDS/LightRAG into embedding-wrapping

This commit is contained in:
yangdx
2025-12-22 17:07:35 +08:00
11 changed files with 46 additions and 23 deletions

View File

@@ -220,8 +220,15 @@ class BaseVectorStorage(StorageNameSpace, ABC):
cosine_better_than_threshold: float = field(default=0.2)
meta_fields: set[str] = field(default_factory=set)
def __post_init__(self):
"""Validate required embedding_func for vector storage."""
def _validate_embedding_func(self):
"""Validate that embedding_func is provided.
This method should be called at the beginning of __post_init__
in all vector storage implementations.
Raises:
ValueError: If embedding_func is None
"""
if self.embedding_func is None:
raise ValueError(
"embedding_func is required for vector storage. "

View File

@@ -21,6 +21,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
def __post_init__(self):
self._validate_embedding_func()
try:
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")

View File

@@ -28,7 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
def __post_init__(self):
super().__post_init__()
self._validate_embedding_func()
# Grab config values if available
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")

View File

@@ -37,6 +37,9 @@ class MemgraphStorage(BaseGraphStorage):
# Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base'
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip():
logger.info(
f"Using MEMGRAPH_WORKSPACE environment variable: '{memgraph_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
workspace = memgraph_workspace
if not workspace or not str(workspace).strip():

View File

@@ -934,7 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
raise
def __post_init__(self):
super().__post_init__()
self._validate_embedding_func()
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Milvus storage instances
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
@@ -942,7 +942,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = milvus_workspace.strip()
logger.info(
f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization

View File

@@ -89,7 +89,7 @@ class MongoKVStorage(BaseKVStorage):
global_config=global_config,
embedding_func=embedding_func,
)
# __post_init__() is automatically called by super().__init__()
self.__post_init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -99,7 +99,7 @@ class MongoKVStorage(BaseKVStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = mongodb_workspace.strip()
logger.info(
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization
@@ -317,7 +317,7 @@ class MongoDocStatusStorage(DocStatusStorage):
global_config=global_config,
embedding_func=embedding_func,
)
# __post_init__() is automatically called by super().__init__()
self.__post_init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -327,7 +327,7 @@ class MongoDocStatusStorage(DocStatusStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = mongodb_workspace.strip()
logger.info(
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization
@@ -750,7 +750,7 @@ class MongoGraphStorage(BaseGraphStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = mongodb_workspace.strip()
logger.info(
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization
@@ -2052,11 +2052,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
# __post_init__() is automatically called by super().__init__()
self.__post_init__()
def __post_init__(self):
# Call parent class __post_init__ to validate embedding_func
super().__post_init__()
self._validate_embedding_func()
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all MongoDB storage instances
@@ -2065,7 +2064,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = mongodb_workspace.strip()
logger.info(
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization

View File

@@ -25,7 +25,7 @@ from .shared_storage import (
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
super().__post_init__()
self._validate_embedding_func()
# Initialize basic attributes
self._client = None
self._storage_lock = None

View File

@@ -68,6 +68,9 @@ class Neo4JStorage(BaseGraphStorage):
# Read env and override the arg if present
neo4j_workspace = os.environ.get("NEO4J_WORKSPACE")
if neo4j_workspace and neo4j_workspace.strip():
logger.info(
f"Using NEO4J_WORKSPACE environment variable: '{neo4j_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
workspace = neo4j_workspace
# Default to 'base' when both arg and env are empty

View File

@@ -1852,6 +1852,9 @@ class PGKVStorage(BaseKVStorage):
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
logger.info(
f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
@@ -2328,7 +2331,7 @@ class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
super().__post_init__()
self._validate_embedding_func()
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -2783,6 +2786,9 @@ class PGVectorStorage(BaseVectorStorage):
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
logger.info(
f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
@@ -3196,6 +3202,9 @@ class PGDocStatusStorage(DocStatusStorage):
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
logger.info(
f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
@@ -3877,6 +3886,9 @@ class PGGraphStorage(BaseGraphStorage):
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
logger.info(
f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)

View File

@@ -124,7 +124,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
# __post_init__() is automatically called by super().__init__()
self.__post_init__()
@staticmethod
def setup_collection(
@@ -413,9 +413,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
def __post_init__(self):
# Call parent class __post_init__ to validate embedding_func
super().__post_init__()
self._validate_embedding_func()
# Check for QDRANT_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Qdrant storage instances
qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
@@ -423,7 +421,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = qdrant_workspace.strip()
logger.info(
f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization

View File

@@ -133,7 +133,7 @@ class RedisKVStorage(BaseKVStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = redis_workspace.strip()
logger.info(
f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization
@@ -526,7 +526,7 @@ class RedisDocStatusStorage(DocStatusStorage):
# Use environment variable value, overriding the passed workspace parameter
effective_workspace = redis_workspace.strip()
logger.info(
f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')"
)
else:
# Use the workspace parameter passed during initialization