Merge pull request #2513 from HKUDS/feature/vectordb-model-isolation
feat: Implement Vector Database Model Isolation and Auto-Migration
.github/workflows/tests.yml (vendored): 2 lines changed
@@ -13,7 +13,7 @@ jobs:
     strategy:
       matrix:
-        python-version: ['3.12', '3.13', '3.14']
+        python-version: ['3.12', '3.14']

     steps:
       - uses: actions/checkout@v6
README-zh.md: 18 lines changed
@@ -286,7 +286,7 @@ if __name__ == "__main__":
 <summary> 参数 </summary>

 | **参数** | **类型** | **说明** | **默认值** |
-|--------------|----------|-----------------|-------------|
+| -------------- | ---------- | ----------------- | ------------- |
 | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
 | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
 | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
@@ -425,7 +425,7 @@ async def llm_model_func(
     **kwargs
 )

-@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed.func(
         texts,
@@ -490,7 +490,7 @@ import numpy as np
 from lightrag.utils import wrap_embedding_func_with_attrs
 from lightrag.llm.ollama import ollama_model_complete, ollama_embed

-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -542,7 +542,7 @@ import numpy as np
 from lightrag.utils import wrap_embedding_func_with_attrs
 from lightrag.llm.ollama import ollama_model_complete, ollama_embed

-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1633,24 +1633,24 @@ LightRAG使用以下提示生成高级查询，相应的代码在`example/genera

 ### 总体性能表

-| |**农业**| |**计算机科学**| |**法律**| |**混合**| |
+||**农业**||**计算机科学**||**法律**||**混合**||
 |----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
-| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
+||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
 |**全面性**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
 |**多样性**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
 |**赋能性**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
 |**总体**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
-| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
+||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
 |**全面性**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
 |**多样性**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
 |**赋能性**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
 |**总体**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
-| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
+||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
 |**全面性**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
 |**多样性**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
 |**赋能性**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
 |**总体**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
-| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
+||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
 |**全面性**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
 |**多样性**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
 |**赋能性**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|
README.md: 20 lines changed
@@ -287,9 +287,9 @@ A full list of LightRAG init parameters:
 <summary> Parameters </summary>

 | **Parameter** | **Type** | **Explanation** | **Default** |
-|--------------|----------|-----------------|-------------|
+| -------------- | ---------- | ----------------- | ------------- |
 | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
 | **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
 | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
 | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
 | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -421,7 +421,7 @@ async def llm_model_func(
     **kwargs
 )

-@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embed.func(
         texts,
@@ -488,7 +488,7 @@ import numpy as np
 from lightrag.utils import wrap_embedding_func_with_attrs
 from lightrag.llm.ollama import ollama_model_complete, ollama_embed

-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -540,7 +540,7 @@ import numpy as np
 from lightrag.utils import wrap_embedding_func_with_attrs
 from lightrag.llm.ollama import ollama_model_complete, ollama_embed

-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1701,24 +1701,24 @@ Output your evaluation in the following JSON format:

 ### Overall Performance Table

-| |**Agriculture**| |**CS**| |**Legal**| |**Mix**| |
+||**Agriculture**||**CS**||**Legal**||**Mix**||
 |----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
-| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
+||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
 |**Comprehensiveness**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
 |**Diversity**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
 |**Empowerment**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
 |**Overall**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
-| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
+||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
 |**Comprehensiveness**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
 |**Diversity**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
 |**Empowerment**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
 |**Overall**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
-| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
+||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
 |**Comprehensiveness**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
 |**Diversity**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
 |**Empowerment**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
 |**Overall**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
-| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
+||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
 |**Comprehensiveness**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
 |**Diversity**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
 |**Empowerment**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|
@@ -868,6 +868,7 @@ def create_app(args):
        func=optimized_embedding_function,
        max_token_size=final_max_token_size,
        send_dimensions=False,  # Will be set later based on binding requirements
+        model_name=model,
    )

    # Log final embedding configuration
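The `model_name` threaded through here is what the vector storages use to derive per-model collection names. A minimal sketch of a custom embedding binding that declares it; the model name and the random-vector backend below are placeholders, not part of this PR:

```python
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs


@wrap_embedding_func_with_attrs(
    embedding_dim=1536,
    max_token_size=8192,
    model_name="my-embedding-model",  # hypothetical; drives the collection suffix
)
async def my_embedding_func(texts: list[str]) -> np.ndarray:
    # Stand-in backend: return random vectors of the declared dimension
    return np.random.rand(len(texts), 1536)
```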
@@ -220,6 +220,37 @@ class BaseVectorStorage(StorageNameSpace, ABC):
    cosine_better_than_threshold: float = field(default=0.2)
    meta_fields: set[str] = field(default_factory=set)

+    def __post_init__(self):
+        """Validate required embedding_func for vector storage."""
+        if self.embedding_func is None:
+            raise ValueError(
+                "embedding_func is required for vector storage. "
+                "Please provide a valid EmbeddingFunc instance."
+            )
+
+    def _generate_collection_suffix(self) -> str | None:
+        """Generate the collection/table suffix from embedding_func.
+
+        Return the suffix if model_name exists in embedding_func, otherwise return None.
+        Note: embedding_func is guaranteed to exist (validated in __post_init__).
+
+        Returns:
+            str | None: Suffix string, e.g. "text_embedding_3_large_3072d", or None if model_name is not available
+        """
+        import re
+
+        # Check if model_name exists (model_name is optional in EmbeddingFunc)
+        model_name = getattr(self.embedding_func, "model_name", None)
+        if not model_name:
+            return None
+
+        # embedding_dim is required in EmbeddingFunc
+        embedding_dim = self.embedding_func.embedding_dim
+
+        # Generate suffix: clean the model name and append the dimension
+        safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
+        return f"{safe_model_name}_{embedding_dim}d"
+
    @abstractmethod
    async def query(
        self, query: str, top_k: int, query_embedding: list[float] = None
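For reference, the suffix rule above can be exercised standalone; this sketch reproduces its regex-and-dimension logic (model names are illustrative):

```python
import re


def collection_suffix(model_name: str, embedding_dim: int) -> str:
    # Lowercase, replace anything outside [a-zA-Z0-9_] with "_", append dimension
    safe = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
    return f"{safe}_{embedding_dim}d"


assert collection_suffix("text-embedding-3-large", 3072) == "text_embedding_3_large_3072d"
assert collection_suffix("amazon.titan-embed-text-v2:0", 1024) == "amazon_titan_embed_text_v2_0_1024d"
```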
@@ -128,8 +128,8 @@ class ChunkTokenLimitExceededError(ValueError):
        self.chunk_preview = truncated_preview


-class QdrantMigrationError(Exception):
-    """Raised when Qdrant data migration from legacy collections fails."""
+class DataMigrationError(Exception):
+    """Raised when data migration from legacy collection/table fails."""

    def __init__(self, message: str):
        super().__init__(message)
@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
    """

    def __post_init__(self):
+        super().__post_init__()
        # Grab config values if available
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -358,9 +359,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
            )
            return

+        dim_mismatch = False
        try:
            # Load the Faiss index
            self._index = faiss.read_index(self._faiss_index_file)

+            # Verify dimension consistency between loaded index and embedding function
+            if self._index.d != self._dim:
+                error_msg = (
+                    f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, "
+                    f"but embedding function expects dimension {self._dim}. "
+                    f"Please ensure the embedding model matches the stored index or rebuild the index."
+                )
+                logger.error(error_msg)
+                dim_mismatch = True
+                raise ValueError(error_msg)
+
            # Load metadata
            with open(self._meta_file, "r", encoding="utf-8") as f:
                stored_dict = json.load(f)
@@ -375,6 +389,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
                f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
            )
        except Exception as e:
+            if dim_mismatch:
+                raise
            logger.error(
                f"[{self.workspace}] Failed to load Faiss index or metadata: {e}"
            )
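A minimal standalone sketch of the same dimension guard, assuming an index file path of your own; `faiss.read_index` and `Index.d` are standard faiss APIs:

```python
import faiss

EXPECTED_DIM = 768  # must match the embedding function's embedding_dim

index = faiss.read_index("my_index.faiss")  # hypothetical path
if index.d != EXPECTED_DIM:
    raise ValueError(
        f"Index dimension {index.d} does not match embedding dimension {EXPECTED_DIM}; "
        "rebuild the index or switch back to the original embedding model."
    )
```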
@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
            raise

    def __post_init__(self):
+        super().__post_init__()
        # Check for MILVUS_WORKSPACE environment variable first (higher priority)
        # This allows administrators to force a specific workspace for all Milvus storage instances
        milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
@@ -89,7 +89,7 @@ class MongoKVStorage(BaseKVStorage):
            global_config=global_config,
            embedding_func=embedding_func,
        )
-        self.__post_init__()
+        # __post_init__() is automatically called by super().__init__()

    def __post_init__(self):
        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -317,7 +317,7 @@ class MongoDocStatusStorage(DocStatusStorage):
            global_config=global_config,
            embedding_func=embedding_func,
        )
-        self.__post_init__()
+        # __post_init__() is automatically called by super().__init__()

    def __post_init__(self):
        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -2052,9 +2052,12 @@ class MongoVectorDBStorage(BaseVectorStorage):
            embedding_func=embedding_func,
            meta_fields=meta_fields or set(),
        )
-        self.__post_init__()
+        # __post_init__() is automatically called by super().__init__()

    def __post_init__(self):
+        # Call parent class __post_init__ to validate embedding_func
+        super().__post_init__()
+
        # Check for MONGODB_WORKSPACE environment variable first (higher priority)
        # This allows administrators to force a specific workspace for all MongoDB storage instances
        mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
@@ -2131,8 +2134,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
        indexes = await indexes_cursor.to_list(length=None)
        for index in indexes:
            if index["name"] == self._index_name:
+                # Check if the existing index has matching vector dimensions
+                existing_dim = None
+                definition = index.get("latestDefinition", {})
+                fields = definition.get("fields", [])
+                for field in fields:
+                    if (
+                        field.get("type") == "vector"
+                        and field.get("path") == "vector"
+                    ):
+                        existing_dim = field.get("numDimensions")
+                        break
+
+                expected_dim = self.embedding_func.embedding_dim
+
+                if existing_dim is not None and existing_dim != expected_dim:
+                    error_msg = (
+                        f"Vector dimension mismatch! Index '{self._index_name}' has "
+                        f"dimension {existing_dim}, but current embedding model expects "
+                        f"dimension {expected_dim}. Please drop the existing index or "
+                        f"use an embedding model with matching dimensions."
+                    )
+                    logger.error(f"[{self.workspace}] {error_msg}")
+                    raise ValueError(error_msg)
+
                logger.info(
-                    f"[{self.workspace}] vector index {self._index_name} already exists"
+                    f"[{self.workspace}] vector index {self._index_name} already exists with matching dimensions ({expected_dim})"
                )
                return
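A small standalone sketch of the dimension lookup mirrored from the check above; the index document shape follows what the diff reads (`latestDefinition.fields` entries):

```python
def extract_vector_dim(index_doc: dict) -> int | None:
    # Scan the Atlas Search index definition for the vector field's dimension
    for f in index_doc.get("latestDefinition", {}).get("fields", []):
        if f.get("type") == "vector" and f.get("path") == "vector":
            return f.get("numDimensions")
    return None


doc = {"latestDefinition": {"fields": [{"type": "vector", "path": "vector", "numDimensions": 1536}]}}
assert extract_vector_dim(doc) == 1536
```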
@@ -25,6 +25,7 @@ from .shared_storage import (
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
    def __post_init__(self):
+        super().__post_init__()
        # Initialize basic attributes
        self._client = None
        self._storage_lock = None
(File diff suppressed because it is too large.)
@@ -10,7 +10,7 @@ import numpy as np
 import pipmaster as pm

 from ..base import BaseVectorStorage
-from ..exceptions import QdrantMigrationError
+from ..exceptions import DataMigrationError
 from ..kg.shared_storage import get_data_init_lock
 from ..utils import compute_mdhash_id, logger
@@ -66,6 +66,51 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition:
    )


+def _find_legacy_collection(
+    client: QdrantClient,
+    namespace: str,
+    workspace: str = None,
+    model_suffix: str = None,
+) -> str | None:
+    """
+    Find a legacy collection with backward-compatibility support.
+
+    This function tries multiple naming patterns to locate legacy collections
+    created by older versions of LightRAG, in priority order:
+
+    1. lightrag_vdb_{namespace} - if model_suffix is provided (HIGHEST PRIORITY)
+    2. {workspace}_{namespace} - if workspace is provided
+    3. lightrag_vdb_{namespace} - unconditional fallback
+    4. {namespace} - bare namespace (LOWEST PRIORITY)
+
+    Args:
+        client: QdrantClient instance
+        namespace: Base namespace (e.g., "chunks", "entities")
+        workspace: Optional workspace identifier
+        model_suffix: Optional model suffix for the new collection
+
+    Returns:
+        Collection name if found, None otherwise
+    """
+    # Try multiple naming patterns for backward compatibility
+    # More specific names (with workspace) have higher priority
+    candidates = [
+        f"lightrag_vdb_{namespace}" if model_suffix else None,
+        f"{workspace}_{namespace}" if workspace else None,
+        f"lightrag_vdb_{namespace}",
+        namespace,
+    ]
+
+    for candidate in candidates:
+        if candidate and client.collection_exists(candidate):
+            logger.info(
+                f"Qdrant: Found legacy collection '{candidate}' "
+                f"(namespace={namespace}, workspace={workspace or 'none'})"
+            )
+            return candidate
+
+    return None
+
+
 @final
 @dataclass
 class QdrantVectorDBStorage(BaseVectorStorage):
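For illustration, with `namespace="chunks"`, `workspace="space1"`, and a model suffix in use, the probe order works out to:

```python
# Candidate collection names tried in order (first existing one wins):
candidates = [
    "lightrag_vdb_chunks",  # prior shared-collection name (model suffix in use)
    "space1_chunks",        # workspace-prefixed legacy name
    "lightrag_vdb_chunks",  # unconditional fallback (same string as the first here)
    "chunks",               # bare namespace from the oldest layout
]
```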
@@ -79,63 +124,52 @@ class QdrantVectorDBStorage(BaseVectorStorage):
            embedding_func=embedding_func,
            meta_fields=meta_fields or set(),
        )
-        self.__post_init__()
+        # __post_init__() is automatically called by super().__init__()

    @staticmethod
    def setup_collection(
        client: QdrantClient,
        collection_name: str,
-        legacy_namespace: str = None,
-        workspace: str = None,
-        **kwargs,
+        namespace: str,
+        workspace: str,
+        vectors_config: models.VectorParams,
+        hnsw_config: models.HnswConfigDiff,
+        model_suffix: str,
    ):
        """
        Setup Qdrant collection with migration support from legacy collections.

+        Ensure the final collection has a workspace isolation index.
+        Check vector dimension compatibility before creating a new collection.
+        Drop the legacy collection if it exists and is empty.
+        Only migrate data from the legacy collection to the new collection when
+        the new collection is first created and the legacy collection is not empty.

        Args:
            client: QdrantClient instance
-            collection_name: Name of the new collection
-            legacy_namespace: Name of the legacy collection (if exists)
+            collection_name: Name of the final collection
+            namespace: Base namespace (e.g., "chunks", "entities")
            workspace: Workspace identifier for data isolation
-            **kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
+            vectors_config: Vector configuration parameters for the collection
+            hnsw_config: HNSW index configuration diff for the collection
        """
+        if not namespace or not workspace:
+            raise ValueError("namespace and workspace must be provided")
+
+        workspace_count_filter = models.Filter(
+            must=[workspace_filter_condition(workspace)]
+        )
+
        new_collection_exists = client.collection_exists(collection_name)
-        legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace)
+        legacy_collection = _find_legacy_collection(
+            client, namespace, workspace, model_suffix
+        )

-        # Case 1: Both new and legacy collections exist - Warning only (no migration)
-        if new_collection_exists and legacy_exists:
-            logger.warning(
-                f"Qdrant: Legacy collection '{legacy_namespace}' still exists. Remove it if migration is complete."
-            )
-            return
-
-        # Case 2: Only new collection exists - Ensure index exists
-        if new_collection_exists:
-            # Check if workspace index exists, create if missing
-            try:
-                collection_info = client.get_collection(collection_name)
-                if WORKSPACE_ID_FIELD not in collection_info.payload_schema:
-                    logger.info(
-                        f"Qdrant: Creating missing workspace index for '{collection_name}'"
-                    )
-                    client.create_payload_index(
-                        collection_name=collection_name,
-                        field_name=WORKSPACE_ID_FIELD,
-                        field_schema=models.KeywordIndexParams(
-                            type=models.KeywordIndexType.KEYWORD,
-                            is_tenant=True,
-                        ),
-                    )
-            except Exception as e:
-                logger.warning(
-                    f"Qdrant: Could not verify/create workspace index for '{collection_name}': {e}"
-                )
-            return
-
-        # Case 3: Neither exists - Create new collection
-        if not legacy_exists:
-            logger.info(f"Qdrant: Creating new collection '{collection_name}'")
-            client.create_collection(collection_name, **kwargs)
+        # Case 1: Only the new collection exists, or it is the same as the legacy collection.
+        # No data migration needed; ensure the index is created, then return.
+        if (new_collection_exists and not legacy_collection) or (
+            collection_name == legacy_collection
+        ):
+            # create_payload_index returns without error if the index already exists
+            client.create_payload_index(
+                collection_name=collection_name,
+                field_name=WORKSPACE_ID_FIELD,
@@ -144,132 +178,244 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                    is_tenant=True,
                ),
            )
-        logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
+            new_workspace_count = client.count(
+                collection_name=collection_name,
+                count_filter=workspace_count_filter,
+                exact=True,
+            ).count
+
+            # Skip data migration if new collection already has workspace data
+            if new_workspace_count == 0 and not (collection_name == legacy_collection):
+                logger.warning(
+                    f"Qdrant: workspace data in collection '{collection_name}' is empty. "
+                    f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
+                )

            return

-        # Case 4: Only legacy exists - Migrate data
-        logger.info(
-            f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'"
-        )
+        legacy_count = None
+        if not new_collection_exists:
+            # Check vector dimension compatibility before creating the new collection
+            if legacy_collection:
+                legacy_count = client.count(
+                    collection_name=legacy_collection, exact=True
+                ).count
+                if legacy_count > 0:
+                    legacy_info = client.get_collection(legacy_collection)
+                    legacy_dim = legacy_info.config.params.vectors.size
+
+                    if vectors_config.size and legacy_dim != vectors_config.size:
+                        logger.error(
+                            f"Qdrant: Dimension mismatch detected! "
+                            f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, "
+                            f"but new embedding model expects {vectors_config.size}d."
+                        )
+
+                        raise DataMigrationError(
+                            f"Dimension mismatch between legacy collection '{legacy_collection}' "
+                            f"and new collection. Expected {vectors_config.size}d but got {legacy_dim}d."
+                        )
+
+            client.create_collection(
+                collection_name, vectors_config=vectors_config, hnsw_config=hnsw_config
+            )
+            logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
+            if not legacy_collection:
+                logger.warning(
+                    "Qdrant: Ensure this new collection creation is caused by new workspace setup and not an unexpected embedding model change."
+                )
+
+        # create_payload_index returns without error if the index already exists
+        client.create_payload_index(
+            collection_name=collection_name,
+            field_name=WORKSPACE_ID_FIELD,
+            field_schema=models.KeywordIndexParams(
+                type=models.KeywordIndexType.KEYWORD,
+                is_tenant=True,
+            ),
+        )

-        try:
-            # Get legacy collection count
-            legacy_count = client.count(
-                collection_name=legacy_namespace, exact=True
-            ).count
-            logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
+        # Case 2: Legacy collection exists
+        if legacy_collection:
+            # Only drop the legacy collection if it's empty
+            if legacy_count is None:
+                legacy_count = client.count(
+                    collection_name=legacy_collection, exact=True
+                ).count
            if legacy_count == 0:
                logger.info("Qdrant: Legacy collection is empty, skipping migration")
-                # Create new empty collection
-                client.create_collection(collection_name, **kwargs)
-                client.create_payload_index(
-                    collection_name=collection_name,
-                    field_name=WORKSPACE_ID_FIELD,
-                    field_schema=models.KeywordIndexParams(
-                        type=models.KeywordIndexType.KEYWORD,
-                        is_tenant=True,
-                    ),
-                )
+                client.delete_collection(collection_name=legacy_collection)
+                logger.info(
+                    f"Qdrant: Empty legacy collection '{legacy_collection}' deleted successfully"
+                )
                return

-            # Create new collection first
-            logger.info(f"Qdrant: Creating new collection '{collection_name}'")
-            client.create_collection(collection_name, **kwargs)
+            new_workspace_count = client.count(
+                collection_name=collection_name,
+                count_filter=workspace_count_filter,
+                exact=True,
+            ).count

-            # Batch migration (500 records per batch)
-            migrated_count = 0
-            offset = None
-            batch_size = 500
-
-            while True:
-                # Scroll through legacy data
-                result = client.scroll(
-                    collection_name=legacy_namespace,
-                    limit=batch_size,
-                    offset=offset,
-                    with_vectors=True,
-                    with_payload=True,
-                )
-                points, next_offset = result
-
-                if not points:
-                    break
+            # Skip data migration if new collection already has workspace data
+            if new_workspace_count > 0:
+                logger.warning(
+                    f"Qdrant: Both new and legacy collection have data. "
+                    f"{legacy_count} records in {legacy_collection} require manual deletion after migration verification."
+                )
+                return

-                # Transform points for new collection
-                new_points = []
-                for point in points:
-                    # Add workspace_id to payload
-                    new_payload = dict(point.payload or {})
-                    new_payload[WORKSPACE_ID_FIELD] = workspace or DEFAULT_WORKSPACE
-
-                    # Create new point with workspace-prefixed ID
-                    original_id = new_payload.get(ID_FIELD)
-                    if original_id:
-                        new_point_id = compute_mdhash_id_for_qdrant(
-                            original_id, prefix=workspace or DEFAULT_WORKSPACE
-                        )
-                    else:
-                        # Fallback: use original point ID
-                        new_point_id = str(point.id)
-
-                    new_points.append(
-                        models.PointStruct(
-                            id=new_point_id,
-                            vector=point.vector,
-                            payload=new_payload,
-                        )
-                    )
+            # Case 3: Only legacy exists - migrate data from legacy collection to new collection
+            # Check if legacy collection has workspace_id to determine migration strategy
+            # Note: payload_schema only reflects INDEXED fields, so we also sample
+            # actual payloads to detect unindexed workspace_id fields
+            legacy_info = client.get_collection(legacy_collection)
+            has_workspace_index = WORKSPACE_ID_FIELD in (
+                legacy_info.payload_schema or {}
+            )

+            # Detect workspace_id field presence by sampling payloads if not indexed
+            # This prevents cross-workspace data leakage when workspace_id exists but isn't indexed
+            has_workspace_field = has_workspace_index
+            if not has_workspace_index:
+                # Sample a small batch of points to check for workspace_id in payloads
+                # All points must have workspace_id if any point has it
+                sample_result = client.scroll(
+                    collection_name=legacy_collection,
+                    limit=10,  # Small sample is sufficient for detection
+                    with_payload=True,
+                    with_vectors=False,
+                )
+                sample_points, _ = sample_result
+                for point in sample_points:
+                    if point.payload and WORKSPACE_ID_FIELD in point.payload:
+                        has_workspace_field = True
+                        logger.info(
+                            f"Qdrant: Detected unindexed {WORKSPACE_ID_FIELD} field "
+                            f"in legacy collection '{legacy_collection}' via payload sampling"
+                        )
+                        break

+            # Build workspace filter if legacy collection has workspace support
+            # This prevents cross-workspace data leakage during migration
+            legacy_scroll_filter = None
+            if has_workspace_field:
+                legacy_scroll_filter = models.Filter(
+                    must=[workspace_filter_condition(workspace)]
+                )
+                # Recount with workspace filter for accurate migration tracking
+                legacy_count = client.count(
+                    collection_name=legacy_collection,
+                    count_filter=legacy_scroll_filter,
+                    exact=True,
+                ).count
+                logger.info(
+                    f"Qdrant: Legacy collection has workspace support, "
+                    f"filtering to {legacy_count} records for workspace '{workspace}'"
+                )

+            logger.info(
+                f"Qdrant: Found legacy collection '{legacy_collection}' with {legacy_count} records to migrate."
+            )
+            logger.info(
+                f"Qdrant: Migrating data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
+            )

+            try:
+                # Batch migration (500 records per batch)
+                migrated_count = 0
+                offset = None
+                batch_size = 500
+
+                while True:
+                    # Scroll through legacy data with optional workspace filter
+                    result = client.scroll(
+                        collection_name=legacy_collection,
+                        scroll_filter=legacy_scroll_filter,
+                        limit=batch_size,
+                        offset=offset,
+                        with_vectors=True,
+                        with_payload=True,
+                    )
+                    points, next_offset = result
+
+                    if not points:
+                        break
+
+                    # Transform points for new collection
+                    new_points = []
+                    for point in points:
+                        # Set workspace_id in payload
+                        new_payload = dict(point.payload or {})
+                        new_payload[WORKSPACE_ID_FIELD] = workspace
+
+                        # Create new point with workspace-prefixed ID
+                        original_id = new_payload.get(ID_FIELD)
+                        if original_id:
+                            new_point_id = compute_mdhash_id_for_qdrant(
+                                original_id, prefix=workspace
+                            )
+                        else:
+                            # Fallback: use original point ID
+                            new_point_id = str(point.id)
+
+                        new_points.append(
+                            models.PointStruct(
+                                id=new_point_id,
+                                vector=point.vector,
+                                payload=new_payload,
+                            )
+                        )

-                # Upsert to new collection
-                client.upsert(
-                    collection_name=collection_name, points=new_points, wait=True
-                )
+                    # Upsert to new collection
+                    client.upsert(
+                        collection_name=collection_name, points=new_points, wait=True
+                    )

-                migrated_count += len(points)
-                logger.info(f"Qdrant: {migrated_count}/{legacy_count} records migrated")
+                    migrated_count += len(points)
+                    logger.info(
+                        f"Qdrant: {migrated_count}/{legacy_count} records migrated"
+                    )

-                # Check if we've reached the end
-                if next_offset is None:
-                    break
-                offset = next_offset
+                    # Check if we've reached the end
+                    if next_offset is None:
+                        break
+                    offset = next_offset

-            # Verify migration by comparing counts
-            logger.info("Verifying migration...")
-            new_count = client.count(collection_name=collection_name, exact=True).count
-
-            if new_count != legacy_count:
-                error_msg = f"Qdrant: Migration verification failed, expected {legacy_count} records, got {new_count} in new collection"
-                logger.error(error_msg)
-                raise QdrantMigrationError(error_msg)
+                new_count_after = client.count(
+                    collection_name=collection_name,
+                    count_filter=workspace_count_filter,
+                    exact=True,
+                ).count
+                inserted_count = new_count_after - new_workspace_count
+                if inserted_count != legacy_count:
+                    error_msg = (
+                        "Qdrant: Migration verification failed, expected "
+                        f"{legacy_count} inserted records, got {inserted_count}."
+                    )
+                    logger.error(error_msg)
+                    raise DataMigrationError(error_msg)

+            except DataMigrationError:
+                # Re-raise DataMigrationError as-is to preserve specific error messages
+                raise
+            except Exception as e:
+                logger.error(
+                    f"Qdrant: Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}': {e}"
+                )
+                raise DataMigrationError(
+                    f"Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
+                ) from e

            logger.info(
-                f"Qdrant: Migration completed successfully: {migrated_count} records migrated"
+                f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully"
            )

-            # Create payload index after successful migration
-            logger.info("Qdrant: Creating workspace payload index...")
-            client.create_payload_index(
-                collection_name=collection_name,
-                field_name=WORKSPACE_ID_FIELD,
-                field_schema=models.KeywordIndexParams(
-                    type=models.KeywordIndexType.KEYWORD,
-                    is_tenant=True,
-                ),
-            )
-            logger.info(
-                f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully"
-            )
-
-        except QdrantMigrationError:
-            # Re-raise migration errors without wrapping
-            raise
-        except Exception as e:
-            error_msg = f"Qdrant: Migration failed with error: {e}"
-            logger.error(error_msg)
-            raise QdrantMigrationError(error_msg) from e
+            logger.warning(
+                "Qdrant: Manual deletion is required after data migration verification."
+            )

    def __post_init__(self):
+        # Call parent class __post_init__ to validate embedding_func
+        super().__post_init__()
+
        # Check for QDRANT_WORKSPACE environment variable first (higher priority)
        # This allows administrators to force a specific workspace for all Qdrant storage instances
        qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
@@ -287,20 +433,23 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                f"Using passed workspace parameter: '{effective_workspace}'"
            )

-        # Get legacy namespace for data migration from old version
-        if effective_workspace:
-            self.legacy_namespace = f"{effective_workspace}_{self.namespace}"
-        else:
-            self.legacy_namespace = self.namespace
-
        self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE

-        # Use a shared collection with payload-based partitioning (Qdrant's recommended approach)
-        # Ref: https://qdrant.tech/documentation/guides/multiple-partitions/
-        self.final_namespace = f"lightrag_vdb_{self.namespace}"
-        logger.debug(
-            f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning"
-        )
+        # Generate model suffix
+        self.model_suffix = self._generate_collection_suffix()
+
+        # New naming scheme with model isolation
+        # Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
+        # Ensure model_suffix is not empty before appending
+        if self.model_suffix:
+            self.final_namespace = f"lightrag_vdb_{self.namespace}_{self.model_suffix}"
+            logger.info(f"Qdrant collection: {self.final_namespace}")
+        else:
+            # Fallback: use legacy namespace if model_suffix is unavailable
+            self.final_namespace = f"lightrag_vdb_{self.namespace}"
+            logger.warning(
+                f"Qdrant collection: {self.final_namespace} missing suffix. Please add model_name to embedding_func for proper workspace data isolation."
+            )

        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
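A worked example of the naming scheme above (values mirror the example in the code comment):

```python
namespace = "chunks"
model_suffix = "text_embedding_ada_002_1536d"  # from _generate_collection_suffix()
final_namespace = f"lightrag_vdb_{namespace}_{model_suffix}"
assert final_namespace == "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
```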
@@ -338,11 +487,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
            )

            # Setup collection (create if not exists and configure indexes)
-            # Pass legacy_namespace and workspace for migration support
+            # Pass namespace and workspace for backward-compatible migration support
            QdrantVectorDBStorage.setup_collection(
                self._client,
                self.final_namespace,
-                legacy_namespace=self.legacy_namespace,
+                namespace=self.namespace,
                workspace=self.effective_workspace,
                vectors_config=models.VectorParams(
                    size=self.embedding_func.embedding_dim,
@@ -352,8 +501,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                    payload_m=16,
                    m=0,
                ),
+                model_suffix=self.model_suffix,
            )

+            # Removed duplicate max batch size initialization
+
            self._initialized = True
            logger.info(
                f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
@@ -481,21 +633,44 @@ class QdrantVectorDBStorage(BaseVectorStorage):
            entity_name: Name of the entity to delete
        """
        try:
-            # Generate the entity ID using the same function as used for storage
+            # Compute entity ID from name (same as Milvus)
            entity_id = compute_mdhash_id(entity_name, prefix=ENTITY_PREFIX)
-            qdrant_entity_id = compute_mdhash_id_for_qdrant(
-                entity_id, prefix=self.effective_workspace
-            )
+            logger.debug(
+                f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
+            )

-            # Delete the entity point by its Qdrant ID directly
-            self._client.delete(
-                collection_name=self.final_namespace,
-                points_selector=models.PointIdsList(points=[qdrant_entity_id]),
-                wait=True,
-            )
-            logger.debug(
-                f"[{self.workspace}] Successfully deleted entity {entity_name}"
-            )
+            # Scroll to find the entity by its ID field in payload with workspace filtering
+            # This is safer than reconstructing the Qdrant point ID
+            results = self._client.scroll(
+                collection_name=self.final_namespace,
+                scroll_filter=models.Filter(
+                    must=[
+                        workspace_filter_condition(self.effective_workspace),
+                        models.FieldCondition(
+                            key=ID_FIELD, match=models.MatchValue(value=entity_id)
+                        ),
+                    ]
+                ),
+                with_payload=False,
+                limit=1,
+            )
+
+            # Extract point IDs to delete
+            points = results[0]
+            if points:
+                ids_to_delete = [point.id for point in points]
+                self._client.delete(
+                    collection_name=self.final_namespace,
+                    points_selector=models.PointIdsList(points=ids_to_delete),
+                    wait=True,
+                )
+                logger.debug(
+                    f"[{self.workspace}] Successfully deleted entity {entity_name}"
+                )
+            else:
+                logger.debug(
+                    f"[{self.workspace}] Entity {entity_name} not found in storage"
+                )
        except Exception as e:
            logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
@@ -506,38 +681,60 @@ class QdrantVectorDBStorage(BaseVectorStorage):
            entity_name: Name of the entity whose relations should be deleted
        """
        try:
-            # Find relations where the entity is either source or target, with workspace filtering
-            results = self._client.scroll(
-                collection_name=self.final_namespace,
-                scroll_filter=models.Filter(
-                    must=[workspace_filter_condition(self.effective_workspace)],
-                    should=[
-                        models.FieldCondition(
-                            key="src_id", match=models.MatchValue(value=entity_name)
-                        ),
-                        models.FieldCondition(
-                            key="tgt_id", match=models.MatchValue(value=entity_name)
-                        ),
-                    ],
-                ),
-                with_payload=True,
-                limit=1000,  # Adjust as needed for your use case
-            )
+            # Build the filter to find relations where entity is either source or target
+            # must + should = workspace_id matches AND (src_id matches OR tgt_id matches)
+            relation_filter = models.Filter(
+                must=[workspace_filter_condition(self.effective_workspace)],
+                should=[
+                    models.FieldCondition(
+                        key="src_id", match=models.MatchValue(value=entity_name)
+                    ),
+                    models.FieldCondition(
+                        key="tgt_id", match=models.MatchValue(value=entity_name)
+                    ),
+                ],
+            )

-            # Extract points that need to be deleted
-            relation_points = results[0]
-            ids_to_delete = [point.id for point in relation_points]
+            # Paginate through all matching relations to handle large datasets
+            total_deleted = 0
+            offset = None
+            batch_size = 1000

-            if ids_to_delete:
-                # Delete the relations with workspace filtering
-                assert isinstance(self._client, QdrantClient)
+            while True:
+                # Scroll to find relations, using with_payload=False for efficiency
+                # since we only need point IDs for deletion
+                results = self._client.scroll(
+                    collection_name=self.final_namespace,
+                    scroll_filter=relation_filter,
+                    with_payload=False,
+                    with_vectors=False,
+                    limit=batch_size,
+                    offset=offset,
+                )
+
+                points, next_offset = results
+                if not points:
+                    break
+
+                # Extract point IDs to delete
+                ids_to_delete = [point.id for point in points]
+
+                # Delete the batch of relations
                self._client.delete(
                    collection_name=self.final_namespace,
                    points_selector=models.PointIdsList(points=ids_to_delete),
                    wait=True,
                )
+                total_deleted += len(ids_to_delete)
+
+                # Check if we've reached the end
+                if next_offset is None:
+                    break
+                offset = next_offset
+
+            if total_deleted > 0:
                logger.debug(
-                    f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
+                    f"[{self.workspace}] Deleted {total_deleted} relations for {entity_name}"
                )
            else:
                logger.debug(
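The same scroll-then-delete pagination pattern as a standalone hedged sketch against a hypothetical collection; `scroll()` returns a `(points, next_offset)` tuple:

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance
offset = None
while True:
    points, next_offset = client.scroll(
        collection_name="my_collection",  # hypothetical
        scroll_filter=None,               # plug in your models.Filter here
        with_payload=False,               # point IDs are enough for deletion
        with_vectors=False,
        limit=1000,
        offset=offset,
    )
    if not points:
        break
    client.delete(
        collection_name="my_collection",
        points_selector=models.PointIdsList(points=[p.id for p in points]),
        wait=True,
    )
    if next_offset is None:
        break
    offset = next_offset
```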
@@ -163,7 +163,9 @@ class UnifiedLock(Generic[T]):
                enable_output=self._enable_logging,
            )

-            # Then acquire the main lock
+            # Acquire the main lock
+            # Note: self._lock should never be None here as the check has been moved
+            # to get_internal_lock() and get_data_init_lock() functions
            if self._is_async:
                await self._lock.acquire()
            else:
@@ -193,19 +195,21 @@ class UnifiedLock(Generic[T]):

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        main_lock_released = False
+        async_lock_released = False
        try:
            # Release main lock first
-            if self._is_async:
-                self._lock.release()
-            else:
-                self._lock.release()
-            main_lock_released = True
+            if self._lock is not None:
+                if self._is_async:
+                    self._lock.release()
+                else:
+                    self._lock.release()

-            direct_log(
-                f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
-                level="INFO",
-                enable_output=self._enable_logging,
-            )
+                direct_log(
+                    f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
+                    level="INFO",
+                    enable_output=self._enable_logging,
+                )
+            main_lock_released = True

            # Then release async lock if in multiprocess mode
            if not self._is_async and self._async_lock is not None:
@@ -215,6 +219,7 @@ class UnifiedLock(Generic[T]):
                    level="DEBUG",
                    enable_output=self._enable_logging,
                )
+                async_lock_released = True

        except Exception as e:
            direct_log(
@@ -223,9 +228,10 @@ class UnifiedLock(Generic[T]):
                enable_output=True,
            )

-            # If main lock release failed but async lock hasn't been released, try to release it
+            # If main lock release failed but async lock hasn't been attempted yet, try to release it
            if (
                not main_lock_released
+                and not async_lock_released
                and not self._is_async
                and self._async_lock is not None
            ):
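The two release-tracking flags generalize to any two-stage lock teardown; a minimal standalone sketch of the pattern, with plain `threading` locks standing in for the unified locks:

```python
import threading

main_lock = threading.Lock()
helper_lock = threading.Lock()
main_lock.acquire()
helper_lock.acquire()

main_released = False
helper_released = False
try:
    main_lock.release()
    main_released = True
    helper_lock.release()
    helper_released = True
except Exception:
    # Only retry the helper lock if its release was never attempted
    if not main_released and not helper_released:
        helper_lock.release()
    raise
```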
@@ -255,6 +261,10 @@ class UnifiedLock(Generic[T]):
        try:
            if self._is_async:
                raise RuntimeError("Use 'async with' for shared_storage lock")

+            # Acquire the main lock
+            # Note: self._lock should never be None here as the check has been moved
+            # to get_internal_lock() and get_data_init_lock() functions
            direct_log(
                f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)",
                level="DEBUG",
@@ -1060,6 +1070,10 @@ class _KeyedLockContext:

 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified storage lock for data consistency"""
+    if _internal_lock is None:
+        raise RuntimeError(
+            "Shared data not initialized. Call initialize_share_data() before using locks!"
+        )
    async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
    return UnifiedLock(
        lock=_internal_lock,
@@ -1090,6 +1104,10 @@ def get_storage_keyed_lock(

 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified data initialization lock for ensuring atomic data initialization"""
+    if _data_init_lock is None:
+        raise RuntimeError(
+            "Shared data not initialized. Call initialize_share_data() before using locks!"
+        )
    async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
    return UnifiedLock(
        lock=_data_init_lock,
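A hedged usage sketch: shared data must be initialized once per process before either accessor is called, otherwise the new guards raise. This assumes `initialize_share_data` is exported from `lightrag.kg.shared_storage`, as the error message suggests:

```python
from lightrag.kg.shared_storage import initialize_share_data, get_data_init_lock


async def init_storage():
    initialize_share_data()  # must run first, or get_data_init_lock() raises RuntimeError
    async with get_data_init_lock():
        ...  # perform one-time data initialization here
```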
@@ -7,7 +7,7 @@ import inspect
 import os
 import time
 import warnings
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, replace
 from datetime import datetime, timezone
 from functools import partial
 from typing import (
@@ -518,14 +518,9 @@ class LightRAG:
                f"max_total_tokens({self.summary_max_tokens}) should be greater than summary_length_recommended({self.summary_length_recommended})"
            )

-        # Fix global_config now
-        global_config = asdict(self)
-
-        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
-        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
-
        # Init Embedding
-        # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
+        # Step 1: Capture embedding_func and max_token_size before applying decorator
+        original_embedding_func = self.embedding_func
        embedding_max_token_size = None
        if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
            embedding_max_token_size = self.embedding_func.max_token_size
@@ -534,12 +529,26 @@ class LightRAG:
            )
        self.embedding_token_limit = embedding_max_token_size

-        # Step 2: Apply priority wrapper decorator
-        self.embedding_func = priority_limit_async_func_call(
-            self.embedding_func_max_async,
-            llm_timeout=self.default_embedding_timeout,
-            queue_name="Embedding func",
-        )(self.embedding_func)
+        # Fix global_config now
+        global_config = asdict(self)
+        # Restore original EmbeddingFunc object (asdict converts it to dict)
+        global_config["embedding_func"] = original_embedding_func
+
+        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
+        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
+
+        # Step 2: Apply priority wrapper decorator to EmbeddingFunc's inner func
+        # Create a NEW EmbeddingFunc instance with the wrapped func to avoid mutating the caller's object
+        # This ensures _generate_collection_suffix can still access attributes (model_name, embedding_dim)
+        # while preventing side effects when the same EmbeddingFunc is reused across multiple LightRAG instances
+        if self.embedding_func is not None:
+            wrapped_func = priority_limit_async_func_call(
+                self.embedding_func_max_async,
+                llm_timeout=self.default_embedding_timeout,
+                queue_name="Embedding func",
+            )(self.embedding_func.func)
+            # Use dataclasses.replace() to create a new instance, leaving the original unchanged
+            self.embedding_func = replace(self.embedding_func, func=wrapped_func)

        # Initialize all storages
        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
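A standalone sketch of the non-mutating wrap above: `dataclasses.replace()` clones the dataclass with one field swapped, so the caller's object keeps its original `func`. The `EmbeddingFuncLike` class is a simplified stand-in, not LightRAG's actual `EmbeddingFunc`:

```python
from dataclasses import dataclass, replace
from typing import Callable


@dataclass
class EmbeddingFuncLike:  # simplified stand-in
    embedding_dim: int
    func: Callable
    model_name: str | None = None


def wrap(f: Callable) -> Callable:
    def inner(*args, **kwargs):
        return f(*args, **kwargs)
    return inner


original = EmbeddingFuncLike(embedding_dim=768, func=lambda xs: xs, model_name="demo-model")
wrapped = replace(original, func=wrap(original.func))

assert wrapped.model_name == "demo-model"  # attributes survive the wrap
assert original.func is not wrapped.func   # caller's object is untouched
```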
@@ -351,7 +351,9 @@ async def bedrock_complete(
    return result


-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
+)
 @retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -453,7 +453,9 @@ async def gemini_model_complete(
    )


-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1536, max_token_size=2048, model_name="gemini-embedding-001"
+)
 @retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -142,7 +142,9 @@ async def hf_model_complete(
    return result


-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1024, max_token_size=8192, model_name="hf_embedding_model"
+)
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    # Detect the appropriate device
    if torch.cuda.is_available():
@@ -58,7 +58,9 @@ async def fetch_data(url, headers, data):
    return data_list


-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=2048, max_token_size=8192, model_name="jina-embeddings-v4"
+)
 @retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -138,7 +138,9 @@ async def lollms_model_complete(
    )


-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
+)
 async def lollms_embed(
    texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
 ) -> np.ndarray:
@@ -33,7 +33,9 @@ from lightrag.utils import (
 import numpy as np


-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
+)
 @retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -172,7 +172,9 @@ async def ollama_model_complete(
    )


-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
+)
 async def ollama_embed(
    texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
 ) -> np.ndarray:
@@ -677,7 +677,9 @@ async def nvidia_openai_complete(
    return result


-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small"
+)
 @retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -867,7 +869,11 @@ async def azure_openai_complete(
    return result


-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1536,
+    max_token_size=8192,
+    model_name="my-text-embedding-3-large-deployment",
+)
 async def azure_openai_embed(
    texts: list[str],
    model: str | None = None,
@@ -179,7 +179,9 @@ async def zhipu_complete(
    )


-@wrap_embedding_func_with_attrs(embedding_dim=1024)
+@wrap_embedding_func_with_attrs(
+    embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
+)
 @retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
lightrag/tools/prepare_qdrant_legacy_data.py (new file): 720 lines
@@ -0,0 +1,720 @@
#!/usr/bin/env python3
"""
Qdrant Legacy Data Preparation Tool for LightRAG

This tool copies data from new collections to legacy collections for testing
the data migration logic in the setup_collection function.

New collections (with workspace_id):
- lightrag_vdb_chunks
- lightrag_vdb_entities
- lightrag_vdb_relationships

Legacy collections (without workspace_id, dynamically named as {workspace}_{suffix}):
- {workspace}_chunks (e.g., space1_chunks)
- {workspace}_entities (e.g., space1_entities)
- {workspace}_relationships (e.g., space1_relationships)

The tool:
1. Filters source data by workspace_id
2. Verifies workspace data exists before creating legacy collections
3. Removes the workspace_id field to simulate the legacy data format
4. Copies only the specified workspace's data to legacy collections

Usage:
    python -m lightrag.tools.prepare_qdrant_legacy_data
    # or
    python lightrag/tools/prepare_qdrant_legacy_data.py

    # Specify custom workspace
    python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1

    # Process specific collection types only
    python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities

    # Dry run (preview only, no actual changes)
    python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
"""

import argparse
import asyncio
import configparser
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import pipmaster as pm
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models  # type: ignore

# Add project root to path for imports
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

# Load environment variables
load_dotenv(dotenv_path=".env", override=False)

# Ensure qdrant-client is installed
if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

# Collection namespace mapping: new collection pattern -> legacy suffix
# Legacy collection will be named as: {workspace}_{suffix}
COLLECTION_NAMESPACES = {
    "chunks": {
        "new": "lightrag_vdb_chunks",
        "suffix": "chunks",
    },
    "entities": {
        "new": "lightrag_vdb_entities",
        "suffix": "entities",
    },
    "relationships": {
        "new": "lightrag_vdb_relationships",
        "suffix": "relationships",
    },
}

# Default batch size for copy operations
DEFAULT_BATCH_SIZE = 500

# Field to remove from legacy data
WORKSPACE_ID_FIELD = "workspace_id"

# ANSI color codes for terminal output
BOLD_CYAN = "\033[1;36m"
BOLD_GREEN = "\033[1;32m"
BOLD_YELLOW = "\033[1;33m"
BOLD_RED = "\033[1;31m"
RESET = "\033[0m"


@dataclass
class CopyStats:
    """Copy operation statistics"""

    collection_type: str
    source_collection: str
    target_collection: str
    total_records: int = 0
    copied_records: int = 0
    failed_records: int = 0
    errors: List[Dict[str, Any]] = field(default_factory=list)
    elapsed_time: float = 0.0

    def add_error(self, batch_idx: int, error: Exception, batch_size: int):
        """Record a batch error"""
        self.errors.append(
            {
                "batch": batch_idx,
                "error_type": type(error).__name__,
                "error_msg": str(error),
                "records_lost": batch_size,
                "timestamp": time.time(),
            }
        )
        self.failed_records += batch_size
class QdrantLegacyDataPreparationTool:
|
||||
"""Tool for preparing legacy data in Qdrant for migration testing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str = "space1",
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
dry_run: bool = False,
|
||||
clear_target: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the tool.
|
||||
|
||||
Args:
|
||||
workspace: Workspace to use for filtering new collection data
|
||||
batch_size: Number of records to process per batch
|
||||
dry_run: If True, only preview operations without making changes
|
||||
clear_target: If True, delete target collection before copying data
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.batch_size = batch_size
|
||||
self.dry_run = dry_run
|
||||
self.clear_target = clear_target
|
||||
self._client: Optional[QdrantClient] = None
|
||||
|
||||
def _get_client(self) -> QdrantClient:
|
||||
"""Get or create QdrantClient instance"""
|
||||
if self._client is None:
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
self._client = QdrantClient(
|
||||
url=os.environ.get(
|
||||
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
|
||||
),
|
||||
api_key=os.environ.get(
|
||||
"QDRANT_API_KEY",
|
||||
config.get("qdrant", "apikey", fallback=None),
|
||||
),
|
||||
)
|
||||
return self._client
|
||||
|
||||
def print_header(self):
|
||||
"""Print tool header"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Qdrant Legacy Data Preparation Tool - LightRAG")
|
||||
print("=" * 60)
|
||||
if self.dry_run:
|
||||
print(f"{BOLD_YELLOW}⚠️ DRY RUN MODE - No changes will be made{RESET}")
|
||||
if self.clear_target:
|
||||
print(
|
||||
f"{BOLD_RED}⚠️ CLEAR TARGET MODE - Target collections will be deleted first{RESET}"
|
||||
)
|
||||
print(f"Workspace: {BOLD_CYAN}{self.workspace}{RESET}")
|
||||
print(f"Batch Size: {self.batch_size}")
|
||||
print("=" * 60)
|
||||
|
||||
def check_connection(self) -> bool:
|
||||
"""Check Qdrant connection"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
# Try to list collections to verify connection
|
||||
client.get_collections()
|
||||
print(f"{BOLD_GREEN}✓{RESET} Qdrant connection successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"{BOLD_RED}✗{RESET} Qdrant connection failed: {e}")
|
||||
return False
|
||||
|
||||
def get_collection_info(self, collection_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get collection information.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection
|
||||
|
||||
Returns:
|
||||
Dictionary with collection info (vector_size, count) or None if not exists
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
if not client.collection_exists(collection_name):
|
||||
return None
|
||||
|
||||
info = client.get_collection(collection_name)
|
||||
count = client.count(collection_name=collection_name, exact=True).count
|
||||
|
||||
# Handle both object and dict formats for vectors config
|
||||
vectors_config = info.config.params.vectors
|
||||
if isinstance(vectors_config, dict):
|
||||
# Named vectors format or dict format
|
||||
if vectors_config:
|
||||
first_key = next(iter(vectors_config.keys()), None)
|
||||
if first_key and hasattr(vectors_config[first_key], "size"):
|
||||
vector_size = vectors_config[first_key].size
|
||||
distance = vectors_config[first_key].distance
|
||||
else:
|
||||
# Try to get from dict values
|
||||
first_val = next(iter(vectors_config.values()), {})
|
||||
vector_size = (
|
||||
first_val.get("size")
|
||||
if isinstance(first_val, dict)
|
||||
else getattr(first_val, "size", None)
|
||||
)
|
||||
distance = (
|
||||
first_val.get("distance")
|
||||
if isinstance(first_val, dict)
|
||||
else getattr(first_val, "distance", None)
|
||||
)
|
||||
else:
|
||||
vector_size = None
|
||||
distance = None
|
||||
else:
|
||||
# Standard single vector format
|
||||
vector_size = vectors_config.size
|
||||
distance = vectors_config.distance
|
||||
|
||||
return {
|
||||
"name": collection_name,
|
||||
"vector_size": vector_size,
|
||||
"count": count,
|
||||
"distance": distance,
|
||||
}
|
||||
|
||||
def delete_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
Delete a collection if it exists.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to delete
|
||||
|
||||
Returns:
|
||||
True if deleted or doesn't exist
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
if not client.collection_exists(collection_name):
|
||||
return True
|
||||
|
||||
if self.dry_run:
|
||||
target_info = self.get_collection_info(collection_name)
|
||||
count = target_info["count"] if target_info else 0
|
||||
print(
|
||||
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would delete collection '{collection_name}' ({count:,} records)"
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
target_info = self.get_collection_info(collection_name)
|
||||
count = target_info["count"] if target_info else 0
|
||||
client.delete_collection(collection_name=collection_name)
|
||||
print(
|
||||
f" {BOLD_RED}✗{RESET} Deleted collection '{collection_name}' ({count:,} records)"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" {BOLD_RED}✗{RESET} Failed to delete collection: {e}")
|
||||
return False
|
||||
|
||||
def create_legacy_collection(
|
||||
self, collection_name: str, vector_size: int, distance: models.Distance
|
||||
) -> bool:
|
||||
"""
|
||||
Create legacy collection if it doesn't exist.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to create
|
||||
vector_size: Dimension of vectors
|
||||
distance: Distance metric
|
||||
|
||||
Returns:
|
||||
True if created or already exists
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
if client.collection_exists(collection_name):
|
||||
print(f" Collection '{collection_name}' already exists")
|
||||
return True
|
||||
|
||||
if self.dry_run:
|
||||
print(
|
||||
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would create collection '{collection_name}' with {vector_size}d vectors"
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=models.VectorParams(
|
||||
size=vector_size,
|
||||
distance=distance,
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16,
|
||||
m=0,
|
||||
),
|
||||
)
|
||||
print(
|
||||
f" {BOLD_GREEN}✓{RESET} Created collection '{collection_name}' with {vector_size}d vectors"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" {BOLD_RED}✗{RESET} Failed to create collection: {e}")
|
||||
return False
|
||||
|
||||
def _get_workspace_filter(self) -> models.Filter:
|
||||
"""Create workspace filter for Qdrant queries"""
|
||||
return models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=WORKSPACE_ID_FIELD,
|
||||
match=models.MatchValue(value=self.workspace),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def get_workspace_count(self, collection_name: str) -> int:
|
||||
"""
|
||||
Get count of records for the current workspace in a collection.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection
|
||||
|
||||
Returns:
|
||||
Count of records for the workspace
|
||||
"""
|
||||
client = self._get_client()
|
||||
return client.count(
|
||||
collection_name=collection_name,
|
||||
count_filter=self._get_workspace_filter(),
|
||||
exact=True,
|
||||
).count
|
||||
|
||||
def copy_collection_data(
|
||||
self,
|
||||
source_collection: str,
|
||||
target_collection: str,
|
||||
collection_type: str,
|
||||
workspace_count: int,
|
||||
) -> CopyStats:
|
||||
"""
|
||||
Copy data from source to target collection.
|
||||
|
||||
This filters by workspace_id and removes it from payload to simulate legacy data format.
|
||||
|
||||
Args:
|
||||
source_collection: Source collection name
|
||||
target_collection: Target collection name
|
||||
collection_type: Type of collection (chunks, entities, relationships)
|
||||
workspace_count: Pre-computed count of workspace records
|
||||
|
||||
Returns:
|
||||
CopyStats with operation results
|
||||
"""
|
||||
client = self._get_client()
|
||||
stats = CopyStats(
|
||||
collection_type=collection_type,
|
||||
source_collection=source_collection,
|
||||
target_collection=target_collection,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
stats.total_records = workspace_count
|
||||
|
||||
if workspace_count == 0:
|
||||
print(f" No records for workspace '{self.workspace}', skipping")
|
||||
stats.elapsed_time = time.time() - start_time
|
||||
return stats
|
||||
|
||||
print(f" Workspace records: {workspace_count:,}")
|
||||
|
||||
if self.dry_run:
|
||||
print(
|
||||
f" {BOLD_YELLOW}[DRY RUN]{RESET} Would copy {workspace_count:,} records to '{target_collection}'"
|
||||
)
|
||||
stats.copied_records = workspace_count
|
||||
stats.elapsed_time = time.time() - start_time
|
||||
return stats
|
||||
|
||||
# Batch copy using scroll with workspace filter
|
||||
workspace_filter = self._get_workspace_filter()
|
||||
offset = None
|
||||
batch_idx = 0
|
||||
|
||||
while True:
|
||||
# Scroll source collection with workspace filter
|
||||
result = client.scroll(
|
||||
collection_name=source_collection,
|
||||
scroll_filter=workspace_filter,
|
||||
limit=self.batch_size,
|
||||
offset=offset,
|
||||
with_vectors=True,
|
||||
with_payload=True,
|
||||
)
|
||||
points, next_offset = result
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
batch_idx += 1
|
||||
|
||||
# Transform points: remove workspace_id from payload
|
||||
new_points = []
|
||||
for point in points:
|
||||
new_payload = dict(point.payload or {})
|
||||
# Remove workspace_id to simulate legacy format
|
||||
new_payload.pop(WORKSPACE_ID_FIELD, None)
|
||||
|
||||
# Use original id from payload if available, otherwise use point.id
|
||||
original_id = new_payload.get("id")
|
||||
if original_id:
|
||||
# Generate a simple deterministic id for legacy format
|
||||
# Use original id directly (legacy format didn't have workspace prefix)
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
hashed = hashlib.sha256(original_id.encode("utf-8")).digest()
|
||||
point_id = uuid.UUID(bytes=hashed[:16], version=4).hex
|
||||
else:
|
||||
point_id = str(point.id)
|
||||
|
||||
new_points.append(
|
||||
models.PointStruct(
|
||||
id=point_id,
|
||||
vector=point.vector,
|
||||
payload=new_payload,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Upsert to target collection
|
||||
client.upsert(
|
||||
collection_name=target_collection, points=new_points, wait=True
|
||||
)
|
||||
stats.copied_records += len(new_points)
|
||||
|
||||
# Progress bar
|
||||
progress = (stats.copied_records / workspace_count) * 100
|
||||
bar_length = 30
|
||||
filled = int(bar_length * stats.copied_records // workspace_count)
|
||||
bar = "█" * filled + "░" * (bar_length - filled)
|
||||
|
||||
print(
|
||||
f"\r Copying: {bar} {stats.copied_records:,}/{workspace_count:,} ({progress:.1f}%) ",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
stats.add_error(batch_idx, e, len(new_points))
|
||||
print(
|
||||
f"\n {BOLD_RED}✗{RESET} Batch {batch_idx} failed: {type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
print() # New line after progress bar
|
||||
stats.elapsed_time = time.time() - start_time
|
||||
|
||||
return stats
|
||||
|
||||
def process_collection_type(self, collection_type: str) -> Optional[CopyStats]:
|
||||
"""
|
||||
Process a single collection type.
|
||||
|
||||
Args:
|
||||
collection_type: Type of collection (chunks, entities, relationships)
|
||||
|
||||
Returns:
|
||||
CopyStats or None if error
|
||||
"""
|
||||
namespace_config = COLLECTION_NAMESPACES.get(collection_type)
|
||||
if not namespace_config:
|
||||
print(f"{BOLD_RED}✗{RESET} Unknown collection type: {collection_type}")
|
||||
return None
|
||||
|
||||
source = namespace_config["new"]
|
||||
# Generate legacy collection name dynamically: {workspace}_{suffix}
|
||||
target = f"{self.workspace}_{namespace_config['suffix']}"
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"Processing: {BOLD_CYAN}{collection_type}{RESET}")
|
||||
print(f"{'=' * 50}")
|
||||
print(f" Source: {source}")
|
||||
print(f" Target: {target}")
|
||||
|
||||
# Check source collection
|
||||
source_info = self.get_collection_info(source)
|
||||
if source_info is None:
|
||||
print(
|
||||
f" {BOLD_YELLOW}⚠{RESET} Source collection '{source}' does not exist, skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
print(f" Source vector dimension: {source_info['vector_size']}d")
|
||||
print(f" Source distance metric: {source_info['distance']}")
|
||||
print(f" Source total records: {source_info['count']:,}")
|
||||
|
||||
# Check workspace data exists BEFORE creating legacy collection
|
||||
workspace_count = self.get_workspace_count(source)
|
||||
print(f" Workspace '{self.workspace}' records: {workspace_count:,}")
|
||||
|
||||
if workspace_count == 0:
|
||||
print(
|
||||
f" {BOLD_YELLOW}⚠{RESET} No data found for workspace '{self.workspace}' in '{source}', skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
# Clear target collection if requested
|
||||
if self.clear_target:
|
||||
if not self.delete_collection(target):
|
||||
return None
|
||||
|
||||
# Create target collection only after confirming workspace data exists
|
||||
if not self.create_legacy_collection(
|
||||
target, source_info["vector_size"], source_info["distance"]
|
||||
):
|
||||
return None
|
||||
|
||||
# Copy data with workspace filter
|
||||
stats = self.copy_collection_data(
|
||||
source, target, collection_type, workspace_count
|
||||
)
|
||||
|
||||
# Print result
|
||||
if stats.failed_records == 0:
|
||||
print(
|
||||
f" {BOLD_GREEN}✓{RESET} Copied {stats.copied_records:,} records in {stats.elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" {BOLD_YELLOW}⚠{RESET} Copied {stats.copied_records:,} records, "
|
||||
f"{BOLD_RED}{stats.failed_records:,} failed{RESET} in {stats.elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
def print_summary(self, all_stats: List[CopyStats]):
|
||||
"""Print summary of all operations"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
|
||||
total_copied = sum(s.copied_records for s in all_stats)
|
||||
total_failed = sum(s.failed_records for s in all_stats)
|
||||
total_time = sum(s.elapsed_time for s in all_stats)
|
||||
|
||||
for stats in all_stats:
|
||||
status = (
|
||||
f"{BOLD_GREEN}✓{RESET}"
|
||||
if stats.failed_records == 0
|
||||
else f"{BOLD_YELLOW}⚠{RESET}"
|
||||
)
|
||||
print(
|
||||
f" {status} {stats.collection_type}: {stats.copied_records:,}/{stats.total_records:,} "
|
||||
f"({stats.source_collection} → {stats.target_collection})"
|
||||
)
|
||||
|
||||
print("-" * 60)
|
||||
print(f" Total records copied: {BOLD_CYAN}{total_copied:,}{RESET}")
|
||||
if total_failed > 0:
|
||||
print(f" Total records failed: {BOLD_RED}{total_failed:,}{RESET}")
|
||||
print(f" Total time: {total_time:.2f}s")
|
||||
|
||||
if self.dry_run:
|
||||
print(f"\n{BOLD_YELLOW}⚠️ DRY RUN - No actual changes were made{RESET}")
|
||||
|
||||
# Print error details if any
|
||||
all_errors = []
|
||||
for stats in all_stats:
|
||||
all_errors.extend(stats.errors)
|
||||
|
||||
if all_errors:
|
||||
print(f"\n{BOLD_RED}Errors ({len(all_errors)}){RESET}")
|
||||
for i, error in enumerate(all_errors[:5], 1):
|
||||
print(
|
||||
f" {i}. Batch {error['batch']}: {error['error_type']}: {error['error_msg']}"
|
||||
)
|
||||
if len(all_errors) > 5:
|
||||
print(f" ... and {len(all_errors) - 5} more errors")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
async def run(self, collection_types: Optional[List[str]] = None):
|
||||
"""
|
||||
Run the data preparation tool.
|
||||
|
||||
Args:
|
||||
collection_types: List of collection types to process (default: all)
|
||||
"""
|
||||
self.print_header()
|
||||
|
||||
# Check connection
|
||||
if not self.check_connection():
|
||||
return
|
||||
|
||||
# Determine which collection types to process
|
||||
if collection_types:
|
||||
types_to_process = [t.strip() for t in collection_types]
|
||||
invalid_types = [
|
||||
t for t in types_to_process if t not in COLLECTION_NAMESPACES
|
||||
]
|
||||
if invalid_types:
|
||||
print(
|
||||
f"{BOLD_RED}✗{RESET} Invalid collection types: {', '.join(invalid_types)}"
|
||||
)
|
||||
print(f" Valid types: {', '.join(COLLECTION_NAMESPACES.keys())}")
|
||||
return
|
||||
else:
|
||||
types_to_process = list(COLLECTION_NAMESPACES.keys())
|
||||
|
||||
print(f"\nCollection types to process: {', '.join(types_to_process)}")
|
||||
|
||||
# Process each collection type
|
||||
all_stats = []
|
||||
for ctype in types_to_process:
|
||||
stats = self.process_collection_type(ctype)
|
||||
if stats:
|
||||
all_stats.append(stats)
|
||||
|
||||
# Print summary
|
||||
if all_stats:
|
||||
self.print_summary(all_stats)
|
||||
else:
|
||||
print(f"\n{BOLD_YELLOW}⚠{RESET} No collections were processed")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare legacy data in Qdrant for migration testing",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python -m lightrag.tools.prepare_qdrant_legacy_data
|
||||
python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1
|
||||
python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities
|
||||
python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workspace",
|
||||
type=str,
|
||||
default="space1",
|
||||
help="Workspace name (default: space1)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--types",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated list of collection types (chunks, entities, relationships)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZE,
|
||||
help=f"Batch size for copy operations (default: {DEFAULT_BATCH_SIZE})",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Preview operations without making changes",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--clear-target",
|
||||
action="store_true",
|
||||
help="Delete target collections before copying (for clean test environment)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
args = parse_args()
|
||||
|
||||
collection_types = None
|
||||
if args.types:
|
||||
collection_types = [t.strip() for t in args.types.split(",")]
|
||||
|
||||
tool = QdrantLegacyDataPreparationTool(
|
||||
workspace=args.workspace,
|
||||
batch_size=args.batch_size,
|
||||
dry_run=args.dry_run,
|
||||
clear_target=args.clear_target,
|
||||
)
|
||||
|
||||
await tool.run(collection_types=collection_types)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
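The tool above derives legacy point IDs deterministically from the original payload id: SHA-256 the id, then reuse the first 16 bytes of the digest as a UUID. A standalone sketch of that derivation, useful for checking that re-running the tool maps the same records to the same points:

```python
import hashlib
import uuid


def legacy_point_id(original_id: str) -> str:
    # Same derivation as in copy_collection_data: hash the original id
    # and take the first 16 digest bytes as a UUID (hex form).
    hashed = hashlib.sha256(original_id.encode("utf-8")).digest()
    return uuid.UUID(bytes=hashed[:16], version=4).hex


# Deterministic: the same input always maps to the same point id,
# so repeated runs upsert rather than duplicate.
assert legacy_point_id("chunk-42") == legacy_point_id("chunk-42")
print(legacy_point_id("chunk-42"))
```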
@@ -425,6 +425,9 @@ class EmbeddingFunc:
    send_dimensions: bool = (
        False  # Control whether to send embedding_dim to the function
    )
    model_name: str | None = (
        None  # Model name for implementing workspace data isolation in the vector DB
    )

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        # Only inject embedding_dim when send_dimensions is True
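For cases where the decorator isn't convenient, an EmbeddingFunc can be built directly. A minimal sketch, assuming the constructor accepts the dataclass fields shown above plus `func` (the embedding implementation here is a random-vector placeholder, not a real model):

```python
import numpy as np

from lightrag.utils import EmbeddingFunc


async def my_embed(texts: list[str]) -> np.ndarray:
    # Placeholder embedding implementation (illustrative only)
    return np.random.rand(len(texts), 1024).astype(np.float32)


# Equivalent to decorating my_embed with @wrap_embedding_func_with_attrs(...):
embedding_func = EmbeddingFunc(
    embedding_dim=1024,
    max_token_size=8192,
    model_name="bge-m3:latest",  # enables per-model vector DB isolation
    func=my_embed,
)
```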
@@ -1016,42 +1019,36 @@ def wrap_embedding_func_with_attrs(**kwargs):

    Correct usage patterns:

    1. Direct implementation (decorated):
    1. Direct decoration:
    ```python
    @wrap_embedding_func_with_attrs(embedding_dim=1536)
    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
    async def my_embed(texts, embedding_dim=None):
        # Direct implementation
        return embeddings
    ```

    2. Wrapper calling decorated function (DO NOT decorate wrapper):
    2. Double decoration:
    ```python
    # my_embed is already decorated above
    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
    @retry(...)
    async def openai_embed(texts, ...):
        # Base implementation
        pass

    async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
        # Must call .func to access the unwrapped implementation
        return await my_embed.func(texts, **kwargs)
    ```

    3. Wrapper calling decorated function (properly decorated):
    ```python
    @wrap_embedding_func_with_attrs(embedding_dim=1536)
    async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
        # Calling .func avoids double decoration
        return await my_embed.func(texts, **kwargs)
    @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model")
    # Note: No @retry here!
    async def new_openai_embed(texts, ...):
        # CRITICAL: Call .func to access the unwrapped function
        return await openai_embed.func(texts, ...)  # ✅ Correct
        # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
    ```

    The decorated function becomes an EmbeddingFunc instance with:
    - embedding_dim: The embedding dimension
    - max_token_size: Maximum token limit (optional)
    - model_name: Model name (optional)
    - func: The original unwrapped function (access via .func)
    - __call__: Wrapper that injects the embedding_dim parameter

    Double decoration causes:
    - Double injection of the embedding_dim parameter
    - Incorrect parameter passing to the underlying implementation
    - Runtime errors due to parameter conflicts

    Args:
        embedding_dim: The dimension of embedding vectors
        max_token_size: Maximum number of tokens (optional)
@@ -1059,21 +1056,6 @@ def wrap_embedding_func_with_attrs(**kwargs):

    Returns:
        A decorator that wraps the function as an EmbeddingFunc instance

    Example of correct wrapper implementation:
    ```python
    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
    @retry(...)
    async def openai_embed(texts, ...):
        # Base implementation
        pass

    @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
    async def azure_openai_embed(texts, ...):
        # CRITICAL: Call .func to access the unwrapped function
        return await openai_embed.func(texts, ...)  # ✅ Correct
        # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
    ```
    """

    def final_decro(func) -> EmbeddingFunc:
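A quick way to see why double decoration breaks: the decorator replaces the function with an EmbeddingFunc instance, so a wrapper that calls the instance (rather than `.func`) runs the injection logic twice. A small sketch consistent with the attribute list in the docstring above (the model name is illustrative):

```python
from lightrag.utils import EmbeddingFunc, wrap_embedding_func_with_attrs


@wrap_embedding_func_with_attrs(embedding_dim=1536, model_name="my_embedding_model")
async def my_embed(texts):
    ...


# The decorated name is no longer a plain coroutine function but an
# EmbeddingFunc instance carrying the declared attributes:
assert isinstance(my_embed, EmbeddingFunc)
assert my_embed.embedding_dim == 1536
assert my_embed.model_name == "my_embedding_model"

# A second decorated wrapper must therefore call my_embed.func(...),
# never my_embed(...), or embedding_dim gets injected twice.
```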
@@ -72,7 +72,6 @@ api = [
    # API-specific dependencies
    "aiofiles",
    "ascii_colors",
    "asyncpg",
    "distro",
    "fastapi",
    "httpcore",
@@ -108,7 +107,8 @@ offline-storage = [
    "neo4j>=5.0.0,<7.0.0",
    "pymilvus>=2.6.2,<3.0.0",
    "pymongo>=4.0.0,<5.0.0",
    "asyncpg>=0.29.0,<1.0.0",
    "asyncpg>=0.31.0,<1.0.0",
    "pgvector>=0.4.2,<1.0.0",
    "qdrant-client>=1.11.0,<2.0.0",
]
@@ -8,8 +8,9 @@
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt

# Storage backend dependencies (with version constraints matching pyproject.toml)
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
neo4j>=5.0.0,<7.0.0
pgvector>=0.4.2,<1.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
qdrant-client>=1.11.0,<2.0.0
@@ -7,20 +7,17 @@
# Recommended: Use pip install lightrag-hku[offline] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt

# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0

# Storage backend dependencies
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0

# Document processing dependencies
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
pgvector>=0.4.2,<1.0.0
pycryptodome>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
377
tests/test_dimension_mismatch.py
Normal file
@@ -0,0 +1,377 @@
"""
|
||||
Tests for dimension mismatch handling during migration.
|
||||
|
||||
This test module verifies that both PostgreSQL and Qdrant storage backends
|
||||
properly detect and handle vector dimension mismatches when migrating from
|
||||
legacy collections/tables to new ones with different embedding models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
|
||||
from lightrag.kg.postgres_impl import PGVectorStorage
|
||||
from lightrag.exceptions import DataMigrationError
|
||||
|
||||
|
||||
# Note: Tests should use proper table names that have DDL templates
|
||||
# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
|
||||
# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
|
||||
|
||||
|
||||
class TestQdrantDimensionMismatch:
|
||||
"""Test suite for Qdrant dimension mismatch handling."""
|
||||
|
||||
def test_qdrant_dimension_mismatch_raises_error(self):
|
||||
"""
|
||||
Test that Qdrant raises DataMigrationError when dimensions don't match.
|
||||
|
||||
Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
|
||||
Expected: DataMigrationError is raised to prevent data corruption.
|
||||
"""
|
||||
from qdrant_client import models
|
||||
|
||||
# Setup mock client
|
||||
client = MagicMock()
|
||||
|
||||
# Mock legacy collection with 1536d vectors
|
||||
legacy_collection_info = MagicMock()
|
||||
legacy_collection_info.config.params.vectors.size = 1536
|
||||
|
||||
# Setup collection existence checks
|
||||
def collection_exists_side_effect(name):
|
||||
if (
|
||||
name == "lightrag_vdb_chunks"
|
||||
): # legacy (matches _find_legacy_collection pattern)
|
||||
return True
|
||||
elif name == "lightrag_chunks_model_3072d": # new
|
||||
return False
|
||||
return False
|
||||
|
||||
client.collection_exists.side_effect = collection_exists_side_effect
|
||||
client.get_collection.return_value = legacy_collection_info
|
||||
client.count.return_value.count = 100 # Legacy has data
|
||||
|
||||
# Patch _find_legacy_collection to return the legacy collection name
|
||||
with patch(
|
||||
"lightrag.kg.qdrant_impl._find_legacy_collection",
|
||||
return_value="lightrag_vdb_chunks",
|
||||
):
|
||||
# Call setup_collection with 3072d (different from legacy 1536d)
|
||||
# Should raise DataMigrationError due to dimension mismatch
|
||||
with pytest.raises(DataMigrationError) as exc_info:
|
||||
QdrantVectorDBStorage.setup_collection(
|
||||
client,
|
||||
"lightrag_chunks_model_3072d",
|
||||
namespace="chunks",
|
||||
workspace="test",
|
||||
vectors_config=models.VectorParams(
|
||||
size=3072, distance=models.Distance.COSINE
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16,
|
||||
m=0,
|
||||
),
|
||||
model_suffix="model_3072d",
|
||||
)
|
||||
|
||||
# Verify error message contains dimension information
|
||||
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
|
||||
|
||||
# Verify new collection was NOT created (error raised before creation)
|
||||
client.create_collection.assert_not_called()
|
||||
|
||||
# Verify migration was NOT attempted
|
||||
client.scroll.assert_not_called()
|
||||
client.upsert.assert_not_called()
|
||||
|
||||
def test_qdrant_dimension_match_proceed_migration(self):
|
||||
"""
|
||||
Test that Qdrant proceeds with migration when dimensions match.
|
||||
|
||||
Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
|
||||
Expected: Migration proceeds normally.
|
||||
"""
|
||||
from qdrant_client import models
|
||||
|
||||
client = MagicMock()
|
||||
|
||||
# Mock legacy collection with 1536d vectors (matching new)
|
||||
legacy_collection_info = MagicMock()
|
||||
legacy_collection_info.config.params.vectors.size = 1536
|
||||
|
||||
def collection_exists_side_effect(name):
|
||||
if name == "lightrag_chunks": # legacy
|
||||
return True
|
||||
elif name == "lightrag_chunks_model_1536d": # new
|
||||
return False
|
||||
return False
|
||||
|
||||
client.collection_exists.side_effect = collection_exists_side_effect
|
||||
client.get_collection.return_value = legacy_collection_info
|
||||
|
||||
# Track whether upsert has been called (migration occurred)
|
||||
migration_done = {"value": False}
|
||||
|
||||
def upsert_side_effect(*args, **kwargs):
|
||||
migration_done["value"] = True
|
||||
return MagicMock()
|
||||
|
||||
client.upsert.side_effect = upsert_side_effect
|
||||
|
||||
# Mock count to return different values based on collection name and migration state
|
||||
# Before migration: new collection has 0 records
|
||||
# After migration: new collection has 1 record (matching migrated data)
|
||||
def count_side_effect(collection_name, **kwargs):
|
||||
result = MagicMock()
|
||||
if collection_name == "lightrag_chunks": # legacy
|
||||
result.count = 1 # Legacy has 1 record
|
||||
elif collection_name == "lightrag_chunks_model_1536d": # new
|
||||
# Return 0 before migration, 1 after migration
|
||||
result.count = 1 if migration_done["value"] else 0
|
||||
else:
|
||||
result.count = 0
|
||||
return result
|
||||
|
||||
client.count.side_effect = count_side_effect
|
||||
|
||||
# Mock scroll to return sample data (1 record for easier verification)
|
||||
sample_point = MagicMock()
|
||||
sample_point.id = "test_id"
|
||||
sample_point.vector = [0.1] * 1536
|
||||
sample_point.payload = {"id": "test"}
|
||||
client.scroll.return_value = ([sample_point], None)
|
||||
|
||||
# Mock _find_legacy_collection to return the legacy collection name
|
||||
with patch(
|
||||
"lightrag.kg.qdrant_impl._find_legacy_collection",
|
||||
return_value="lightrag_chunks",
|
||||
):
|
||||
# Call setup_collection with matching 1536d
|
||||
QdrantVectorDBStorage.setup_collection(
|
||||
client,
|
||||
"lightrag_chunks_model_1536d",
|
||||
namespace="chunks",
|
||||
workspace="test",
|
||||
vectors_config=models.VectorParams(
|
||||
size=1536, distance=models.Distance.COSINE
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16,
|
||||
m=0,
|
||||
),
|
||||
model_suffix="model_1536d",
|
||||
)
|
||||
|
||||
# Verify migration WAS attempted
|
||||
client.create_collection.assert_called_once()
|
||||
client.scroll.assert_called()
|
||||
client.upsert.assert_called()
|
||||
|
||||
|
||||
class TestPostgresDimensionMismatch:
|
||||
"""Test suite for PostgreSQL dimension mismatch handling."""
|
||||
|
||||
async def test_postgres_dimension_mismatch_raises_error_metadata(self):
|
||||
"""
|
||||
Test that PostgreSQL raises DataMigrationError when dimensions don't match.
|
||||
|
||||
Scenario: Legacy table has 1536d vectors, new model expects 3072d.
|
||||
Expected: DataMigrationError is raised to prevent data corruption.
|
||||
"""
|
||||
# Setup mock database
|
||||
db = AsyncMock()
|
||||
|
||||
# Mock check_table_exists
|
||||
async def mock_check_table_exists(table_name):
|
||||
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
|
||||
return True
|
||||
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
|
||||
return False
|
||||
return False
|
||||
|
||||
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
|
||||
|
||||
# Mock table existence and dimension checks
|
||||
async def query_side_effect(query, params, **kwargs):
|
||||
if "COUNT(*)" in query:
|
||||
return {"count": 100} # Legacy has data
|
||||
elif "SELECT content_vector FROM" in query:
|
||||
# Return sample vector with 1536 dimensions
|
||||
return {"content_vector": [0.1] * 1536}
|
||||
return {}
|
||||
|
||||
db.query.side_effect = query_side_effect
|
||||
db.execute = AsyncMock()
|
||||
db._create_vector_index = AsyncMock()
|
||||
|
||||
# Call setup_table with 3072d (different from legacy 1536d)
|
||||
# Should raise DataMigrationError due to dimension mismatch
|
||||
with pytest.raises(DataMigrationError) as exc_info:
|
||||
await PGVectorStorage.setup_table(
|
||||
db,
|
||||
"LIGHTRAG_DOC_CHUNKS_model_3072d",
|
||||
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
|
||||
base_table="LIGHTRAG_DOC_CHUNKS",
|
||||
embedding_dim=3072,
|
||||
workspace="test",
|
||||
)
|
||||
|
||||
# Verify error message contains dimension information
|
||||
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
|
||||
|
||||
async def test_postgres_dimension_mismatch_raises_error_sampling(self):
|
||||
"""
|
||||
Test that PostgreSQL raises error when dimensions don't match (via sampling).
|
||||
|
||||
Scenario: Legacy table vector sampling detects 1536d vs expected 3072d.
|
||||
Expected: DataMigrationError is raised to prevent data corruption.
|
||||
"""
|
||||
db = AsyncMock()
|
||||
|
||||
# Mock check_table_exists
|
||||
async def mock_check_table_exists(table_name):
|
||||
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
|
||||
return True
|
||||
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
|
||||
return False
|
||||
return False
|
||||
|
||||
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
|
||||
|
||||
# Mock table existence and dimension checks
|
||||
async def query_side_effect(query, params, **kwargs):
|
||||
if "information_schema.tables" in query:
|
||||
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
|
||||
return {"exists": True}
|
||||
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
|
||||
return {"exists": False}
|
||||
elif "COUNT(*)" in query:
|
||||
return {"count": 100} # Legacy has data
|
||||
elif "SELECT content_vector FROM" in query:
|
||||
# Return sample vector with 1536 dimensions as a JSON string
|
||||
return {"content_vector": json.dumps([0.1] * 1536)}
|
||||
return {}
|
||||
|
||||
db.query.side_effect = query_side_effect
|
||||
db.execute = AsyncMock()
|
||||
db._create_vector_index = AsyncMock()
|
||||
|
||||
# Call setup_table with 3072d (different from legacy 1536d)
|
||||
# Should raise DataMigrationError due to dimension mismatch
|
||||
with pytest.raises(DataMigrationError) as exc_info:
|
||||
await PGVectorStorage.setup_table(
|
||||
db,
|
||||
"LIGHTRAG_DOC_CHUNKS_model_3072d",
|
||||
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
|
||||
base_table="LIGHTRAG_DOC_CHUNKS",
|
||||
embedding_dim=3072,
|
||||
workspace="test",
|
||||
)
|
||||
|
||||
# Verify error message contains dimension information
|
||||
assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
|
||||
|
||||
async def test_postgres_dimension_match_proceed_migration(self):
|
||||
"""
|
||||
Test that PostgreSQL proceeds with migration when dimensions match.
|
||||
|
||||
Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
|
||||
Expected: Migration proceeds normally.
|
||||
"""
|
||||
db = AsyncMock()
|
||||
|
||||
# Track migration state
|
||||
migration_done = {"value": False}
|
||||
|
||||
# Define exactly 2 records for consistency
|
||||
mock_records = [
|
||||
{
|
||||
"id": "test1",
|
||||
"content_vector": [0.1] * 1536,
|
||||
"workspace": "test",
|
||||
},
|
||||
{
|
||||
"id": "test2",
|
||||
"content_vector": [0.2] * 1536,
|
||||
"workspace": "test",
|
||||
},
|
||||
]
|
||||
|
||||
# Mock check_table_exists
|
||||
async def mock_check_table_exists(table_name):
|
||||
if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy exists
|
||||
return True
|
||||
elif table_name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist
|
||||
return False
|
||||
return False
|
||||
|
||||
db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
|
||||
|
||||
async def query_side_effect(query, params, **kwargs):
|
||||
multirows = kwargs.get("multirows", False)
|
||||
query_upper = query.upper()
|
||||
|
||||
if "information_schema.tables" in query:
|
||||
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
|
||||
return {"exists": True}
|
||||
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new
|
||||
return {"exists": False}
|
||||
elif "COUNT(*)" in query_upper:
|
||||
# Return different counts based on table name in query and migration state
|
||||
if "LIGHTRAG_DOC_CHUNKS_MODEL_1536D" in query_upper:
|
||||
# After migration: return migrated count, before: return 0
|
||||
return {
|
||||
"count": len(mock_records) if migration_done["value"] else 0
|
||||
}
|
||||
# Legacy table always has 2 records (matching mock_records)
|
||||
return {"count": len(mock_records)}
|
||||
elif "PG_ATTRIBUTE" in query_upper:
|
||||
return {"vector_dim": 1536} # Legacy has matching 1536d
|
||||
elif "SELECT" in query_upper and "FROM" in query_upper and multirows:
|
||||
# Return sample data for migration using keyset pagination
|
||||
# Handle keyset pagination: params = [workspace, limit] or [workspace, last_id, limit]
|
||||
if "id >" in query.lower():
|
||||
# Keyset pagination: params = [workspace, last_id, limit]
|
||||
last_id = params[1] if len(params) > 1 else None
|
||||
# Find records after last_id
|
||||
found_idx = -1
|
||||
for i, rec in enumerate(mock_records):
|
||||
if rec["id"] == last_id:
|
||||
found_idx = i
|
||||
break
|
||||
if found_idx >= 0:
|
||||
return mock_records[found_idx + 1 :]
|
||||
return []
|
||||
else:
|
||||
# First batch: params = [workspace, limit]
|
||||
return mock_records
|
||||
return {}
|
||||
|
||||
db.query.side_effect = query_side_effect
|
||||
|
||||
# Mock _run_with_retry to track when migration happens
|
||||
migration_executed = []
|
||||
|
||||
async def mock_run_with_retry(operation, *args, **kwargs):
|
||||
migration_executed.append(True)
|
||||
migration_done["value"] = True
|
||||
return None
|
||||
|
||||
db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
|
||||
db.execute = AsyncMock()
|
||||
db._create_vector_index = AsyncMock()
|
||||
|
||||
# Call setup_table with matching 1536d
|
||||
await PGVectorStorage.setup_table(
|
||||
db,
|
||||
"LIGHTRAG_DOC_CHUNKS_model_1536d",
|
||||
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
|
||||
base_table="LIGHTRAG_DOC_CHUNKS",
|
||||
embedding_dim=1536,
|
||||
workspace="test",
|
||||
)
|
||||
|
||||
# Verify migration WAS called (via _run_with_retry for batch operations)
|
||||
assert len(migration_executed) > 0, "Migration should have been executed"
|
||||
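The tests above pin down the contract: compare the legacy store's vector size against the new embedding dimension before creating anything, and abort on mismatch. A minimal sketch of that guard, using the same `DataMigrationError` the tests import (the helper name and message wording here are hypothetical, not the backends' exact code):

```python
from lightrag.exceptions import DataMigrationError


def check_migration_dimensions(legacy_dim: int, new_dim: int, legacy_name: str) -> None:
    # Raise before any collection/table is created, so nothing is corrupted.
    if legacy_dim != new_dim:
        raise DataMigrationError(
            f"Dimension mismatch: legacy '{legacy_name}' stores {legacy_dim}d "
            f"vectors but the new embedding model produces {new_dim}d vectors. "
            "Migration aborted to prevent data corruption."
        )


check_migration_dimensions(1536, 1536, "lightrag_chunks")    # OK, migration proceeds
# check_migration_dimensions(1536, 3072, "lightrag_chunks")  # raises DataMigrationError
```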
220
tests/test_no_model_suffix_safety.py
Normal file
@@ -0,0 +1,220 @@
"""
|
||||
Tests for safety when model suffix is absent (no model_name provided).
|
||||
|
||||
This test module verifies that the system correctly handles the case when
|
||||
no model_name is provided, preventing accidental deletion of the only table/collection
|
||||
on restart.
|
||||
|
||||
Critical Bug: When model_suffix is empty, table_name == legacy_table_name.
|
||||
On second startup, Case 1 logic would delete the only table/collection thinking
|
||||
it's "legacy", causing all subsequent operations to fail.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
|
||||
from lightrag.kg.postgres_impl import PGVectorStorage
|
||||
|
||||
|
||||
class TestNoModelSuffixSafety:
|
||||
"""Test suite for preventing data loss when model_suffix is absent."""
|
||||
|
||||
def test_qdrant_no_suffix_second_startup(self):
|
||||
"""
|
||||
Test Qdrant doesn't delete collection on second startup when no model_name.
|
||||
|
||||
Scenario:
|
||||
1. First startup: Creates collection without suffix
|
||||
2. Collection is empty
|
||||
3. Second startup: Should NOT delete the collection
|
||||
|
||||
Bug: Without fix, Case 1 would delete the only collection.
|
||||
"""
|
||||
from qdrant_client import models
|
||||
|
||||
client = MagicMock()
|
||||
|
||||
# Simulate second startup: collection already exists and is empty
|
||||
# IMPORTANT: Without suffix, collection_name == legacy collection name
|
||||
collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy
|
||||
|
||||
# Both exist (they're the same collection)
|
||||
client.collection_exists.return_value = True
|
||||
|
||||
# Collection is empty
|
||||
client.count.return_value.count = 0
|
||||
|
||||
# Patch _find_legacy_collection to return the SAME collection name
|
||||
# This simulates the scenario where new collection == legacy collection
|
||||
with patch(
|
||||
"lightrag.kg.qdrant_impl._find_legacy_collection",
|
||||
return_value="lightrag_vdb_chunks", # Same as collection_name
|
||||
):
|
||||
# Call setup_collection
|
||||
# This should detect that new == legacy and skip deletion
|
||||
QdrantVectorDBStorage.setup_collection(
|
||||
client,
|
||||
collection_name,
|
||||
namespace="chunks",
|
||||
workspace="_",
|
||||
vectors_config=models.VectorParams(
|
||||
size=1536, distance=models.Distance.COSINE
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16,
|
||||
m=0,
|
||||
),
|
||||
model_suffix="", # Empty suffix to simulate no model_name provided
|
||||
)
|
||||
|
||||
# CRITICAL: Collection should NOT be deleted
|
||||
client.delete_collection.assert_not_called()
|
||||
|
||||
# Verify we returned early (skipped Case 1 cleanup)
|
||||
# The collection_exists was checked, but we didn't proceed to count
|
||||
# because we detected same name
|
||||
assert client.collection_exists.call_count >= 1
|
||||
|
||||
async def test_postgres_no_suffix_second_startup(self):
|
||||
"""
|
||||
Test PostgreSQL doesn't delete table on second startup when no model_name.
|
||||
|
||||
Scenario:
|
||||
1. First startup: Creates table without suffix
|
||||
2. Table is empty
|
||||
3. Second startup: Should NOT delete the table
|
||||
|
||||
Bug: Without fix, Case 1 would delete the only table.
|
||||
"""
|
||||
db = AsyncMock()
|
||||
|
||||
# Configure mock return values to avoid unawaited coroutine warnings
|
||||
db.query.return_value = {"count": 0}
|
||||
db._create_vector_index.return_value = None
|
||||
|
||||
# Simulate second startup: table already exists and is empty
|
||||
# IMPORTANT: table_name and legacy_table_name are THE SAME
|
||||
table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix
|
||||
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new
|
||||
|
||||
# Setup mock responses using check_table_exists on db
|
||||
async def check_table_exists_side_effect(name):
|
||||
# Both tables exist (they're the same)
|
||||
return True
|
||||
|
||||
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
|
||||
|
||||
# Call setup_table
|
||||
# This should detect that new == legacy and skip deletion
|
||||
await PGVectorStorage.setup_table(
|
||||
db,
|
||||
table_name,
|
||||
workspace="test_workspace",
|
||||
embedding_dim=1536,
|
||||
legacy_table_name=legacy_table_name,
|
||||
base_table="LIGHTRAG_VDB_CHUNKS",
|
||||
)
|
||||
|
||||
# CRITICAL: Table should NOT be deleted (no DROP TABLE)
|
||||
drop_calls = [
|
||||
call
|
||||
for call in db.execute.call_args_list
|
||||
if call[0][0] and "DROP TABLE" in call[0][0]
|
||||
]
|
||||
assert (
|
||||
len(drop_calls) == 0
|
||||
), "Should not drop table when new and legacy are the same"
|
||||
|
||||
# Note: COUNT queries for workspace data are expected behavior in Case 1
|
||||
# (for logging/warning purposes when workspace data is empty).
|
||||
# The critical safety check is that DROP TABLE is not called.
|
||||
|
||||
def test_qdrant_with_suffix_case1_still_works(self):
|
||||
"""
|
||||
Test that Case 1 cleanup still works when there IS a suffix.
|
||||
|
||||
This ensures our fix doesn't break the normal Case 1 scenario.
|
||||
"""
|
||||
from qdrant_client import models
|
||||
|
||||
client = MagicMock()
|
||||
|
||||
# Different names (normal case)
|
||||
collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix
|
||||
legacy_collection = "lightrag_vdb_chunks" # Without suffix
|
||||
|
||||
# Setup: both exist
|
||||
def collection_exists_side_effect(name):
|
||||
return name in [collection_name, legacy_collection]
|
||||
|
||||
client.collection_exists.side_effect = collection_exists_side_effect
|
||||
|
||||
# Legacy is empty
|
||||
client.count.return_value.count = 0
|
||||
|
||||
# Call setup_collection
|
||||
QdrantVectorDBStorage.setup_collection(
|
||||
client,
|
||||
collection_name,
|
||||
namespace="chunks",
|
||||
workspace="_",
|
||||
vectors_config=models.VectorParams(
|
||||
size=1536, distance=models.Distance.COSINE
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16,
|
||||
m=0,
|
||||
),
|
||||
model_suffix="ada_002_1536d",
|
||||
)
|
||||
|
||||
# SHOULD delete legacy (normal Case 1 behavior)
|
||||
client.delete_collection.assert_called_once_with(
|
||||
collection_name=legacy_collection
|
||||
)
|
||||
|
||||
async def test_postgres_with_suffix_case1_still_works(self):
|
||||
"""
|
||||
Test that Case 1 cleanup still works when there IS a suffix.
|
||||
|
||||
This ensures our fix doesn't break the normal Case 1 scenario.
|
||||
"""
|
||||
db = AsyncMock()
|
||||
|
||||
# Different names (normal case)
|
||||
table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix
|
||||
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix
|
||||
|
||||
# Setup mock responses using check_table_exists on db
|
||||
async def check_table_exists_side_effect(name):
|
||||
# Both tables exist
|
||||
return True
|
||||
|
||||
db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
|
||||
|
||||
# Mock empty table
|
||||
async def query_side_effect(sql, params, **kwargs):
|
||||
if "COUNT(*)" in sql:
|
||||
return {"count": 0}
|
||||
return {}
|
||||
|
||||
db.query.side_effect = query_side_effect
|
||||
|
||||
# Call setup_table
|
||||
await PGVectorStorage.setup_table(
|
||||
db,
|
||||
table_name,
|
||||
workspace="test_workspace",
|
||||
embedding_dim=1536,
|
||||
legacy_table_name=legacy_table_name,
|
||||
base_table="LIGHTRAG_VDB_CHUNKS",
|
||||
)
|
||||
|
||||
# SHOULD delete legacy (normal Case 1 behavior)
|
||||
drop_calls = [
|
||||
call
|
||||
for call in db.execute.call_args_list
|
||||
if call[0][0] and "DROP TABLE" in call[0][0]
|
||||
]
|
||||
assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1"
|
||||
assert legacy_table_name in drop_calls[0][0][0]
|
||||
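The backend implementations are more involved, but the invariant these tests enforce boils down to one comparison. A distilled sketch (the helper name is hypothetical; the real check lives inside setup_collection / setup_table):

```python
def should_run_legacy_cleanup(new_name: str, legacy_name: str) -> bool:
    # When no model suffix is configured, the "new" and "legacy" names are
    # identical; Case 1 cleanup must be skipped entirely, otherwise a restart
    # would drop the only table/collection.
    return new_name != legacy_name


assert should_run_legacy_cleanup(
    "lightrag_vdb_chunks_ada_002_1536d", "lightrag_vdb_chunks"
)  # normal case: legacy cleanup allowed
assert not should_run_legacy_cleanup(
    "lightrag_vdb_chunks", "lightrag_vdb_chunks"
)  # no suffix: same object, never delete
```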
210
tests/test_postgres_index_name.py
Normal file
@@ -0,0 +1,210 @@
"""
|
||||
Unit tests for PostgreSQL safe index name generation.
|
||||
|
||||
This module tests the _safe_index_name helper function which prevents
|
||||
PostgreSQL's silent 63-byte identifier truncation from causing index
|
||||
lookup failures.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
# Mark all tests as offline (no external dependencies)
|
||||
pytestmark = pytest.mark.offline
|
||||
|
||||
|
||||
class TestSafeIndexName:
|
||||
"""Test suite for _safe_index_name function."""
|
||||
|
||||
def test_short_name_unchanged(self):
|
||||
"""Short index names should remain unchanged."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
# Short table name - should return unchanged
|
||||
result = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
|
||||
assert result == "idx_lightrag_vdb_entity_hnsw_cosine"
|
||||
assert len(result.encode("utf-8")) <= 63
|
||||
|
||||
def test_long_name_gets_hashed(self):
|
||||
"""Long table names exceeding 63 bytes should get hashed."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
# Long table name that would exceed 63 bytes
|
||||
long_table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
|
||||
result = _safe_index_name(long_table_name, "hnsw_cosine")
|
||||
|
||||
# Should be within 63 bytes
|
||||
assert len(result.encode("utf-8")) <= 63
|
||||
|
||||
# Should start with idx_ prefix
|
||||
assert result.startswith("idx_")
|
||||
|
||||
# Should contain the suffix
|
||||
assert result.endswith("_hnsw_cosine")
|
||||
|
||||
# Should NOT be the naive concatenation (which would be truncated)
|
||||
naive_name = f"idx_{long_table_name.lower()}_hnsw_cosine"
|
||||
assert result != naive_name
|
||||
|
||||
def test_deterministic_output(self):
|
||||
"""Same input should always produce same output (deterministic)."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
|
||||
suffix = "hnsw_cosine"
|
||||
|
||||
result1 = _safe_index_name(table_name, suffix)
|
||||
result2 = _safe_index_name(table_name, suffix)
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_different_suffixes_different_results(self):
|
||||
"""Different suffixes should produce different index names."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
|
||||
|
||||
result1 = _safe_index_name(table_name, "hnsw_cosine")
|
||||
result2 = _safe_index_name(table_name, "ivfflat_cosine")
|
||||
|
||||
assert result1 != result2
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Table names should be normalized to lowercase."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
result_upper = _safe_index_name("LIGHTRAG_VDB_ENTITY", "hnsw_cosine")
|
||||
result_lower = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
|
||||
|
||||
assert result_upper == result_lower
|
||||
|
||||
def test_boundary_case_exactly_63_bytes(self):
|
||||
"""Test boundary case where name is exactly at 63-byte limit."""
|
||||
from lightrag.kg.postgres_impl import _safe_index_name
|
||||
|
||||
# Create a table name that results in exactly 63 bytes
|
||||
# idx_ (4) + table_name + _ (1) + suffix = 63
|
||||
# So table_name + suffix = 58
|
||||
|
||||
# Test a name that's just under the limit (should remain unchanged)
|
||||
short_suffix = "id"
|
||||
        # idx_ (4) + 56 chars + _ (1) + id (2) = 63
        table_56 = "a" * 56
        result = _safe_index_name(table_56, short_suffix)
        expected = f"idx_{table_56}_{short_suffix}"
        assert result == expected
        assert len(result.encode("utf-8")) == 63

    def test_unicode_handling(self):
        """Unicode characters should be properly handled (bytes, not chars)."""
        from lightrag.kg.postgres_impl import _safe_index_name

        # Unicode characters can take more bytes than visible chars
        # Chinese characters are 3 bytes each in UTF-8
        table_name = "lightrag_测试_table"  # Contains Chinese chars
        result = _safe_index_name(table_name, "hnsw_cosine")

        # Should always be within 63 bytes
        assert len(result.encode("utf-8")) <= 63

    def test_real_world_model_names(self):
        """Test with real-world embedding model names that cause issues."""
        from lightrag.kg.postgres_impl import _safe_index_name

        # These are actual model names that have caused issues
        test_cases = [
            ("LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d", "hnsw_cosine"),
            ("LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d", "hnsw_cosine"),
            ("LIGHTRAG_VDB_RELATION_text_embedding_3_large_3072d", "hnsw_cosine"),
            (
                "LIGHTRAG_VDB_ENTITY_bge_m3_1024d",
                "hnsw_cosine",
            ),  # Shorter model name
            (
                "LIGHTRAG_VDB_CHUNKS_nomic_embed_text_v1_768d",
                "ivfflat_cosine",
            ),  # Different index type
        ]

        for table_name, suffix in test_cases:
            result = _safe_index_name(table_name, suffix)

            # Critical: must be within PostgreSQL's 63-byte limit
            assert (
                len(result.encode("utf-8")) <= 63
            ), f"Index name too long: {result} for table {table_name}"

            # Must have consistent format
            assert result.startswith("idx_"), f"Missing idx_ prefix: {result}"
            assert result.endswith(f"_{suffix}"), f"Missing suffix {suffix}: {result}"

    def test_hash_uniqueness_for_similar_tables(self):
        """Similar but different table names should produce different hashes."""
        from lightrag.kg.postgres_impl import _safe_index_name

        # These tables have similar names but should have different hashes
        tables = [
            "LIGHTRAG_VDB_CHUNKS_model_a_1024d",
            "LIGHTRAG_VDB_CHUNKS_model_b_1024d",
            "LIGHTRAG_VDB_ENTITY_model_a_1024d",
        ]

        results = [_safe_index_name(t, "hnsw_cosine") for t in tables]

        # All results should be unique
        assert len(set(results)) == len(results), "Hash collision detected!"


class TestIndexNameIntegration:
    """Integration-style tests for index name usage patterns."""

    def test_pg_indexes_lookup_compatibility(self):
        """
        Test that the generated index name will work with pg_indexes lookup.

        This is the core problem: PostgreSQL stores the truncated name,
        but we were looking up the untruncated name. Our fix ensures we
        always use a name that fits within 63 bytes.
        """
        from lightrag.kg.postgres_impl import _safe_index_name

        table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
        suffix = "hnsw_cosine"

        # Generate the index name
        index_name = _safe_index_name(table_name, suffix)

        # Simulate what PostgreSQL would store (truncate at 63 bytes)
        stored_name = index_name.encode("utf-8")[:63].decode("utf-8", errors="ignore")

        # The key fix: our generated name should equal the stored name
        # because it's already within the 63-byte limit
        assert (
            index_name == stored_name
        ), "Index name would be truncated by PostgreSQL, causing lookup failures!"

    def test_backward_compatibility_short_names(self):
        """
        Ensure backward compatibility with existing short index names.

        For tables that have existing indexes with short names (pre-model-suffix era),
        the function should not change their names.
        """
        from lightrag.kg.postgres_impl import _safe_index_name

        # Legacy table names without model suffix
        legacy_tables = [
            "LIGHTRAG_VDB_ENTITY",
            "LIGHTRAG_VDB_RELATION",
            "LIGHTRAG_VDB_CHUNKS",
        ]

        for table in legacy_tables:
            for suffix in ["hnsw_cosine", "ivfflat_cosine", "id"]:
                result = _safe_index_name(table, suffix)
                expected = f"idx_{table.lower()}_{suffix}"

                # Short names should remain unchanged for backward compatibility
                if len(expected.encode("utf-8")) <= 63:
                    assert (
                        result == expected
                    ), f"Short name changed unexpectedly: {result} != {expected}"
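
The tests above pin down the contract of _safe_index_name: short names pass through verbatim for backward compatibility, while long ones are hashed so the result always fits PostgreSQL's 63-byte identifier limit and stays unique per table. As a rough illustration only, a helper satisfying these assertions could look like the sketch below; the actual implementation lives in lightrag/kg/postgres_impl.py and may use a different hash scheme (the name safe_index_name_sketch and the 16-hex-digit digest length are assumptions).

import hashlib

PG_MAX_IDENTIFIER_BYTES = 63  # PostgreSQL truncates longer identifiers

def safe_index_name_sketch(table_name: str, suffix: str) -> str:
    # Short names pass through unchanged (backward compatible).
    candidate = f"idx_{table_name.lower()}_{suffix}"
    if len(candidate.encode("utf-8")) <= PG_MAX_IDENTIFIER_BYTES:
        return candidate
    # Too long: replace the table part with a short, stable hash so that
    # similar table names still map to distinct index names.
    digest = hashlib.sha256(table_name.encode("utf-8")).hexdigest()[:16]
    return f"idx_{digest}_{suffix}"
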
856
tests/test_postgres_migration.py
Normal file
@@ -0,0 +1,856 @@
import pytest
from unittest.mock import patch, AsyncMock
import numpy as np
from lightrag.utils import EmbeddingFunc
from lightrag.kg.postgres_impl import (
    PGVectorStorage,
)
from lightrag.namespace import NameSpace


# Mock PostgreSQLDB
@pytest.fixture
def mock_pg_db():
    """Mock PostgreSQL database connection"""
    db = AsyncMock()
    db.workspace = "test_workspace"

    # Mock query responses with multirows support
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        # Default return value
        if multirows:
            return []  # Return empty list for multirows
        return {"exists": False, "count": 0}

    # Mock for execute that mimics PostgreSQLDB.execute() behavior
    async def mock_execute(sql, data=None, **kwargs):
        """
        Mock that mimics PostgreSQLDB.execute() behavior:
        - Accepts data as dict[str, Any] | None (second parameter)
        - Internally converts dict.values() to tuple for AsyncPG
        """
        # Mimic real execute() which accepts dict and converts to tuple
        if data is not None and not isinstance(data, dict):
            raise TypeError(
                f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}"
            )
        return None

    db.query = AsyncMock(side_effect=mock_query)
    db.execute = AsyncMock(side_effect=mock_execute)

    return db


# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
    with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
        mock_lock_ctx = AsyncMock()
        mock_lock.return_value = mock_lock_ctx
        yield mock_lock


# Mock ClientManager
@pytest.fixture
def mock_client_manager(mock_pg_db):
    with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
        mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
        mock_manager.release_client = AsyncMock()
        yield mock_manager


# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
    async def embed_func(texts, **kwargs):
        return np.array([[0.1] * 768 for _ in texts])

    func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
    return func


async def test_postgres_table_naming(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test if table name is correctly generated with model suffix"""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Verify table name contains model suffix
    expected_suffix = "test_model_768d"
    assert expected_suffix in storage.table_name
    assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"

    # Verify legacy table name
    assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"

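For intuition, the model-specific table suffix asserted above is presumably derived by sanitizing the embedding model name and appending the embedding dimension. A hypothetical sketch of that derivation (the helper name table_name_sketch and the exact sanitation regex are assumptions; the real naming logic lives in lightrag.kg.postgres_impl):

import re

def table_name_sketch(base: str, model_name: str, embedding_dim: int) -> str:
    # Collapse any non-alphanumeric runs to "_" and lowercase the model part.
    model_part = re.sub(r"[^a-zA-Z0-9]+", "_", model_name).strip("_").lower()
    return f"{base}_{model_part}_{embedding_dim}d"

assert (
    table_name_sketch("LIGHTRAG_VDB_CHUNKS", "test_model", 768)
    == "LIGHTRAG_VDB_CHUNKS_test_model_768d"
)
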
async def test_postgres_migration_trigger(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test if migration logic is triggered correctly"""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Setup mocks for migration scenario
    # 1. New table does not exist, legacy table exists
    async def mock_check_table_exists(table_name):
        return table_name == storage.legacy_table_name

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    # 2. Legacy table has 100 records
    mock_rows = [
        {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
        for i in range(100)
    ]
    migration_state = {"new_table_count": 0}

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            sql_upper = sql.upper()
            legacy_table = storage.legacy_table_name.upper()
            new_table = storage.table_name.upper()
            is_new_table = new_table in sql_upper
            is_legacy_table = legacy_table in sql_upper and not is_new_table

            if is_new_table:
                return {"count": migration_state["new_table_count"]}
            if is_legacy_table:
                return {"count": 100}
            return {"count": 0}
        elif multirows and "SELECT *" in sql:
            # Mock batch fetch for migration using keyset pagination
            # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
            # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
            if "WHERE workspace" in sql:
                if "id >" in sql:
                    # Keyset pagination: params = [workspace, last_id, limit]
                    last_id = params[1] if len(params) > 1 else None
                    # Find rows after last_id
                    start_idx = 0
                    for i, row in enumerate(mock_rows):
                        if row["id"] == last_id:
                            start_idx = i + 1
                            break
                    limit = params[2] if len(params) > 2 else 500
                else:
                    # First batch (no last_id): params = [workspace, limit]
                    start_idx = 0
                    limit = params[1] if len(params) > 1 else 500
            else:
                # No workspace filter with keyset
                if "id >" in sql:
                    last_id = params[0] if params else None
                    start_idx = 0
                    for i, row in enumerate(mock_rows):
                        if row["id"] == last_id:
                            start_idx = i + 1
                            break
                    limit = params[1] if len(params) > 1 else 500
                else:
                    start_idx = 0
                    limit = params[0] if params else 500
            end = min(start_idx + limit, len(mock_rows))
            return mock_rows[start_idx:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    # Track migration through _run_with_retry calls
    migration_executed = []

    async def mock_run_with_retry(operation, **kwargs):
        # Track that migration batch operation was called
        migration_executed.append(True)
        migration_state["new_table_count"] = 100
        return None

    mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)

    with patch(
        "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
    ):
        # Initialize storage (should trigger migration)
        await storage.initialize()

    # Verify migration was executed by checking _run_with_retry was called
    # (batch migration uses _run_with_retry with executemany)
    assert len(migration_executed) > 0, "Migration should have been executed"


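The mock above recognizes the two keyset-pagination query shapes the migration issues: a first batch ordered by id with a LIMIT, and follow-up batches filtered by id > last_id. For orientation, a driver loop consistent with those shapes might look like the sketch below; the function and parameter names are assumptions, and per the test's tracking the real batch copy is wrapped in db._run_with_retry rather than called directly.

async def migrate_in_batches(db, legacy_table, workspace, copy_batch, batch_size=500):
    # Hypothetical sketch of a keyset-paginated migration loop.
    last_id = None
    while True:
        if last_id is None:
            sql = (
                f"SELECT * FROM {legacy_table} "
                "WHERE workspace = $1 ORDER BY id LIMIT $2"
            )
            rows = await db.query(sql, [workspace, batch_size], multirows=True)
        else:
            sql = (
                f"SELECT * FROM {legacy_table} "
                "WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
            )
            rows = await db.query(sql, [workspace, last_id, batch_size], multirows=True)
        if not rows:
            break
        await copy_batch(rows)  # e.g. executemany wrapped in db._run_with_retry(...)
        last_id = rows[-1]["id"]
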
async def test_postgres_no_migration_needed(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test scenario where new table already exists (no migration needed)"""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Mock: new table already exists
    async def mock_check_table_exists(table_name):
        return table_name == storage.table_name

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    with patch(
        "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
    ) as mock_create:
        await storage.initialize()

    # Verify no table creation was attempted
    mock_create.assert_not_called()


async def test_scenario_1_new_workspace_creation(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 1: New workspace creation

    Expected behavior:
    - No legacy table exists
    - Directly create new table with model suffix
    - No migration needed
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=3072,
        func=mock_embedding_func.func,
        model_name="text-embedding-3-large",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="new_workspace",
    )

    # Mock: neither table exists
    async def mock_check_table_exists(table_name):
        return False

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    with patch(
        "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
    ) as mock_create:
        await storage.initialize()

    # Verify table name format
    assert "text_embedding_3_large_3072d" in storage.table_name

    # Verify new table creation was called
    mock_create.assert_called_once()
    call_args = mock_create.call_args
    assert (
        call_args[0][1] == storage.table_name
    )  # table_name is second positional arg


async def test_scenario_2_legacy_upgrade_migration(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 2: Upgrade from legacy version

    Expected behavior:
    - Legacy table exists (without model suffix)
    - New table doesn't exist
    - Automatically migrate data to new table with suffix
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="text-embedding-ada-002",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="legacy_workspace",
    )

    # Mock: only legacy table exists
    async def mock_check_table_exists(table_name):
        return table_name == storage.legacy_table_name

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    # Mock: legacy table has 50 records
    mock_rows = [
        {
            "id": f"legacy_id_{i}",
            "content": f"legacy_content_{i}",
            "workspace": "legacy_workspace",
        }
        for i in range(50)
    ]

    # Track which queries have been made for proper response
    query_history = []
    migration_state = {"new_table_count": 0}

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        query_history.append(sql)

        if "COUNT(*)" in sql:
            # Determine table type:
            # - Legacy: contains base name but NOT model suffix
            # - New: contains model suffix (e.g., text_embedding_ada_002_1536d)
            sql_upper = sql.upper()
            base_name = storage.legacy_table_name.upper()

            # Check if this is querying the new table (has model suffix)
            has_model_suffix = storage.table_name.upper() in sql_upper

            is_legacy_table = base_name in sql_upper and not has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy_table and has_workspace_filter:
                # Count for legacy table with workspace filter (before migration)
                return {"count": 50}
            elif is_legacy_table and not has_workspace_filter:
                # Total count for legacy table
                return {"count": 50}
            else:
                # New table count (before/after migration)
                return {"count": migration_state["new_table_count"]}
        elif multirows and "SELECT *" in sql:
            # Mock batch fetch for migration using keyset pagination
            # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
            # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
            if "WHERE workspace" in sql:
                if "id >" in sql:
                    # Keyset pagination: params = [workspace, last_id, limit]
                    last_id = params[1] if len(params) > 1 else None
                    # Find rows after last_id
                    start_idx = 0
                    for i, row in enumerate(mock_rows):
                        if row["id"] == last_id:
                            start_idx = i + 1
                            break
                    limit = params[2] if len(params) > 2 else 500
                else:
                    # First batch (no last_id): params = [workspace, limit]
                    start_idx = 0
                    limit = params[1] if len(params) > 1 else 500
            else:
                # No workspace filter with keyset
                if "id >" in sql:
                    last_id = params[0] if params else None
                    start_idx = 0
                    for i, row in enumerate(mock_rows):
                        if row["id"] == last_id:
                            start_idx = i + 1
                            break
                    limit = params[1] if len(params) > 1 else 500
                else:
                    start_idx = 0
                    limit = params[0] if params else 500
            end = min(start_idx + limit, len(mock_rows))
            return mock_rows[start_idx:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    # Track migration through _run_with_retry calls
    migration_executed = []

    async def mock_run_with_retry(operation, **kwargs):
        # Track that migration batch operation was called
        migration_executed.append(True)
        migration_state["new_table_count"] = 50
        return None

    mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)

    with patch(
        "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
    ) as mock_create:
        await storage.initialize()

    # Verify table name contains ada-002
    assert "text_embedding_ada_002_1536d" in storage.table_name

    # Verify migration was executed (batch migration uses _run_with_retry)
    assert len(migration_executed) > 0, "Migration should have been executed"
    mock_create.assert_called_once()


async def test_scenario_3_multi_model_coexistence(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 3: Multiple embedding models coexist

    Expected behavior:
    - Different embedding models create separate tables
    - Tables are isolated by model suffix
    - No interference between different models
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    # Workspace A: uses bge-small (768d)
    embedding_func_a = EmbeddingFunc(
        embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small"
    )

    storage_a = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func_a,
        workspace="workspace_a",
    )

    # Workspace B: uses bge-large (1024d)
    async def embed_func_b(texts, **kwargs):
        return np.array([[0.1] * 1024 for _ in texts])

    embedding_func_b = EmbeddingFunc(
        embedding_dim=1024, func=embed_func_b, model_name="bge-large"
    )

    storage_b = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func_b,
        workspace="workspace_b",
    )

    # Verify different table names
    assert storage_a.table_name != storage_b.table_name
    assert "bge_small_768d" in storage_a.table_name
    assert "bge_large_1024d" in storage_b.table_name

    # Mock: both tables don't exist yet
    async def mock_check_table_exists(table_name):
        return False

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    with patch(
        "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
    ) as mock_create:
        # Initialize both storages
        await storage_a.initialize()
        await storage_b.initialize()

    # Verify two separate tables were created
    assert mock_create.call_count == 2

    # Verify table names are different
    call_args_list = mock_create.call_args_list
    table_names = [call[0][1] for call in call_args_list]  # Second positional arg
    assert len(set(table_names)) == 2  # Two unique table names
    assert storage_a.table_name in table_names
    assert storage_b.table_name in table_names


async def test_case1_empty_legacy_auto_cleanup(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1a: Both new and legacy tables exist, but legacy is EMPTY
    Expected: Automatically delete empty legacy table (safe cleanup)
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="test_ws",
    )

    # Mock: Both tables exist
    async def mock_check_table_exists(table_name):
        return True  # Both new and legacy exist

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    # Mock: Legacy table is empty (0 records)
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            if storage.legacy_table_name in sql:
                return {"count": 0}  # Empty legacy table
            else:
                return {"count": 100}  # New table has data
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with patch("lightrag.kg.postgres_impl.logger"):
        await storage.initialize()

    # Verify: Empty legacy table should be automatically cleaned up
    # Empty tables are safe to delete without data loss risk
    delete_calls = [
        call
        for call in mock_pg_db.execute.call_args_list
        if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted"
    # Check if legacy table was dropped
    dropped_table = storage.legacy_table_name
    assert any(
        dropped_table in str(call) for call in delete_calls
    ), f"Expected to drop empty legacy table '{dropped_table}'"

    print(
        f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully"
    )


async def test_case1_nonempty_legacy_warning(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1b: Both new and legacy tables exist, and legacy HAS DATA
    Expected: Log warning, do not delete legacy (preserve data)
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="test_ws",
    )

    # Mock: Both tables exist
    async def mock_check_table_exists(table_name):
        return True  # Both new and legacy exist

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)

    # Mock: Legacy table has data (50 records)
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            if storage.legacy_table_name in sql:
                return {"count": 50}  # Legacy has data
            else:
                return {"count": 100}  # New table has data
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with patch("lightrag.kg.postgres_impl.logger"):
        await storage.initialize()

    # Verify: Legacy table with data should be preserved
    # We never auto-delete tables that contain data to prevent accidental data loss
    delete_calls = [
        call
        for call in mock_pg_db.execute.call_args_list
        if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    # Check if legacy table was deleted (it should not be)
    dropped_table = storage.legacy_table_name
    legacy_deleted = any(dropped_table in str(call) for call in delete_calls)
    assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted"

    print(
        f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)"
    )


async def test_case1_sequential_workspace_migration(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1c: Sequential workspace migration (Multi-tenant scenario)

    Critical bug fix verification:
    Timeline:
    1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
    2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data
    3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data
    4. Verify workspace B's data is correctly migrated to new table

    This test verifies the migration logic correctly handles multi-tenant scenarios
    where different workspaces migrate sequentially.
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    # Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b)
    mock_rows_a = [
        {"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"}
        for i in range(3)
    ]
    mock_rows_b = [
        {"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"}
        for i in range(3)
    ]

    # Track migration state
    migration_state = {
        "new_table_exists": False,
        "workspace_a_migrated": False,
        "workspace_a_migration_count": 0,
        "workspace_b_migration_count": 0,
    }

    # Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists)
    # CRITICAL: Set db.workspace to workspace_a
    mock_pg_db.workspace = "workspace_a"

    storage_a = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="workspace_a",
    )

    # Mock table_exists for workspace_a
    async def mock_check_table_exists_a(table_name):
        if table_name == storage_a.legacy_table_name:
            return True
        if table_name == storage_a.table_name:
            return migration_state["new_table_exists"]
        return False

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_a)

    # Mock query for workspace_a (Case 3)
    async def mock_query_a(sql, params=None, multirows=False, **kwargs):
        sql_upper = sql.upper()
        base_name = storage_a.legacy_table_name.upper()

        if "COUNT(*)" in sql:
            has_model_suffix = "TEST_MODEL_1536D" in sql_upper
            is_legacy = base_name in sql_upper and not has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy and has_workspace_filter:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_a":
                    return {"count": 3}
                elif workspace == "workspace_b":
                    return {"count": 3}
            elif is_legacy and not has_workspace_filter:
                # Global count in legacy table
                return {"count": 6}
            elif has_model_suffix:
                if has_workspace_filter:
                    workspace = params[0] if params and len(params) > 0 else None
                    if workspace == "workspace_a":
                        return {"count": migration_state["workspace_a_migration_count"]}
                    if workspace == "workspace_b":
                        return {"count": migration_state["workspace_b_migration_count"]}
                return {
                    "count": migration_state["workspace_a_migration_count"]
                    + migration_state["workspace_b_migration_count"]
                }
        elif multirows and "SELECT *" in sql:
            if "WHERE workspace" in sql:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_a":
                    # Handle keyset pagination
                    if "id >" in sql:
                        # params = [workspace, last_id, limit]
                        last_id = params[1] if len(params) > 1 else None
                        start_idx = 0
                        for i, row in enumerate(mock_rows_a):
                            if row["id"] == last_id:
                                start_idx = i + 1
                                break
                        limit = params[2] if len(params) > 2 else 500
                    else:
                        # First batch: params = [workspace, limit]
                        start_idx = 0
                        limit = params[1] if len(params) > 1 else 500
                    end = min(start_idx + limit, len(mock_rows_a))
                    return mock_rows_a[start_idx:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query_a)

    # Track migration via _run_with_retry (batch migration uses this)
    migration_a_executed = []

    async def mock_run_with_retry_a(operation, **kwargs):
        migration_a_executed.append(True)
        migration_state["workspace_a_migration_count"] = len(mock_rows_a)
        return None

    mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a)

    # Initialize workspace_a (Case 3)
    with patch("lightrag.kg.postgres_impl.logger"):
        await storage_a.initialize()
        migration_state["new_table_exists"] = True
        migration_state["workspace_a_migrated"] = True

    print("✅ Step 1: Workspace A initialized")
    # Verify migration was executed via _run_with_retry (batch migration uses executemany)
    assert (
        len(migration_a_executed) > 0
    ), "Migration should have been executed for workspace_a"
    print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)")

    # Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data)
    # CRITICAL: Set db.workspace to workspace_b
    mock_pg_db.workspace = "workspace_b"

    storage_b = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="workspace_b",
    )

    mock_pg_db.reset_mock()

    # Mock table_exists for workspace_b (both exist)
    async def mock_check_table_exists_b(table_name):
        return True  # Both tables exist

    mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_b)

    # Mock query for workspace_b (Case 3)
    async def mock_query_b(sql, params=None, multirows=False, **kwargs):
        sql_upper = sql.upper()
        base_name = storage_b.legacy_table_name.upper()

        if "COUNT(*)" in sql:
            has_model_suffix = "TEST_MODEL_1536D" in sql_upper
            is_legacy = base_name in sql_upper and not has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy and has_workspace_filter:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_b":
                    return {"count": 3}  # workspace_b still has data in legacy
                elif workspace == "workspace_a":
                    return {"count": 0}  # workspace_a already migrated
            elif is_legacy and not has_workspace_filter:
                # Global count: only workspace_b data remains
                return {"count": 3}
            elif has_model_suffix:
                if has_workspace_filter:
                    workspace = params[0] if params and len(params) > 0 else None
                    if workspace == "workspace_b":
                        return {"count": migration_state["workspace_b_migration_count"]}
                    elif workspace == "workspace_a":
                        return {"count": 3}
                else:
                    return {"count": 3 + migration_state["workspace_b_migration_count"]}
        elif multirows and "SELECT *" in sql:
            if "WHERE workspace" in sql:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_b":
                    # Handle keyset pagination
                    if "id >" in sql:
                        # params = [workspace, last_id, limit]
                        last_id = params[1] if len(params) > 1 else None
                        start_idx = 0
                        for i, row in enumerate(mock_rows_b):
                            if row["id"] == last_id:
                                start_idx = i + 1
                                break
                        limit = params[2] if len(params) > 2 else 500
                    else:
                        # First batch: params = [workspace, limit]
                        start_idx = 0
                        limit = params[1] if len(params) > 1 else 500
                    end = min(start_idx + limit, len(mock_rows_b))
                    return mock_rows_b[start_idx:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query_b)

    # Track migration via _run_with_retry for workspace_b
    migration_b_executed = []

    async def mock_run_with_retry_b(operation, **kwargs):
        migration_b_executed.append(True)
        migration_state["workspace_b_migration_count"] = len(mock_rows_b)
        return None

    mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b)

    # Initialize workspace_b (Case 3 - both tables exist)
    with patch("lightrag.kg.postgres_impl.logger"):
        await storage_b.initialize()

    print("✅ Step 2: Workspace B initialized")

    # Verify workspace_b migration happens when new table has no workspace_b data
    # but legacy table still has workspace_b data.
    assert (
        len(migration_b_executed) > 0
    ), "Migration should have been executed for workspace_b"
    print("✅ Step 2: Migration executed for workspace_b")

    print("\n🎉 Case 1c: Sequential workspace migration verification complete!")
    print(" - Workspace A: Migrated successfully (only legacy existed)")
    print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")
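
The sequential scenario hinges on a per-workspace decision: migrate only when the legacy table still holds rows for this workspace and the new table holds none for it. A hypothetical sketch of that predicate, consistent with the COUNT(*) queries the mocks answer (the helper name and exact SQL are assumptions; the real check lives inside PGVectorStorage.initialize()):

async def needs_workspace_migration(db, legacy_table, new_table, workspace):
    # Rows this workspace still has in the legacy (un-suffixed) table.
    legacy_count = (await db.query(
        f"SELECT COUNT(*) AS count FROM {legacy_table} WHERE workspace = $1",
        [workspace],
    ))["count"]
    # Rows already present for this workspace in the model-suffixed table.
    new_count = (await db.query(
        f"SELECT COUNT(*) AS count FROM {new_table} WHERE workspace = $1",
        [workspace],
    ))["count"]
    return legacy_count > 0 and new_count == 0
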
562
tests/test_qdrant_migration.py
Normal file
@@ -0,0 +1,562 @@
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
import numpy as np
from qdrant_client import models
from lightrag.utils import EmbeddingFunc
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage


# Mock QdrantClient
@pytest.fixture
def mock_qdrant_client():
    with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls:
        client = mock_client_cls.return_value
        client.collection_exists.return_value = False
        client.count.return_value.count = 0
        # Mock payload schema and vector config for get_collection
        collection_info = MagicMock()
        collection_info.payload_schema = {}
        # Mock vector dimension to match mock_embedding_func (768d)
        collection_info.config.params.vectors.size = 768
        client.get_collection.return_value = collection_info
        yield client


# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
    with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock:
        mock_lock_ctx = AsyncMock()
        mock_lock.return_value = mock_lock_ctx
        yield mock_lock


# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
    async def embed_func(texts, **kwargs):
        return np.array([[0.1] * 768 for _ in texts])

    func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model")
    return func


async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func):
    """Test if collection name is correctly generated with model suffix"""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Verify collection name contains model suffix
    expected_suffix = "test_model_768d"
    assert expected_suffix in storage.final_namespace
    assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}"


async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func):
    """Test if migration logic is triggered correctly"""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Legacy collection name (without model suffix)
    legacy_collection = "lightrag_vdb_chunks"

    # Setup mocks for migration scenario
    # 1. New collection does not exist, only legacy exists
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == legacy_collection
    )

    # 2. Legacy collection exists and has data
    migration_state = {"new_workspace_count": 0}

    def count_mock(collection_name, exact=True, count_filter=None):
        mock_result = MagicMock()
        if collection_name == legacy_collection:
            mock_result.count = 100
        elif collection_name == storage.final_namespace:
            mock_result.count = migration_state["new_workspace_count"]
        else:
            mock_result.count = 0
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # 3. Mock scroll for data migration
    mock_point = MagicMock()
    mock_point.id = "old_id"
    mock_point.vector = [0.1] * 768
    mock_point.payload = {"content": "test"}  # No workspace_id in payload

    # When payload_schema is empty, the code first samples payloads to detect workspace_id
    # Then proceeds with migration batches
    # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
    mock_qdrant_client.scroll.side_effect = [
        ([mock_point], "_"),  # Sampling scroll - no workspace_id found
        ([mock_point], "next_offset"),  # Migration batch
        ([], None),  # End of migration
    ]

    def upsert_mock(*args, **kwargs):
        migration_state["new_workspace_count"] = 100
        return None

    mock_qdrant_client.upsert.side_effect = upsert_mock

    # Initialize storage (triggers migration)
    await storage.initialize()

    # Verify migration steps
    # 1. Legacy count checked
    mock_qdrant_client.count.assert_any_call(
        collection_name=legacy_collection, exact=True
    )

    # 2. New collection created
    mock_qdrant_client.create_collection.assert_called()

    # 3. Data scrolled from legacy
    # First call (index 0) is sampling scroll with limit=10
    # Second call (index 1) is migration batch with limit=500
    assert mock_qdrant_client.scroll.call_count >= 2
    # Check sampling scroll
    sampling_call = mock_qdrant_client.scroll.call_args_list[0]
    assert sampling_call.kwargs["collection_name"] == legacy_collection
    assert sampling_call.kwargs["limit"] == 10
    # Check migration batch scroll
    migration_call = mock_qdrant_client.scroll.call_args_list[1]
    assert migration_call.kwargs["collection_name"] == legacy_collection
    assert migration_call.kwargs["limit"] == 500

    # 4. Data upserted to new
    mock_qdrant_client.upsert.assert_called()

    # 5. Payload index created
    mock_qdrant_client.create_payload_index.assert_called()


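The mocks above assume a sampling scroll followed by batched scroll/upsert calls. As a rough sketch only (not the actual qdrant_impl code; the batch size, the omission of the sampling step, and the workspace_id tagging shown here are assumptions), the consuming migration loop could look like:

from qdrant_client import QdrantClient, models

def migrate_collection_sketch(client: QdrantClient, legacy: str, target: str,
                              workspace_id: str, batch_size: int = 500) -> None:
    offset = None
    while True:
        # scroll() returns (points, next_page_offset)
        points, offset = client.scroll(
            collection_name=legacy,
            limit=batch_size,
            offset=offset,
            with_payload=True,
            with_vectors=True,
        )
        if not points:
            break
        client.upsert(
            collection_name=target,
            points=[
                models.PointStruct(
                    id=p.id,
                    vector=p.vector,
                    # Tag legacy rows with the current workspace during migration
                    payload={**(p.payload or {}), "workspace_id": workspace_id},
                )
                for p in points
            ],
        )
        if offset is None:
            break
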
async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
    """Test scenario where new collection already exists (Case 1 in setup_collection)

    When only the new collection exists and no legacy collection is found,
    the implementation should:
    1. Create payload index on the new collection (ensure index exists)
    2. NOT attempt any data migration (no scroll calls)
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Only new collection exists (no legacy collection found)
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == storage.final_namespace
    )

    # Initialize
    await storage.initialize()

    # Should create payload index on the new collection (ensure index)
    mock_qdrant_client.create_payload_index.assert_called_with(
        collection_name=storage.final_namespace,
        field_name="workspace_id",
        field_schema=models.KeywordIndexParams(
            type=models.KeywordIndexType.KEYWORD,
            is_tenant=True,
        ),
    )
    # Should NOT migrate (no scroll calls since no legacy collection exists)
    mock_qdrant_client.scroll.assert_not_called()


# ============================================================================
# Tests for scenarios described in design document (Lines 606-649)
# ============================================================================


async def test_scenario_1_new_workspace_creation(
    mock_qdrant_client, mock_embedding_func
):
    """
    Scenario 1: Creating a new workspace
    Expected: directly create lightrag_vdb_chunks_text_embedding_3_large_3072d
    """
    # Use a large embedding model
    large_model_func = EmbeddingFunc(
        embedding_dim=3072,
        func=mock_embedding_func.func,
        model_name="text-embedding-3-large",
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=large_model_func,
        workspace="test_new",
    )

    # Case 3: Neither legacy nor new collection exists
    mock_qdrant_client.collection_exists.return_value = False

    # Initialize storage
    await storage.initialize()

    # Verify: Should create new collection with model suffix
    expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
    assert storage.final_namespace == expected_collection

    # Verify create_collection was called with correct name
    create_calls = [
        call for call in mock_qdrant_client.create_collection.call_args_list
    ]
    assert len(create_calls) > 0
    assert (
        create_calls[0][0][0] == expected_collection
        or create_calls[0].kwargs.get("collection_name") == expected_collection
    )

    # Verify no migration was attempted
    mock_qdrant_client.scroll.assert_not_called()

    print(
        f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
    )


async def test_scenario_2_legacy_upgrade_migration(
    mock_qdrant_client, mock_embedding_func
):
    """
    Scenario 2: Upgrading from a legacy version
    lightrag_vdb_chunks (without suffix) already exists
    Expected: automatically migrate data to lightrag_vdb_chunks_text_embedding_ada_002_1536d
    Note: the legacy collection is no longer auto-deleted after migration; it must be removed manually
    """
    # Use ada-002 model
    ada_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="text-embedding-ada-002",
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=ada_func,
        workspace="test_legacy",
    )

    # Legacy collection name (without model suffix)
    legacy_collection = "lightrag_vdb_chunks"
    new_collection = storage.final_namespace

    # Case 4: Only legacy collection exists
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == legacy_collection
    )

    # Mock legacy collection info with 1536d vectors
    legacy_collection_info = MagicMock()
    legacy_collection_info.payload_schema = {}
    legacy_collection_info.config.params.vectors.size = 1536
    mock_qdrant_client.get_collection.return_value = legacy_collection_info

    migration_state = {"new_workspace_count": 0}

    def count_mock(collection_name, exact=True, count_filter=None):
        mock_result = MagicMock()
        if collection_name == legacy_collection:
            mock_result.count = 150
        elif collection_name == new_collection:
            mock_result.count = migration_state["new_workspace_count"]
        else:
            mock_result.count = 0
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # Mock scroll results (simulate migration in batches)
    mock_points = []
    for i in range(10):
        point = MagicMock()
        point.id = f"legacy-{i}"
        point.vector = [0.1] * 1536
        # No workspace_id in payload - simulates legacy data
        point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
        mock_points.append(point)

    # When payload_schema is empty, the code first samples payloads to detect workspace_id
    # Then proceeds with migration batches
    # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
    mock_qdrant_client.scroll.side_effect = [
        (mock_points, "_"),  # Sampling scroll - no workspace_id found in payloads
        (mock_points, "offset1"),  # Migration batch
        ([], None),  # End of migration
    ]

    def upsert_mock(*args, **kwargs):
        migration_state["new_workspace_count"] = 150
        return None

    mock_qdrant_client.upsert.side_effect = upsert_mock

    # Initialize (triggers migration)
    await storage.initialize()

    # Verify: New collection should be created
    expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
    assert storage.final_namespace == expected_new_collection

    # Verify migration steps
    # 1. Check legacy count
    mock_qdrant_client.count.assert_any_call(
        collection_name=legacy_collection, exact=True
    )

    # 2. Create new collection
    mock_qdrant_client.create_collection.assert_called()

    # 3. Scroll legacy data
    scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
    assert len(scroll_calls) >= 1
    assert scroll_calls[0].kwargs["collection_name"] == legacy_collection

    # 4. Upsert to new collection
    upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
    assert len(upsert_calls) >= 1
    assert upsert_calls[0].kwargs["collection_name"] == new_collection

    # Note: Legacy collection is NOT automatically deleted after migration
    # Manual deletion is required after data migration verification

    print(
        f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'"
    )


async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
    """
    Scenario 3: Multiple models coexisting
    Expected: two independent collections that do not interfere with each other
    """

    # Model A: bge-small with 768d
    async def embed_func_a(texts, **kwargs):
        return np.array([[0.1] * 768 for _ in texts])

    model_a_func = EmbeddingFunc(
        embedding_dim=768, func=embed_func_a, model_name="bge-small"
    )

    # Model B: bge-large with 1024d
    async def embed_func_b(texts, **kwargs):
        return np.array([[0.2] * 1024 for _ in texts])

    model_b_func = EmbeddingFunc(
        embedding_dim=1024, func=embed_func_b, model_name="bge-large"
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    # Create storage for workspace A with model A
    storage_a = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=model_a_func,
        workspace="workspace_a",
    )

    # Create storage for workspace B with model B
    storage_b = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=model_b_func,
        workspace="workspace_b",
    )

    # Verify: Collection names are different
    assert storage_a.final_namespace != storage_b.final_namespace

    # Verify: Model A collection
    expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
    assert storage_a.final_namespace == expected_collection_a

    # Verify: Model B collection
    expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
    assert storage_b.final_namespace == expected_collection_b

    # Verify: Different embedding dimensions are preserved
    assert storage_a.embedding_func.embedding_dim == 768
    assert storage_b.embedding_func.embedding_dim == 1024

    print("✅ Scenario 3: Multi-model coexistence verified")
    print(f" - Workspace A: {expected_collection_a} (768d)")
    print(f" - Workspace B: {expected_collection_b} (1024d)")
    print(" - Collections are independent")


async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
    """
    Case 1a: Both the new and legacy collections exist, and the legacy one is empty
    Expected: auto-delete the legacy collection
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Legacy collection name (without model suffix)
    legacy_collection = "lightrag_vdb_chunks"
    new_collection = storage.final_namespace

    # Mock: Both collections exist
    mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
        legacy_collection,
        new_collection,
    ]

    # Mock: Legacy collection is empty (0 records)
    def count_mock(collection_name, exact=True, count_filter=None):
        mock_result = MagicMock()
        if collection_name == legacy_collection:
            mock_result.count = 0  # Empty legacy collection
        else:
            mock_result.count = 100  # New collection has data
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # Mock get_collection for Case 2 check
    collection_info = MagicMock()
    collection_info.payload_schema = {"workspace_id": True}
    mock_qdrant_client.get_collection.return_value = collection_info

    # Initialize storage
    await storage.initialize()

    # Verify: Empty legacy collection should be automatically cleaned up
    # Empty collections are safe to delete without data loss risk
    delete_calls = [
        call for call in mock_qdrant_client.delete_collection.call_args_list
    ]
    assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
    deleted_collection = (
        delete_calls[0][0][0]
        if delete_calls[0][0]
        else delete_calls[0].kwargs.get("collection_name")
    )
    assert (
        deleted_collection == legacy_collection
    ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"

    print(
        f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
    )


async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
    """
    Case 1b: Both the new and legacy collections exist, and the legacy one has data
    Expected: warn but do not delete
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Legacy collection name (without model suffix)
    legacy_collection = "lightrag_vdb_chunks"
    new_collection = storage.final_namespace

    # Mock: Both collections exist
    mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
        legacy_collection,
        new_collection,
    ]

    # Mock: Legacy collection has data (50 records)
    def count_mock(collection_name, exact=True, count_filter=None):
        mock_result = MagicMock()
        if collection_name == legacy_collection:
            mock_result.count = 50  # Legacy has data
        else:
            mock_result.count = 100  # New collection has data
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # Mock get_collection for Case 2 check
    collection_info = MagicMock()
    collection_info.payload_schema = {"workspace_id": True}
    mock_qdrant_client.get_collection.return_value = collection_info

    # Initialize storage
    await storage.initialize()

    # Verify: Legacy collection with data should be preserved
    # We never auto-delete collections that contain data to prevent accidental data loss
    delete_calls = [
        call for call in mock_qdrant_client.delete_collection.call_args_list
    ]
    # Check if legacy collection was deleted (it should not be)
    legacy_deleted = any(
        (call[0][0] if call[0] else call.kwargs.get("collection_name"))
        == legacy_collection
        for call in delete_calls
    )
    assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"

    print(
        f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
    )
122
tests/test_unified_lock_safety.py
Normal file
@@ -0,0 +1,122 @@
"""
Tests for UnifiedLock safety when lock is None.

This test module verifies that get_internal_lock() and get_data_init_lock()
raise RuntimeError when shared data is not initialized, preventing a false
sense of safety and potential race conditions.

Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
to the lock factory functions (get_internal_lock, get_data_init_lock) for
early failure detection.

Critical Bug 1 (fixed): when self._lock is None, the code would fail with
AttributeError. The check now lives in the factory functions for clearer errors.

Critical Bug 2: in __aexit__, when async_lock.release() fails, the error
recovery logic would attempt to release it again, causing double-release issues.
"""

from unittest.mock import MagicMock, AsyncMock

import pytest

from lightrag.kg.shared_storage import (
    UnifiedLock,
    get_internal_lock,
    get_data_init_lock,
    finalize_share_data,
)


class TestUnifiedLockSafety:
    """Test suite for UnifiedLock None safety checks."""

    def setup_method(self):
        """Ensure shared data is finalized before each test."""
        finalize_share_data()

    def teardown_method(self):
        """Clean up after each test."""
        finalize_share_data()

    def test_get_internal_lock_raises_when_not_initialized(self):
        """
        Test that get_internal_lock() raises RuntimeError when shared data is not initialized.

        Scenario: Call get_internal_lock() before initialize_share_data() is called.
        Expected: RuntimeError raised with a clear error message.

        This test verifies the None check has been moved to the factory function.
        """
        with pytest.raises(
            RuntimeError, match="Shared data not initialized.*initialize_share_data"
        ):
            get_internal_lock()

    def test_get_data_init_lock_raises_when_not_initialized(self):
        """
        Test that get_data_init_lock() raises RuntimeError when shared data is not initialized.

        Scenario: Call get_data_init_lock() before initialize_share_data() is called.
        Expected: RuntimeError raised with a clear error message.

        This test verifies the None check has been moved to the factory function.
        """
        with pytest.raises(
            RuntimeError, match="Shared data not initialized.*initialize_share_data"
        ):
            get_data_init_lock()

    @pytest.mark.offline
    async def test_aexit_no_double_release_on_async_lock_failure(self):
        """
        Test that __aexit__ doesn't attempt to release async_lock twice when it fails.

        Scenario: async_lock.release() fails during normal release.
        Expected: Recovery logic should NOT attempt to release async_lock again,
        preventing double-release issues.

        This tests the Bug 2 fix: async_lock_released tracking prevents double release.
        """
        # Create mock locks
        main_lock = MagicMock()
        main_lock.acquire = MagicMock()
        main_lock.release = MagicMock()

        async_lock = AsyncMock()
        async_lock.acquire = AsyncMock()

        # Make async_lock.release() fail
        release_call_count = 0

        def mock_release_fail():
            nonlocal release_call_count
            release_call_count += 1
            raise RuntimeError("Async lock release failed")

        async_lock.release = MagicMock(side_effect=mock_release_fail)

        # Create UnifiedLock with both locks (sync mode with async_lock)
        lock = UnifiedLock(
            lock=main_lock,
            is_async=False,
            name="test_double_release",
            enable_logging=False,
        )
        lock._async_lock = async_lock

        # Try to use the lock - should fail during __aexit__
        try:
            async with lock:
                pass
        except RuntimeError as e:
            # Should get the async lock release error
            assert "Async lock release failed" in str(e)

        # Verify async_lock.release() was called only ONCE, not twice
        assert (
            release_call_count == 1
        ), f"async_lock.release() should be called only once, but was called {release_call_count} times"

        # Main lock should have been released successfully
        main_lock.release.assert_called_once()
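For orientation, the two behaviors these tests pin down can be sketched in a few lines. This is a minimal illustration under assumed names, not the actual lightrag.kg.shared_storage implementation (which also manages multiprocess state):

# Hypothetical sketch of the factory guard and release-tracking under test.
import asyncio
from typing import Optional

_internal_lock: Optional[asyncio.Lock] = None

def initialize_share_data() -> None:
    """Create shared locks; must run before any lock factory is used."""
    global _internal_lock
    _internal_lock = asyncio.Lock()

def finalize_share_data() -> None:
    """Tear down shared state, returning factories to the uninitialized state."""
    global _internal_lock
    _internal_lock = None

def get_internal_lock() -> asyncio.Lock:
    # Failing here, at acquisition time, is the design the tests verify:
    # a None lock must never silently "succeed" inside __aenter__.
    if _internal_lock is None:
        raise RuntimeError(
            "Shared data not initialized. Call initialize_share_data() first."
        )
    return _internal_lock

class UnifiedLockSketch:
    """Simplified stand-in for UnifiedLock, only to show the release guard."""

    def __init__(self, lock, async_lock=None):
        self._lock = lock
        self._async_lock = async_lock

    async def __aenter__(self):
        if self._async_lock is not None:
            await self._async_lock.acquire()
        self._lock.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        async_lock_released = False
        try:
            self._lock.release()            # main (sync) lock first
            if self._async_lock is not None:
                async_lock_released = True  # mark BEFORE the call: one attempt only
                self._async_lock.release()
        except Exception:
            # Recovery path must never release the async lock a second time.
            if self._async_lock is not None and not async_lock_released:
                self._async_lock.release()
            raise

Setting the flag before calling release() is the essential detail: even a failed release counts as the one permitted attempt, which is exactly what the call-count assertion above checks.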
@@ -149,7 +149,6 @@ def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None:


@pytest.mark.offline
@pytest.mark.asyncio
async def test_pipeline_status_isolation():
    """
    Test that pipeline status is isolated between different workspaces.
@@ -204,7 +203,6 @@ async def test_pipeline_status_isolation():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_lock_mechanism(stress_test_mode, parallel_workers):
    """
    Test that the new keyed lock mechanism works correctly without deadlocks.
@@ -274,7 +272,6 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers):


@pytest.mark.offline
@pytest.mark.asyncio
async def test_backward_compatibility():
    """
    Test that legacy code without workspace parameter still works correctly.
@@ -348,7 +345,6 @@ async def test_backward_compatibility():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_multi_workspace_concurrency():
    """
    Test that multiple workspaces can operate concurrently without interference.
@@ -432,7 +428,6 @@ async def test_multi_workspace_concurrency():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_namespace_lock_reentrance():
    """
    Test that NamespaceLock prevents re-entrance in the same coroutine
@@ -506,7 +501,6 @@ async def test_namespace_lock_reentrance():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_different_namespace_lock_isolation():
    """
    Test that locks for different namespaces (same workspace) are independent.
@@ -546,7 +540,6 @@ async def test_different_namespace_lock_isolation():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_error_handling():
    """
    Test error handling for invalid workspace configurations.
@@ -597,7 +590,6 @@ async def test_error_handling():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_update_flags_workspace_isolation():
    """
    Test that update flags are properly isolated between workspaces.
@@ -727,7 +719,6 @@ async def test_update_flags_workspace_isolation():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_empty_workspace_standardization():
    """
    Test that empty workspace is properly standardized to "" instead of "_".
@@ -781,7 +772,6 @@ async def test_empty_workspace_standardization():


@pytest.mark.offline
@pytest.mark.asyncio
async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):
    """
    Integration test: Verify JsonKVStorage properly isolates data between workspaces.
@@ -961,7 +951,6 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):


@pytest.mark.offline
@pytest.mark.asyncio
async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts):
    """
    End-to-end test: Create two LightRAG instances with different workspaces,
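Two of the properties exercised above, empty-workspace standardization and per-namespace lock independence, reduce to a small keyed-lock pattern. A rough sketch with hypothetical names; LightRAG's real keyed locks are considerably more involved (workspace scoping, reentrance checks, multiprocess support):

# Hypothetical sketch: independent locks per (workspace, namespace) key.
import asyncio
from collections import defaultdict
from typing import Optional

def normalize_workspace(workspace: Optional[str]) -> str:
    # An unset or empty workspace standardizes to "" (not a "_" placeholder),
    # the behavior test_empty_workspace_standardization describes.
    return workspace if workspace else ""

_keyed_locks: "defaultdict[tuple[str, str], asyncio.Lock]" = defaultdict(asyncio.Lock)

def get_namespace_lock(workspace: Optional[str], namespace: str) -> asyncio.Lock:
    # One independent lock per (workspace, namespace) pair, so holding the
    # lock for one namespace never blocks another namespace or workspace.
    return _keyed_locks[(normalize_workspace(workspace), namespace)]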
288
tests/test_workspace_migration_isolation.py
Normal file
@@ -0,0 +1,288 @@
"""
Tests for workspace isolation during PostgreSQL migration.

This test module verifies that setup_table() properly filters migration data
by workspace, preventing cross-workspace data leakage during legacy table migration.

Critical Bug: Migration copied ALL records from the legacy table regardless of
workspace, causing workspace A to receive workspace B's data and violating
multi-tenant isolation.
"""

import pytest
from unittest.mock import AsyncMock

from lightrag.kg.postgres_impl import PGVectorStorage


class TestWorkspaceMigrationIsolation:
    """Test suite for workspace-scoped migration in PostgreSQL."""

    async def test_migration_filters_by_workspace(self):
        """
        Test that migration only copies data from the specified workspace.

        Scenario: Legacy table contains data from multiple workspaces.
        Migrate only workspace_a's data to the new table.
        Expected: New table contains only workspace_a data; workspace_b data is excluded.
        """
        db = AsyncMock()

        # Configure mock return values to avoid unawaited coroutine warnings
        db._create_vector_index.return_value = None

        # Track state for new table count (starts at 0, increases after migration)
        new_table_record_count = {"count": 0}

        # Mock table existence checks
        async def table_exists_side_effect(db_instance, name):
            if name.lower() == "lightrag_doc_chunks":  # legacy
                return True
            elif name.lower() == "lightrag_doc_chunks_model_1536d":  # new
                return False  # New table doesn't exist initially
            return False

        # Mock data for workspace_a
        mock_records_a = [
            {
                "id": "a1",
                "workspace": "workspace_a",
                "content": "content_a1",
                "content_vector": [0.1] * 1536,
            },
            {
                "id": "a2",
                "workspace": "workspace_a",
                "content": "content_a2",
                "content_vector": [0.2] * 1536,
            },
        ]

        # Mock query responses
        async def query_side_effect(sql, params, **kwargs):
            multirows = kwargs.get("multirows", False)
            sql_upper = sql.upper()

            # Count query for new table workspace data (verification before migration)
            if (
                "COUNT(*)" in sql_upper
                and "MODEL_1536D" in sql_upper
                and "WHERE WORKSPACE" in sql_upper
            ):
                return new_table_record_count  # Initially 0

            # Count query with workspace filter (legacy table) - for workspace count
            elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
                if params and params[0] == "workspace_a":
                    return {"count": 2}  # workspace_a has 2 records
                elif params and params[0] == "workspace_b":
                    return {"count": 3}  # workspace_b has 3 records
                return {"count": 0}

            # Count query for legacy table (total, no workspace filter)
            elif (
                "COUNT(*)" in sql_upper
                and "LIGHTRAG" in sql_upper
                and "WHERE WORKSPACE" not in sql_upper
            ):
                return {"count": 5}  # Total records in legacy

            # SELECT with workspace filter for migration (multirows)
            elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
                workspace = params[0] if params else None
                if workspace == "workspace_a":
                    # Handle keyset pagination: check for "id >" pattern
                    if "id >" in sql.lower():
                        # Keyset pagination: params = [workspace, last_id, limit]
                        last_id = params[1] if len(params) > 1 else None
                        # Find records after last_id
                        found_idx = -1
                        for i, rec in enumerate(mock_records_a):
                            if rec["id"] == last_id:
                                found_idx = i
                                break
                        if found_idx >= 0:
                            return mock_records_a[found_idx + 1 :]
                        return []
                    else:
                        # First batch: params = [workspace, limit]
                        return mock_records_a
                return []  # No data for other workspaces

            return {}

        db.query.side_effect = query_side_effect
        db.execute = AsyncMock()

        # Mock check_table_exists on db
        async def check_table_exists_side_effect(name):
            if name.lower() == "lightrag_doc_chunks":  # legacy
                return True
            elif name.lower() == "lightrag_doc_chunks_model_1536d":  # new
                return False  # New table doesn't exist initially
            return False

        db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)

        # Track migration through _run_with_retry calls
        migration_executed = []

        async def mock_run_with_retry(operation, *args, **kwargs):
            migration_executed.append(True)
            new_table_record_count["count"] = 2  # Simulate 2 records migrated
            return None

        db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)

        # Migrate for workspace_a only - correct parameter order
        await PGVectorStorage.setup_table(
            db,
            "LIGHTRAG_DOC_CHUNKS_model_1536d",
            workspace="workspace_a",  # CRITICAL: Only migrate workspace_a
            embedding_dim=1536,
            legacy_table_name="LIGHTRAG_DOC_CHUNKS",
            base_table="LIGHTRAG_DOC_CHUNKS",
        )

        # Verify the migration was triggered
        assert (
            len(migration_executed) > 0
        ), "Migration should have been executed for workspace_a"

    async def test_migration_without_workspace_raises_error(self):
        """
        Test that migration without a workspace parameter raises ValueError.

        Scenario: setup_table called without a workspace parameter.
        Expected: ValueError is raised because workspace is required.
        """
        db = AsyncMock()

        # workspace is now a required parameter - calling with None should raise ValueError
        with pytest.raises(ValueError, match="workspace must be provided"):
            await PGVectorStorage.setup_table(
                db,
                "lightrag_doc_chunks_model_1536d",
                workspace=None,  # No workspace - should raise ValueError
                embedding_dim=1536,
                legacy_table_name="lightrag_doc_chunks",
                base_table="lightrag_doc_chunks",
            )

    async def test_no_cross_workspace_contamination(self):
        """
        Test that workspace B's migration doesn't include workspace A's data.

        Scenario: Migration for workspace_b only.
        Expected: Only workspace_b data is queried; workspace_a data is excluded.
        """
        db = AsyncMock()

        # Configure mock return values to avoid unawaited coroutine warnings
        db._create_vector_index.return_value = None

        # Track which workspace is being queried
        queried_workspace = None
        new_table_count = {"count": 0}

        # Mock data for workspace_b
        mock_records_b = [
            {
                "id": "b1",
                "workspace": "workspace_b",
                "content": "content_b1",
                "content_vector": [0.3] * 1536,
            },
        ]

        async def table_exists_side_effect(db_instance, name):
            if name.lower() == "lightrag_doc_chunks":  # legacy
                return True
            elif name.lower() == "lightrag_doc_chunks_model_1536d":  # new
                return False
            return False

        async def query_side_effect(sql, params, **kwargs):
            nonlocal queried_workspace
            multirows = kwargs.get("multirows", False)
            sql_upper = sql.upper()

            # Count query for new table workspace data (should be 0 initially)
            if (
                "COUNT(*)" in sql_upper
                and "MODEL_1536D" in sql_upper
                and "WHERE WORKSPACE" in sql_upper
            ):
                return new_table_count

            # Count query with workspace filter (legacy table)
            elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
                queried_workspace = params[0] if params else None
                return {"count": 1}  # 1 record for the queried workspace

            # Count query for legacy table total (no workspace filter)
            elif (
                "COUNT(*)" in sql_upper
                and "LIGHTRAG" in sql_upper
                and "WHERE WORKSPACE" not in sql_upper
            ):
                return {"count": 3}  # 3 total records in legacy

            # SELECT with workspace filter for migration (multirows)
            elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
                workspace = params[0] if params else None
                if workspace == "workspace_b":
                    # Handle keyset pagination: check for "id >" pattern
                    if "id >" in sql.lower():
                        # Keyset pagination: params = [workspace, last_id, limit]
                        last_id = params[1] if len(params) > 1 else None
                        # Find records after last_id
                        found_idx = -1
                        for i, rec in enumerate(mock_records_b):
                            if rec["id"] == last_id:
                                found_idx = i
                                break
                        if found_idx >= 0:
                            return mock_records_b[found_idx + 1 :]
                        return []
                    else:
                        # First batch: params = [workspace, limit]
                        return mock_records_b
                return []  # No data for other workspaces

            return {}

        db.query.side_effect = query_side_effect
        db.execute = AsyncMock()

        # Mock check_table_exists on db
        async def check_table_exists_side_effect(name):
            if name.lower() == "lightrag_doc_chunks":  # legacy
                return True
            elif name.lower() == "lightrag_doc_chunks_model_1536d":  # new
                return False
            return False

        db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)

        # Track migration through _run_with_retry calls
        migration_executed = []

        async def mock_run_with_retry(operation, *args, **kwargs):
            migration_executed.append(True)
            new_table_count["count"] = 1  # Simulate migration
            return None

        db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)

        # Migrate workspace_b - correct parameter order
        await PGVectorStorage.setup_table(
            db,
            "LIGHTRAG_DOC_CHUNKS_model_1536d",
            workspace="workspace_b",  # Only migrate workspace_b
            embedding_dim=1536,
            legacy_table_name="LIGHTRAG_DOC_CHUNKS",
            base_table="LIGHTRAG_DOC_CHUNKS",
        )

        # Verify only workspace_b was queried
        assert queried_workspace == "workspace_b", "Should only query workspace_b"
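The SQL shapes that these mocks pattern-match against (a COUNT with a workspace filter, then a keyset-paginated SELECT) correspond to a workspace-scoped copy loop along the following lines. The function name, column list, and upsert conflict target are illustrative assumptions, not the actual postgres_impl code:

# Hypothetical sketch of a workspace-filtered, keyset-paginated migration.
async def migrate_workspace(db, legacy_table: str, new_table: str,
                            workspace: str, batch_size: int = 500) -> None:
    """Copy only `workspace`'s rows from the legacy table, in id order."""
    if not workspace:
        raise ValueError("workspace must be provided")
    last_id = None
    while True:
        if last_id is None:
            # First batch: params = [workspace, limit]
            rows = await db.query(
                f"SELECT * FROM {legacy_table} WHERE workspace = $1 "
                f"ORDER BY id LIMIT $2",
                [workspace, batch_size],
                multirows=True,
            )
        else:
            # Keyset pagination: params = [workspace, last_id, limit]
            rows = await db.query(
                f"SELECT * FROM {legacy_table} WHERE workspace = $1 "
                f"AND id > $2 ORDER BY id LIMIT $3",
                [workspace, last_id, batch_size],
                multirows=True,
            )
        if not rows:
            break
        for row in rows:
            # Assumes a (workspace, id) unique key on the new table.
            await db.execute(
                f"INSERT INTO {new_table} (id, workspace, content, content_vector) "
                f"VALUES ($1, $2, $3, $4) ON CONFLICT (workspace, id) DO NOTHING",
                [row["id"], row["workspace"], row["content"], row["content_vector"]],
            )
        last_id = rows[-1]["id"]

Keyset pagination (WHERE id > last_id ORDER BY id) rather than OFFSET keeps each batch cheap on large legacy tables, and the WHERE workspace = $1 filter on every query is what the contamination test above verifies.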