Merge pull request #2513 from HKUDS/feature/vectordb-model-isolation

feat: Implement Vector Database Model Isolation and Auto-Migration
Daniel.y
2025-12-21 18:58:34 +08:00
committed by GitHub
36 changed files with 4746 additions and 485 deletions

View File

@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
python-version: ['3.12', '3.13', '3.14']
python-version: ['3.12', '3.14']
steps:
- uses: actions/checkout@v6

View File

@@ -286,7 +286,7 @@ if __name__ == "__main__":
<summary> Parameters </summary>
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| -------------- | ---------- | ----------------- | ------------- |
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
@@ -425,7 +425,7 @@ async def llm_model_func(
**kwargs
)
@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed.func(
texts,
@@ -490,7 +490,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -542,7 +542,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1633,24 +1633,24 @@ LightRAG uses the following prompts to generate high-level queries; the corresponding code is in `example/genera
### Overall Performance Table
| |**Agriculture**| |**CS**| |**Legal**| |**Mix**| |
||**Agriculture**||**CS**||**Legal**||**Mix**||
|----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
|**Comprehensiveness**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
|**Diversity**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
|**Empowerment**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
|**Overall**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
|**Comprehensiveness**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
|**Diversity**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
|**Empowerment**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
|**Overall**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
|**Comprehensiveness**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
|**Diversity**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
|**Empowerment**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
|**Overall**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
|**Comprehensiveness**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
|**Diversity**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
|**Empowerment**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|

View File

@@ -287,9 +287,9 @@ A full list of LightRAG init parameters:
<summary> Parameters </summary>
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| -------------- | ---------- | ----------------- | ------------- |
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -421,7 +421,7 @@ async def llm_model_func(
**kwargs
)
@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed.func(
texts,
@@ -488,7 +488,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -540,7 +540,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1701,24 +1701,24 @@ Output your evaluation in the following JSON format:
### Overall Performance Table
| |**Agriculture**| |**CS**| |**Legal**| |**Mix**| |
||**Agriculture**||**CS**||**Legal**||**Mix**||
|----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
|**Comprehensiveness**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
|**Diversity**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
|**Empowerment**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
|**Overall**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
|**Comprehensiveness**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
|**Diversity**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
|**Empowerment**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
|**Overall**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
|**Comprehensiveness**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
|**Diversity**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
|**Empowerment**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
|**Overall**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
|**Comprehensiveness**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
|**Diversity**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
|**Empowerment**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|

View File

@@ -868,6 +868,7 @@ def create_app(args):
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
model_name=model,
)
# Log final embedding configuration
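For reference, the same `model_name` attribute that the server now passes into its embedding configuration can be attached to a user-defined embedding function via the `wrap_embedding_func_with_attrs` decorator, mirroring the README examples above; the Ollama model below is only an illustration.

```python
import numpy as np

from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_embed


@wrap_embedding_func_with_attrs(
    embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text"
)
async def embedding_func(texts: list[str]) -> np.ndarray:
    # model_name lets vector storages derive a per-model collection/table suffix
    return await ollama_embed.func(texts, embed_model="nomic-embed-text")
```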

View File

@@ -220,6 +220,37 @@ class BaseVectorStorage(StorageNameSpace, ABC):
cosine_better_than_threshold: float = field(default=0.2)
meta_fields: set[str] = field(default_factory=set)
def __post_init__(self):
"""Validate required embedding_func for vector storage."""
if self.embedding_func is None:
raise ValueError(
"embedding_func is required for vector storage. "
"Please provide a valid EmbeddingFunc instance."
)
def _generate_collection_suffix(self) -> str | None:
"""Generates collection/table suffix from embedding_func.
Return suffix if model_name exists in embedding_func, otherwise return None.
Note: embedding_func is guaranteed to exist (validated in __post_init__).
Returns:
str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available
"""
import re
# Check if model_name exists (model_name is optional in EmbeddingFunc)
model_name = getattr(self.embedding_func, "model_name", None)
if not model_name:
return None
# embedding_dim is required in EmbeddingFunc
embedding_dim = self.embedding_func.embedding_dim
# Generate suffix: clean model name and append dimension
safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
return f"{safe_model_name}_{embedding_dim}d"
@abstractmethod
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
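A minimal, standalone sketch of the suffix derivation implemented by `_generate_collection_suffix` above (model name and dimension are illustrative):

```python
import re


def collection_suffix(model_name: str | None, embedding_dim: int) -> str | None:
    """Clean the model name and append the dimension, e.g. 'text_embedding_3_large_3072d'."""
    if not model_name:
        return None  # no model_name configured -> no suffix, legacy naming is used
    safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
    return f"{safe_model_name}_{embedding_dim}d"


print(collection_suffix("text-embedding-3-large", 3072))  # text_embedding_3_large_3072d
print(collection_suffix(None, 3072))                      # None
```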

View File

@@ -128,8 +128,8 @@ class ChunkTokenLimitExceededError(ValueError):
self.chunk_preview = truncated_preview
class QdrantMigrationError(Exception):
"""Raised when Qdrant data migration from legacy collections fails."""
class DataMigrationError(Exception):
"""Raised when data migration from legacy collection/table fails."""
def __init__(self, message: str):
super().__init__(message)

View File

@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
def __post_init__(self):
super().__post_init__()
# Grab config values if available
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -358,9 +359,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
return
dim_mismatch = False
try:
# Load the Faiss index
self._index = faiss.read_index(self._faiss_index_file)
# Verify dimension consistency between loaded index and embedding function
if self._index.d != self._dim:
error_msg = (
f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, "
f"but embedding function expects dimension {self._dim}. "
f"Please ensure the embedding model matches the stored index or rebuild the index."
)
logger.error(error_msg)
dim_mismatch = True
raise ValueError(error_msg)
# Load metadata
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
@@ -375,6 +389,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
if dim_mismatch:
raise
logger.error(
f"[{self.workspace}] Failed to load Faiss index or metadata: {e}"
)
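A self-contained sketch of the dimension check added above, exercised against a throwaway on-disk index; the file name and dimensions are illustrative, not LightRAG defaults.

```python
import faiss  # pip install faiss-cpu

expected_dim = 768  # dimension the current embedding model produces
faiss.write_index(faiss.IndexFlatIP(1536), "demo.index")  # index built by an older model

loaded = faiss.read_index("demo.index")
if loaded.d != expected_dim:
    # Mirrors the new behavior: surface the mismatch instead of querying a bad index
    print(f"Dimension mismatch: index has {loaded.d}, embedding model expects {expected_dim}")
```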

View File

@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
raise
def __post_init__(self):
super().__post_init__()
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Milvus storage instances
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")

View File

@@ -89,7 +89,7 @@ class MongoKVStorage(BaseKVStorage):
global_config=global_config,
embedding_func=embedding_func,
)
self.__post_init__()
# __post_init__() is automatically called by super().__init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -317,7 +317,7 @@ class MongoDocStatusStorage(DocStatusStorage):
global_config=global_config,
embedding_func=embedding_func,
)
self.__post_init__()
# __post_init__() is automatically called by super().__init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -2052,9 +2052,12 @@ class MongoVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
self.__post_init__()
# __post_init__() is automatically called by super().__init__()
def __post_init__(self):
# Call parent class __post_init__ to validate embedding_func
super().__post_init__()
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all MongoDB storage instances
mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
@@ -2131,8 +2134,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
indexes = await indexes_cursor.to_list(length=None)
for index in indexes:
if index["name"] == self._index_name:
# Check if the existing index has matching vector dimensions
existing_dim = None
definition = index.get("latestDefinition", {})
fields = definition.get("fields", [])
for field in fields:
if (
field.get("type") == "vector"
and field.get("path") == "vector"
):
existing_dim = field.get("numDimensions")
break
expected_dim = self.embedding_func.embedding_dim
if existing_dim is not None and existing_dim != expected_dim:
error_msg = (
f"Vector dimension mismatch! Index '{self._index_name}' has "
f"dimension {existing_dim}, but current embedding model expects "
f"dimension {expected_dim}. Please drop the existing index or "
f"use an embedding model with matching dimensions."
)
logger.error(f"[{self.workspace}] {error_msg}")
raise ValueError(error_msg)
logger.info(
f"[{self.workspace}] vector index {self._index_name} already exist"
f"[{self.workspace}] vector index {self._index_name} already exists with matching dimensions ({expected_dim})"
)
return
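The same comparison can be sketched in isolation against an index definition held as a plain dict; the definition shape follows the Atlas Vector Search fields parsed above, and all names and values are illustrative.

```python
index = {
    "name": "vector_knn_index",  # illustrative index name
    "latestDefinition": {
        "fields": [{"type": "vector", "path": "vector", "numDimensions": 1536}]
    },
}
expected_dim = 3072  # dimension of the newly configured embedding model (illustrative)

existing_dim = next(
    (
        f.get("numDimensions")
        for f in index.get("latestDefinition", {}).get("fields", [])
        if f.get("type") == "vector" and f.get("path") == "vector"
    ),
    None,
)
if existing_dim is not None and existing_dim != expected_dim:
    print(
        f"Vector dimension mismatch: index has {existing_dim}, "
        f"current embedding model expects {expected_dim}"
    )
```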

View File

@@ -25,6 +25,7 @@ from .shared_storage import (
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
super().__post_init__()
# Initialize basic attributes
self._client = None
self._storage_lock = None

File diff suppressed because it is too large

View File

@@ -10,7 +10,7 @@ import numpy as np
import pipmaster as pm
from ..base import BaseVectorStorage
from ..exceptions import QdrantMigrationError
from ..exceptions import DataMigrationError
from ..kg.shared_storage import get_data_init_lock
from ..utils import compute_mdhash_id, logger
@@ -66,6 +66,51 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition:
)
def _find_legacy_collection(
client: QdrantClient,
namespace: str,
workspace: str = None,
model_suffix: str = None,
) -> str | None:
"""
Find legacy collection with backward compatibility support.
This function tries multiple naming patterns to locate legacy collections
created by older versions of LightRAG:
1. lightrag_vdb_{namespace} - if model_suffix is provided (HIGHEST PRIORITY)
2. {workspace}_{namespace} or {namespace} - regardless of whether model_suffix is provided
3. lightrag_vdb_{namespace} - fallback value regardless of whether model_suffix is provided (LOWEST PRIORITY)
Args:
client: QdrantClient instance
namespace: Base namespace (e.g., "chunks", "entities")
workspace: Optional workspace identifier
model_suffix: Optional model suffix for new collection
Returns:
Collection name if found, None otherwise
"""
# Try multiple naming patterns for backward compatibility
# More specific names (with workspace) have higher priority
candidates = [
f"lightrag_vdb_{namespace}" if model_suffix else None,
f"{workspace}_{namespace}" if workspace else None,
f"lightrag_vdb_{namespace}",
namespace,
]
for candidate in candidates:
if candidate and client.collection_exists(candidate):
logger.info(
f"Qdrant: Found legacy collection '{candidate}' "
f"(namespace={namespace}, workspace={workspace or 'none'})"
)
return candidate
return None
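For illustration, the candidate order the helper tries can be previewed without a Qdrant connection (workspace and suffix values are examples):

```python
namespace, workspace, model_suffix = "chunks", "space1", "text_embedding_3_large_3072d"

candidates = [
    f"lightrag_vdb_{namespace}" if model_suffix else None,  # highest priority when a suffix exists
    f"{workspace}_{namespace}" if workspace else None,      # workspace-scoped legacy name
    f"lightrag_vdb_{namespace}",                            # shared-collection fallback
    namespace,                                              # bare namespace (legacy name when no workspace was set)
]
print([c for c in candidates if c])
# ['lightrag_vdb_chunks', 'space1_chunks', 'lightrag_vdb_chunks', 'chunks']
```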
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
@@ -79,63 +124,52 @@ class QdrantVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
self.__post_init__()
# __post_init__() is automatically called by super().__init__()
@staticmethod
def setup_collection(
client: QdrantClient,
collection_name: str,
legacy_namespace: str = None,
workspace: str = None,
**kwargs,
namespace: str,
workspace: str,
vectors_config: models.VectorParams,
hnsw_config: models.HnswConfigDiff,
model_suffix: str,
):
"""
Setup Qdrant collection with migration support from legacy collections.
Ensure final collection has workspace isolation index.
Check vector dimension compatibility before new collection creation.
Drop the legacy collection if it exists and is empty.
Only migrate data from the legacy collection to the new collection when the new collection is first created and the legacy collection is not empty.
Args:
client: QdrantClient instance
collection_name: Name of the new collection
legacy_namespace: Name of the legacy collection (if exists)
collection_name: Name of the final collection
namespace: Base namespace (e.g., "chunks", "entities")
workspace: Workspace identifier for data isolation
**kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
vectors_config: Vector configuration parameters for the collection
hnsw_config: HNSW index configuration diff for the collection
"""
if not namespace or not workspace:
raise ValueError("namespace and workspace must be provided")
workspace_count_filter = models.Filter(
must=[workspace_filter_condition(workspace)]
)
new_collection_exists = client.collection_exists(collection_name)
legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace)
legacy_collection = _find_legacy_collection(
client, namespace, workspace, model_suffix
)
# Case 1: Both new and legacy collections exist - Warning only (no migration)
if new_collection_exists and legacy_exists:
logger.warning(
f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete."
)
return
# Case 2: Only new collection exists - Ensure index exists
if new_collection_exists:
# Check if workspace index exists, create if missing
try:
collection_info = client.get_collection(collection_name)
if WORKSPACE_ID_FIELD not in collection_info.payload_schema:
logger.info(
f"Qdrant: Creating missing workspace index for '{collection_name}'"
)
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
)
except Exception as e:
logger.warning(
f"Qdrant: Could not verify/create workspace index for '{collection_name}': {e}"
)
return
# Case 3: Neither exists - Create new collection
if not legacy_exists:
logger.info(f"Qdrant: Creating new collection '{collection_name}'")
client.create_collection(collection_name, **kwargs)
# Case 1: Only new collection exists or new collection is the same as legacy collection
# No data migration needed, and ensuring index is created then return
if (new_collection_exists and not legacy_collection) or (
collection_name == legacy_collection
):
# create_payload_index returns without error if the index already exists
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
@@ -144,132 +178,244 @@ class QdrantVectorDBStorage(BaseVectorStorage):
is_tenant=True,
),
)
logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
new_workspace_count = client.count(
collection_name=collection_name,
count_filter=workspace_count_filter,
exact=True,
).count
# Skip data migration if new collection already has workspace data
if new_workspace_count == 0 and not (collection_name == legacy_collection):
logger.warning(
f"Qdrant: workspace data in collection '{collection_name}' is empty. "
f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
)
return
# Case 4: Only legacy exists - Migrate data
logger.info(
f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'"
legacy_count = None
if not new_collection_exists:
# Check vector dimension compatibility before creating new collection
if legacy_collection:
legacy_count = client.count(
collection_name=legacy_collection, exact=True
).count
if legacy_count > 0:
legacy_info = client.get_collection(legacy_collection)
legacy_dim = legacy_info.config.params.vectors.size
if vectors_config.size and legacy_dim != vectors_config.size:
logger.error(
f"Qdrant: Dimension mismatch detected! "
f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, "
f"but new embedding model expects {vectors_config.size}d."
)
raise DataMigrationError(
f"Dimension mismatch between legacy collection '{legacy_collection}' "
f"and new collection. Expected {vectors_config.size}d but got {legacy_dim}d."
)
client.create_collection(
collection_name, vectors_config=vectors_config, hnsw_config=hnsw_config
)
logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
if not legacy_collection:
logger.warning(
"Qdrant: Ensure this new collection creation is caused by new workspace setup and not an unexpected embedding model change."
)
# create_payload_index returns without error if the index already exists
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
)
try:
# Get legacy collection count
legacy_count = client.count(
collection_name=legacy_namespace, exact=True
).count
logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
# Case 2: Legacy collection exists
if legacy_collection:
# Only drop legacy collection if it's empty
if legacy_count is None:
legacy_count = client.count(
collection_name=legacy_collection, exact=True
).count
if legacy_count == 0:
logger.info("Qdrant: Legacy collection is empty, skipping migration")
# Create new empty collection
client.create_collection(collection_name, **kwargs)
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
client.delete_collection(collection_name=legacy_collection)
logger.info(
f"Qdrant: Empty legacy collection '{legacy_collection}' deleted successfully"
)
return
# Create new collection first
logger.info(f"Qdrant: Creating new collection '{collection_name}'")
client.create_collection(collection_name, **kwargs)
new_workspace_count = client.count(
collection_name=collection_name,
count_filter=workspace_count_filter,
exact=True,
).count
# Batch migration (500 records per batch)
migrated_count = 0
offset = None
batch_size = 500
while True:
# Scroll through legacy data
result = client.scroll(
collection_name=legacy_namespace,
limit=batch_size,
offset=offset,
with_vectors=True,
with_payload=True,
# Skip data migration if new collection already has workspace data
if new_workspace_count > 0:
logger.warning(
f"Qdrant: Both new and legacy collection have data. "
f"{legacy_count} records in {legacy_collection} require manual deletion after migration verification."
)
points, next_offset = result
return
if not points:
break
# Case 3: Only legacy exists - migrate data from legacy collection to new collection
# Check if legacy collection has workspace_id to determine migration strategy
# Note: payload_schema only reflects INDEXED fields, so we also sample
# actual payloads to detect unindexed workspace_id fields
legacy_info = client.get_collection(legacy_collection)
has_workspace_index = WORKSPACE_ID_FIELD in (
legacy_info.payload_schema or {}
)
# Transform points for new collection
new_points = []
for point in points:
# Add workspace_id to payload
new_payload = dict(point.payload or {})
new_payload[WORKSPACE_ID_FIELD] = workspace or DEFAULT_WORKSPACE
# Create new point with workspace-prefixed ID
original_id = new_payload.get(ID_FIELD)
if original_id:
new_point_id = compute_mdhash_id_for_qdrant(
original_id, prefix=workspace or DEFAULT_WORKSPACE
# Detect workspace_id field presence by sampling payloads if not indexed
# This prevents cross-workspace data leakage when workspace_id exists but isn't indexed
has_workspace_field = has_workspace_index
if not has_workspace_index:
# Sample a small batch of points to check for workspace_id in payloads
# All points must have workspace_id if any point has it
sample_result = client.scroll(
collection_name=legacy_collection,
limit=10, # Small sample is sufficient for detection
with_payload=True,
with_vectors=False,
)
sample_points, _ = sample_result
for point in sample_points:
if point.payload and WORKSPACE_ID_FIELD in point.payload:
has_workspace_field = True
logger.info(
f"Qdrant: Detected unindexed {WORKSPACE_ID_FIELD} field "
f"in legacy collection '{legacy_collection}' via payload sampling"
)
else:
# Fallback: use original point ID
new_point_id = str(point.id)
break
new_points.append(
models.PointStruct(
id=new_point_id,
vector=point.vector,
payload=new_payload,
# Build workspace filter if legacy collection has workspace support
# This prevents cross-workspace data leakage during migration
legacy_scroll_filter = None
if has_workspace_field:
legacy_scroll_filter = models.Filter(
must=[workspace_filter_condition(workspace)]
)
# Recount with workspace filter for accurate migration tracking
legacy_count = client.count(
collection_name=legacy_collection,
count_filter=legacy_scroll_filter,
exact=True,
).count
logger.info(
f"Qdrant: Legacy collection has workspace support, "
f"filtering to {legacy_count} records for workspace '{workspace}'"
)
logger.info(
f"Qdrant: Found legacy collection '{legacy_collection}' with {legacy_count} records to migrate."
)
logger.info(
f"Qdrant: Migrating data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
)
try:
# Batch migration (500 records per batch)
migrated_count = 0
offset = None
batch_size = 500
while True:
# Scroll through legacy data with optional workspace filter
result = client.scroll(
collection_name=legacy_collection,
scroll_filter=legacy_scroll_filter,
limit=batch_size,
offset=offset,
with_vectors=True,
with_payload=True,
)
points, next_offset = result
if not points:
break
# Transform points for new collection
new_points = []
for point in points:
# Set workspace_id in payload
new_payload = dict(point.payload or {})
new_payload[WORKSPACE_ID_FIELD] = workspace
# Create new point with workspace-prefixed ID
original_id = new_payload.get(ID_FIELD)
if original_id:
new_point_id = compute_mdhash_id_for_qdrant(
original_id, prefix=workspace
)
else:
# Fallback: use original point ID
new_point_id = str(point.id)
new_points.append(
models.PointStruct(
id=new_point_id,
vector=point.vector,
payload=new_payload,
)
)
# Upsert to new collection
client.upsert(
collection_name=collection_name, points=new_points, wait=True
)
# Upsert to new collection
client.upsert(
collection_name=collection_name, points=new_points, wait=True
migrated_count += len(points)
logger.info(
f"Qdrant: {migrated_count}/{legacy_count} records migrated"
)
# Check if we've reached the end
if next_offset is None:
break
offset = next_offset
new_count_after = client.count(
collection_name=collection_name,
count_filter=workspace_count_filter,
exact=True,
).count
inserted_count = new_count_after - new_workspace_count
if inserted_count != legacy_count:
error_msg = (
"Qdrant: Migration verification failed, expected "
f"{legacy_count} inserted records, got {inserted_count}."
)
logger.error(error_msg)
raise DataMigrationError(error_msg)
except DataMigrationError:
# Re-raise DataMigrationError as-is to preserve specific error messages
raise
except Exception as e:
logger.error(
f"Qdrant: Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}': {e}"
)
migrated_count += len(points)
logger.info(f"Qdrant: {migrated_count}/{legacy_count} records migrated")
# Check if we've reached the end
if next_offset is None:
break
offset = next_offset
# Verify migration by comparing counts
logger.info("Verifying migration...")
new_count = client.count(collection_name=collection_name, exact=True).count
if new_count != legacy_count:
error_msg = f"Qdrant: Migration verification failed, expected {legacy_count} records, got {new_count} in new collection"
logger.error(error_msg)
raise QdrantMigrationError(error_msg)
raise DataMigrationError(
f"Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
) from e
logger.info(
f"Qdrant: Migration completed successfully: {migrated_count} records migrated"
f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully"
)
# Create payload index after successful migration
logger.info("Qdrant: Creating workspace payload index...")
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
logger.warning(
"Qdrant: Manual deletion is required after data migration verification."
)
logger.info(
f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully"
)
except QdrantMigrationError:
# Re-raise migration errors without wrapping
raise
except Exception as e:
error_msg = f"Qdrant: Migration failed with error: {e}"
logger.error(error_msg)
raise QdrantMigrationError(error_msg) from e
def __post_init__(self):
# Call parent class __post_init__ to validate embedding_func
super().__post_init__()
# Check for QDRANT_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Qdrant storage instances
qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
@@ -287,20 +433,23 @@ class QdrantVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Get legacy namespace for data migration from old version
if effective_workspace:
self.legacy_namespace = f"{effective_workspace}_{self.namespace}"
else:
self.legacy_namespace = self.namespace
self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE
# Use a shared collection with payload-based partitioning (Qdrant's recommended approach)
# Ref: https://qdrant.tech/documentation/guides/multiple-partitions/
self.final_namespace = f"lightrag_vdb_{self.namespace}"
logger.debug(
f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning"
)
# Generate model suffix
self.model_suffix = self._generate_collection_suffix()
# New naming scheme with model isolation
# Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
# Ensure model_suffix is not empty before appending
if self.model_suffix:
self.final_namespace = f"lightrag_vdb_{self.namespace}_{self.model_suffix}"
logger.info(f"Qdrant collection: {self.final_namespace}")
else:
# Fallback: use legacy namespace if model_suffix is unavailable
self.final_namespace = f"lightrag_vdb_{self.namespace}"
logger.warning(
f"Qdrant collection: {self.final_namespace} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation."
)
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -338,11 +487,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
# Setup collection (create if not exists and configure indexes)
# Pass legacy_namespace and workspace for migration support
# Pass namespace and workspace for backward-compatible migration support
QdrantVectorDBStorage.setup_collection(
self._client,
self.final_namespace,
legacy_namespace=self.legacy_namespace,
namespace=self.namespace,
workspace=self.effective_workspace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
@@ -352,8 +501,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
payload_m=16,
m=0,
),
model_suffix=self.model_suffix,
)
# Removed duplicate max batch size initialization
self._initialized = True
logger.info(
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
@@ -481,21 +633,44 @@ class QdrantVectorDBStorage(BaseVectorStorage):
entity_name: Name of the entity to delete
"""
try:
# Generate the entity ID using the same function as used for storage
# Compute entity ID from name (same as Milvus)
entity_id = compute_mdhash_id(entity_name, prefix=ENTITY_PREFIX)
qdrant_entity_id = compute_mdhash_id_for_qdrant(
entity_id, prefix=self.effective_workspace
logger.debug(
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity point by its Qdrant ID directly
self._client.delete(
# Scroll to find the entity by its ID field in payload with workspace filtering
# This is safer than reconstructing the Qdrant point ID
results = self._client.scroll(
collection_name=self.final_namespace,
points_selector=models.PointIdsList(points=[qdrant_entity_id]),
wait=True,
)
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
scroll_filter=models.Filter(
must=[
workspace_filter_condition(self.effective_workspace),
models.FieldCondition(
key=ID_FIELD, match=models.MatchValue(value=entity_id)
),
]
),
with_payload=False,
limit=1,
)
# Extract point IDs to delete
points = results[0]
if points:
ids_to_delete = [point.id for point in points]
self._client.delete(
collection_name=self.final_namespace,
points_selector=models.PointIdsList(points=ids_to_delete),
wait=True,
)
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else:
logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except Exception as e:
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
@@ -506,38 +681,60 @@ class QdrantVectorDBStorage(BaseVectorStorage):
entity_name: Name of the entity whose relations should be deleted
"""
try:
# Find relations where the entity is either source or target, with workspace filtering
results = self._client.scroll(
collection_name=self.final_namespace,
scroll_filter=models.Filter(
must=[workspace_filter_condition(self.effective_workspace)],
should=[
models.FieldCondition(
key="src_id", match=models.MatchValue(value=entity_name)
),
models.FieldCondition(
key="tgt_id", match=models.MatchValue(value=entity_name)
),
],
),
with_payload=True,
limit=1000, # Adjust as needed for your use case
# Build the filter to find relations where entity is either source or target
# must + should = workspace_id matches AND (src_id matches OR tgt_id matches)
relation_filter = models.Filter(
must=[workspace_filter_condition(self.effective_workspace)],
should=[
models.FieldCondition(
key="src_id", match=models.MatchValue(value=entity_name)
),
models.FieldCondition(
key="tgt_id", match=models.MatchValue(value=entity_name)
),
],
)
# Extract points that need to be deleted
relation_points = results[0]
ids_to_delete = [point.id for point in relation_points]
# Paginate through all matching relations to handle large datasets
total_deleted = 0
offset = None
batch_size = 1000
if ids_to_delete:
# Delete the relations with workspace filtering
assert isinstance(self._client, QdrantClient)
while True:
# Scroll to find relations, using with_payload=False for efficiency
# since we only need point IDs for deletion
results = self._client.scroll(
collection_name=self.final_namespace,
scroll_filter=relation_filter,
with_payload=False,
with_vectors=False,
limit=batch_size,
offset=offset,
)
points, next_offset = results
if not points:
break
# Extract point IDs to delete
ids_to_delete = [point.id for point in points]
# Delete the batch of relations
self._client.delete(
collection_name=self.final_namespace,
points_selector=models.PointIdsList(points=ids_to_delete),
wait=True,
)
total_deleted += len(ids_to_delete)
# Check if we've reached the end
if next_offset is None:
break
offset = next_offset
if total_deleted > 0:
logger.debug(
f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
f"[{self.workspace}] Deleted {total_deleted} relations for {entity_name}"
)
else:
logger.debug(
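The scroll-then-delete pagination used above can be exercised end to end against an in-memory Qdrant instance; the following is a hedged sketch with illustrative collection and payload values, not LightRAG's own code.

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")
client.create_collection(
    "demo", vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE)
)
client.upsert(
    "demo",
    points=[
        models.PointStruct(id=i, vector=[0.1, 0.2, 0.3, 0.4], payload={"src_id": "A"})
        for i in range(5)
    ],
    wait=True,
)

relation_filter = models.Filter(
    must=[models.FieldCondition(key="src_id", match=models.MatchValue(value="A"))]
)
total_deleted, offset = 0, None
while True:
    # Fetch only point IDs (no payload/vectors) for each batch of matches
    points, next_offset = client.scroll(
        "demo",
        scroll_filter=relation_filter,
        with_payload=False,
        with_vectors=False,
        limit=2,
        offset=offset,
    )
    if not points:
        break
    client.delete(
        "demo",
        points_selector=models.PointIdsList(points=[p.id for p in points]),
        wait=True,
    )
    total_deleted += len(points)
    if next_offset is None:
        break
    offset = next_offset

print(f"deleted {total_deleted} points")
```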

View File

@@ -163,7 +163,9 @@ class UnifiedLock(Generic[T]):
enable_output=self._enable_logging,
)
# Then acquire the main lock
# Acquire the main lock
# Note: self._lock should never be None here as the check has been moved
# to get_internal_lock() and get_data_init_lock() functions
if self._is_async:
await self._lock.acquire()
else:
@@ -193,19 +195,21 @@ class UnifiedLock(Generic[T]):
async def __aexit__(self, exc_type, exc_val, exc_tb):
main_lock_released = False
async_lock_released = False
try:
# Release main lock first
if self._is_async:
self._lock.release()
else:
self._lock.release()
main_lock_released = True
if self._lock is not None:
if self._is_async:
self._lock.release()
else:
self._lock.release()
direct_log(
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
direct_log(
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
main_lock_released = True
# Then release async lock if in multiprocess mode
if not self._is_async and self._async_lock is not None:
@@ -215,6 +219,7 @@ class UnifiedLock(Generic[T]):
level="DEBUG",
enable_output=self._enable_logging,
)
async_lock_released = True
except Exception as e:
direct_log(
@@ -223,9 +228,10 @@ class UnifiedLock(Generic[T]):
enable_output=True,
)
# If main lock release failed but async lock hasn't been released, try to release it
# If main lock release failed but async lock hasn't been attempted yet, try to release it
if (
not main_lock_released
and not async_lock_released
and not self._is_async
and self._async_lock is not None
):
@@ -255,6 +261,10 @@ class UnifiedLock(Generic[T]):
try:
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
# Acquire the main lock
# Note: self._lock should never be None here as the check has been moved
# to get_internal_lock() and get_data_init_lock() functions
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)",
level="DEBUG",
@@ -1060,6 +1070,10 @@ class _KeyedLockContext:
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
if _internal_lock is None:
raise RuntimeError(
"Shared data not initialized. Call initialize_share_data() before using locks!"
)
async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_internal_lock,
@@ -1090,6 +1104,10 @@ def get_storage_keyed_lock(
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
if _data_init_lock is None:
raise RuntimeError(
"Shared data not initialized. Call initialize_share_data() before using locks!"
)
async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_data_init_lock,
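A hedged usage sketch for the new guard: shared data must be initialized before any unified lock is requested, otherwise a RuntimeError is raised. This assumes `initialize_share_data()` accepts being called with its defaults and that both helpers live in `lightrag.kg.shared_storage`.

```python
import asyncio

from lightrag.kg.shared_storage import get_internal_lock, initialize_share_data


async def main() -> None:
    initialize_share_data()           # without this, get_internal_lock() now raises RuntimeError
    async with get_internal_lock():   # UnifiedLock is used as an async context manager
        ...                           # critical section over shared data


asyncio.run(main())
```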

View File

@@ -7,7 +7,7 @@ import inspect
import os
import time
import warnings
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, replace
from datetime import datetime, timezone
from functools import partial
from typing import (
@@ -518,14 +518,9 @@ class LightRAG:
f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
)
# Fix global_config now
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
# Step 1: Capture embedding_func and max_token_size before applying decorator
original_embedding_func = self.embedding_func
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
@@ -534,12 +529,26 @@ class LightRAG:
)
self.embedding_token_limit = embedding_max_token_size
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
queue_name="Embedding func",
)(self.embedding_func)
# Fix global_config now
global_config = asdict(self)
# Restore original EmbeddingFunc object (asdict converts it to dict)
global_config["embedding_func"] = original_embedding_func
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Step 2: Apply priority wrapper decorator to EmbeddingFunc's inner func
# Create a NEW EmbeddingFunc instance with the wrapped func to avoid mutating the caller's object
# This ensures _generate_collection_suffix can still access attributes (model_name, embedding_dim)
# while preventing side effects when the same EmbeddingFunc is reused across multiple LightRAG instances
if self.embedding_func is not None:
wrapped_func = priority_limit_async_func_call(
self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
queue_name="Embedding func",
)(self.embedding_func.func)
# Use dataclasses.replace() to create a new instance, leaving the original unchanged
self.embedding_func = replace(self.embedding_func, func=wrapped_func)
# Initialize all storages
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
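The `dataclasses.replace()` pattern adopted here can be shown in isolation; the dataclass and wrapper below are simplified stand-ins, not LightRAG's actual `EmbeddingFunc` or `priority_limit_async_func_call`.

```python
import asyncio
from dataclasses import dataclass, replace
from typing import Awaitable, Callable


@dataclass
class EmbeddingFuncDemo:  # simplified stand-in for EmbeddingFunc
    embedding_dim: int
    func: Callable[..., Awaitable[list[list[float]]]]
    model_name: str | None = None


async def raw_embed(texts: list[str]) -> list[list[float]]:
    return [[0.0] * 4 for _ in texts]


def with_limit(fn):  # stand-in for the priority/concurrency wrapper
    async def wrapper(*args, **kwargs):
        return await fn(*args, **kwargs)
    return wrapper


original = EmbeddingFuncDemo(embedding_dim=4, func=raw_embed, model_name="demo-model")
wrapped = replace(original, func=with_limit(original.func))  # new object, attrs preserved

assert wrapped is not original                 # the caller's object is left untouched
assert wrapped.model_name == original.model_name
print(asyncio.run(wrapped.func(["hello"])))    # [[0.0, 0.0, 0.0, 0.0]]
```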

View File

@@ -351,7 +351,9 @@ async def bedrock_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),

View File

@@ -453,7 +453,9 @@ async def gemini_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
@wrap_embedding_func_with_attrs(
embedding_dim=1536, max_token_size=2048, model_name="gemini-embedding-001"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View File

@@ -142,7 +142,9 @@ async def hf_model_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="hf_embedding_model"
)
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
# Detect the appropriate device
if torch.cuda.is_available():

View File

@@ -58,7 +58,9 @@ async def fetch_data(url, headers, data):
return data_list
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=2048, max_token_size=8192, model_name="jina-embeddings-v4"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View File

@@ -138,7 +138,9 @@ async def lollms_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:

View File

@@ -33,7 +33,9 @@ from lightrag.utils import (
import numpy as np
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View File

@@ -172,7 +172,9 @@ async def ollama_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
)
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
) -> np.ndarray:

View File

@@ -677,7 +677,9 @@ async def nvidia_openai_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -867,7 +869,11 @@ async def azure_openai_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@wrap_embedding_func_with_attrs(
embedding_dim=1536,
max_token_size=8192,
model_name="my-text-embedding-3-large-deployment",
)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,

View File

@@ -179,7 +179,9 @@ async def zhipu_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024)
@wrap_embedding_func_with_attrs(
embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View File

@@ -0,0 +1,720 @@
#!/usr/bin/env python3
"""
Qdrant Legacy Data Preparation Tool for LightRAG
This tool copies data from new collections to legacy collections for testing
the data migration logic in the setup_collection function.
New Collections (with workspace_id):
- lightrag_vdb_chunks
- lightrag_vdb_entities
- lightrag_vdb_relationships
Legacy Collections (without workspace_id, dynamically named as {workspace}_{suffix}):
- {workspace}_chunks (e.g., space1_chunks)
- {workspace}_entities (e.g., space1_entities)
- {workspace}_relationships (e.g., space1_relationships)
The tool:
1. Filters source data by workspace_id
2. Verifies workspace data exists before creating legacy collections
3. Removes workspace_id field to simulate legacy data format
4. Copies only the specified workspace's data to legacy collections
Usage:
python -m lightrag.tools.prepare_qdrant_legacy_data
# or
python lightrag/tools/prepare_qdrant_legacy_data.py
# Specify custom workspace
python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1
# Process specific collection types only
python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities
# Dry run (preview only, no actual changes)
python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
"""
import argparse
import asyncio
import configparser
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import pipmaster as pm
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models # type: ignore
# Add project root to path for imports
sys.path.insert(
0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
# Load environment variables
load_dotenv(dotenv_path=".env", override=False)
# Ensure qdrant-client is installed
if not pm.is_installed("qdrant-client"):
pm.install("qdrant-client")
# Collection namespace mapping: new collection pattern -> legacy suffix
# Legacy collection will be named as: {workspace}_{suffix}
COLLECTION_NAMESPACES = {
"chunks": {
"new": "lightrag_vdb_chunks",
"suffix": "chunks",
},
"entities": {
"new": "lightrag_vdb_entities",
"suffix": "entities",
},
"relationships": {
"new": "lightrag_vdb_relationships",
"suffix": "relationships",
},
}
# Default batch size for copy operations
DEFAULT_BATCH_SIZE = 500
# Field to remove from legacy data
WORKSPACE_ID_FIELD = "workspace_id"
# ANSI color codes for terminal output
BOLD_CYAN = "\033[1;36m"
BOLD_GREEN = "\033[1;32m"
BOLD_YELLOW = "\033[1;33m"
BOLD_RED = "\033[1;31m"
RESET = "\033[0m"
@dataclass
class CopyStats:
"""Copy operation statistics"""
collection_type: str
source_collection: str
target_collection: str
total_records: int = 0
copied_records: int = 0
failed_records: int = 0
errors: List[Dict[str, Any]] = field(default_factory=list)
elapsed_time: float = 0.0
def add_error(self, batch_idx: int, error: Exception, batch_size: int):
"""Record batch error"""
self.errors.append(
{
"batch": batch_idx,
"error_type": type(error).__name__,
"error_msg": str(error),
"records_lost": batch_size,
"timestamp": time.time(),
}
)
self.failed_records += batch_size
class QdrantLegacyDataPreparationTool:
"""Tool for preparing legacy data in Qdrant for migration testing"""
def __init__(
self,
workspace: str = "space1",
batch_size: int = DEFAULT_BATCH_SIZE,
dry_run: bool = False,
clear_target: bool = False,
):
"""
Initialize the tool.
Args:
workspace: Workspace to use for filtering new collection data
batch_size: Number of records to process per batch
dry_run: If True, only preview operations without making changes
clear_target: If True, delete target collection before copying data
"""
self.workspace = workspace
self.batch_size = batch_size
self.dry_run = dry_run
self.clear_target = clear_target
self._client: Optional[QdrantClient] = None
def _get_client(self) -> QdrantClient:
"""Get or create QdrantClient instance"""
if self._client is None:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
self._client = QdrantClient(
url=os.environ.get(
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
),
api_key=os.environ.get(
"QDRANT_API_KEY",
config.get("qdrant", "apikey", fallback=None),
),
)
return self._client
def print_header(self):
"""Print tool header"""
print("\n" + "=" * 60)
print("Qdrant Legacy Data Preparation Tool - LightRAG")
print("=" * 60)
if self.dry_run:
print(f"{BOLD_YELLOW}⚠️ DRY RUN MODE - No changes will be made{RESET}")
if self.clear_target:
print(
f"{BOLD_RED}⚠️ CLEAR TARGET MODE - Target collections will be deleted first{RESET}"
)
print(f"Workspace: {BOLD_CYAN}{self.workspace}{RESET}")
print(f"Batch Size: {self.batch_size}")
print("=" * 60)
def check_connection(self) -> bool:
"""Check Qdrant connection"""
try:
client = self._get_client()
# Try to list collections to verify connection
client.get_collections()
print(f"{BOLD_GREEN}{RESET} Qdrant connection successful")
return True
except Exception as e:
print(f"{BOLD_RED}{RESET} Qdrant connection failed: {e}")
return False
def get_collection_info(self, collection_name: str) -> Optional[Dict[str, Any]]:
"""
Get collection information.
Args:
collection_name: Name of the collection
Returns:
Dictionary with collection info (vector_size, count) or None if not exists
"""
client = self._get_client()
if not client.collection_exists(collection_name):
return None
info = client.get_collection(collection_name)
count = client.count(collection_name=collection_name, exact=True).count
# Handle both object and dict formats for vectors config
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
# Named vectors format or dict format
if vectors_config:
first_key = next(iter(vectors_config.keys()), None)
if first_key and hasattr(vectors_config[first_key], "size"):
vector_size = vectors_config[first_key].size
distance = vectors_config[first_key].distance
else:
# Try to get from dict values
first_val = next(iter(vectors_config.values()), {})
vector_size = (
first_val.get("size")
if isinstance(first_val, dict)
else getattr(first_val, "size", None)
)
distance = (
first_val.get("distance")
if isinstance(first_val, dict)
else getattr(first_val, "distance", None)
)
else:
vector_size = None
distance = None
else:
# Standard single vector format
vector_size = vectors_config.size
distance = vectors_config.distance
return {
"name": collection_name,
"vector_size": vector_size,
"count": count,
"distance": distance,
}
def delete_collection(self, collection_name: str) -> bool:
"""
Delete a collection if it exists.
Args:
collection_name: Name of the collection to delete
Returns:
True if deleted or doesn't exist
"""
client = self._get_client()
if not client.collection_exists(collection_name):
return True
if self.dry_run:
target_info = self.get_collection_info(collection_name)
count = target_info["count"] if target_info else 0
print(
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would delete collection '{collection_name}' ({count:,} records)"
)
return True
try:
target_info = self.get_collection_info(collection_name)
count = target_info["count"] if target_info else 0
client.delete_collection(collection_name=collection_name)
print(
f" {BOLD_RED}{RESET} Deleted collection '{collection_name}' ({count:,} records)"
)
return True
except Exception as e:
print(f" {BOLD_RED}{RESET} Failed to delete collection: {e}")
return False
def create_legacy_collection(
self, collection_name: str, vector_size: int, distance: models.Distance
) -> bool:
"""
Create legacy collection if it doesn't exist.
Args:
collection_name: Name of the collection to create
vector_size: Dimension of vectors
distance: Distance metric
Returns:
True if created or already exists
"""
client = self._get_client()
if client.collection_exists(collection_name):
print(f" Collection '{collection_name}' already exists")
return True
if self.dry_run:
print(
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would create collection '{collection_name}' with {vector_size}d vectors"
)
return True
try:
client.create_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(
size=vector_size,
distance=distance,
),
hnsw_config=models.HnswConfigDiff(
payload_m=16,
m=0,
),
)
print(
f" {BOLD_GREEN}{RESET} Created collection '{collection_name}' with {vector_size}d vectors"
)
return True
except Exception as e:
print(f" {BOLD_RED}{RESET} Failed to create collection: {e}")
return False
def _get_workspace_filter(self) -> models.Filter:
"""Create workspace filter for Qdrant queries"""
return models.Filter(
must=[
models.FieldCondition(
key=WORKSPACE_ID_FIELD,
match=models.MatchValue(value=self.workspace),
)
]
)
def get_workspace_count(self, collection_name: str) -> int:
"""
Get count of records for the current workspace in a collection.
Args:
collection_name: Name of the collection
Returns:
Count of records for the workspace
"""
client = self._get_client()
return client.count(
collection_name=collection_name,
count_filter=self._get_workspace_filter(),
exact=True,
).count
def copy_collection_data(
self,
source_collection: str,
target_collection: str,
collection_type: str,
workspace_count: int,
) -> CopyStats:
"""
Copy data from source to target collection.
This filters by workspace_id and removes it from payload to simulate legacy data format.
Args:
source_collection: Source collection name
target_collection: Target collection name
collection_type: Type of collection (chunks, entities, relationships)
workspace_count: Pre-computed count of workspace records
Returns:
CopyStats with operation results
"""
client = self._get_client()
stats = CopyStats(
collection_type=collection_type,
source_collection=source_collection,
target_collection=target_collection,
)
start_time = time.time()
stats.total_records = workspace_count
if workspace_count == 0:
print(f" No records for workspace '{self.workspace}', skipping")
stats.elapsed_time = time.time() - start_time
return stats
print(f" Workspace records: {workspace_count:,}")
if self.dry_run:
print(
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would copy {workspace_count:,} records to '{target_collection}'"
)
stats.copied_records = workspace_count
stats.elapsed_time = time.time() - start_time
return stats
# Batch copy using scroll with workspace filter
workspace_filter = self._get_workspace_filter()
offset = None
batch_idx = 0
while True:
# Scroll source collection with workspace filter
result = client.scroll(
collection_name=source_collection,
scroll_filter=workspace_filter,
limit=self.batch_size,
offset=offset,
with_vectors=True,
with_payload=True,
)
points, next_offset = result
if not points:
break
batch_idx += 1
# Transform points: remove workspace_id from payload
new_points = []
for point in points:
new_payload = dict(point.payload or {})
# Remove workspace_id to simulate legacy format
new_payload.pop(WORKSPACE_ID_FIELD, None)
# Use original id from payload if available, otherwise use point.id
original_id = new_payload.get("id")
if original_id:
# Generate a simple deterministic id for legacy format
# Use original id directly (legacy format didn't have workspace prefix)
import hashlib
import uuid
hashed = hashlib.sha256(original_id.encode("utf-8")).digest()
point_id = uuid.UUID(bytes=hashed[:16], version=4).hex
else:
point_id = str(point.id)
new_points.append(
models.PointStruct(
id=point_id,
vector=point.vector,
payload=new_payload,
)
)
try:
# Upsert to target collection
client.upsert(
collection_name=target_collection, points=new_points, wait=True
)
stats.copied_records += len(new_points)
# Progress bar
progress = (stats.copied_records / workspace_count) * 100
bar_length = 30
filled = int(bar_length * stats.copied_records // workspace_count)
bar = "" * filled + "" * (bar_length - filled)
print(
f"\r Copying: {bar} {stats.copied_records:,}/{workspace_count:,} ({progress:.1f}%) ",
end="",
flush=True,
)
except Exception as e:
stats.add_error(batch_idx, e, len(new_points))
print(
f"\n {BOLD_RED}{RESET} Batch {batch_idx} failed: {type(e).__name__}: {e}"
)
if next_offset is None:
break
offset = next_offset
print() # New line after progress bar
stats.elapsed_time = time.time() - start_time
return stats
def process_collection_type(self, collection_type: str) -> Optional[CopyStats]:
"""
Process a single collection type.
Args:
collection_type: Type of collection (chunks, entities, relationships)
Returns:
CopyStats or None if error
"""
namespace_config = COLLECTION_NAMESPACES.get(collection_type)
if not namespace_config:
print(f"{BOLD_RED}{RESET} Unknown collection type: {collection_type}")
return None
source = namespace_config["new"]
# Generate legacy collection name dynamically: {workspace}_{suffix}
target = f"{self.workspace}_{namespace_config['suffix']}"
print(f"\n{'=' * 50}")
print(f"Processing: {BOLD_CYAN}{collection_type}{RESET}")
print(f"{'=' * 50}")
print(f" Source: {source}")
print(f" Target: {target}")
# Check source collection
source_info = self.get_collection_info(source)
if source_info is None:
print(
f" {BOLD_YELLOW}{RESET} Source collection '{source}' does not exist, skipping"
)
return None
print(f" Source vector dimension: {source_info['vector_size']}d")
print(f" Source distance metric: {source_info['distance']}")
print(f" Source total records: {source_info['count']:,}")
# Check workspace data exists BEFORE creating legacy collection
workspace_count = self.get_workspace_count(source)
print(f" Workspace '{self.workspace}' records: {workspace_count:,}")
if workspace_count == 0:
print(
f" {BOLD_YELLOW}{RESET} No data found for workspace '{self.workspace}' in '{source}', skipping"
)
return None
# Clear target collection if requested
if self.clear_target:
if not self.delete_collection(target):
return None
# Create target collection only after confirming workspace data exists
if not self.create_legacy_collection(
target, source_info["vector_size"], source_info["distance"]
):
return None
# Copy data with workspace filter
stats = self.copy_collection_data(
source, target, collection_type, workspace_count
)
# Print result
if stats.failed_records == 0:
print(
f" {BOLD_GREEN}{RESET} Copied {stats.copied_records:,} records in {stats.elapsed_time:.2f}s"
)
else:
print(
f" {BOLD_YELLOW}{RESET} Copied {stats.copied_records:,} records, "
f"{BOLD_RED}{stats.failed_records:,} failed{RESET} in {stats.elapsed_time:.2f}s"
)
return stats
def print_summary(self, all_stats: List[CopyStats]):
"""Print summary of all operations"""
print("\n" + "=" * 60)
print("Summary")
print("=" * 60)
total_copied = sum(s.copied_records for s in all_stats)
total_failed = sum(s.failed_records for s in all_stats)
total_time = sum(s.elapsed_time for s in all_stats)
for stats in all_stats:
status = (
f"{BOLD_GREEN}{RESET}"
if stats.failed_records == 0
else f"{BOLD_YELLOW}{RESET}"
)
print(
f" {status} {stats.collection_type}: {stats.copied_records:,}/{stats.total_records:,} "
f"({stats.source_collection}{stats.target_collection})"
)
print("-" * 60)
print(f" Total records copied: {BOLD_CYAN}{total_copied:,}{RESET}")
if total_failed > 0:
print(f" Total records failed: {BOLD_RED}{total_failed:,}{RESET}")
print(f" Total time: {total_time:.2f}s")
if self.dry_run:
print(f"\n{BOLD_YELLOW}⚠️ DRY RUN - No actual changes were made{RESET}")
# Print error details if any
all_errors = []
for stats in all_stats:
all_errors.extend(stats.errors)
if all_errors:
print(f"\n{BOLD_RED}Errors ({len(all_errors)}){RESET}")
for i, error in enumerate(all_errors[:5], 1):
print(
f" {i}. Batch {error['batch']}: {error['error_type']}: {error['error_msg']}"
)
if len(all_errors) > 5:
print(f" ... and {len(all_errors) - 5} more errors")
print("=" * 60)
async def run(self, collection_types: Optional[List[str]] = None):
"""
Run the data preparation tool.
Args:
collection_types: List of collection types to process (default: all)
"""
self.print_header()
# Check connection
if not self.check_connection():
return
# Determine which collection types to process
if collection_types:
types_to_process = [t.strip() for t in collection_types]
invalid_types = [
t for t in types_to_process if t not in COLLECTION_NAMESPACES
]
if invalid_types:
print(
f"{BOLD_RED}{RESET} Invalid collection types: {', '.join(invalid_types)}"
)
print(f" Valid types: {', '.join(COLLECTION_NAMESPACES.keys())}")
return
else:
types_to_process = list(COLLECTION_NAMESPACES.keys())
print(f"\nCollection types to process: {', '.join(types_to_process)}")
# Process each collection type
all_stats = []
for ctype in types_to_process:
stats = self.process_collection_type(ctype)
if stats:
all_stats.append(stats)
# Print summary
if all_stats:
self.print_summary(all_stats)
else:
print(f"\n{BOLD_YELLOW}{RESET} No collections were processed")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Prepare legacy data in Qdrant for migration testing",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python -m lightrag.tools.prepare_qdrant_legacy_data
python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1
python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities
python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
""",
)
parser.add_argument(
"--workspace",
type=str,
default="space1",
help="Workspace name (default: space1)",
)
parser.add_argument(
"--types",
type=str,
default=None,
help="Comma-separated list of collection types (chunks, entities, relationships)",
)
parser.add_argument(
"--batch-size",
type=int,
default=DEFAULT_BATCH_SIZE,
help=f"Batch size for copy operations (default: {DEFAULT_BATCH_SIZE})",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Preview operations without making changes",
)
parser.add_argument(
"--clear-target",
action="store_true",
help="Delete target collections before copying (for clean test environment)",
)
return parser.parse_args()
async def main():
"""Main entry point"""
args = parse_args()
collection_types = None
if args.types:
collection_types = [t.strip() for t in args.types.split(",")]
tool = QdrantLegacyDataPreparationTool(
workspace=args.workspace,
batch_size=args.batch_size,
dry_run=args.dry_run,
clear_target=args.clear_target,
)
await tool.run(collection_types=collection_types)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -425,6 +425,9 @@ class EmbeddingFunc:
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
model_name: str | None = (
None # Model name used to implement model-based data isolation in vector DB
)
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True
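# Illustrative note only (not part of this diff): model_name can also be
# supplied when constructing EmbeddingFunc directly, as the new tests in this
# PR do; "demo_embed" below is a placeholder embedding callable, not LightRAG API.
import numpy as np
from lightrag.utils import EmbeddingFunc


async def demo_embed(texts, **kwargs):
    return np.array([[0.1] * 768 for _ in texts])


embedding_func = EmbeddingFunc(embedding_dim=768, func=demo_embed, model_name="bge-small")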
@@ -1016,42 +1019,36 @@ def wrap_embedding_func_with_attrs(**kwargs):
Correct usage patterns:
1. Direct implementation (decorated):
1. Direct decoration:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
2. Double decoration:
```python
# my_embed is already decorated above
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model")
# Note: No @retry here!
async def new_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- model_name: Model name (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
@@ -1059,21 +1056,6 @@ def wrap_embedding_func_with_attrs(**kwargs):
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
def final_decro(func) -> EmbeddingFunc:

View File

@@ -72,7 +72,6 @@ api = [
# API-specific dependencies
"aiofiles",
"ascii_colors",
"asyncpg",
"distro",
"fastapi",
"httpcore",
@@ -108,7 +107,8 @@ offline-storage = [
"neo4j>=5.0.0,<7.0.0",
"pymilvus>=2.6.2,<3.0.0",
"pymongo>=4.0.0,<5.0.0",
"asyncpg>=0.29.0,<1.0.0",
"asyncpg>=0.31.0,<1.0.0",
"pgvector>=0.4.2,<1.0.0",
"qdrant-client>=1.11.0,<2.0.0",
]

View File

@@ -8,8 +8,9 @@
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt
# Storage backend dependencies (with version constraints matching pyproject.toml)
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
neo4j>=5.0.0,<7.0.0
pgvector>=0.4.2,<1.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
qdrant-client>=1.11.0,<2.0.0

View File

@@ -7,20 +7,17 @@
# Recommended: Use pip install lightrag-hku[offline] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt
# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
# Storage backend dependencies
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0
# Document processing dependencies
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
pgvector>=0.4.2,<1.0.0
pycryptodome>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0

View File

@@ -0,0 +1,377 @@
"""
Tests for dimension mismatch handling during migration.
This test module verifies that both PostgreSQL and Qdrant storage backends
properly detect and handle vector dimension mismatches when migrating from
legacy collections/tables to new ones with different embedding models.
"""
import json
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
from lightrag.kg.postgres_impl import PGVectorStorage
from lightrag.exceptions import DataMigrationError
# Note: Tests should use proper table names that have DDL templates
# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
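# Illustrative sketch only -- a hypothetical helper, not the actual
# setup_collection/setup_table code. The tests below exercise a guard of
# roughly this shape: before any data is copied, the backend compares the
# legacy vector dimension with the configured embedding dimension and aborts
# with DataMigrationError on a mismatch so no data is corrupted.
def _sketch_check_dimension_compatibility(legacy_dim: int, expected_dim: int) -> None:
    if legacy_dim != expected_dim:
        raise DataMigrationError(
            f"Dimension mismatch: legacy data holds {legacy_dim}d vectors but the "
            f"configured embedding model produces {expected_dim}d vectors; migration aborted."
        )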
class TestQdrantDimensionMismatch:
"""Test suite for Qdrant dimension mismatch handling."""
def test_qdrant_dimension_mismatch_raises_error(self):
"""
Test that Qdrant raises DataMigrationError when dimensions don't match.
Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
Expected: DataMigrationError is raised to prevent data corruption.
"""
from qdrant_client import models
# Setup mock client
client = MagicMock()
# Mock legacy collection with 1536d vectors
legacy_collection_info = MagicMock()
legacy_collection_info.config.params.vectors.size = 1536
# Setup collection existence checks
def collection_exists_side_effect(name):
if (
name == "lightrag_vdb_chunks"
): # legacy (matches _find_legacy_collection pattern)
return True
elif name == "lightrag_chunks_model_3072d": # new
return False
return False
client.collection_exists.side_effect = collection_exists_side_effect
client.get_collection.return_value = legacy_collection_info
client.count.return_value.count = 100 # Legacy has data
# Patch _find_legacy_collection to return the legacy collection name
with patch(
"lightrag.kg.qdrant_impl._find_legacy_collection",
return_value="lightrag_vdb_chunks",
):
# Call setup_collection with 3072d (different from legacy 1536d)
# Should raise DataMigrationError due to dimension mismatch
with pytest.raises(DataMigrationError) as exc_info:
QdrantVectorDBStorage.setup_collection(
client,
"lightrag_chunks_model_3072d",
namespace="chunks",
workspace="test",
vectors_config=models.VectorParams(
size=3072, distance=models.Distance.COSINE
),
hnsw_config=models.HnswConfigDiff(
payload_m=16,
m=0,
),
model_suffix="model_3072d",
)
# Verify error message contains dimension information
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
# Verify new collection was NOT created (error raised before creation)
client.create_collection.assert_not_called()
# Verify migration was NOT attempted
client.scroll.assert_not_called()
client.upsert.assert_not_called()
def test_qdrant_dimension_match_proceed_migration(self):
"""
Test that Qdrant proceeds with migration when dimensions match.
Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
Expected: Migration proceeds normally.
"""
from qdrant_client import models
client = MagicMock()
# Mock legacy collection with 1536d vectors (matching new)
legacy_collection_info = MagicMock()
legacy_collection_info.config.params.vectors.size = 1536
def collection_exists_side_effect(name):
if name == "lightrag_chunks": # legacy
return True
elif name == "lightrag_chunks_model_1536d": # new
return False
return False
client.collection_exists.side_effect = collection_exists_side_effect
client.get_collection.return_value = legacy_collection_info
# Track whether upsert has been called (migration occurred)
migration_done = {"value": False}
def upsert_side_effect(*args, **kwargs):
migration_done["value"] = True
return MagicMock()
client.upsert.side_effect = upsert_side_effect
# Mock count to return different values based on collection name and migration state
# Before migration: new collection has 0 records
# After migration: new collection has 1 record (matching migrated data)
def count_side_effect(collection_name, **kwargs):
result = MagicMock()
if collection_name == "lightrag_chunks": # legacy
result.count = 1 # Legacy has 1 record
elif collection_name == "lightrag_chunks_model_1536d": # new
# Return 0 before migration, 1 after migration
result.count = 1 if migration_done["value"] else 0
else:
result.count = 0
return result
client.count.side_effect = count_side_effect
# Mock scroll to return sample data (1 record for easier verification)
sample_point = MagicMock()
sample_point.id = "test_id"
sample_point.vector = [0.1] * 1536
sample_point.payload = {"id": "test"}
client.scroll.return_value = ([sample_point], None)
# Mock _find_legacy_collection to return the legacy collection name
with patch(
"lightrag.kg.qdrant_impl._find_legacy_collection",
return_value="lightrag_chunks",
):
# Call setup_collection with matching 1536d
QdrantVectorDBStorage.setup_collection(
client,
"lightrag_chunks_model_1536d",
namespace="chunks",
workspace="test",
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
hnsw_config=models.HnswConfigDiff(
payload_m=16,
m=0,
),
model_suffix="model_1536d",
)
# Verify migration WAS attempted
client.create_collection.assert_called_once()
client.scroll.assert_called()
client.upsert.assert_called()
class TestPostgresDimensionMismatch:
"""Test suite for PostgreSQL dimension mismatch handling."""
async def test_postgres_dimension_mismatch_raises_error_metadata(self):
"""
Test that PostgreSQL raises DataMigrationError when dimensions don't match.
Scenario: Legacy table has 1536d vectors, new model expects 3072d.
Expected: DataMigrationError is raised to prevent data corruption.
"""
# Setup mock database
db = AsyncMock()
# Mock check_table_exists
async def mock_check_table_exists(table_name):
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
return True
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
return False
return False
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# Mock table existence and dimension checks
async def query_side_effect(query, params, **kwargs):
if "COUNT(*)" in query:
return {"count": 100} # Legacy has data
elif "SELECT content_vector FROM" in query:
# Return sample vector with 1536 dimensions
return {"content_vector": [0.1] * 1536}
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Call setup_table with 3072d (different from legacy 1536d)
# Should raise DataMigrationError due to dimension mismatch
with pytest.raises(DataMigrationError) as exc_info:
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_3072d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=3072,
workspace="test",
)
# Verify error message contains dimension information
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
async def test_postgres_dimension_mismatch_raises_error_sampling(self):
"""
Test that PostgreSQL raises error when dimensions don't match (via sampling).
Scenario: Legacy table vector sampling detects 1536d vs expected 3072d.
Expected: DataMigrationError is raised to prevent data corruption.
"""
db = AsyncMock()
# Mock check_table_exists
async def mock_check_table_exists(table_name):
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
return True
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
return False
return False
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# Mock table existence and dimension checks
async def query_side_effect(query, params, **kwargs):
if "information_schema.tables" in query:
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
return {"exists": True}
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
return {"exists": False}
elif "COUNT(*)" in query:
return {"count": 100} # Legacy has data
elif "SELECT content_vector FROM" in query:
# Return sample vector with 1536 dimensions as a JSON string
return {"content_vector": json.dumps([0.1] * 1536)}
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Call setup_table with 3072d (different from legacy 1536d)
# Should raise DataMigrationError due to dimension mismatch
with pytest.raises(DataMigrationError) as exc_info:
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_3072d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=3072,
workspace="test",
)
# Verify error message contains dimension information
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
async def test_postgres_dimension_match_proceed_migration(self):
"""
Test that PostgreSQL proceeds with migration when dimensions match.
Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
Expected: Migration proceeds normally.
"""
db = AsyncMock()
# Track migration state
migration_done = {"value": False}
# Define exactly 2 records for consistency
mock_records = [
{
"id": "test1",
"content_vector": [0.1] * 1536,
"workspace": "test",
},
{
"id": "test2",
"content_vector": [0.2] * 1536,
"workspace": "test",
},
]
# Mock check_table_exists
async def mock_check_table_exists(table_name):
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy exists
return True
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist
return False
return False
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
async def query_side_effect(query, params, **kwargs):
multirows = kwargs.get("multirows", False)
query_upper = query.upper()
if "information_schema.tables" in query:
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
return {"exists": True}
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new
return {"exists": False}
elif "COUNT(*)" in query_upper:
# Return different counts based on table name in query and migration state
if "LIGHTRAG_DOC_CHUNKS_MODEL_1536D" in query_upper:
# After migration: return migrated count, before: return 0
return {
"count": len(mock_records) if migration_done["value"] else 0
}
# Legacy table always has 2 records (matching mock_records)
return {"count": len(mock_records)}
elif "PG_ATTRIBUTE" in query_upper:
return {"vector_dim": 1536} # Legacy has matching 1536d
elif "SELECT" in query_upper and "FROM" in query_upper and multirows:
# Return sample data for migration using keyset pagination
# Handle keyset pagination: params = [workspace, limit] or [workspace, last_id, limit]
if "id >" in query.lower():
# Keyset pagination: params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
# Find records after last_id
found_idx = -1
for i, rec in enumerate(mock_records):
if rec["id"] == last_id:
found_idx = i
break
if found_idx >= 0:
return mock_records[found_idx + 1 :]
return []
else:
# First batch: params = [workspace, limit]
return mock_records
return {}
db.query.side_effect = query_side_effect
# Mock _run_with_retry to track when migration happens
migration_executed = []
async def mock_run_with_retry(operation, *args, **kwargs):
migration_executed.append(True)
migration_done["value"] = True
return None
db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Call setup_table with matching 1536d
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_1536d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=1536,
workspace="test",
)
# Verify migration WAS called (via _run_with_retry for batch operations)
assert len(migration_executed) > 0, "Migration should have been executed"

View File

@@ -0,0 +1,220 @@
"""
Tests for safety when model suffix is absent (no model_name provided).
This test module verifies that the system correctly handles the case when
no model_name is provided, preventing accidental deletion of the only table/collection
on restart.
Critical Bug: When model_suffix is empty, table_name == legacy_table_name.
On second startup, Case 1 logic would delete the only table/collection thinking
it's "legacy", causing all subsequent operations to fail.
"""
from unittest.mock import MagicMock, AsyncMock, patch
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
from lightrag.kg.postgres_impl import PGVectorStorage
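# Illustrative sketch only -- a hypothetical predicate, not LightRAG code.
# The safety property verified below reduces to an early check of this shape:
# when no model suffix is configured, the "new" and "legacy" names collide,
# so Case 1 cleanup must be skipped or the only table/collection would be dropped.
def _sketch_should_skip_legacy_cleanup(new_name: str, legacy_name: str) -> bool:
    return new_name == legacy_name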
class TestNoModelSuffixSafety:
"""Test suite for preventing data loss when model_suffix is absent."""
def test_qdrant_no_suffix_second_startup(self):
"""
Test Qdrant doesn't delete collection on second startup when no model_name.
Scenario:
1. First startup: Creates collection without suffix
2. Collection is empty
3. Second startup: Should NOT delete the collection
Bug: Without fix, Case 1 would delete the only collection.
"""
from qdrant_client import models
client = MagicMock()
# Simulate second startup: collection already exists and is empty
# IMPORTANT: Without suffix, collection_name == legacy collection name
collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy
# Both exist (they're the same collection)
client.collection_exists.return_value = True
# Collection is empty
client.count.return_value.count = 0
# Patch _find_legacy_collection to return the SAME collection name
# This simulates the scenario where new collection == legacy collection
with patch(
"lightrag.kg.qdrant_impl._find_legacy_collection",
return_value="lightrag_vdb_chunks", # Same as collection_name
):
# Call setup_collection
# This should detect that new == legacy and skip deletion
QdrantVectorDBStorage.setup_collection(
client,
collection_name,
namespace="chunks",
workspace="_",
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
hnsw_config=models.HnswConfigDiff(
payload_m=16,
m=0,
),
model_suffix="", # Empty suffix to simulate no model_name provided
)
# CRITICAL: Collection should NOT be deleted
client.delete_collection.assert_not_called()
# Verify we returned early (skipped Case 1 cleanup)
# The collection_exists was checked, but we didn't proceed to count
# because we detected same name
assert client.collection_exists.call_count >= 1
async def test_postgres_no_suffix_second_startup(self):
"""
Test PostgreSQL doesn't delete table on second startup when no model_name.
Scenario:
1. First startup: Creates table without suffix
2. Table is empty
3. Second startup: Should NOT delete the table
Bug: Without fix, Case 1 would delete the only table.
"""
db = AsyncMock()
# Configure mock return values to avoid unawaited coroutine warnings
db.query.return_value = {"count": 0}
db._create_vector_index.return_value = None
# Simulate second startup: table already exists and is empty
# IMPORTANT: table_name and legacy_table_name are THE SAME
table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new
# Setup mock responses using check_table_exists on db
async def check_table_exists_side_effect(name):
# Both tables exist (they're the same)
return True
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
# Call setup_table
# This should detect that new == legacy and skip deletion
await PGVectorStorage.setup_table(
db,
table_name,
workspace="test_workspace",
embedding_dim=1536,
legacy_table_name=legacy_table_name,
base_table="LIGHTRAG_VDB_CHUNKS",
)
# CRITICAL: Table should NOT be deleted (no DROP TABLE)
drop_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert (
len(drop_calls) == 0
), "Should not drop table when new and legacy are the same"
# Note: COUNT queries for workspace data are expected behavior in Case 1
# (for logging/warning purposes when workspace data is empty).
# The critical safety check is that DROP TABLE is not called.
def test_qdrant_with_suffix_case1_still_works(self):
"""
Test that Case 1 cleanup still works when there IS a suffix.
This ensures our fix doesn't break the normal Case 1 scenario.
"""
from qdrant_client import models
client = MagicMock()
# Different names (normal case)
collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix
legacy_collection = "lightrag_vdb_chunks" # Without suffix
# Setup: both exist
def collection_exists_side_effect(name):
return name in [collection_name, legacy_collection]
client.collection_exists.side_effect = collection_exists_side_effect
# Legacy is empty
client.count.return_value.count = 0
# Call setup_collection
QdrantVectorDBStorage.setup_collection(
client,
collection_name,
namespace="chunks",
workspace="_",
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
hnsw_config=models.HnswConfigDiff(
payload_m=16,
m=0,
),
model_suffix="ada_002_1536d",
)
# SHOULD delete legacy (normal Case 1 behavior)
client.delete_collection.assert_called_once_with(
collection_name=legacy_collection
)
async def test_postgres_with_suffix_case1_still_works(self):
"""
Test that Case 1 cleanup still works when there IS a suffix.
This ensures our fix doesn't break the normal Case 1 scenario.
"""
db = AsyncMock()
# Different names (normal case)
table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix
# Setup mock responses using check_table_exists on db
async def check_table_exists_side_effect(name):
# Both tables exist
return True
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
# Mock empty table
async def query_side_effect(sql, params, **kwargs):
if "COUNT(*)" in sql:
return {"count": 0}
return {}
db.query.side_effect = query_side_effect
# Call setup_table
await PGVectorStorage.setup_table(
db,
table_name,
workspace="test_workspace",
embedding_dim=1536,
legacy_table_name=legacy_table_name,
base_table="LIGHTRAG_VDB_CHUNKS",
)
# SHOULD delete legacy (normal Case 1 behavior)
drop_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1"
assert legacy_table_name in drop_calls[0][0][0]

View File

@@ -0,0 +1,210 @@
"""
Unit tests for PostgreSQL safe index name generation.
This module tests the _safe_index_name helper function which prevents
PostgreSQL's silent 63-byte identifier truncation from causing index
lookup failures.
"""
import pytest
# Mark all tests as offline (no external dependencies)
pytestmark = pytest.mark.offline
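# Illustrative sketch only -- not the actual _safe_index_name implementation.
# One common way to respect PostgreSQL's 63-byte identifier limit is to fall
# back to a short, deterministic hash of the table name whenever the naive
# "idx_<table>_<suffix>" form would overflow; the real helper may differ.
import hashlib


def _sketch_safe_index_name(table_name: str, suffix: str, max_bytes: int = 63) -> str:
    naive = f"idx_{table_name.lower()}_{suffix}"
    if len(naive.encode("utf-8")) <= max_bytes:
        return naive
    digest = hashlib.sha256(table_name.lower().encode("utf-8")).hexdigest()[:12]
    return f"idx_{digest}_{suffix}"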
class TestSafeIndexName:
"""Test suite for _safe_index_name function."""
def test_short_name_unchanged(self):
"""Short index names should remain unchanged."""
from lightrag.kg.postgres_impl import _safe_index_name
# Short table name - should return unchanged
result = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
assert result == "idx_lightrag_vdb_entity_hnsw_cosine"
assert len(result.encode("utf-8")) <= 63
def test_long_name_gets_hashed(self):
"""Long table names exceeding 63 bytes should get hashed."""
from lightrag.kg.postgres_impl import _safe_index_name
# Long table name that would exceed 63 bytes
long_table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
result = _safe_index_name(long_table_name, "hnsw_cosine")
# Should be within 63 bytes
assert len(result.encode("utf-8")) <= 63
# Should start with idx_ prefix
assert result.startswith("idx_")
# Should contain the suffix
assert result.endswith("_hnsw_cosine")
# Should NOT be the naive concatenation (which would be truncated)
naive_name = f"idx_{long_table_name.lower()}_hnsw_cosine"
assert result != naive_name
def test_deterministic_output(self):
"""Same input should always produce same output (deterministic)."""
from lightrag.kg.postgres_impl import _safe_index_name
table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
suffix = "hnsw_cosine"
result1 = _safe_index_name(table_name, suffix)
result2 = _safe_index_name(table_name, suffix)
assert result1 == result2
def test_different_suffixes_different_results(self):
"""Different suffixes should produce different index names."""
from lightrag.kg.postgres_impl import _safe_index_name
table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
result1 = _safe_index_name(table_name, "hnsw_cosine")
result2 = _safe_index_name(table_name, "ivfflat_cosine")
assert result1 != result2
def test_case_insensitive(self):
"""Table names should be normalized to lowercase."""
from lightrag.kg.postgres_impl import _safe_index_name
result_upper = _safe_index_name("LIGHTRAG_VDB_ENTITY", "hnsw_cosine")
result_lower = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
assert result_upper == result_lower
def test_boundary_case_exactly_63_bytes(self):
"""Test boundary case where name is exactly at 63-byte limit."""
from lightrag.kg.postgres_impl import _safe_index_name
# Create a table name that results in exactly 63 bytes
# idx_ (4) + table_name + _ (1) + suffix = 63
# So table_name + suffix = 58
# Test a name that's just under the limit (should remain unchanged)
short_suffix = "id"
# idx_ (4) + 56 chars + _ (1) + id (2) = 63
table_56 = "a" * 56
result = _safe_index_name(table_56, short_suffix)
expected = f"idx_{table_56}_{short_suffix}"
assert result == expected
assert len(result.encode("utf-8")) == 63
def test_unicode_handling(self):
"""Unicode characters should be properly handled (bytes, not chars)."""
from lightrag.kg.postgres_impl import _safe_index_name
# Unicode characters can take more bytes than visible chars
# Chinese characters are 3 bytes each in UTF-8
table_name = "lightrag_测试_table" # Contains Chinese chars
result = _safe_index_name(table_name, "hnsw_cosine")
# Should always be within 63 bytes
assert len(result.encode("utf-8")) <= 63
def test_real_world_model_names(self):
"""Test with real-world embedding model names that cause issues."""
from lightrag.kg.postgres_impl import _safe_index_name
# These are actual model names that have caused issues
test_cases = [
("LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d", "hnsw_cosine"),
("LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d", "hnsw_cosine"),
("LIGHTRAG_VDB_RELATION_text_embedding_3_large_3072d", "hnsw_cosine"),
(
"LIGHTRAG_VDB_ENTITY_bge_m3_1024d",
"hnsw_cosine",
), # Shorter model name
(
"LIGHTRAG_VDB_CHUNKS_nomic_embed_text_v1_768d",
"ivfflat_cosine",
), # Different index type
]
for table_name, suffix in test_cases:
result = _safe_index_name(table_name, suffix)
# Critical: must be within PostgreSQL's 63-byte limit
assert (
len(result.encode("utf-8")) <= 63
), f"Index name too long: {result} for table {table_name}"
# Must have consistent format
assert result.startswith("idx_"), f"Missing idx_ prefix: {result}"
assert result.endswith(f"_{suffix}"), f"Missing suffix {suffix}: {result}"
def test_hash_uniqueness_for_similar_tables(self):
"""Similar but different table names should produce different hashes."""
from lightrag.kg.postgres_impl import _safe_index_name
# These tables have similar names but should have different hashes
tables = [
"LIGHTRAG_VDB_CHUNKS_model_a_1024d",
"LIGHTRAG_VDB_CHUNKS_model_b_1024d",
"LIGHTRAG_VDB_ENTITY_model_a_1024d",
]
results = [_safe_index_name(t, "hnsw_cosine") for t in tables]
# All results should be unique
assert len(set(results)) == len(results), "Hash collision detected!"
class TestIndexNameIntegration:
"""Integration-style tests for index name usage patterns."""
def test_pg_indexes_lookup_compatibility(self):
"""
Test that the generated index name will work with pg_indexes lookup.
This is the core problem: PostgreSQL stores the truncated name,
but we were looking up the untruncated name. Our fix ensures we
always use a name that fits within 63 bytes.
"""
from lightrag.kg.postgres_impl import _safe_index_name
table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
suffix = "hnsw_cosine"
# Generate the index name
index_name = _safe_index_name(table_name, suffix)
# Simulate what PostgreSQL would store (truncate at 63 bytes)
stored_name = index_name.encode("utf-8")[:63].decode("utf-8", errors="ignore")
# The key fix: our generated name should equal the stored name
# because it's already within the 63-byte limit
assert (
index_name == stored_name
), "Index name would be truncated by PostgreSQL, causing lookup failures!"
def test_backward_compatibility_short_names(self):
"""
Ensure backward compatibility with existing short index names.
For tables that have existing indexes with short names (pre-model-suffix era),
the function should not change their names.
"""
from lightrag.kg.postgres_impl import _safe_index_name
# Legacy table names without model suffix
legacy_tables = [
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
"LIGHTRAG_VDB_CHUNKS",
]
for table in legacy_tables:
for suffix in ["hnsw_cosine", "ivfflat_cosine", "id"]:
result = _safe_index_name(table, suffix)
expected = f"idx_{table.lower()}_{suffix}"
# Short names should remain unchanged for backward compatibility
if len(expected.encode("utf-8")) <= 63:
assert (
result == expected
), f"Short name changed unexpectedly: {result} != {expected}"

View File

@@ -0,0 +1,856 @@
import pytest
from unittest.mock import patch, AsyncMock
import numpy as np
from lightrag.utils import EmbeddingFunc
from lightrag.kg.postgres_impl import (
PGVectorStorage,
)
from lightrag.namespace import NameSpace
# Mock PostgreSQLDB
@pytest.fixture
def mock_pg_db():
"""Mock PostgreSQL database connection"""
db = AsyncMock()
db.workspace = "test_workspace"
# Mock query responses with multirows support
async def mock_query(sql, params=None, multirows=False, **kwargs):
# Default return value
if multirows:
return [] # Return empty list for multirows
return {"exists": False, "count": 0}
# Mock for execute that mimics PostgreSQLDB.execute() behavior
async def mock_execute(sql, data=None, **kwargs):
"""
Mock that mimics PostgreSQLDB.execute() behavior:
- Accepts data as dict[str, Any] | None (second parameter)
- Internally converts dict.values() to tuple for AsyncPG
"""
# Mimic real execute() which accepts dict and converts to tuple
if data is not None and not isinstance(data, dict):
raise TypeError(
f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}"
)
return None
db.query = AsyncMock(side_effect=mock_query)
db.execute = AsyncMock(side_effect=mock_execute)
return db
# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
mock_lock_ctx = AsyncMock()
mock_lock.return_value = mock_lock_ctx
yield mock_lock
# Mock ClientManager
@pytest.fixture
def mock_client_manager(mock_pg_db):
with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
mock_manager.release_client = AsyncMock()
yield mock_manager
# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
async def embed_func(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
return func
async def test_postgres_table_naming(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test if table name is correctly generated with model suffix"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Verify table name contains model suffix
expected_suffix = "test_model_768d"
assert expected_suffix in storage.table_name
assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"
# Verify legacy table name
assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
async def test_postgres_migration_trigger(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test if migration logic is triggered correctly"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Setup mocks for migration scenario
# 1. New table does not exist, legacy table exists
async def mock_check_table_exists(table_name):
return table_name == storage.legacy_table_name
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# 2. Legacy table has 100 records
mock_rows = [
{"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
for i in range(100)
]
migration_state = {"new_table_count": 0}
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
sql_upper = sql.upper()
legacy_table = storage.legacy_table_name.upper()
new_table = storage.table_name.upper()
is_new_table = new_table in sql_upper
is_legacy_table = legacy_table in sql_upper and not is_new_table
if is_new_table:
return {"count": migration_state["new_table_count"]}
if is_legacy_table:
return {"count": 100}
return {"count": 0}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration using keyset pagination
# New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
# or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
if "WHERE workspace" in sql:
if "id >" in sql:
# Keyset pagination: params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
# Find rows after last_id
start_idx = 0
for i, row in enumerate(mock_rows):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[2] if len(params) > 2 else 500
else:
# First batch (no last_id): params = [workspace, limit]
start_idx = 0
limit = params[1] if len(params) > 1 else 500
else:
# No workspace filter with keyset
if "id >" in sql:
last_id = params[0] if params else None
start_idx = 0
for i, row in enumerate(mock_rows):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[1] if len(params) > 1 else 500
else:
start_idx = 0
limit = params[0] if params else 500
end = min(start_idx + limit, len(mock_rows))
return mock_rows[start_idx:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
# Track migration through _run_with_retry calls
migration_executed = []
async def mock_run_with_retry(operation, **kwargs):
# Track that migration batch operation was called
migration_executed.append(True)
migration_state["new_table_count"] = 100
return None
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
with patch(
"lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
):
# Initialize storage (should trigger migration)
await storage.initialize()
# Verify migration was executed by checking _run_with_retry was called
# (batch migration uses _run_with_retry with executemany)
assert len(migration_executed) > 0, "Migration should have been executed"
async def test_postgres_no_migration_needed(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test scenario where new table already exists (no migration needed)"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Mock: new table already exists
async def mock_check_table_exists(table_name):
return table_name == storage.table_name
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
with patch(
"lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
) as mock_create:
await storage.initialize()
# Verify no table creation was attempted
mock_create.assert_not_called()
async def test_scenario_1_new_workspace_creation(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 1: New workspace creation
Expected behavior:
- No legacy table exists
- Directly create new table with model suffix
- No migration needed
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=3072,
func=mock_embedding_func.func,
model_name="text-embedding-3-large",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="new_workspace",
)
# Mock: neither table exists
async def mock_check_table_exists(table_name):
return False
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
with patch(
"lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
) as mock_create:
await storage.initialize()
# Verify table name format
assert "text_embedding_3_large_3072d" in storage.table_name
# Verify new table creation was called
mock_create.assert_called_once()
call_args = mock_create.call_args
assert (
call_args[0][1] == storage.table_name
) # table_name is second positional arg
async def test_scenario_2_legacy_upgrade_migration(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 2: Upgrade from legacy version
Expected behavior:
- Legacy table exists (without model suffix)
- New table doesn't exist
- Automatically migrate data to new table with suffix
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="text-embedding-ada-002",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="legacy_workspace",
)
# Mock: only legacy table exists
async def mock_check_table_exists(table_name):
return table_name == storage.legacy_table_name
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# Mock: legacy table has 50 records
mock_rows = [
{
"id": f"legacy_id_{i}",
"content": f"legacy_content_{i}",
"workspace": "legacy_workspace",
}
for i in range(50)
]
# Track which queries have been made for proper response
query_history = []
migration_state = {"new_table_count": 0}
async def mock_query(sql, params=None, multirows=False, **kwargs):
query_history.append(sql)
if "COUNT(*)" in sql:
# Determine table type:
# - Legacy: contains base name but NOT model suffix
# - New: contains model suffix (e.g., text_embedding_ada_002_1536d)
sql_upper = sql.upper()
base_name = storage.legacy_table_name.upper()
# Check if this is querying the new table (has model suffix)
has_model_suffix = storage.table_name.upper() in sql_upper
is_legacy_table = base_name in sql_upper and not has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy_table and has_workspace_filter:
# Count for legacy table with workspace filter (before migration)
return {"count": 50}
elif is_legacy_table and not has_workspace_filter:
# Total count for legacy table
return {"count": 50}
else:
# New table count (before/after migration)
return {"count": migration_state["new_table_count"]}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration using keyset pagination
# New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
# or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
if "WHERE workspace" in sql:
if "id >" in sql:
# Keyset pagination: params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
# Find rows after last_id
start_idx = 0
for i, row in enumerate(mock_rows):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[2] if len(params) > 2 else 500
else:
# First batch (no last_id): params = [workspace, limit]
start_idx = 0
limit = params[1] if len(params) > 1 else 500
else:
# No workspace filter with keyset
if "id >" in sql:
last_id = params[0] if params else None
start_idx = 0
for i, row in enumerate(mock_rows):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[1] if len(params) > 1 else 500
else:
start_idx = 0
limit = params[0] if params else 500
end = min(start_idx + limit, len(mock_rows))
return mock_rows[start_idx:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
# Track migration through _run_with_retry calls
migration_executed = []
async def mock_run_with_retry(operation, **kwargs):
# Track that migration batch operation was called
migration_executed.append(True)
migration_state["new_table_count"] = 50
return None
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
with patch(
"lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
) as mock_create:
await storage.initialize()
# Verify table name contains ada-002
assert "text_embedding_ada_002_1536d" in storage.table_name
# Verify migration was executed (batch migration uses _run_with_retry)
assert len(migration_executed) > 0, "Migration should have been executed"
mock_create.assert_called_once()
async def test_scenario_3_multi_model_coexistence(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 3: Multiple embedding models coexist
Expected behavior:
- Different embedding models create separate tables
- Tables are isolated by model suffix
- No interference between different models
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
# Workspace A: uses bge-small (768d)
embedding_func_a = EmbeddingFunc(
embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small"
)
storage_a = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func_a,
workspace="workspace_a",
)
# Workspace B: uses bge-large (1024d)
async def embed_func_b(texts, **kwargs):
return np.array([[0.1] * 1024 for _ in texts])
embedding_func_b = EmbeddingFunc(
embedding_dim=1024, func=embed_func_b, model_name="bge-large"
)
storage_b = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func_b,
workspace="workspace_b",
)
# Verify different table names
assert storage_a.table_name != storage_b.table_name
assert "bge_small_768d" in storage_a.table_name
assert "bge_large_1024d" in storage_b.table_name
# Mock: both tables don't exist yet
async def mock_check_table_exists(table_name):
return False
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
with patch(
"lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
) as mock_create:
# Initialize both storages
await storage_a.initialize()
await storage_b.initialize()
# Verify two separate tables were created
assert mock_create.call_count == 2
# Verify table names are different
call_args_list = mock_create.call_args_list
table_names = [call[0][1] for call in call_args_list] # Second positional arg
assert len(set(table_names)) == 2 # Two unique table names
assert storage_a.table_name in table_names
assert storage_b.table_name in table_names
async def test_case1_empty_legacy_auto_cleanup(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1a: Both new and legacy tables exist, but legacy is EMPTY
Expected: Automatically delete empty legacy table (safe cleanup)
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="test_ws",
)
# Mock: Both tables exist
async def mock_check_table_exists(table_name):
return True # Both new and legacy exist
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# Mock: Legacy table is empty (0 records)
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
if storage.legacy_table_name in sql:
return {"count": 0} # Empty legacy table
else:
return {"count": 100} # New table has data
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with patch("lightrag.kg.postgres_impl.logger"):
await storage.initialize()
# Verify: Empty legacy table should be automatically cleaned up
# Empty tables are safe to delete without data loss risk
delete_calls = [
call
for call in mock_pg_db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted"
# Check if legacy table was dropped
dropped_table = storage.legacy_table_name
assert any(
dropped_table in str(call) for call in delete_calls
), f"Expected to drop empty legacy table '{dropped_table}'"
print(
f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully"
)
async def test_case1_nonempty_legacy_warning(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1b: Both new and legacy tables exist, and legacy HAS DATA
Expected: Log warning, do not delete legacy (preserve data)
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="test_ws",
)
# Mock: Both tables exist
async def mock_check_table_exists(table_name):
return True # Both new and legacy exist
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
# Mock: Legacy table has data (50 records)
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
if storage.legacy_table_name in sql:
return {"count": 50} # Legacy has data
else:
return {"count": 100} # New table has data
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with patch("lightrag.kg.postgres_impl.logger"):
await storage.initialize()
# Verify: Legacy table with data should be preserved
# We never auto-delete tables that contain data to prevent accidental data loss
delete_calls = [
call
for call in mock_pg_db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
# Check if legacy table was deleted (it should not be)
dropped_table = storage.legacy_table_name
legacy_deleted = any(dropped_table in str(call) for call in delete_calls)
assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted"
print(
f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)"
)
async def test_case1_sequential_workspace_migration(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1c: Sequential workspace migration (Multi-tenant scenario)
Critical bug fix verification:
Timeline:
1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data
3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data
4. Verify workspace B's data is correctly migrated to new table
This test verifies the migration logic correctly handles multi-tenant scenarios
where different workspaces migrate sequentially.
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
# Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b)
mock_rows_a = [
{"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"}
for i in range(3)
]
mock_rows_b = [
{"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"}
for i in range(3)
]
# Track migration state
migration_state = {
"new_table_exists": False,
"workspace_a_migrated": False,
"workspace_a_migration_count": 0,
"workspace_b_migration_count": 0,
}
# Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists)
# CRITICAL: Set db.workspace to workspace_a
mock_pg_db.workspace = "workspace_a"
storage_a = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="workspace_a",
)
# Mock table_exists for workspace_a
async def mock_check_table_exists_a(table_name):
if table_name == storage_a.legacy_table_name:
return True
if table_name == storage_a.table_name:
return migration_state["new_table_exists"]
return False
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_a)
# Mock query for workspace_a (Case 3)
async def mock_query_a(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_a.legacy_table_name.upper()
if "COUNT(*)" in sql:
has_model_suffix = "TEST_MODEL_1536D" in sql_upper
is_legacy = base_name in sql_upper and not has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
return {"count": 3}
elif workspace == "workspace_b":
return {"count": 3}
elif is_legacy and not has_workspace_filter:
# Global count in legacy table
return {"count": 6}
elif has_model_suffix:
if has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
return {"count": migration_state["workspace_a_migration_count"]}
if workspace == "workspace_b":
return {"count": migration_state["workspace_b_migration_count"]}
return {
"count": migration_state["workspace_a_migration_count"]
+ migration_state["workspace_b_migration_count"]
}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
# Handle keyset pagination
if "id >" in sql:
# params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
start_idx = 0
for i, row in enumerate(mock_rows_a):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[2] if len(params) > 2 else 500
else:
# First batch: params = [workspace, limit]
start_idx = 0
limit = params[1] if len(params) > 1 else 500
end = min(start_idx + limit, len(mock_rows_a))
return mock_rows_a[start_idx:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
# Track migration via _run_with_retry (batch migration uses this)
migration_a_executed = []
async def mock_run_with_retry_a(operation, **kwargs):
migration_a_executed.append(True)
migration_state["workspace_a_migration_count"] = len(mock_rows_a)
return None
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a)
# Initialize workspace_a (Case 3)
with patch("lightrag.kg.postgres_impl.logger"):
await storage_a.initialize()
migration_state["new_table_exists"] = True
migration_state["workspace_a_migrated"] = True
print("✅ Step 1: Workspace A initialized")
# Verify migration was executed via _run_with_retry (batch migration uses executemany)
assert (
len(migration_a_executed) > 0
), "Migration should have been executed for workspace_a"
print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)")
# Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data)
# CRITICAL: Set db.workspace to workspace_b
mock_pg_db.workspace = "workspace_b"
storage_b = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="workspace_b",
)
mock_pg_db.reset_mock()
# Mock table_exists for workspace_b (both exist)
async def mock_check_table_exists_b(table_name):
return True # Both tables exist
mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_b)
# Mock query for workspace_b (Case 3)
async def mock_query_b(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_b.legacy_table_name.upper()
if "COUNT(*)" in sql:
has_model_suffix = "TEST_MODEL_1536D" in sql_upper
is_legacy = base_name in sql_upper and not has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
return {"count": 3} # workspace_b still has data in legacy
elif workspace == "workspace_a":
return {"count": 0} # workspace_a already migrated
elif is_legacy and not has_workspace_filter:
# Global count: only workspace_b data remains
return {"count": 3}
elif has_model_suffix:
if has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
return {"count": migration_state["workspace_b_migration_count"]}
elif workspace == "workspace_a":
return {"count": 3}
else:
return {"count": 3 + migration_state["workspace_b_migration_count"]}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
# Handle keyset pagination
if "id >" in sql:
# params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
start_idx = 0
for i, row in enumerate(mock_rows_b):
if row["id"] == last_id:
start_idx = i + 1
break
limit = params[2] if len(params) > 2 else 500
else:
# First batch: params = [workspace, limit]
start_idx = 0
limit = params[1] if len(params) > 1 else 500
end = min(start_idx + limit, len(mock_rows_b))
return mock_rows_b[start_idx:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
# Track migration via _run_with_retry for workspace_b
migration_b_executed = []
async def mock_run_with_retry_b(operation, **kwargs):
migration_b_executed.append(True)
migration_state["workspace_b_migration_count"] = len(mock_rows_b)
return None
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b)
# Initialize workspace_b (Case 3 - both tables exist)
with patch("lightrag.kg.postgres_impl.logger"):
await storage_b.initialize()
print("✅ Step 2: Workspace B initialized")
# Verify workspace_b migration happens when new table has no workspace_b data
# but legacy table still has workspace_b data.
assert (
len(migration_b_executed) > 0
), "Migration should have been executed for workspace_b"
print("✅ Step 2: Migration executed for workspace_b")
print("\n🎉 Case 1c: Sequential workspace migration verification complete!")
print(" - Workspace A: Migrated successfully (only legacy existed)")
print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")

View File

@@ -0,0 +1,562 @@
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
import numpy as np
from qdrant_client import models
from lightrag.utils import EmbeddingFunc
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
# Mock QdrantClient
@pytest.fixture
def mock_qdrant_client():
with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls:
client = mock_client_cls.return_value
client.collection_exists.return_value = False
client.count.return_value.count = 0
# Mock payload schema and vector config for get_collection
collection_info = MagicMock()
collection_info.payload_schema = {}
# Mock vector dimension to match mock_embedding_func (768d)
collection_info.config.params.vectors.size = 768
client.get_collection.return_value = collection_info
yield client
# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock:
mock_lock_ctx = AsyncMock()
mock_lock.return_value = mock_lock_ctx
yield mock_lock
# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
async def embed_func(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model")
return func
async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func):
"""Test if collection name is correctly generated with model suffix"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Verify collection name contains model suffix
expected_suffix = "test_model_768d"
assert expected_suffix in storage.final_namespace
assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}"
async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func):
"""Test if migration logic is triggered correctly"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Legacy collection name (without model suffix)
legacy_collection = "lightrag_vdb_chunks"
# Setup mocks for migration scenario
# 1. New collection does not exist, only legacy exists
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == legacy_collection
)
# 2. Legacy collection exists and has data
migration_state = {"new_workspace_count": 0}
def count_mock(collection_name, exact=True, count_filter=None):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 100
elif collection_name == storage.final_namespace:
mock_result.count = migration_state["new_workspace_count"]
else:
mock_result.count = 0
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# 3. Mock scroll for data migration
mock_point = MagicMock()
mock_point.id = "old_id"
mock_point.vector = [0.1] * 768
mock_point.payload = {"content": "test"} # No workspace_id in payload
# When payload_schema is empty, the code first samples payloads to detect workspace_id
# Then proceeds with migration batches
# Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
mock_qdrant_client.scroll.side_effect = [
([mock_point], "_"), # Sampling scroll - no workspace_id found
([mock_point], "next_offset"), # Migration batch
([], None), # End of migration
]
def upsert_mock(*args, **kwargs):
migration_state["new_workspace_count"] = 100
return None
mock_qdrant_client.upsert.side_effect = upsert_mock
# Initialize storage (triggers migration)
await storage.initialize()
# Verify migration steps
# 1. Legacy count checked
mock_qdrant_client.count.assert_any_call(
collection_name=legacy_collection, exact=True
)
# 2. New collection created
mock_qdrant_client.create_collection.assert_called()
# 3. Data scrolled from legacy
# First call (index 0) is sampling scroll with limit=10
# Second call (index 1) is migration batch with limit=500
assert mock_qdrant_client.scroll.call_count >= 2
# Check sampling scroll
sampling_call = mock_qdrant_client.scroll.call_args_list[0]
assert sampling_call.kwargs["collection_name"] == legacy_collection
assert sampling_call.kwargs["limit"] == 10
# Check migration batch scroll
migration_call = mock_qdrant_client.scroll.call_args_list[1]
assert migration_call.kwargs["collection_name"] == legacy_collection
assert migration_call.kwargs["limit"] == 500
# 4. Data upserted to new
mock_qdrant_client.upsert.assert_called()
# 5. Payload index created
mock_qdrant_client.create_payload_index.assert_called()
async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
"""Test scenario where new collection already exists (Case 1 in setup_collection)
When only the new collection exists and no legacy collection is found,
the implementation should:
1. Create payload index on the new collection (ensure index exists)
2. NOT attempt any data migration (no scroll calls)
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Only new collection exists (no legacy collection found)
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == storage.final_namespace
)
# Initialize
await storage.initialize()
# Should create payload index on the new collection (ensure index)
mock_qdrant_client.create_payload_index.assert_called_with(
collection_name=storage.final_namespace,
field_name="workspace_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
)
# Should NOT migrate (no scroll calls since no legacy collection exists)
mock_qdrant_client.scroll.assert_not_called()
# ============================================================================
# Tests for scenarios described in design document (Lines 606-649)
# ============================================================================
async def test_scenario_1_new_workspace_creation(
mock_qdrant_client, mock_embedding_func
):
"""
Scenario 1: Creating a new workspace.
Expected: The collection lightrag_vdb_chunks_text_embedding_3_large_3072d is created directly.
"""
# Use a large embedding model
large_model_func = EmbeddingFunc(
embedding_dim=3072,
func=mock_embedding_func.func,
model_name="text-embedding-3-large",
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=large_model_func,
workspace="test_new",
)
# Case 3: Neither legacy nor new collection exists
mock_qdrant_client.collection_exists.return_value = False
# Initialize storage
await storage.initialize()
# Verify: Should create new collection with model suffix
expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
assert storage.final_namespace == expected_collection
# Verify create_collection was called with correct name
create_calls = [
call for call in mock_qdrant_client.create_collection.call_args_list
]
assert len(create_calls) > 0
assert (
create_calls[0][0][0] == expected_collection
or create_calls[0].kwargs.get("collection_name") == expected_collection
)
# Verify no migration was attempted
mock_qdrant_client.scroll.assert_not_called()
print(
f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
)
async def test_scenario_2_legacy_upgrade_migration(
mock_qdrant_client, mock_embedding_func
):
"""
Scenario 2: Upgrading from a legacy version.
A legacy collection lightrag_vdb_chunks (without a model suffix) already exists.
Expected: Data is automatically migrated to lightrag_vdb_chunks_text_embedding_ada_002_1536d.
Note: The legacy collection is no longer deleted automatically after migration; it must be removed manually.
"""
# Use ada-002 model
ada_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="text-embedding-ada-002",
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=ada_func,
workspace="test_legacy",
)
# Legacy collection name (without model suffix)
legacy_collection = "lightrag_vdb_chunks"
new_collection = storage.final_namespace
# Case 4: Only legacy collection exists
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == legacy_collection
)
# Mock legacy collection info with 1536d vectors
legacy_collection_info = MagicMock()
legacy_collection_info.payload_schema = {}
legacy_collection_info.config.params.vectors.size = 1536
mock_qdrant_client.get_collection.return_value = legacy_collection_info
migration_state = {"new_workspace_count": 0}
def count_mock(collection_name, exact=True, count_filter=None):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 150
elif collection_name == new_collection:
mock_result.count = migration_state["new_workspace_count"]
else:
mock_result.count = 0
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# Mock scroll results (simulate migration in batches)
mock_points = []
for i in range(10):
point = MagicMock()
point.id = f"legacy-{i}"
point.vector = [0.1] * 1536
# No workspace_id in payload - simulates legacy data
point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
mock_points.append(point)
# When payload_schema is empty, the code first samples payloads to detect workspace_id
# Then proceeds with migration batches
# Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
mock_qdrant_client.scroll.side_effect = [
(mock_points, "_"), # Sampling scroll - no workspace_id found in payloads
(mock_points, "offset1"), # Migration batch
([], None), # End of migration
]
def upsert_mock(*args, **kwargs):
migration_state["new_workspace_count"] = 150
return None
mock_qdrant_client.upsert.side_effect = upsert_mock
# Initialize (triggers migration)
await storage.initialize()
# Verify: New collection should be created
expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
assert storage.final_namespace == expected_new_collection
# Verify migration steps
# 1. Check legacy count
mock_qdrant_client.count.assert_any_call(
collection_name=legacy_collection, exact=True
)
# 2. Create new collection
mock_qdrant_client.create_collection.assert_called()
# 3. Scroll legacy data
scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
assert len(scroll_calls) >= 1
assert scroll_calls[0].kwargs["collection_name"] == legacy_collection
# 4. Upsert to new collection
upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
assert len(upsert_calls) >= 1
assert upsert_calls[0].kwargs["collection_name"] == new_collection
# Note: Legacy collection is NOT automatically deleted after migration
# Manual deletion is required after data migration verification
print(
f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'"
)
async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
"""
Scenario 3: Multiple models coexisting.
Expected: Two independent collections that do not interfere with each other.
"""
# Model A: bge-small with 768d
async def embed_func_a(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
model_a_func = EmbeddingFunc(
embedding_dim=768, func=embed_func_a, model_name="bge-small"
)
# Model B: bge-large with 1024d
async def embed_func_b(texts, **kwargs):
return np.array([[0.2] * 1024 for _ in texts])
model_b_func = EmbeddingFunc(
embedding_dim=1024, func=embed_func_b, model_name="bge-large"
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
# Create storage for workspace A with model A
storage_a = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=model_a_func,
workspace="workspace_a",
)
# Create storage for workspace B with model B
storage_b = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=model_b_func,
workspace="workspace_b",
)
# Verify: Collection names are different
assert storage_a.final_namespace != storage_b.final_namespace
# Verify: Model A collection
expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
assert storage_a.final_namespace == expected_collection_a
# Verify: Model B collection
expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
assert storage_b.final_namespace == expected_collection_b
# Verify: Different embedding dimensions are preserved
assert storage_a.embedding_func.embedding_dim == 768
assert storage_b.embedding_func.embedding_dim == 1024
print("✅ Scenario 3: Multi-model coexistence verified")
print(f" - Workspace A: {expected_collection_a} (768d)")
print(f" - Workspace B: {expected_collection_b} (1024d)")
print(" - Collections are independent")
async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
"""
Case 1a: Both the new and legacy collections exist, and the legacy collection is empty.
Expected: The legacy collection is deleted automatically.
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Legacy collection name (without model suffix)
legacy_collection = "lightrag_vdb_chunks"
new_collection = storage.final_namespace
# Mock: Both collections exist
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
legacy_collection,
new_collection,
]
# Mock: Legacy collection is empty (0 records)
def count_mock(collection_name, exact=True, count_filter=None):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 0 # Empty legacy collection
else:
mock_result.count = 100 # New collection has data
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# Mock get_collection for Case 2 check
collection_info = MagicMock()
collection_info.payload_schema = {"workspace_id": True}
mock_qdrant_client.get_collection.return_value = collection_info
# Initialize storage
await storage.initialize()
# Verify: Empty legacy collection should be automatically cleaned up
# Empty collections are safe to delete without data loss risk
delete_calls = [
call for call in mock_qdrant_client.delete_collection.call_args_list
]
assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
deleted_collection = (
delete_calls[0][0][0]
if delete_calls[0][0]
else delete_calls[0].kwargs.get("collection_name")
)
assert (
deleted_collection == legacy_collection
), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"
print(
f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
)
async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
"""
Case 1b: Both the new and legacy collections exist, and the legacy collection still has data.
Expected: A warning is logged, but the legacy collection is not deleted.
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Legacy collection name (without model suffix)
legacy_collection = "lightrag_vdb_chunks"
new_collection = storage.final_namespace
# Mock: Both collections exist
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
legacy_collection,
new_collection,
]
# Mock: Legacy collection has data (50 records)
def count_mock(collection_name, exact=True, count_filter=None):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 50 # Legacy has data
else:
mock_result.count = 100 # New collection has data
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# Mock get_collection for Case 2 check
collection_info = MagicMock()
collection_info.payload_schema = {"workspace_id": True}
mock_qdrant_client.get_collection.return_value = collection_info
# Initialize storage
await storage.initialize()
# Verify: Legacy collection with data should be preserved
# We never auto-delete collections that contain data to prevent accidental data loss
delete_calls = [
call for call in mock_qdrant_client.delete_collection.call_args_list
]
# Check if legacy collection was deleted (it should not be)
legacy_deleted = any(
(call[0][0] if call[0] else call.kwargs.get("collection_name"))
== legacy_collection
for call in delete_calls
)
assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"
print(
f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
)
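The naming assertions in this file all follow the same pattern: lowercase the model name, replace non-alphanumeric runs with underscores, and append the embedding dimension as `<dim>d`. Below is a small sketch of that derivation with hypothetical helper names; the real implementation in `lightrag` may differ in details such as truncation or collision handling:

```python
import re

def model_suffix(model_name: str, embedding_dim: int) -> str:
    """Build a per-model suffix such as 'text_embedding_3_large_3072d'."""
    sanitized = re.sub(r"[^a-z0-9]+", "_", model_name.lower()).strip("_")
    return f"{sanitized}_{embedding_dim}d"

def collection_name(base: str, model_name: str, embedding_dim: int) -> str:
    """Compose the full collection name, e.g. 'lightrag_vdb_chunks_bge_small_768d'."""
    return f"{base}_{model_suffix(model_name, embedding_dim)}"

# Matches the expectations asserted in the tests above
assert collection_name("lightrag_vdb_chunks", "text-embedding-ada-002", 1536) == (
    "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
)
assert collection_name("lightrag_vdb_chunks", "bge-large", 1024) == (
    "lightrag_vdb_chunks_bge_large_1024d"
)
```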

View File

@@ -0,0 +1,122 @@
"""
Tests for UnifiedLock safety when lock is None.
This test module verifies that get_internal_lock() and get_data_init_lock()
raise RuntimeError when shared data is not initialized, preventing a false
sense of safety and potential race conditions.
Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
to the lock factory functions (get_internal_lock, get_data_init_lock) for
early failure detection.
Critical Bug 1 (Fixed): When self._lock is None, the code would fail with
AttributeError. Now the check is in factory functions for clearer errors.
Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
recovery logic would attempt to release it again, causing double-release issues.
"""
from unittest.mock import MagicMock, AsyncMock
import pytest
from lightrag.kg.shared_storage import (
UnifiedLock,
get_internal_lock,
get_data_init_lock,
finalize_share_data,
)
class TestUnifiedLockSafety:
"""Test suite for UnifiedLock None safety checks."""
def setup_method(self):
"""Ensure shared data is finalized before each test."""
finalize_share_data()
def teardown_method(self):
"""Clean up after each test."""
finalize_share_data()
def test_get_internal_lock_raises_when_not_initialized(self):
"""
Test that get_internal_lock() raises RuntimeError when shared data is not initialized.
Scenario: Call get_internal_lock() before initialize_share_data() is called.
Expected: RuntimeError raised with clear error message.
This test verifies the None check has been moved to the factory function.
"""
with pytest.raises(
RuntimeError, match="Shared data not initialized.*initialize_share_data"
):
get_internal_lock()
def test_get_data_init_lock_raises_when_not_initialized(self):
"""
Test that get_data_init_lock() raises RuntimeError when shared data is not initialized.
Scenario: Call get_data_init_lock() before initialize_share_data() is called.
Expected: RuntimeError raised with clear error message.
This test verifies the None check has been moved to the factory function.
"""
with pytest.raises(
RuntimeError, match="Shared data not initialized.*initialize_share_data"
):
get_data_init_lock()
@pytest.mark.offline
async def test_aexit_no_double_release_on_async_lock_failure(self):
"""
Test that __aexit__ doesn't attempt to release async_lock twice when it fails.
Scenario: async_lock.release() fails during normal release.
Expected: Recovery logic should NOT attempt to release async_lock again,
preventing double-release issues.
This tests Bug 2 fix: async_lock_released tracking prevents double release.
"""
# Create mock locks
main_lock = MagicMock()
main_lock.acquire = MagicMock()
main_lock.release = MagicMock()
async_lock = AsyncMock()
async_lock.acquire = AsyncMock()
# Make async_lock.release() fail
release_call_count = 0
def mock_release_fail():
nonlocal release_call_count
release_call_count += 1
raise RuntimeError("Async lock release failed")
async_lock.release = MagicMock(side_effect=mock_release_fail)
# Create UnifiedLock with both locks (sync mode with async_lock)
lock = UnifiedLock(
lock=main_lock,
is_async=False,
name="test_double_release",
enable_logging=False,
)
lock._async_lock = async_lock
# Try to use the lock - should fail during __aexit__
try:
async with lock:
pass
except RuntimeError as e:
# Should get the async lock release error
assert "Async lock release failed" in str(e)
# Verify async_lock.release() was called only ONCE, not twice
assert (
release_call_count == 1
), f"async_lock.release() should be called only once, but was called {release_call_count} times"
# Main lock should have been released successfully
main_lock.release.assert_called_once()
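The two factory tests rely on a fail-fast guard in the lock factories rather than a deferred check inside `UnifiedLock`. The sketch below illustrates that guard; the module-level variable and helper names are hypothetical, and only the error-message shape is taken from the tests' `match` pattern:

```python
import threading

_shared_data: dict[str, threading.Lock] | None = None  # populated by initialize_share_data()

def get_internal_lock_sketch() -> threading.Lock:
    """Raise immediately instead of handing out a lock wrapper around None."""
    if _shared_data is None:
        raise RuntimeError(
            "Shared data not initialized. Call initialize_share_data() before acquiring locks."
        )
    return _shared_data["internal_lock"]
```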

View File

@@ -149,7 +149,6 @@ def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None:
@pytest.mark.offline
@pytest.mark.asyncio
async def test_pipeline_status_isolation():
"""
Test that pipeline status is isolated between different workspaces.
@@ -204,7 +203,6 @@ async def test_pipeline_status_isolation():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_lock_mechanism(stress_test_mode, parallel_workers):
"""
Test that the new keyed lock mechanism works correctly without deadlocks.
@@ -274,7 +272,6 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers):
@pytest.mark.offline
@pytest.mark.asyncio
async def test_backward_compatibility():
"""
Test that legacy code without workspace parameter still works correctly.
@@ -348,7 +345,6 @@ async def test_backward_compatibility():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_multi_workspace_concurrency():
"""
Test that multiple workspaces can operate concurrently without interference.
@@ -432,7 +428,6 @@ async def test_multi_workspace_concurrency():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_namespace_lock_reentrance():
"""
Test that NamespaceLock prevents re-entrance in the same coroutine
@@ -506,7 +501,6 @@ async def test_namespace_lock_reentrance():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_different_namespace_lock_isolation():
"""
Test that locks for different namespaces (same workspace) are independent.
@@ -546,7 +540,6 @@ async def test_different_namespace_lock_isolation():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_error_handling():
"""
Test error handling for invalid workspace configurations.
@@ -597,7 +590,6 @@ async def test_error_handling():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_update_flags_workspace_isolation():
"""
Test that update flags are properly isolated between workspaces.
@@ -727,7 +719,6 @@ async def test_update_flags_workspace_isolation():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_empty_workspace_standardization():
"""
Test that empty workspace is properly standardized to "" instead of "_".
@@ -781,7 +772,6 @@ async def test_empty_workspace_standardization():
@pytest.mark.offline
@pytest.mark.asyncio
async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):
"""
Integration test: Verify JsonKVStorage properly isolates data between workspaces.
@@ -961,7 +951,6 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):
@pytest.mark.offline
@pytest.mark.asyncio
async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts):
"""
End-to-end test: Create two LightRAG instances with different workspaces,

View File

@@ -0,0 +1,288 @@
"""
Tests for workspace isolation during PostgreSQL migration.
This test module verifies that setup_table() properly filters migration data
by workspace, preventing cross-workspace data leakage during legacy table migration.
Critical Bug: Migration copied ALL records from the legacy table regardless of workspace,
causing workspace A to receive workspace B's data, violating multi-tenant isolation.
"""
import pytest
from unittest.mock import AsyncMock
from lightrag.kg.postgres_impl import PGVectorStorage
class TestWorkspaceMigrationIsolation:
"""Test suite for workspace-scoped migration in PostgreSQL."""
async def test_migration_filters_by_workspace(self):
"""
Test that migration only copies data from the specified workspace.
Scenario: Legacy table contains data from multiple workspaces.
Migrate only workspace_a's data to new table.
Expected: New table contains only workspace_a data, workspace_b data excluded.
"""
db = AsyncMock()
# Configure mock return values to avoid unawaited coroutine warnings
db._create_vector_index.return_value = None
# Track state for new table count (starts at 0, increases after migration)
new_table_record_count = {"count": 0}
# Mock table existence checks
async def table_exists_side_effect(db_instance, name):
if name.lower() == "lightrag_doc_chunks": # legacy
return True
elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
return False # New table doesn't exist initially
return False
# Mock data for workspace_a
mock_records_a = [
{
"id": "a1",
"workspace": "workspace_a",
"content": "content_a1",
"content_vector": [0.1] * 1536,
},
{
"id": "a2",
"workspace": "workspace_a",
"content": "content_a2",
"content_vector": [0.2] * 1536,
},
]
# Mock query responses
async def query_side_effect(sql, params, **kwargs):
multirows = kwargs.get("multirows", False)
sql_upper = sql.upper()
# Count query for new table workspace data (verification before migration)
if (
"COUNT(*)" in sql_upper
and "MODEL_1536D" in sql_upper
and "WHERE WORKSPACE" in sql_upper
):
return new_table_record_count # Initially 0
# Count query with workspace filter (legacy table) - for workspace count
elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
if params and params[0] == "workspace_a":
return {"count": 2} # workspace_a has 2 records
elif params and params[0] == "workspace_b":
return {"count": 3} # workspace_b has 3 records
return {"count": 0}
# Count query for legacy table (total, no workspace filter)
elif (
"COUNT(*)" in sql_upper
and "LIGHTRAG" in sql_upper
and "WHERE WORKSPACE" not in sql_upper
):
return {"count": 5} # Total records in legacy
# SELECT with workspace filter for migration (multirows)
elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
workspace = params[0] if params else None
if workspace == "workspace_a":
# Handle keyset pagination: check for "id >" pattern
if "id >" in sql.lower():
# Keyset pagination: params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
# Find records after last_id
found_idx = -1
for i, rec in enumerate(mock_records_a):
if rec["id"] == last_id:
found_idx = i
break
if found_idx >= 0:
return mock_records_a[found_idx + 1 :]
return []
else:
# First batch: params = [workspace, limit]
return mock_records_a
return [] # No data for other workspaces
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
# Mock check_table_exists on db
async def check_table_exists_side_effect(name):
if name.lower() == "lightrag_doc_chunks": # legacy
return True
elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
return False # New table doesn't exist initially
return False
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
# Track migration through _run_with_retry calls
migration_executed = []
async def mock_run_with_retry(operation, *args, **kwargs):
migration_executed.append(True)
new_table_record_count["count"] = 2 # Simulate 2 records migrated
return None
db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
# Migrate for workspace_a only - correct parameter order
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_1536d",
workspace="workspace_a", # CRITICAL: Only migrate workspace_a
embedding_dim=1536,
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
)
# Verify the migration was triggered
assert (
len(migration_executed) > 0
), "Migration should have been executed for workspace_a"
async def test_migration_without_workspace_raises_error(self):
"""
Test that migration without workspace parameter raises ValueError.
Scenario: setup_table called without workspace parameter.
Expected: ValueError is raised because workspace is required.
"""
db = AsyncMock()
# workspace is now a required parameter - calling with None should raise ValueError
with pytest.raises(ValueError, match="workspace must be provided"):
await PGVectorStorage.setup_table(
db,
"lightrag_doc_chunks_model_1536d",
workspace=None, # No workspace - should raise ValueError
embedding_dim=1536,
legacy_table_name="lightrag_doc_chunks",
base_table="lightrag_doc_chunks",
)
async def test_no_cross_workspace_contamination(self):
"""
Test that workspace B's migration doesn't include workspace A's data.
Scenario: Migration for workspace_b only.
Expected: Only workspace_b data is queried, workspace_a data excluded.
"""
db = AsyncMock()
# Configure mock return values to avoid unawaited coroutine warnings
db._create_vector_index.return_value = None
# Track which workspace is being queried
queried_workspace = None
new_table_count = {"count": 0}
# Mock data for workspace_b
mock_records_b = [
{
"id": "b1",
"workspace": "workspace_b",
"content": "content_b1",
"content_vector": [0.3] * 1536,
},
]
async def table_exists_side_effect(db_instance, name):
if name.lower() == "lightrag_doc_chunks": # legacy
return True
elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
return False
return False
async def query_side_effect(sql, params, **kwargs):
nonlocal queried_workspace
multirows = kwargs.get("multirows", False)
sql_upper = sql.upper()
# Count query for new table workspace data (should be 0 initially)
if (
"COUNT(*)" in sql_upper
and "MODEL_1536D" in sql_upper
and "WHERE WORKSPACE" in sql_upper
):
return new_table_count
# Count query with workspace filter (legacy table)
elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
queried_workspace = params[0] if params else None
return {"count": 1} # 1 record for the queried workspace
# Count query for legacy table total (no workspace filter)
elif (
"COUNT(*)" in sql_upper
and "LIGHTRAG" in sql_upper
and "WHERE WORKSPACE" not in sql_upper
):
return {"count": 3} # 3 total records in legacy
# SELECT with workspace filter for migration (multirows)
elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
workspace = params[0] if params else None
if workspace == "workspace_b":
# Handle keyset pagination: check for "id >" pattern
if "id >" in sql.lower():
# Keyset pagination: params = [workspace, last_id, limit]
last_id = params[1] if len(params) > 1 else None
# Find records after last_id
found_idx = -1
for i, rec in enumerate(mock_records_b):
if rec["id"] == last_id:
found_idx = i
break
if found_idx >= 0:
return mock_records_b[found_idx + 1 :]
return []
else:
# First batch: params = [workspace, limit]
return mock_records_b
return [] # No data for other workspaces
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
# Mock check_table_exists on db
async def check_table_exists_side_effect(name):
if name.lower() == "lightrag_doc_chunks": # legacy
return True
elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
return False
return False
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
# Track migration through _run_with_retry calls
migration_executed = []
async def mock_run_with_retry(operation, *args, **kwargs):
migration_executed.append(True)
new_table_count["count"] = 1 # Simulate migration
return None
db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
# Migrate workspace_b - correct parameter order
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_1536d",
workspace="workspace_b", # Only migrate workspace_b
embedding_dim=1536,
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
)
# Verify only workspace_b was queried
assert queried_workspace == "workspace_b", "Should only query workspace_b"
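Taken together, these tests encode a simple per-workspace decision rule: migrate only when the new table holds no rows for the workspace while the legacy table still does, and scope both the counts and the copied rows by `workspace`. The sketch below states that rule explicitly; the helper name and SQL are illustrative, not the actual `setup_table` internals:

```python
from typing import Any

async def should_migrate_workspace(
    db: Any, legacy_table: str, new_table: str, workspace: str
) -> bool:
    """Return True only if this workspace's data exists in legacy but not yet in the new table."""
    new_count = await db.query(
        f"SELECT COUNT(*) AS count FROM {new_table} WHERE workspace = $1", [workspace]
    )
    legacy_count = await db.query(
        f"SELECT COUNT(*) AS count FROM {legacy_table} WHERE workspace = $1", [workspace]
    )
    return new_count["count"] == 0 and legacy_count["count"] > 0
```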