diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index f54a7ee3..c9fb6c0b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
- python-version: ['3.12', '3.13', '3.14']
+ python-version: ['3.12', '3.14']
steps:
- uses: actions/checkout@v6
diff --git a/README-zh.md b/README-zh.md
index 5a331b39..f72f4e01 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -286,7 +286,7 @@ if __name__ == "__main__":
参数
| **参数** | **类型** | **说明** | **默认值** |
-|--------------|----------|-----------------|-------------|
+| -------------- | ---------- | ----------------- | ------------- |
| **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
@@ -425,7 +425,7 @@ async def llm_model_func(
**kwargs
)
-@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed.func(
texts,
@@ -490,7 +490,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -542,7 +542,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1633,24 +1633,24 @@ LightRAG使用以下提示生成高级查询,相应的代码在`example/genera
### 总体性能表
-| |**农业**| |**计算机科学**| |**法律**| |**混合**| |
+||**农业**||**计算机科学**||**法律**||**混合**||
|----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
-| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
+||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
|**全面性**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
|**多样性**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
|**赋能性**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
|**总体**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
-| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
+||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
|**全面性**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
|**多样性**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
|**赋能性**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
|**总体**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
-| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
+||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
|**全面性**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
|**多样性**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
|**赋能性**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
|**总体**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
-| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
+||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
|**全面性**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
|**多样性**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
|**赋能性**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|
diff --git a/README.md b/README.md
index b157c350..d7a4f563 100644
--- a/README.md
+++ b/README.md
@@ -287,9 +287,9 @@ A full list of LightRAG init parameters:
Parameters
| **Parameter** | **Type** | **Explanation** | **Default** |
-|--------------|----------|-----------------|-------------|
+| -------------- | ---------- | ----------------- | ------------- |
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
-| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
+| **workspace** | `str` | Workspace name for data isolation between different LightRAG instances | |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -421,7 +421,7 @@ async def llm_model_func(
**kwargs
)
-@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query")
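+# The optional model_name tags the embedding model so vector storages can keep data isolated per model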
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed.func(
texts,
@@ -488,7 +488,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -540,7 +540,7 @@ import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text")
async def embedding_func(texts: list[str]) -> np.ndarray:
return await ollama_embed.func(texts, embed_model="nomic-embed-text")
@@ -1701,24 +1701,24 @@ Output your evaluation in the following JSON format:
### Overall Performance Table
-| |**Agriculture**| |**CS**| |**Legal**| |**Mix**| |
+||**Agriculture**||**CS**||**Legal**||**Mix**||
|----------------------|---------------|------------|------|------------|---------|------------|-------|------------|
-| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
+||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|
|**Comprehensiveness**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**|
|**Diversity**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**|
|**Empowerment**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**|
|**Overall**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**|
-| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
+||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|
|**Comprehensiveness**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**|
|**Diversity**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**|
|**Empowerment**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**|
|**Overall**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**|
-| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
+||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|
|**Comprehensiveness**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**|
|**Diversity**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**|
|**Empowerment**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**|
|**Overall**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**|
-| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
+||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|
|**Comprehensiveness**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%|
|**Diversity**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**|
|**Empowerment**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%|
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 692be453..9151f02e 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -868,6 +868,7 @@ def create_app(args):
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
+ model_name=model,
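+ # model_name lets vector storages derive a per-model collection/table suffix (see BaseVectorStorage._generate_collection_suffix)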
)
# Log final embedding configuration
diff --git a/lightrag/base.py b/lightrag/base.py
index bae0728b..75059377 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -220,6 +220,37 @@ class BaseVectorStorage(StorageNameSpace, ABC):
cosine_better_than_threshold: float = field(default=0.2)
meta_fields: set[str] = field(default_factory=set)
+ def __post_init__(self):
+ """Validate required embedding_func for vector storage."""
+ if self.embedding_func is None:
+ raise ValueError(
+ "embedding_func is required for vector storage. "
+ "Please provide a valid EmbeddingFunc instance."
+ )
+
+ def _generate_collection_suffix(self) -> str | None:
+ """Generates collection/table suffix from embedding_func.
+
+ Return suffix if model_name exists in embedding_func, otherwise return None.
+ Note: embedding_func is guaranteed to exist (validated in __post_init__).
+
+ Returns:
+ str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available
+ """
+ import re
+
+ # Check if model_name exists (model_name is optional in EmbeddingFunc)
+ model_name = getattr(self.embedding_func, "model_name", None)
+ if not model_name:
+ return None
+
+ # embedding_dim is required in EmbeddingFunc
+ embedding_dim = self.embedding_func.embedding_dim
+
+ # Generate suffix: clean model name and append dimension
+ safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower())
+ return f"{safe_model_name}_{embedding_dim}d"
+
@abstractmethod
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
index 709f294d..7c9accef 100644
--- a/lightrag/exceptions.py
+++ b/lightrag/exceptions.py
@@ -128,8 +128,8 @@ class ChunkTokenLimitExceededError(ValueError):
self.chunk_preview = truncated_preview
-class QdrantMigrationError(Exception):
- """Raised when Qdrant data migration from legacy collections fails."""
+class DataMigrationError(Exception):
+ """Raised when data migration from legacy collection/table fails."""
def __init__(self, message: str):
super().__init__(message)
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index adb0058b..e299211b 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
def __post_init__(self):
+ super().__post_init__()
# Grab config values if available
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -358,9 +359,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
return
+ dim_mismatch = False
try:
# Load the Faiss index
self._index = faiss.read_index(self._faiss_index_file)
+
+ # Verify dimension consistency between loaded index and embedding function
+ if self._index.d != self._dim:
+ error_msg = (
+ f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, "
+ f"but embedding function expects dimension {self._dim}. "
+ f"Please ensure the embedding model matches the stored index or rebuild the index."
+ )
+ logger.error(error_msg)
+ dim_mismatch = True
+ raise ValueError(error_msg)
+
# Load metadata
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
@@ -375,6 +389,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
+ if dim_mismatch:
+ raise
logger.error(
f"[{self.workspace}] Failed to load Faiss index or metadata: {e}"
)
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index d42c91a7..50c233a8 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
raise
def __post_init__(self):
+ super().__post_init__()
# Check for MILVUS_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Milvus storage instances
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index e11e6411..351e0039 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -89,7 +89,7 @@ class MongoKVStorage(BaseKVStorage):
global_config=global_config,
embedding_func=embedding_func,
)
- self.__post_init__()
+ # __post_init__() is automatically called by super().__init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -317,7 +317,7 @@ class MongoDocStatusStorage(DocStatusStorage):
global_config=global_config,
embedding_func=embedding_func,
)
- self.__post_init__()
+ # __post_init__() is automatically called by super().__init__()
def __post_init__(self):
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
@@ -2052,9 +2052,12 @@ class MongoVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
- self.__post_init__()
+ # __post_init__() is automatically called by super().__init__()
def __post_init__(self):
+ # Call parent class __post_init__ to validate embedding_func
+ super().__post_init__()
+
# Check for MONGODB_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all MongoDB storage instances
mongodb_workspace = os.environ.get("MONGODB_WORKSPACE")
@@ -2131,8 +2134,32 @@ class MongoVectorDBStorage(BaseVectorStorage):
indexes = await indexes_cursor.to_list(length=None)
for index in indexes:
if index["name"] == self._index_name:
+ # Check if the existing index has matching vector dimensions
+ existing_dim = None
+ definition = index.get("latestDefinition", {})
+ fields = definition.get("fields", [])
+ for field in fields:
+ if (
+ field.get("type") == "vector"
+ and field.get("path") == "vector"
+ ):
+ existing_dim = field.get("numDimensions")
+ break
+
+ expected_dim = self.embedding_func.embedding_dim
+
+ if existing_dim is not None and existing_dim != expected_dim:
+ error_msg = (
+ f"Vector dimension mismatch! Index '{self._index_name}' has "
+ f"dimension {existing_dim}, but current embedding model expects "
+ f"dimension {expected_dim}. Please drop the existing index or "
+ f"use an embedding model with matching dimensions."
+ )
+ logger.error(f"[{self.workspace}] {error_msg}")
+ raise ValueError(error_msg)
+
logger.info(
- f"[{self.workspace}] vector index {self._index_name} already exist"
+ f"[{self.workspace}] vector index {self._index_name} already exists with matching dimensions ({expected_dim})"
)
return
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index d390c37b..9b868c11 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -25,6 +25,7 @@ from .shared_storage import (
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
+ super().__post_init__()
# Initialize basic attributes
self._client = None
self._storage_lock = None
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 49069ce3..6b70b505 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -1,4 +1,5 @@
import asyncio
+import hashlib
import json
import os
import re
@@ -31,6 +32,7 @@ from ..base import (
DocStatus,
DocStatusStorage,
)
+from ..exceptions import DataMigrationError
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..kg.shared_storage import get_data_init_lock
@@ -39,9 +41,12 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
+if not pm.is_installed("pgvector"):
+ pm.install("pgvector")
import asyncpg # type: ignore
from asyncpg import Pool # type: ignore
+from pgvector.asyncpg import register_vector # type: ignore
from dotenv import load_dotenv
@@ -52,6 +57,42 @@ load_dotenv(dotenv_path=".env", override=False)
T = TypeVar("T")
+# PostgreSQL identifier length limit (in bytes)
+PG_MAX_IDENTIFIER_LENGTH = 63
+
+
+def _safe_index_name(table_name: str, index_suffix: str) -> str:
+ """
+ Generate a PostgreSQL-safe index name that won't be truncated.
+
+ PostgreSQL silently truncates identifiers to 63 bytes. This function
+ ensures index names stay within that limit by hashing long table names.
+
+ Args:
+ table_name: The table name (may be long with model suffix)
+ index_suffix: The index type suffix (e.g., 'hnsw_cosine', 'id', 'workspace_id')
+
+ Returns:
+ A deterministic index name that fits within 63 bytes
+ """
+ # Construct the full index name
+ full_name = f"idx_{table_name.lower()}_{index_suffix}"
+
+ # If it fits within the limit, use it as-is
+ if len(full_name.encode("utf-8")) <= PG_MAX_IDENTIFIER_LENGTH:
+ return full_name
+
+ # Otherwise, hash the table name to create a shorter unique identifier
+ # Keep 'idx_' prefix and suffix readable, hash the middle
+ hash_input = table_name.lower().encode("utf-8")
+ table_hash = hashlib.md5(hash_input).hexdigest()[:12] # 12 hex chars
+
+ # Format: idx_{hash}_{suffix} - guaranteed to fit
+ # Maximum: idx_ (4) + hash (12) + _ (1) + suffix (variable) = 17 + suffix
+ shortened_name = f"idx_{table_hash}_{index_suffix}"
+
+ return shortened_name
+
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -252,14 +293,41 @@ class PostgreSQLDB:
else wait_fixed(0)
)
+ async def _init_connection(connection: asyncpg.Connection) -> None:
+ """Initialize each connection with pgvector codec.
+
+ This callback is invoked by asyncpg for every new connection in the pool.
+ Registering the vector codec here ensures ALL connections can properly
+ encode/decode vector columns, eliminating non-deterministic behavior
+ where some connections have the codec and others don't.
+ """
+ await register_vector(connection)
+
async def _create_pool_once() -> None:
- pool = await asyncpg.create_pool(**connection_params) # type: ignore
+ # STEP 1: Bootstrap - ensure vector extension exists BEFORE pool creation.
+ # On a fresh database, register_vector() in _init_connection will fail
+ # if the vector extension doesn't exist yet, because the 'vector' type
+ # won't be found in pg_catalog. We must create the extension first
+ # using a standalone bootstrap connection.
+ bootstrap_conn = await asyncpg.connect(
+ user=self.user,
+ password=self.password,
+ database=self.database,
+ host=self.host,
+ port=self.port,
+ ssl=connection_params.get("ssl"),
+ )
try:
- async with pool.acquire() as connection:
- await self.configure_vector_extension(connection)
- except Exception:
- await pool.close()
- raise
+ await self.configure_vector_extension(bootstrap_conn)
+ finally:
+ await bootstrap_conn.close()
+
+ # STEP 2: Now safe to create pool with register_vector callback.
+ # The vector extension is guaranteed to exist at this point.
+ pool = await asyncpg.create_pool(
+ **connection_params,
+ init=_init_connection, # Register pgvector codec on every connection
+ ) # type: ignore
self.pool = pool
try:
@@ -564,6 +632,24 @@ class PostgreSQLDB:
}
try:
+ # Filter out tables that don't exist (e.g., legacy vector tables may not exist)
+ existing_tables = {}
+ for table_name, columns in tables_to_migrate.items():
+ if await self.check_table_exists(table_name):
+ existing_tables[table_name] = columns
+ else:
+ logger.debug(
+ f"Table {table_name} does not exist, skipping timestamp migration"
+ )
+
+ # Skip if no tables to migrate
+ if not existing_tables:
+ logger.debug("No tables found for timestamp migration")
+ return
+
+ # Use filtered tables for migration
+ tables_to_migrate = existing_tables
+
# Optimization: Batch check all columns in one query instead of 8 separate queries
table_names_lower = [t.lower() for t in tables_to_migrate.keys()]
all_column_names = list(
@@ -640,6 +726,22 @@ class PostgreSQLDB:
"""
try:
+ # 0. Check if both tables exist before proceeding
+ vdb_chunks_exists = await self.check_table_exists("LIGHTRAG_VDB_CHUNKS")
+ doc_chunks_exists = await self.check_table_exists("LIGHTRAG_DOC_CHUNKS")
+
+ if not vdb_chunks_exists:
+ logger.debug(
+ "Skipping migration: LIGHTRAG_VDB_CHUNKS table does not exist"
+ )
+ return
+
+ if not doc_chunks_exists:
+ logger.debug(
+ "Skipping migration: LIGHTRAG_DOC_CHUNKS table does not exist"
+ )
+ return
+
# 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
@@ -1008,6 +1110,24 @@ class PostgreSQLDB:
]
try:
+ # Filter out tables that don't exist (e.g., legacy vector tables may not exist)
+ existing_migrations = []
+ for migration in field_migrations:
+ if await self.check_table_exists(migration["table"]):
+ existing_migrations.append(migration)
+ else:
+ logger.debug(
+ f"Table {migration['table']} does not exist, skipping field length migration for {migration['column']}"
+ )
+
+ # Skip if no migrations to process
+ if not existing_migrations:
+ logger.debug("No tables found for field length migration")
+ return
+
+ # Use filtered migrations for processing
+ field_migrations = existing_migrations
+
# Optimization: Batch check all columns in one query instead of 5 separate queries
unique_tables = list(set(m["table"].lower() for m in field_migrations))
unique_columns = list(set(m["column"] for m in field_migrations))
@@ -1092,8 +1212,20 @@ class PostgreSQLDB:
logger.error(f"Failed to batch check field lengths: {e}")
async def check_tables(self):
- # First create all tables
+ # Vector tables that should be skipped - they are created by PGVectorStorage.setup_table()
+ # with proper embedding model and dimension suffix for data isolation
+ vector_tables_to_skip = {
+ "LIGHTRAG_VDB_CHUNKS",
+ "LIGHTRAG_VDB_ENTITY",
+ "LIGHTRAG_VDB_RELATION",
+ }
+
+ # First create all tables (except vector tables)
for k, v in TABLES.items():
+ # Skip vector tables - they are created by PGVectorStorage.setup_table()
+ if k in vector_tables_to_skip:
+ continue
+
try:
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
except Exception:
@@ -1111,7 +1243,8 @@ class PostgreSQLDB:
# Batch check all indexes at once (optimization: single query instead of N queries)
try:
- table_names = list(TABLES.keys())
+ # Exclude vector tables from index creation since they are created by PGVectorStorage.setup_table()
+ table_names = [k for k in TABLES.keys() if k not in vector_tables_to_skip]
table_names_lower = [t.lower() for t in table_names]
# Get all existing indexes for our tables in one query
@@ -1163,23 +1296,9 @@ class PostgreSQLDB:
except Exception as e:
logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")
- # Create vector indexs
- if self.vector_index_type:
- logger.info(
- f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
- )
- try:
- if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
- await self._create_vector_indexes()
- else:
- logger.warning(
- "Doesn't support this vector index type: {self.vector_index_type}. "
- "Supported types: HNSW, IVFFLAT, VCHORDRQ"
- )
- except Exception as e:
- logger.error(
- f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
- )
+ # NOTE: Vector index creation moved to PGVectorStorage.setup_table()
+ # Each vector storage instance creates its own index with correct embedding_dim
+
# After all tables are created, attempt to migrate timestamp fields
try:
await self._migrate_timestamp_columns()
@@ -1381,64 +1500,74 @@ class PostgreSQLDB:
except Exception as e:
logger.warning(f"Failed to create index {index['name']}: {e}")
- async def _create_vector_indexes(self):
- vdb_tables = [
- "LIGHTRAG_VDB_CHUNKS",
- "LIGHTRAG_VDB_ENTITY",
- "LIGHTRAG_VDB_RELATION",
- ]
+ async def _create_vector_index(self, table_name: str, embedding_dim: int):
+ """
+ Create vector index for a specific table.
+
+ Args:
+ table_name: Name of the table to create index on
+ embedding_dim: Embedding dimension for the vector column
+ """
+ if not self.vector_index_type:
+ return
create_sql = {
"HNSW": f"""
CREATE INDEX {{vector_index_name}}
- ON {{k}} USING hnsw (content_vector vector_cosine_ops)
+ ON {{table_name}} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
""",
"IVFFLAT": f"""
CREATE INDEX {{vector_index_name}}
- ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
+ ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
""",
"VCHORDRQ": f"""
CREATE INDEX {{vector_index_name}}
- ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
- {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
+ ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops)
+ {f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""}
""",
}
- embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
- for k in vdb_tables:
- vector_index_name = (
- f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
+ if self.vector_index_type not in create_sql:
+ logger.warning(
+ f"Unsupported vector index type: {self.vector_index_type}. "
+ "Supported types: HNSW, IVFFLAT, VCHORDRQ"
)
- check_vector_index_sql = f"""
- SELECT 1 FROM pg_indexes
- WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
- """
- try:
- vector_index_exists = await self.query(check_vector_index_sql)
- if not vector_index_exists:
- # Only set vector dimension when index doesn't exist
- alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
- await self.execute(alter_sql)
- logger.debug(f"Ensured vector dimension for {k}")
- logger.info(
- f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
+ return
+
+ k = table_name
+ # Use _safe_index_name to avoid PostgreSQL's 63-byte identifier truncation
+ index_suffix = f"{self.vector_index_type.lower()}_cosine"
+ vector_index_name = _safe_index_name(k, index_suffix)
+ check_vector_index_sql = f"""
+ SELECT 1 FROM pg_indexes
+ WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
+ """
+ try:
+ vector_index_exists = await self.query(check_vector_index_sql)
+ if not vector_index_exists:
+ # Only set vector dimension when index doesn't exist
+ alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
+ await self.execute(alter_sql)
+ logger.debug(f"Ensured vector dimension for {k}")
+ logger.info(
+ f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
+ )
+ await self.execute(
+ create_sql[self.vector_index_type].format(
+ vector_index_name=vector_index_name, table_name=k
)
- await self.execute(
- create_sql[self.vector_index_type].format(
- vector_index_name=vector_index_name, k=k
- )
- )
- logger.info(
- f"Successfully created vector index {vector_index_name} on table {k}"
- )
- else:
- logger.info(
- f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
- )
- except Exception as e:
- logger.error(f"Failed to create vector index on table {k}, Got: {e}")
+ )
+ logger.info(
+ f"Successfully created vector index {vector_index_name} on table {k}"
+ )
+ else:
+ logger.info(
+ f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
+ )
+ except Exception as e:
+ logger.error(f"Failed to create vector index on table {k}, Got: {e}")
async def query(
self,
@@ -1474,6 +1603,24 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database, error:{e}")
raise
+ async def check_table_exists(self, table_name: str) -> bool:
+ """Check if a table exists in PostgreSQL database
+
+ Args:
+ table_name: Name of the table to check
+
+ Returns:
+ bool: True if table exists, False otherwise
+ """
+ query = """
+ SELECT EXISTS (
+ SELECT FROM information_schema.tables
+ WHERE table_name = $1
+ )
+ """
+ result = await self.query(query, [table_name.lower()])
+ return result.get("exists", False) if result else False
+
async def execute(
self,
sql: str,
@@ -2181,6 +2328,7 @@ class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
+ super().__post_init__()
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -2190,6 +2338,443 @@ class PGVectorStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
+ # Generate model suffix for table isolation
+ self.model_suffix = self._generate_collection_suffix()
+
+ # Get base table name
+ base_table = namespace_to_table_name(self.namespace)
+ if not base_table:
+ raise ValueError(f"Unknown namespace: {self.namespace}")
+
+ # New table name (with suffix)
+ # Ensure model_suffix is not empty before appending
+ if self.model_suffix:
+ self.table_name = f"{base_table}_{self.model_suffix}"
+ logger.info(f"PostgreSQL table: {self.table_name}")
+ else:
+ # Fallback: use base table name if model_suffix is unavailable
+ self.table_name = base_table
+ logger.warning(
+ f"PostgreSQL table: {self.table_name} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation."
+ )
+
+ # Legacy table name (without suffix, for migration)
+ self.legacy_table_name = base_table
+
+ # Validate table name length (PostgreSQL identifier limit is 63 characters)
+ if len(self.table_name) > PG_MAX_IDENTIFIER_LENGTH:
+ raise ValueError(
+ f"PostgreSQL table name exceeds {PG_MAX_IDENTIFIER_LENGTH} character limit: '{self.table_name}' "
+ f"(length: {len(self.table_name)}). "
+ f"Consider using a shorter embedding model name or workspace name."
+ )
+
+ @staticmethod
+ async def _pg_create_table(
+ db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
+ ) -> None:
+ """Create a new vector table by replacing the table name in DDL template,
+ and create indexes on id and (workspace, id) columns.
+
+ Args:
+ db: PostgreSQLDB instance
+ table_name: Name of the new table to create
+ base_table: Base table name for DDL template lookup
+ embedding_dim: Embedding dimension for vector column
+ """
+ if base_table not in TABLES:
+ raise ValueError(f"No DDL template found for table: {base_table}")
+
+ ddl_template = TABLES[base_table]["ddl"]
+
+ # Replace embedding dimension placeholder if exists
+ ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
+
+ # Replace table name
+ ddl = ddl.replace(base_table, table_name)
+
+ await db.execute(ddl)
+
+ # Create indexes similar to check_tables() but with safe index names
+ # Create index for id column
+ id_index_name = _safe_index_name(table_name, "id")
+ try:
+ create_id_index_sql = f"CREATE INDEX {id_index_name} ON {table_name}(id)"
+ logger.info(
+ f"PostgreSQL, Creating index {id_index_name} on table {table_name}"
+ )
+ await db.execute(create_id_index_sql)
+ except Exception as e:
+ logger.error(
+ f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}"
+ )
+
+ # Create composite index for (workspace, id)
+ workspace_id_index_name = _safe_index_name(table_name, "workspace_id")
+ try:
+ create_composite_index_sql = (
+ f"CREATE INDEX {workspace_id_index_name} ON {table_name}(workspace, id)"
+ )
+ logger.info(
+ f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}"
+ )
+ await db.execute(create_composite_index_sql)
+ except Exception as e:
+ logger.error(
+ f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}"
+ )
+
+ @staticmethod
+ async def _pg_migrate_workspace_data(
+ db: PostgreSQLDB,
+ legacy_table_name: str,
+ new_table_name: str,
+ workspace: str,
+ expected_count: int,
+ embedding_dim: int,
+ ) -> int:
+ """Migrate workspace data from legacy table to new table using batch insert.
+
+ This function uses asyncpg's executemany for efficient batch insertion,
+ reducing database round-trips from N to 1 per batch.
+
+ Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
+ This ensures every legacy row is migrated exactly once, avoiding the
+ non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
+
+ Args:
+ db: PostgreSQLDB instance
+ legacy_table_name: Name of the legacy table to migrate from
+ new_table_name: Name of the new table to migrate to
+ workspace: Workspace to filter records for migration
+ expected_count: Expected number of records to migrate
+ embedding_dim: Embedding dimension for vector column
+
+ Returns:
+ Number of records migrated
+ """
+ migrated_count = 0
+ last_id: str | None = None
+ batch_size = 500
+
+ while True:
+ # Use keyset pagination with ORDER BY id for deterministic ordering
+ # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
+ if workspace:
+ if last_id is not None:
+ select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
+ rows = await db.query(
+ select_query, [workspace, last_id, batch_size], multirows=True
+ )
+ else:
+ select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
+ rows = await db.query(
+ select_query, [workspace, batch_size], multirows=True
+ )
+ else:
+ if last_id is not None:
+ select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
+ rows = await db.query(
+ select_query, [last_id, batch_size], multirows=True
+ )
+ else:
+ select_query = (
+ f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
+ )
+ rows = await db.query(select_query, [batch_size], multirows=True)
+
+ if not rows:
+ break
+
+ # Track the last ID for keyset pagination cursor
+ last_id = rows[-1]["id"]
+
+ # Batch insert optimization: use executemany instead of individual inserts
+ # Get column names from the first row
+ first_row = dict(rows[0])
+ columns = list(first_row.keys())
+ columns_str = ", ".join(columns)
+ placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
+
+ insert_query = f"""
+ INSERT INTO {new_table_name} ({columns_str})
+ VALUES ({placeholders})
+ ON CONFLICT (workspace, id) DO NOTHING
+ """
+
+ # Prepare batch data: convert rows to list of tuples
+ batch_values = []
+ for row in rows:
+ row_dict = dict(row)
+
+ # FIX: Parse vector strings from connections without register_vector codec.
+ # When pgvector codec is not registered on the read connection, vector
+ # columns are returned as text strings like "[0.1,0.2,...]" instead of
+ # lists/arrays. We need to convert these to numpy arrays before passing
+ # to executemany, which uses a connection WITH register_vector codec
+ # that expects list/tuple/ndarray types.
+ if "content_vector" in row_dict:
+ vec = row_dict["content_vector"]
+ if isinstance(vec, str):
+ # pgvector text format: "[0.1,0.2,0.3,...]"
+ vec = vec.strip("[]")
+ if vec:
+ row_dict["content_vector"] = np.array(
+ [float(x) for x in vec.split(",")], dtype=np.float32
+ )
+ else:
+ row_dict["content_vector"] = None
+
+ # Extract values in column order to match placeholders
+ values_tuple = tuple(row_dict[col] for col in columns)
+ batch_values.append(values_tuple)
+
+ # Use executemany for batch execution - significantly reduces DB round-trips
+ # Note: register_vector is already called on pool init, no need to call it again
+ async def _batch_insert(connection: asyncpg.Connection) -> None:
+ await connection.executemany(insert_query, batch_values)
+
+ await db._run_with_retry(_batch_insert)
+
+ migrated_count += len(rows)
+ workspace_info = f" for workspace '{workspace}'" if workspace else ""
+ logger.info(
+ f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
+ )
+
+ return migrated_count
+
+ @staticmethod
+ async def setup_table(
+ db: PostgreSQLDB,
+ table_name: str,
+ workspace: str,
+ embedding_dim: int,
+ legacy_table_name: str,
+ base_table: str,
+ ):
+ """
+ Setup PostgreSQL table with migration support from legacy tables.
+
+ Ensure the final table has its vector index and workspace isolation indexes.
+ Check vector dimension compatibility before creating the new table.
+ Drop the legacy table if it exists and is empty.
+ Only migrate data from the legacy table to the new table when the new table is newly created and the legacy table is not empty.
+ This function must be called after ClientManager.get_client() so that the legacy table has already been migrated to the latest schema.
+
+ Args:
+ db: PostgreSQLDB instance
+ table_name: Name of the new table
+ workspace: Workspace to filter records for migration
+ embedding_dim: Embedding dimension for the vector column
+ legacy_table_name: Name of the legacy table to check for migration
+ base_table: Base table name for DDL template lookup
+ """
+ if not workspace:
+ raise ValueError("workspace must be provided")
+
+ new_table_exists = await db.check_table_exists(table_name)
+ legacy_exists = legacy_table_name and await db.check_table_exists(
+ legacy_table_name
+ )
+
+ # Case 1: Only new table exists or new table is the same as legacy table
+ # No data migration needed, ensuring index is created then return
+ if (new_table_exists and not legacy_exists) or (
+ new_table_exists and (table_name.lower() == legacy_table_name.lower())
+ ):
+ await db._create_vector_index(table_name, embedding_dim)
+
+ workspace_count_query = (
+ f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
+ )
+ workspace_count_result = await db.query(workspace_count_query, [workspace])
+ workspace_count = (
+ workspace_count_result.get("count", 0) if workspace_count_result else 0
+ )
+ if workspace_count == 0 and not (
+ table_name.lower() == legacy_table_name.lower()
+ ):
+ logger.warning(
+ f"PostgreSQL: workspace data in table '{table_name}' is empty. "
+ f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
+ )
+
+ return
+
+ legacy_count = None
+ if not new_table_exists:
+ # Check vector dimension compatibility before creating new table
+ if legacy_exists:
+ count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
+ count_result = await db.query(count_query, [workspace])
+ legacy_count = count_result.get("count", 0) if count_result else 0
+
+ if legacy_count > 0:
+ legacy_dim = None
+ try:
+ sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1"
+ sample_result = await db.query(sample_query, [workspace])
+ # Fix: Use 'is not None' instead of truthiness check to avoid
+ # NumPy array boolean ambiguity error
+ if (
+ sample_result
+ and sample_result.get("content_vector") is not None
+ ):
+ vector_data = sample_result["content_vector"]
+ # pgvector returns list directly, but may also return NumPy arrays
+ # when register_vector codec is active on the connection
+ if isinstance(vector_data, (list, tuple)):
+ legacy_dim = len(vector_data)
+ elif hasattr(vector_data, "__len__") and not isinstance(
+ vector_data, str
+ ):
+ # Handle NumPy arrays and other array-like objects
+ legacy_dim = len(vector_data)
+ elif isinstance(vector_data, str):
+ vector_list = json.loads(vector_data)
+ legacy_dim = len(vector_list)
+
+ if legacy_dim and legacy_dim != embedding_dim:
+ logger.error(
+ f"PostgreSQL: Dimension mismatch detected! "
+ f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
+ f"but new embedding model expects {embedding_dim}d."
+ )
+ raise DataMigrationError(
+ f"Dimension mismatch between legacy table '{legacy_table_name}' "
+ f"and new embedding model. Expected {embedding_dim}d but got {legacy_dim}d."
+ )
+
+ except DataMigrationError:
+ # Re-raise DataMigrationError as-is to preserve specific error messages
+ raise
+ except Exception as e:
+ raise DataMigrationError(
+ f"Could not verify legacy table vector dimension: {e}. "
+ f"Proceeding with caution..."
+ )
+
+ await PGVectorStorage._pg_create_table(
+ db, table_name, base_table, embedding_dim
+ )
+ logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
+
+ if not legacy_exists:
+ await db._create_vector_index(table_name, embedding_dim)
+ logger.info(
+ "Ensure this new table creation is caused by new workspace setup and not an unexpected embedding model change."
+ )
+ return
+
+ # Ensure vector index is created
+ await db._create_vector_index(table_name, embedding_dim)
+
+ # Case 2: Legacy table exist
+ if legacy_exists:
+ workspace_info = f" for workspace '{workspace}'"
+
+ # Only drop legacy table if entire table is empty
+ total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
+ total_count_result = await db.query(total_count_query, [])
+ total_count = (
+ total_count_result.get("count", 0) if total_count_result else 0
+ )
+ if total_count == 0:
+ drop_query = f"DROP TABLE {legacy_table_name}"
+ await db.execute(drop_query, None)
+ logger.info(
+ f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully"
+ )
+ return
+
+ # No data migration needed if legacy workspace is empty
+ if legacy_count is None:
+ count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
+ count_result = await db.query(count_query, [workspace])
+ legacy_count = count_result.get("count", 0) if count_result else 0
+
+ if legacy_count == 0:
+ logger.info(
+ f"PostgreSQL: No records{workspace_info} found in legacy table. "
+ f"No data migration needed."
+ )
+ return
+
+ new_count_query = (
+ f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
+ )
+ new_count_result = await db.query(new_count_query, [workspace])
+ new_table_workspace_count = (
+ new_count_result.get("count", 0) if new_count_result else 0
+ )
+
+ if new_table_workspace_count > 0:
+ logger.warning(
+ f"PostgreSQL: Both new and legacy collection have data. "
+ f"{legacy_count} records in {legacy_table_name} require manual deletion after migration verification."
+ )
+ return
+
+ # Case 3: Legacy has workspace data and new table is empty for workspace
+ logger.info(
+ f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}."
+ )
+ logger.info(
+ f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'"
+ )
+
+ try:
+ migrated_count = await PGVectorStorage._pg_migrate_workspace_data(
+ db,
+ legacy_table_name,
+ table_name,
+ workspace,
+ legacy_count,
+ embedding_dim,
+ )
+ if migrated_count != legacy_count:
+ logger.warning(
+ "PostgreSQL: Read %s legacy records%s during migration, expected %s.",
+ migrated_count,
+ workspace_info,
+ legacy_count,
+ )
+
+ new_count_result = await db.query(new_count_query, [workspace])
+ new_table_count_after = (
+ new_count_result.get("count", 0) if new_count_result else 0
+ )
+ inserted_count = new_table_count_after - new_table_workspace_count
+
+ if inserted_count != legacy_count:
+ error_msg = (
+ "PostgreSQL: Migration verification failed, "
+ f"expected {legacy_count} inserted records, got {inserted_count}."
+ )
+ logger.error(error_msg)
+ raise DataMigrationError(error_msg)
+
+ except DataMigrationError:
+ # Re-raise DataMigrationError as-is to preserve specific error messages
+ raise
+ except Exception as e:
+ logger.error(
+ f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}"
+ )
+ raise DataMigrationError(
+ f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'"
+ ) from e
+
+ logger.info(
+ f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
+ )
+ logger.warning(
+ "PostgreSQL: Manual deletion is required after data migration verification."
+ )
+
async def initialize(self):
async with get_data_init_lock():
if self.db is None:
@@ -2206,6 +2791,16 @@ class PGVectorStorage(BaseVectorStorage):
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
+ # Setup table (create if not exists and handle migration)
+ await PGVectorStorage.setup_table(
+ self.db,
+ self.table_name,
+ self.workspace, # CRITICAL: Filter migration by workspace
+ embedding_dim=self.embedding_func.embedding_dim,
+ legacy_table_name=self.legacy_table_name,
+ base_table=self.legacy_table_name, # base_table for DDL template lookup
+ )
+
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
@@ -2213,75 +2808,97 @@ class PGVectorStorage(BaseVectorStorage):
def _upsert_chunks(
self, item: dict[str, Any], current_time: datetime.datetime
- ) -> tuple[str, dict[str, Any]]:
+ ) -> tuple[str, tuple[Any, ...]]:
+ """Prepare upsert data for chunks.
+
+ Returns:
+ Tuple of (SQL template, values tuple for executemany)
+ """
try:
- upsert_sql = SQL_TEMPLATES["upsert_chunk"]
- data: dict[str, Any] = {
- "workspace": self.workspace,
- "id": item["__id__"],
- "tokens": item["tokens"],
- "chunk_order_index": item["chunk_order_index"],
- "full_doc_id": item["full_doc_id"],
- "content": item["content"],
- "content_vector": json.dumps(item["__vector__"].tolist()),
- "file_path": item["file_path"],
- "create_time": current_time,
- "update_time": current_time,
- }
+ upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
+ table_name=self.table_name
+ )
+ # Return tuple in the exact order of SQL parameters ($1, $2, ...)
+ values: tuple[Any, ...] = (
+ self.workspace, # $1
+ item["__id__"], # $2
+ item["tokens"], # $3
+ item["chunk_order_index"], # $4
+ item["full_doc_id"], # $5
+ item["content"], # $6
+ item["__vector__"], # $7 - numpy array, handled by pgvector codec
+ item["file_path"], # $8
+ current_time, # $9
+ current_time, # $10
+ )
except Exception as e:
logger.error(
- f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}"
+ f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}"
)
raise
- return upsert_sql, data
+ return upsert_sql, values
def _upsert_entities(
self, item: dict[str, Any], current_time: datetime.datetime
- ) -> tuple[str, dict[str, Any]]:
- upsert_sql = SQL_TEMPLATES["upsert_entity"]
+ ) -> tuple[str, tuple[Any, ...]]:
+ """Prepare upsert data for entities.
+
+ Returns:
+ Tuple of (SQL template, values tuple for executemany)
+ """
+ upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
source_id = item["source_id"]
if isinstance(source_id, str) and "" in source_id:
chunk_ids = source_id.split("")
else:
chunk_ids = [source_id]
- data: dict[str, Any] = {
- "workspace": self.workspace,
- "id": item["__id__"],
- "entity_name": item["entity_name"],
- "content": item["content"],
- "content_vector": json.dumps(item["__vector__"].tolist()),
- "chunk_ids": chunk_ids,
- "file_path": item.get("file_path", None),
- "create_time": current_time,
- "update_time": current_time,
- }
- return upsert_sql, data
+ # Return tuple in the exact order of SQL parameters ($1, $2, ...)
+ values: tuple[Any, ...] = (
+ self.workspace, # $1
+ item["__id__"], # $2
+ item["entity_name"], # $3
+ item["content"], # $4
+ item["__vector__"], # $5 - numpy array, handled by pgvector codec
+ chunk_ids, # $6
+ item.get("file_path", None), # $7
+ current_time, # $8
+ current_time, # $9
+ )
+ return upsert_sql, values
def _upsert_relationships(
self, item: dict[str, Any], current_time: datetime.datetime
- ) -> tuple[str, dict[str, Any]]:
- upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+ ) -> tuple[str, tuple[Any, ...]]:
+ """Prepare upsert data for relationships.
+
+ Returns:
+ Tuple of (SQL template, values tuple for executemany)
+ """
+ upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
+ table_name=self.table_name
+ )
source_id = item["source_id"]
if isinstance(source_id, str) and "" in source_id:
chunk_ids = source_id.split("")
else:
chunk_ids = [source_id]
- data: dict[str, Any] = {
- "workspace": self.workspace,
- "id": item["__id__"],
- "source_id": item["src_id"],
- "target_id": item["tgt_id"],
- "content": item["content"],
- "content_vector": json.dumps(item["__vector__"].tolist()),
- "chunk_ids": chunk_ids,
- "file_path": item.get("file_path", None),
- "create_time": current_time,
- "update_time": current_time,
- }
- return upsert_sql, data
+ # Return tuple in the exact order of SQL parameters ($1, $2, ...)
+ values: tuple[Any, ...] = (
+ self.workspace, # $1
+ item["__id__"], # $2
+ item["src_id"], # $3
+ item["tgt_id"], # $4
+ item["content"], # $5
+ item["__vector__"], # $6 - numpy array, handled by pgvector codec
+ chunk_ids, # $7
+ item.get("file_path", None), # $8
+ current_time, # $9
+ current_time, # $10
+ )
+ return upsert_sql, values
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@@ -2309,17 +2926,34 @@ class PGVectorStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
+
+ # Prepare batch values for executemany
+ batch_values: list[tuple[Any, ...]] = []
+ upsert_sql = None
+
for item in list_data:
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
- upsert_sql, data = self._upsert_chunks(item, current_time)
+ upsert_sql, values = self._upsert_chunks(item, current_time)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
- upsert_sql, data = self._upsert_entities(item, current_time)
+ upsert_sql, values = self._upsert_entities(item, current_time)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
- upsert_sql, data = self._upsert_relationships(item, current_time)
+ upsert_sql, values = self._upsert_relationships(item, current_time)
else:
raise ValueError(f"{self.namespace} is not supported")
- await self.db.execute(upsert_sql, data)
+ batch_values.append(values)
+
+ # Use executemany for batch execution - significantly reduces DB round-trips
+ # Note: register_vector is already called on pool init, no need to call it again
+ if batch_values and upsert_sql:
+
+ async def _batch_upsert(connection: asyncpg.Connection) -> None:
+ await connection.executemany(upsert_sql, batch_values)
+
+ await self.db._run_with_retry(_batch_upsert)
+ logger.debug(
+ f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace}"
+ )
#################### query method ###############
async def query(
@@ -2335,7 +2969,9 @@ class PGVectorStorage(BaseVectorStorage):
embedding_string = ",".join(map(str, embedding))
- sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
+ sql = SQL_TEMPLATES[self.namespace].format(
+ embedding_string=embedding_string, table_name=self.table_name
+ )
params = {
"workspace": self.workspace,
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
@@ -2357,14 +2993,9 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return
- table_name = namespace_to_table_name(self.namespace)
- if not table_name:
- logger.error(
- f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
- )
- return
-
- delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
+ delete_sql = (
+ f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
+ )
try:
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
@@ -2383,8 +3014,8 @@ class PGVectorStorage(BaseVectorStorage):
entity_name: The name of the entity to delete
"""
try:
- # Construct SQL to delete the entity
- delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
+ # Construct SQL to delete the entity using dynamic table name
+ delete_sql = f"""DELETE FROM {self.table_name}
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
@@ -2404,7 +3035,7 @@ class PGVectorStorage(BaseVectorStorage):
"""
try:
# Delete relations where the entity is either the source or target
- delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
+ delete_sql = f"""DELETE FROM {self.table_name}
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
@@ -2427,14 +3058,7 @@ class PGVectorStorage(BaseVectorStorage):
Returns:
The vector data if found, or None if not found
"""
- table_name = namespace_to_table_name(self.namespace)
- if not table_name:
- logger.error(
- f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}"
- )
- return None
-
- query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
+ query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id=$2"
params = {"workspace": self.workspace, "id": id}
try:
@@ -2460,15 +3084,8 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return []
- table_name = namespace_to_table_name(self.namespace)
- if not table_name:
- logger.error(
- f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}"
- )
- return []
-
ids_str = ",".join([f"'{id}'" for id in ids])
- query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
+ query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.workspace}
try:
@@ -2509,15 +3126,8 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return {}
- table_name = namespace_to_table_name(self.namespace)
- if not table_name:
- logger.error(
- f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}"
- )
- return {}
-
ids_str = ",".join([f"'{id}'" for id in ids])
- query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
+ query = f"SELECT id, content_vector FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.workspace}
try:
@@ -2527,10 +3137,18 @@ class PGVectorStorage(BaseVectorStorage):
for result in results:
if result and "content_vector" in result and "id" in result:
try:
- # Parse JSON string to get vector as list of floats
- vector_data = json.loads(result["content_vector"])
- if isinstance(vector_data, list):
- vectors_dict[result["id"]] = vector_data
+ vector_data = result["content_vector"]
+ # Handle both pgvector-registered connections (returns list/tuple)
+ # and non-registered connections (returns JSON string)
+ if isinstance(vector_data, (list, tuple)):
+ vectors_dict[result["id"]] = list(vector_data)
+ elif isinstance(vector_data, str):
+ parsed = json.loads(vector_data)
+ if isinstance(parsed, list):
+ vectors_dict[result["id"]] = parsed
+ # Handle numpy arrays from pgvector
+ elif hasattr(vector_data, "tolist"):
+ vectors_dict[result["id"]] = vector_data.tolist()
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"[{self.workspace}] Failed to parse vector data for ID {result['id']}: {e}"
@@ -2546,15 +3164,8 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
- table_name = namespace_to_table_name(self.namespace)
- if not table_name:
- return {
- "status": "error",
- "message": f"Unknown namespace: {self.namespace}",
- }
-
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
- table_name=table_name
+ table_name=self.table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
@@ -2593,6 +3204,9 @@ class PGDocStatusStorage(DocStatusStorage):
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
+        # NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization
+        # No need to create the table here; its DDL is already defined in the TABLES dict
+
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
@@ -4787,14 +5401,14 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_CHUNKS": {
- "ddl": f"""CREATE TABLE LIGHTRAG_VDB_CHUNKS (
+ "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
id VARCHAR(255),
workspace VARCHAR(255),
full_doc_id VARCHAR(256),
chunk_order_index INTEGER,
tokens INTEGER,
content TEXT,
- content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+ content_vector VECTOR(dimension),
file_path TEXT NULL,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
@@ -4802,12 +5416,12 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_ENTITY": {
- "ddl": f"""CREATE TABLE LIGHTRAG_VDB_ENTITY (
+ "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
id VARCHAR(255),
workspace VARCHAR(255),
entity_name VARCHAR(512),
content TEXT,
- content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+ content_vector VECTOR(dimension),
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
chunk_ids VARCHAR(255)[] NULL,
@@ -4816,13 +5430,13 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_RELATION": {
- "ddl": f"""CREATE TABLE LIGHTRAG_VDB_RELATION (
+ "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
id VARCHAR(255),
workspace VARCHAR(255),
source_id VARCHAR(512),
target_id VARCHAR(512),
content TEXT,
- content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+ content_vector VECTOR(dimension),
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
chunk_ids VARCHAR(255)[] NULL,
@@ -5047,7 +5661,7 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
- "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
+ "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@@ -5060,7 +5674,7 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time
""",
- "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
+ "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
ON CONFLICT (workspace,id) DO UPDATE
@@ -5071,7 +5685,7 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path,
update_time=EXCLUDED.update_time
""",
- "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
+ "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
ON CONFLICT (workspace,id) DO UPDATE
@@ -5087,7 +5701,7 @@ SQL_TEMPLATES = {
SELECT r.source_id AS src_id,
r.target_id AS tgt_id,
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
- FROM LIGHTRAG_VDB_RELATION r
+ FROM {table_name} r
WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
@@ -5096,7 +5710,7 @@ SQL_TEMPLATES = {
"entities": """
SELECT e.entity_name,
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
- FROM LIGHTRAG_VDB_ENTITY e
+ FROM {table_name} e
WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
@@ -5107,7 +5721,7 @@ SQL_TEMPLATES = {
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
- FROM LIGHTRAG_VDB_CHUNKS c
+ FROM {table_name} c
WHERE c.workspace = $1
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index 75de2613..81f0f4e4 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -10,7 +10,7 @@ import numpy as np
import pipmaster as pm
from ..base import BaseVectorStorage
-from ..exceptions import QdrantMigrationError
+from ..exceptions import DataMigrationError
from ..kg.shared_storage import get_data_init_lock
from ..utils import compute_mdhash_id, logger
@@ -66,6 +66,51 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition:
)
+def _find_legacy_collection(
+ client: QdrantClient,
+ namespace: str,
+ workspace: str = None,
+ model_suffix: str = None,
+) -> str | None:
+ """
+ Find legacy collection with backward compatibility support.
+
+ This function tries multiple naming patterns to locate legacy collections
+ created by older versions of LightRAG:
+
+    1. lightrag_vdb_{namespace} - only when model_suffix is provided (HIGHEST PRIORITY)
+    2. {workspace}_{namespace} - only when workspace is provided
+    3. lightrag_vdb_{namespace} - always checked as a fallback
+    4. {namespace} - bare namespace (LOWEST PRIORITY)
+
+ Args:
+ client: QdrantClient instance
+ namespace: Base namespace (e.g., "chunks", "entities")
+ workspace: Optional workspace identifier
+ model_suffix: Optional model suffix for new collection
+
+ Returns:
+ Collection name if found, None otherwise
+ """
+ # Try multiple naming patterns for backward compatibility
+ # More specific names (with workspace) have higher priority
+ candidates = [
+ f"lightrag_vdb_{namespace}" if model_suffix else None,
+ f"{workspace}_{namespace}" if workspace else None,
+ f"lightrag_vdb_{namespace}",
+ namespace,
+ ]
+
+ for candidate in candidates:
+ if candidate and client.collection_exists(candidate):
+ logger.info(
+ f"Qdrant: Found legacy collection '{candidate}' "
+ f"(namespace={namespace}, workspace={workspace or 'none'})"
+ )
+ return candidate
+
+ return None
+
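+# Illustrative example (assumed values): with namespace="chunks", workspace="space1"
+# and a model_suffix available, the candidate list above resolves to
+#   "lightrag_vdb_chunks" -> "space1_chunks" -> "lightrag_vdb_chunks" -> "chunks"
+# and the first name for which client.collection_exists() returns True is returned.
+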
+
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
@@ -79,63 +124,52 @@ class QdrantVectorDBStorage(BaseVectorStorage):
embedding_func=embedding_func,
meta_fields=meta_fields or set(),
)
- self.__post_init__()
+ # __post_init__() is automatically called by super().__init__()
@staticmethod
def setup_collection(
client: QdrantClient,
collection_name: str,
- legacy_namespace: str = None,
- workspace: str = None,
- **kwargs,
+ namespace: str,
+ workspace: str,
+ vectors_config: models.VectorParams,
+ hnsw_config: models.HnswConfigDiff,
+ model_suffix: str,
):
"""
Setup Qdrant collection with migration support from legacy collections.
+        Ensures the final collection has a workspace isolation index.
+        Checks vector dimension compatibility before creating a new collection.
+        Drops the legacy collection if it exists and is empty.
+        Migrates data from the legacy collection to the new collection only when the
+        new collection's workspace holds no data and the legacy collection is not empty.
+
Args:
client: QdrantClient instance
- collection_name: Name of the new collection
- legacy_namespace: Name of the legacy collection (if exists)
+ collection_name: Name of the final collection
+ namespace: Base namespace (e.g., "chunks", "entities")
workspace: Workspace identifier for data isolation
- **kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
+ vectors_config: Vector configuration parameters for the collection
+            hnsw_config: HNSW index configuration diff for the collection
+            model_suffix: Model-derived suffix used to locate legacy collections during migration
"""
+ if not namespace or not workspace:
+ raise ValueError("namespace and workspace must be provided")
+
+ workspace_count_filter = models.Filter(
+ must=[workspace_filter_condition(workspace)]
+ )
+
new_collection_exists = client.collection_exists(collection_name)
- legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace)
+ legacy_collection = _find_legacy_collection(
+ client, namespace, workspace, model_suffix
+ )
- # Case 1: Both new and legacy collections exist - Warning only (no migration)
- if new_collection_exists and legacy_exists:
- logger.warning(
- f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete."
- )
- return
-
- # Case 2: Only new collection exists - Ensure index exists
- if new_collection_exists:
- # Check if workspace index exists, create if missing
- try:
- collection_info = client.get_collection(collection_name)
- if WORKSPACE_ID_FIELD not in collection_info.payload_schema:
- logger.info(
- f"Qdrant: Creating missing workspace index for '{collection_name}'"
- )
- client.create_payload_index(
- collection_name=collection_name,
- field_name=WORKSPACE_ID_FIELD,
- field_schema=models.KeywordIndexParams(
- type=models.KeywordIndexType.KEYWORD,
- is_tenant=True,
- ),
- )
- except Exception as e:
- logger.warning(
- f"Qdrant: Could not verify/create workspace index for '{collection_name}': {e}"
- )
- return
-
- # Case 3: Neither exists - Create new collection
- if not legacy_exists:
- logger.info(f"Qdrant: Creating new collection '{collection_name}'")
- client.create_collection(collection_name, **kwargs)
+        # Case 1: Only the new collection exists, or the new collection is the same as the legacy collection
+        # No data migration is needed; ensure the workspace index exists, then return
+ if (new_collection_exists and not legacy_collection) or (
+ collection_name == legacy_collection
+ ):
+            # create_payload_index returns without error if the index already exists
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
@@ -144,132 +178,244 @@ class QdrantVectorDBStorage(BaseVectorStorage):
is_tenant=True,
),
)
- logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
+ new_workspace_count = client.count(
+ collection_name=collection_name,
+ count_filter=workspace_count_filter,
+ exact=True,
+ ).count
+
+ # Skip data migration if new collection already has workspace data
+ if new_workspace_count == 0 and not (collection_name == legacy_collection):
+ logger.warning(
+ f"Qdrant: workspace data in collection '{collection_name}' is empty. "
+ f"Ensure it is caused by new workspace setup and not an unexpected embedding model change."
+ )
+
return
- # Case 4: Only legacy exists - Migrate data
- logger.info(
- f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'"
+ legacy_count = None
+ if not new_collection_exists:
+ # Check vector dimension compatibility before creating new collection
+ if legacy_collection:
+ legacy_count = client.count(
+ collection_name=legacy_collection, exact=True
+ ).count
+ if legacy_count > 0:
+ legacy_info = client.get_collection(legacy_collection)
+ legacy_dim = legacy_info.config.params.vectors.size
+
+ if vectors_config.size and legacy_dim != vectors_config.size:
+ logger.error(
+ f"Qdrant: Dimension mismatch detected! "
+ f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, "
+ f"but new embedding model expects {vectors_config.size}d."
+ )
+
+ raise DataMigrationError(
+ f"Dimension mismatch between legacy collection '{legacy_collection}' "
+ f"and new collection. Expected {vectors_config.size}d but got {legacy_dim}d."
+ )
+
+ client.create_collection(
+ collection_name, vectors_config=vectors_config, hnsw_config=hnsw_config
+ )
+ logger.info(f"Qdrant: Collection '{collection_name}' created successfully")
+ if not legacy_collection:
+ logger.warning(
+ "Qdrant: Ensure this new collection creation is caused by new workspace setup and not an unexpected embedding model change."
+ )
+
+        # create_payload_index returns without error if the index already exists
+ client.create_payload_index(
+ collection_name=collection_name,
+ field_name=WORKSPACE_ID_FIELD,
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=True,
+ ),
)
- try:
- # Get legacy collection count
- legacy_count = client.count(
- collection_name=legacy_namespace, exact=True
- ).count
- logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
-
+        # Case 2: Legacy collection exists
+ if legacy_collection:
+ # Only drop legacy collection if it's empty
+ if legacy_count is None:
+ legacy_count = client.count(
+ collection_name=legacy_collection, exact=True
+ ).count
if legacy_count == 0:
- logger.info("Qdrant: Legacy collection is empty, skipping migration")
- # Create new empty collection
- client.create_collection(collection_name, **kwargs)
- client.create_payload_index(
- collection_name=collection_name,
- field_name=WORKSPACE_ID_FIELD,
- field_schema=models.KeywordIndexParams(
- type=models.KeywordIndexType.KEYWORD,
- is_tenant=True,
- ),
+ client.delete_collection(collection_name=legacy_collection)
+ logger.info(
+ f"Qdrant: Empty legacy collection '{legacy_collection}' deleted successfully"
)
return
- # Create new collection first
- logger.info(f"Qdrant: Creating new collection '{collection_name}'")
- client.create_collection(collection_name, **kwargs)
+ new_workspace_count = client.count(
+ collection_name=collection_name,
+ count_filter=workspace_count_filter,
+ exact=True,
+ ).count
- # Batch migration (500 records per batch)
- migrated_count = 0
- offset = None
- batch_size = 500
-
- while True:
- # Scroll through legacy data
- result = client.scroll(
- collection_name=legacy_namespace,
- limit=batch_size,
- offset=offset,
- with_vectors=True,
- with_payload=True,
+ # Skip data migration if new collection already has workspace data
+ if new_workspace_count > 0:
+ logger.warning(
+                    f"Qdrant: Both the new and legacy collections have data. "
+ f"{legacy_count} records in {legacy_collection} require manual deletion after migration verification."
)
- points, next_offset = result
+ return
- if not points:
- break
+            # Case 3: Legacy collection has data and the new collection's workspace is empty - migrate
+ # Check if legacy collection has workspace_id to determine migration strategy
+ # Note: payload_schema only reflects INDEXED fields, so we also sample
+ # actual payloads to detect unindexed workspace_id fields
+ legacy_info = client.get_collection(legacy_collection)
+ has_workspace_index = WORKSPACE_ID_FIELD in (
+ legacy_info.payload_schema or {}
+ )
- # Transform points for new collection
- new_points = []
- for point in points:
- # Add workspace_id to payload
- new_payload = dict(point.payload or {})
- new_payload[WORKSPACE_ID_FIELD] = workspace or DEFAULT_WORKSPACE
-
- # Create new point with workspace-prefixed ID
- original_id = new_payload.get(ID_FIELD)
- if original_id:
- new_point_id = compute_mdhash_id_for_qdrant(
- original_id, prefix=workspace or DEFAULT_WORKSPACE
+ # Detect workspace_id field presence by sampling payloads if not indexed
+ # This prevents cross-workspace data leakage when workspace_id exists but isn't indexed
+ has_workspace_field = has_workspace_index
+ if not has_workspace_index:
+ # Sample a small batch of points to check for workspace_id in payloads
+ # All points must have workspace_id if any point has it
+ sample_result = client.scroll(
+ collection_name=legacy_collection,
+ limit=10, # Small sample is sufficient for detection
+ with_payload=True,
+ with_vectors=False,
+ )
+ sample_points, _ = sample_result
+ for point in sample_points:
+ if point.payload and WORKSPACE_ID_FIELD in point.payload:
+ has_workspace_field = True
+ logger.info(
+ f"Qdrant: Detected unindexed {WORKSPACE_ID_FIELD} field "
+ f"in legacy collection '{legacy_collection}' via payload sampling"
)
- else:
- # Fallback: use original point ID
- new_point_id = str(point.id)
+ break
- new_points.append(
- models.PointStruct(
- id=new_point_id,
- vector=point.vector,
- payload=new_payload,
+ # Build workspace filter if legacy collection has workspace support
+ # This prevents cross-workspace data leakage during migration
+ legacy_scroll_filter = None
+ if has_workspace_field:
+ legacy_scroll_filter = models.Filter(
+ must=[workspace_filter_condition(workspace)]
+ )
+ # Recount with workspace filter for accurate migration tracking
+ legacy_count = client.count(
+ collection_name=legacy_collection,
+ count_filter=legacy_scroll_filter,
+ exact=True,
+ ).count
+ logger.info(
+ f"Qdrant: Legacy collection has workspace support, "
+ f"filtering to {legacy_count} records for workspace '{workspace}'"
+ )
+
+ logger.info(
+ f"Qdrant: Found legacy collection '{legacy_collection}' with {legacy_count} records to migrate."
+ )
+ logger.info(
+ f"Qdrant: Migrating data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
+ )
+
+ try:
+ # Batch migration (500 records per batch)
+ migrated_count = 0
+ offset = None
+ batch_size = 500
+
+ while True:
+ # Scroll through legacy data with optional workspace filter
+ result = client.scroll(
+ collection_name=legacy_collection,
+ scroll_filter=legacy_scroll_filter,
+ limit=batch_size,
+ offset=offset,
+ with_vectors=True,
+ with_payload=True,
+ )
+ points, next_offset = result
+
+ if not points:
+ break
+
+ # Transform points for new collection
+ new_points = []
+ for point in points:
+ # Set workspace_id in payload
+ new_payload = dict(point.payload or {})
+ new_payload[WORKSPACE_ID_FIELD] = workspace
+
+ # Create new point with workspace-prefixed ID
+ original_id = new_payload.get(ID_FIELD)
+ if original_id:
+ new_point_id = compute_mdhash_id_for_qdrant(
+ original_id, prefix=workspace
+ )
+ else:
+ # Fallback: use original point ID
+ new_point_id = str(point.id)
+
+ new_points.append(
+ models.PointStruct(
+ id=new_point_id,
+ vector=point.vector,
+ payload=new_payload,
+ )
)
+
+ # Upsert to new collection
+ client.upsert(
+ collection_name=collection_name, points=new_points, wait=True
)
- # Upsert to new collection
- client.upsert(
- collection_name=collection_name, points=new_points, wait=True
+ migrated_count += len(points)
+ logger.info(
+ f"Qdrant: {migrated_count}/{legacy_count} records migrated"
+ )
+
+ # Check if we've reached the end
+ if next_offset is None:
+ break
+ offset = next_offset
+
+ new_count_after = client.count(
+ collection_name=collection_name,
+ count_filter=workspace_count_filter,
+ exact=True,
+ ).count
+ inserted_count = new_count_after - new_workspace_count
+ if inserted_count != legacy_count:
+ error_msg = (
+ "Qdrant: Migration verification failed, expected "
+ f"{legacy_count} inserted records, got {inserted_count}."
+ )
+ logger.error(error_msg)
+ raise DataMigrationError(error_msg)
+
+ except DataMigrationError:
+ # Re-raise DataMigrationError as-is to preserve specific error messages
+ raise
+ except Exception as e:
+ logger.error(
+ f"Qdrant: Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}': {e}"
)
-
- migrated_count += len(points)
- logger.info(f"Qdrant: {migrated_count}/{legacy_count} records migrated")
-
- # Check if we've reached the end
- if next_offset is None:
- break
- offset = next_offset
-
- # Verify migration by comparing counts
- logger.info("Verifying migration...")
- new_count = client.count(collection_name=collection_name, exact=True).count
-
- if new_count != legacy_count:
- error_msg = f"Qdrant: Migration verification failed, expected {legacy_count} records, got {new_count} in new collection"
- logger.error(error_msg)
- raise QdrantMigrationError(error_msg)
+ raise DataMigrationError(
+ f"Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}'"
+ ) from e
logger.info(
- f"Qdrant: Migration completed successfully: {migrated_count} records migrated"
+ f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully"
)
-
- # Create payload index after successful migration
- logger.info("Qdrant: Creating workspace payload index...")
- client.create_payload_index(
- collection_name=collection_name,
- field_name=WORKSPACE_ID_FIELD,
- field_schema=models.KeywordIndexParams(
- type=models.KeywordIndexType.KEYWORD,
- is_tenant=True,
- ),
+ logger.warning(
+            "Qdrant: Manual deletion of the legacy collection is required after data migration verification."
)
- logger.info(
- f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully"
- )
-
- except QdrantMigrationError:
- # Re-raise migration errors without wrapping
- raise
- except Exception as e:
- error_msg = f"Qdrant: Migration failed with error: {e}"
- logger.error(error_msg)
- raise QdrantMigrationError(error_msg) from e
def __post_init__(self):
+ # Call parent class __post_init__ to validate embedding_func
+ super().__post_init__()
+
# Check for QDRANT_WORKSPACE environment variable first (higher priority)
# This allows administrators to force a specific workspace for all Qdrant storage instances
qdrant_workspace = os.environ.get("QDRANT_WORKSPACE")
@@ -287,20 +433,23 @@ class QdrantVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
- # Get legacy namespace for data migration from old version
- if effective_workspace:
- self.legacy_namespace = f"{effective_workspace}_{self.namespace}"
- else:
- self.legacy_namespace = self.namespace
-
self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE
- # Use a shared collection with payload-based partitioning (Qdrant's recommended approach)
- # Ref: https://qdrant.tech/documentation/guides/multiple-partitions/
- self.final_namespace = f"lightrag_vdb_{self.namespace}"
- logger.debug(
- f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning"
- )
+ # Generate model suffix
+ self.model_suffix = self._generate_collection_suffix()
+
+ # New naming scheme with model isolation
+ # Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
+ # Ensure model_suffix is not empty before appending
+ if self.model_suffix:
+ self.final_namespace = f"lightrag_vdb_{self.namespace}_{self.model_suffix}"
+ logger.info(f"Qdrant collection: {self.final_namespace}")
+ else:
+ # Fallback: use legacy namespace if model_suffix is unavailable
+ self.final_namespace = f"lightrag_vdb_{self.namespace}"
+ logger.warning(
+                f"Qdrant collection '{self.final_namespace}' is missing a model suffix. Please add model_name to embedding_func for proper data isolation."
+ )
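+        # Illustrative note: namespace "chunks" therefore maps to a name like
+        # "lightrag_vdb_chunks_<model_suffix>" when a suffix is available, or the
+        # legacy-style "lightrag_vdb_chunks" when it is not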
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -338,11 +487,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
# Setup collection (create if not exists and configure indexes)
- # Pass legacy_namespace and workspace for migration support
+ # Pass namespace and workspace for backward-compatible migration support
QdrantVectorDBStorage.setup_collection(
self._client,
self.final_namespace,
- legacy_namespace=self.legacy_namespace,
+ namespace=self.namespace,
workspace=self.effective_workspace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
@@ -352,8 +501,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
payload_m=16,
m=0,
),
+ model_suffix=self.model_suffix,
)
+ # Removed duplicate max batch size initialization
+
self._initialized = True
logger.info(
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
@@ -481,21 +633,44 @@ class QdrantVectorDBStorage(BaseVectorStorage):
entity_name: Name of the entity to delete
"""
try:
- # Generate the entity ID using the same function as used for storage
+ # Compute entity ID from name (same as Milvus)
entity_id = compute_mdhash_id(entity_name, prefix=ENTITY_PREFIX)
- qdrant_entity_id = compute_mdhash_id_for_qdrant(
- entity_id, prefix=self.effective_workspace
+ logger.debug(
+ f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
- # Delete the entity point by its Qdrant ID directly
- self._client.delete(
+ # Scroll to find the entity by its ID field in payload with workspace filtering
+ # This is safer than reconstructing the Qdrant point ID
+ results = self._client.scroll(
collection_name=self.final_namespace,
- points_selector=models.PointIdsList(points=[qdrant_entity_id]),
- wait=True,
- )
- logger.debug(
- f"[{self.workspace}] Successfully deleted entity {entity_name}"
+ scroll_filter=models.Filter(
+ must=[
+ workspace_filter_condition(self.effective_workspace),
+ models.FieldCondition(
+ key=ID_FIELD, match=models.MatchValue(value=entity_id)
+ ),
+ ]
+ ),
+ with_payload=False,
+ limit=1,
)
+
+ # Extract point IDs to delete
+ points = results[0]
+ if points:
+ ids_to_delete = [point.id for point in points]
+ self._client.delete(
+ collection_name=self.final_namespace,
+ points_selector=models.PointIdsList(points=ids_to_delete),
+ wait=True,
+ )
+ logger.debug(
+ f"[{self.workspace}] Successfully deleted entity {entity_name}"
+ )
+ else:
+ logger.debug(
+ f"[{self.workspace}] Entity {entity_name} not found in storage"
+ )
except Exception as e:
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
@@ -506,38 +681,60 @@ class QdrantVectorDBStorage(BaseVectorStorage):
entity_name: Name of the entity whose relations should be deleted
"""
try:
- # Find relations where the entity is either source or target, with workspace filtering
- results = self._client.scroll(
- collection_name=self.final_namespace,
- scroll_filter=models.Filter(
- must=[workspace_filter_condition(self.effective_workspace)],
- should=[
- models.FieldCondition(
- key="src_id", match=models.MatchValue(value=entity_name)
- ),
- models.FieldCondition(
- key="tgt_id", match=models.MatchValue(value=entity_name)
- ),
- ],
- ),
- with_payload=True,
- limit=1000, # Adjust as needed for your use case
+ # Build the filter to find relations where entity is either source or target
+ # must + should = workspace_id matches AND (src_id matches OR tgt_id matches)
+ relation_filter = models.Filter(
+ must=[workspace_filter_condition(self.effective_workspace)],
+ should=[
+ models.FieldCondition(
+ key="src_id", match=models.MatchValue(value=entity_name)
+ ),
+ models.FieldCondition(
+ key="tgt_id", match=models.MatchValue(value=entity_name)
+ ),
+ ],
)
- # Extract points that need to be deleted
- relation_points = results[0]
- ids_to_delete = [point.id for point in relation_points]
+ # Paginate through all matching relations to handle large datasets
+ total_deleted = 0
+ offset = None
+ batch_size = 1000
- if ids_to_delete:
- # Delete the relations with workspace filtering
- assert isinstance(self._client, QdrantClient)
+ while True:
+ # Scroll to find relations, using with_payload=False for efficiency
+ # since we only need point IDs for deletion
+ results = self._client.scroll(
+ collection_name=self.final_namespace,
+ scroll_filter=relation_filter,
+ with_payload=False,
+ with_vectors=False,
+ limit=batch_size,
+ offset=offset,
+ )
+
+ points, next_offset = results
+ if not points:
+ break
+
+ # Extract point IDs to delete
+ ids_to_delete = [point.id for point in points]
+
+ # Delete the batch of relations
self._client.delete(
collection_name=self.final_namespace,
points_selector=models.PointIdsList(points=ids_to_delete),
wait=True,
)
+ total_deleted += len(ids_to_delete)
+
+ # Check if we've reached the end
+ if next_offset is None:
+ break
+ offset = next_offset
+
+ if total_deleted > 0:
logger.debug(
- f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
+ f"[{self.workspace}] Deleted {total_deleted} relations for {entity_name}"
)
else:
logger.debug(
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index ef0f61e2..6da56308 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -163,7 +163,9 @@ class UnifiedLock(Generic[T]):
enable_output=self._enable_logging,
)
- # Then acquire the main lock
+ # Acquire the main lock
+ # Note: self._lock should never be None here as the check has been moved
+ # to get_internal_lock() and get_data_init_lock() functions
if self._is_async:
await self._lock.acquire()
else:
@@ -193,19 +195,21 @@ class UnifiedLock(Generic[T]):
async def __aexit__(self, exc_type, exc_val, exc_tb):
main_lock_released = False
+ async_lock_released = False
try:
# Release main lock first
- if self._is_async:
- self._lock.release()
- else:
- self._lock.release()
- main_lock_released = True
+ if self._lock is not None:
+ if self._is_async:
+ self._lock.release()
+ else:
+ self._lock.release()
- direct_log(
- f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
- level="INFO",
- enable_output=self._enable_logging,
- )
+ direct_log(
+ f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
+ level="INFO",
+ enable_output=self._enable_logging,
+ )
+ main_lock_released = True
# Then release async lock if in multiprocess mode
if not self._is_async and self._async_lock is not None:
@@ -215,6 +219,7 @@ class UnifiedLock(Generic[T]):
level="DEBUG",
enable_output=self._enable_logging,
)
+ async_lock_released = True
except Exception as e:
direct_log(
@@ -223,9 +228,10 @@ class UnifiedLock(Generic[T]):
enable_output=True,
)
- # If main lock release failed but async lock hasn't been released, try to release it
+ # If main lock release failed but async lock hasn't been attempted yet, try to release it
if (
not main_lock_released
+ and not async_lock_released
and not self._is_async
and self._async_lock is not None
):
@@ -255,6 +261,10 @@ class UnifiedLock(Generic[T]):
try:
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
+
+ # Acquire the main lock
+ # Note: self._lock should never be None here as the check has been moved
+ # to get_internal_lock() and get_data_init_lock() functions
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)",
level="DEBUG",
@@ -1060,6 +1070,10 @@ class _KeyedLockContext:
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
+ if _internal_lock is None:
+ raise RuntimeError(
+ "Shared data not initialized. Call initialize_share_data() before using locks!"
+ )
async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_internal_lock,
@@ -1090,6 +1104,10 @@ def get_storage_keyed_lock(
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
+ if _data_init_lock is None:
+ raise RuntimeError(
+ "Shared data not initialized. Call initialize_share_data() before using locks!"
+ )
async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_data_init_lock,
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 8a638759..d1e2bac3 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -7,7 +7,7 @@ import inspect
import os
import time
import warnings
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, replace
from datetime import datetime, timezone
from functools import partial
from typing import (
@@ -518,14 +518,9 @@ class LightRAG:
f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
)
- # Fix global_config now
- global_config = asdict(self)
-
- _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
- logger.debug(f"LightRAG init with param:\n {_print_config}\n")
-
# Init Embedding
- # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
+ # Step 1: Capture embedding_func and max_token_size before applying decorator
+ original_embedding_func = self.embedding_func
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
@@ -534,12 +529,26 @@ class LightRAG:
)
self.embedding_token_limit = embedding_max_token_size
- # Step 2: Apply priority wrapper decorator
- self.embedding_func = priority_limit_async_func_call(
- self.embedding_func_max_async,
- llm_timeout=self.default_embedding_timeout,
- queue_name="Embedding func",
- )(self.embedding_func)
+ # Fix global_config now
+ global_config = asdict(self)
+ # Restore original EmbeddingFunc object (asdict converts it to dict)
+ global_config["embedding_func"] = original_embedding_func
+
+ _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
+ logger.debug(f"LightRAG init with param:\n {_print_config}\n")
+
+ # Step 2: Apply priority wrapper decorator to EmbeddingFunc's inner func
+ # Create a NEW EmbeddingFunc instance with the wrapped func to avoid mutating the caller's object
+ # This ensures _generate_collection_suffix can still access attributes (model_name, embedding_dim)
+ # while preventing side effects when the same EmbeddingFunc is reused across multiple LightRAG instances
+ if self.embedding_func is not None:
+ wrapped_func = priority_limit_async_func_call(
+ self.embedding_func_max_async,
+ llm_timeout=self.default_embedding_timeout,
+ queue_name="Embedding func",
+ )(self.embedding_func.func)
+ # Use dataclasses.replace() to create a new instance, leaving the original unchanged
+ self.embedding_func = replace(self.embedding_func, func=wrapped_func)
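+            # Descriptive note: the new instance still exposes embedding_dim and
+            # model_name, while calls are routed through the priority-limited wrapper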
# Initialize all storages
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index f6871422..e651e3c8 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -351,7 +351,9 @@ async def bedrock_complete(
return result
-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0"
+)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 37ce7206..5e438ceb 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -453,7 +453,9 @@ async def gemini_model_complete(
)
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1536, max_token_size=2048, model_name="gemini-embedding-001"
+)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index 447f95c3..eff89650 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -142,7 +142,9 @@ async def hf_model_complete(
return result
-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1024, max_token_size=8192, model_name="hf_embedding_model"
+)
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
# Detect the appropriate device
if torch.cuda.is_available():
diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py
index 41251f4a..5c380854 100644
--- a/lightrag/llm/jina.py
+++ b/lightrag/llm/jina.py
@@ -58,7 +58,9 @@ async def fetch_data(url, headers, data):
return data_list
-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=2048, max_token_size=8192, model_name="jina-embeddings-v4"
+)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 2f2a1dbf..3eaef1af 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -138,7 +138,9 @@ async def lollms_model_complete(
)
-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model"
+)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py
index 1ebaf3a6..9025ec13 100644
--- a/lightrag/llm/nvidia_openai.py
+++ b/lightrag/llm/nvidia_openai.py
@@ -33,7 +33,9 @@ from lightrag.utils import (
import numpy as np
-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model"
+)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index cd633e80..62269296 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -172,7 +172,9 @@ async def ollama_model_complete(
)
-@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
+)
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
) -> np.ndarray:
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 9c3d0261..b49cac71 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -677,7 +677,9 @@ async def nvidia_openai_complete(
return result
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small"
+)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -867,7 +869,11 @@ async def azure_openai_complete(
return result
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1536,
+ max_token_size=8192,
+ model_name="my-text-embedding-3-large-deployment",
+)
async def azure_openai_embed(
texts: list[str],
model: str | None = None,
diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py
index d90f3cc1..5caa82bf 100644
--- a/lightrag/llm/zhipu.py
+++ b/lightrag/llm/zhipu.py
@@ -179,7 +179,9 @@ async def zhipu_complete(
)
-@wrap_embedding_func_with_attrs(embedding_dim=1024)
+@wrap_embedding_func_with_attrs(
+ embedding_dim=1024, max_token_size=8192, model_name="embedding-3"
+)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/tools/prepare_qdrant_legacy_data.py b/lightrag/tools/prepare_qdrant_legacy_data.py
new file mode 100644
index 00000000..2ac90196
--- /dev/null
+++ b/lightrag/tools/prepare_qdrant_legacy_data.py
@@ -0,0 +1,720 @@
+#!/usr/bin/env python3
+"""
+Qdrant Legacy Data Preparation Tool for LightRAG
+
+This tool copies data from new collections to legacy collections for testing
+the data migration logic in setup_collection function.
+
+New Collections (with workspace_id):
+ - lightrag_vdb_chunks
+ - lightrag_vdb_entities
+ - lightrag_vdb_relationships
+
+Legacy Collections (without workspace_id, dynamically named as {workspace}_{suffix}):
+ - {workspace}_chunks (e.g., space1_chunks)
+ - {workspace}_entities (e.g., space1_entities)
+ - {workspace}_relationships (e.g., space1_relationships)
+
+The tool:
+ 1. Filters source data by workspace_id
+ 2. Verifies workspace data exists before creating legacy collections
+ 3. Removes workspace_id field to simulate legacy data format
+ 4. Copies only the specified workspace's data to legacy collections
+
+Usage:
+ python -m lightrag.tools.prepare_qdrant_legacy_data
+ # or
+ python lightrag/tools/prepare_qdrant_legacy_data.py
+
+ # Specify custom workspace
+ python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1
+
+ # Process specific collection types only
+ python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities
+
+ # Dry run (preview only, no actual changes)
+ python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
+"""
+
+import argparse
+import asyncio
+import configparser
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import pipmaster as pm
+from dotenv import load_dotenv
+
+# Ensure qdrant-client is installed before importing it
+if not pm.is_installed("qdrant-client"):
+    pm.install("qdrant-client")
+
+from qdrant_client import QdrantClient, models  # type: ignore
+
+# Add project root to path for imports
+sys.path.insert(
+    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+
+# Load environment variables
+load_dotenv(dotenv_path=".env", override=False)
+
+# Collection namespace mapping: new collection pattern -> legacy suffix
+# Legacy collection will be named as: {workspace}_{suffix}
+COLLECTION_NAMESPACES = {
+ "chunks": {
+ "new": "lightrag_vdb_chunks",
+ "suffix": "chunks",
+ },
+ "entities": {
+ "new": "lightrag_vdb_entities",
+ "suffix": "entities",
+ },
+ "relationships": {
+ "new": "lightrag_vdb_relationships",
+ "suffix": "relationships",
+ },
+}
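+
+# Example (illustrative): with --workspace space1, the "chunks" entry maps the
+# source collection "lightrag_vdb_chunks" to the legacy target "space1_chunks".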
+
+# Default batch size for copy operations
+DEFAULT_BATCH_SIZE = 500
+
+# Field to remove from legacy data
+WORKSPACE_ID_FIELD = "workspace_id"
+
+# ANSI color codes for terminal output
+BOLD_CYAN = "\033[1;36m"
+BOLD_GREEN = "\033[1;32m"
+BOLD_YELLOW = "\033[1;33m"
+BOLD_RED = "\033[1;31m"
+RESET = "\033[0m"
+
+
+@dataclass
+class CopyStats:
+ """Copy operation statistics"""
+
+ collection_type: str
+ source_collection: str
+ target_collection: str
+ total_records: int = 0
+ copied_records: int = 0
+ failed_records: int = 0
+ errors: List[Dict[str, Any]] = field(default_factory=list)
+ elapsed_time: float = 0.0
+
+ def add_error(self, batch_idx: int, error: Exception, batch_size: int):
+ """Record batch error"""
+ self.errors.append(
+ {
+ "batch": batch_idx,
+ "error_type": type(error).__name__,
+ "error_msg": str(error),
+ "records_lost": batch_size,
+ "timestamp": time.time(),
+ }
+ )
+ self.failed_records += batch_size
+
+
+class QdrantLegacyDataPreparationTool:
+ """Tool for preparing legacy data in Qdrant for migration testing"""
+
+ def __init__(
+ self,
+ workspace: str = "space1",
+ batch_size: int = DEFAULT_BATCH_SIZE,
+ dry_run: bool = False,
+ clear_target: bool = False,
+ ):
+ """
+ Initialize the tool.
+
+ Args:
+ workspace: Workspace to use for filtering new collection data
+ batch_size: Number of records to process per batch
+ dry_run: If True, only preview operations without making changes
+ clear_target: If True, delete target collection before copying data
+ """
+ self.workspace = workspace
+ self.batch_size = batch_size
+ self.dry_run = dry_run
+ self.clear_target = clear_target
+ self._client: Optional[QdrantClient] = None
+
+ def _get_client(self) -> QdrantClient:
+ """Get or create QdrantClient instance"""
+ if self._client is None:
+ config = configparser.ConfigParser()
+ config.read("config.ini", "utf-8")
+
+ self._client = QdrantClient(
+ url=os.environ.get(
+ "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
+ ),
+ api_key=os.environ.get(
+ "QDRANT_API_KEY",
+ config.get("qdrant", "apikey", fallback=None),
+ ),
+ )
+ return self._client
+
+ def print_header(self):
+ """Print tool header"""
+ print("\n" + "=" * 60)
+ print("Qdrant Legacy Data Preparation Tool - LightRAG")
+ print("=" * 60)
+ if self.dry_run:
+ print(f"{BOLD_YELLOW}⚠️ DRY RUN MODE - No changes will be made{RESET}")
+ if self.clear_target:
+ print(
+ f"{BOLD_RED}⚠️ CLEAR TARGET MODE - Target collections will be deleted first{RESET}"
+ )
+ print(f"Workspace: {BOLD_CYAN}{self.workspace}{RESET}")
+ print(f"Batch Size: {self.batch_size}")
+ print("=" * 60)
+
+ def check_connection(self) -> bool:
+ """Check Qdrant connection"""
+ try:
+ client = self._get_client()
+ # Try to list collections to verify connection
+ client.get_collections()
+ print(f"{BOLD_GREEN}✓{RESET} Qdrant connection successful")
+ return True
+ except Exception as e:
+ print(f"{BOLD_RED}✗{RESET} Qdrant connection failed: {e}")
+ return False
+
+ def get_collection_info(self, collection_name: str) -> Optional[Dict[str, Any]]:
+ """
+ Get collection information.
+
+ Args:
+ collection_name: Name of the collection
+
+ Returns:
+ Dictionary with collection info (vector_size, count) or None if not exists
+ """
+ client = self._get_client()
+
+ if not client.collection_exists(collection_name):
+ return None
+
+ info = client.get_collection(collection_name)
+ count = client.count(collection_name=collection_name, exact=True).count
+
+ # Handle both object and dict formats for vectors config
+ vectors_config = info.config.params.vectors
+ if isinstance(vectors_config, dict):
+ # Named vectors format or dict format
+ if vectors_config:
+ first_key = next(iter(vectors_config.keys()), None)
+ if first_key and hasattr(vectors_config[first_key], "size"):
+ vector_size = vectors_config[first_key].size
+ distance = vectors_config[first_key].distance
+ else:
+ # Try to get from dict values
+ first_val = next(iter(vectors_config.values()), {})
+ vector_size = (
+ first_val.get("size")
+ if isinstance(first_val, dict)
+ else getattr(first_val, "size", None)
+ )
+ distance = (
+ first_val.get("distance")
+ if isinstance(first_val, dict)
+ else getattr(first_val, "distance", None)
+ )
+ else:
+ vector_size = None
+ distance = None
+ else:
+ # Standard single vector format
+ vector_size = vectors_config.size
+ distance = vectors_config.distance
+
+ return {
+ "name": collection_name,
+ "vector_size": vector_size,
+ "count": count,
+ "distance": distance,
+ }
+
+ def delete_collection(self, collection_name: str) -> bool:
+ """
+ Delete a collection if it exists.
+
+ Args:
+ collection_name: Name of the collection to delete
+
+ Returns:
+ True if deleted or doesn't exist
+ """
+ client = self._get_client()
+
+ if not client.collection_exists(collection_name):
+ return True
+
+ if self.dry_run:
+ target_info = self.get_collection_info(collection_name)
+ count = target_info["count"] if target_info else 0
+ print(
+ f" {BOLD_YELLOW}[DRY RUN]{RESET} Would delete collection '{collection_name}' ({count:,} records)"
+ )
+ return True
+
+ try:
+ target_info = self.get_collection_info(collection_name)
+ count = target_info["count"] if target_info else 0
+ client.delete_collection(collection_name=collection_name)
+ print(
+ f" {BOLD_RED}✗{RESET} Deleted collection '{collection_name}' ({count:,} records)"
+ )
+ return True
+ except Exception as e:
+ print(f" {BOLD_RED}✗{RESET} Failed to delete collection: {e}")
+ return False
+
+ def create_legacy_collection(
+ self, collection_name: str, vector_size: int, distance: models.Distance
+ ) -> bool:
+ """
+ Create legacy collection if it doesn't exist.
+
+ Args:
+ collection_name: Name of the collection to create
+ vector_size: Dimension of vectors
+ distance: Distance metric
+
+ Returns:
+ True if created or already exists
+ """
+ client = self._get_client()
+
+ if client.collection_exists(collection_name):
+ print(f" Collection '{collection_name}' already exists")
+ return True
+
+ if self.dry_run:
+ print(
+ f" {BOLD_YELLOW}[DRY RUN]{RESET} Would create collection '{collection_name}' with {vector_size}d vectors"
+ )
+ return True
+
+ try:
+ client.create_collection(
+ collection_name=collection_name,
+ vectors_config=models.VectorParams(
+ size=vector_size,
+ distance=distance,
+ ),
+ hnsw_config=models.HnswConfigDiff(
+ payload_m=16,
+ m=0,
+ ),
+ )
+ print(
+ f" {BOLD_GREEN}✓{RESET} Created collection '{collection_name}' with {vector_size}d vectors"
+ )
+ return True
+ except Exception as e:
+ print(f" {BOLD_RED}✗{RESET} Failed to create collection: {e}")
+ return False
+
+ def _get_workspace_filter(self) -> models.Filter:
+ """Create workspace filter for Qdrant queries"""
+ return models.Filter(
+ must=[
+ models.FieldCondition(
+ key=WORKSPACE_ID_FIELD,
+ match=models.MatchValue(value=self.workspace),
+ )
+ ]
+ )
+
+ def get_workspace_count(self, collection_name: str) -> int:
+ """
+ Get count of records for the current workspace in a collection.
+
+ Args:
+ collection_name: Name of the collection
+
+ Returns:
+ Count of records for the workspace
+ """
+ client = self._get_client()
+ return client.count(
+ collection_name=collection_name,
+ count_filter=self._get_workspace_filter(),
+ exact=True,
+ ).count
+
+ def copy_collection_data(
+ self,
+ source_collection: str,
+ target_collection: str,
+ collection_type: str,
+ workspace_count: int,
+ ) -> CopyStats:
+ """
+ Copy data from source to target collection.
+
+ This filters by workspace_id and removes it from payload to simulate legacy data format.
+
+ Args:
+ source_collection: Source collection name
+ target_collection: Target collection name
+ collection_type: Type of collection (chunks, entities, relationships)
+ workspace_count: Pre-computed count of workspace records
+
+ Returns:
+ CopyStats with operation results
+ """
+ client = self._get_client()
+ stats = CopyStats(
+ collection_type=collection_type,
+ source_collection=source_collection,
+ target_collection=target_collection,
+ )
+
+ start_time = time.time()
+ stats.total_records = workspace_count
+
+ if workspace_count == 0:
+ print(f" No records for workspace '{self.workspace}', skipping")
+ stats.elapsed_time = time.time() - start_time
+ return stats
+
+ print(f" Workspace records: {workspace_count:,}")
+
+ if self.dry_run:
+ print(
+ f" {BOLD_YELLOW}[DRY RUN]{RESET} Would copy {workspace_count:,} records to '{target_collection}'"
+ )
+ stats.copied_records = workspace_count
+ stats.elapsed_time = time.time() - start_time
+ return stats
+
+ # Batch copy using scroll with workspace filter
+ workspace_filter = self._get_workspace_filter()
+ offset = None
+ batch_idx = 0
+
+ while True:
+ # Scroll source collection with workspace filter
+ result = client.scroll(
+ collection_name=source_collection,
+ scroll_filter=workspace_filter,
+ limit=self.batch_size,
+ offset=offset,
+ with_vectors=True,
+ with_payload=True,
+ )
+ points, next_offset = result
+
+ if not points:
+ break
+
+ batch_idx += 1
+
+ # Transform points: remove workspace_id from payload
+ new_points = []
+ for point in points:
+ new_payload = dict(point.payload or {})
+ # Remove workspace_id to simulate legacy format
+ new_payload.pop(WORKSPACE_ID_FIELD, None)
+
+ # Use original id from payload if available, otherwise use point.id
+ original_id = new_payload.get("id")
+ if original_id:
+ # Generate a simple deterministic id for legacy format
+ # Use original id directly (legacy format didn't have workspace prefix)
+ import hashlib
+ import uuid
+
+ hashed = hashlib.sha256(original_id.encode("utf-8")).digest()
+ point_id = uuid.UUID(bytes=hashed[:16], version=4).hex
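+                    # Deterministic mapping: the same original id always yields the same
+                    # UUID, so re-running the tool upserts rather than duplicating points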
+ else:
+ point_id = str(point.id)
+
+ new_points.append(
+ models.PointStruct(
+ id=point_id,
+ vector=point.vector,
+ payload=new_payload,
+ )
+ )
+
+ try:
+ # Upsert to target collection
+ client.upsert(
+ collection_name=target_collection, points=new_points, wait=True
+ )
+ stats.copied_records += len(new_points)
+
+ # Progress bar
+ progress = (stats.copied_records / workspace_count) * 100
+ bar_length = 30
+ filled = int(bar_length * stats.copied_records // workspace_count)
+ bar = "█" * filled + "░" * (bar_length - filled)
+
+ print(
+ f"\r Copying: {bar} {stats.copied_records:,}/{workspace_count:,} ({progress:.1f}%) ",
+ end="",
+ flush=True,
+ )
+
+ except Exception as e:
+ stats.add_error(batch_idx, e, len(new_points))
+ print(
+ f"\n {BOLD_RED}✗{RESET} Batch {batch_idx} failed: {type(e).__name__}: {e}"
+ )
+
+ if next_offset is None:
+ break
+ offset = next_offset
+
+ print() # New line after progress bar
+ stats.elapsed_time = time.time() - start_time
+
+ return stats
+
+ def process_collection_type(self, collection_type: str) -> Optional[CopyStats]:
+ """
+ Process a single collection type.
+
+ Args:
+ collection_type: Type of collection (chunks, entities, relationships)
+
+ Returns:
+ CopyStats or None if error
+ """
+ namespace_config = COLLECTION_NAMESPACES.get(collection_type)
+ if not namespace_config:
+ print(f"{BOLD_RED}✗{RESET} Unknown collection type: {collection_type}")
+ return None
+
+ source = namespace_config["new"]
+ # Generate legacy collection name dynamically: {workspace}_{suffix}
+ target = f"{self.workspace}_{namespace_config['suffix']}"
+
+ print(f"\n{'=' * 50}")
+ print(f"Processing: {BOLD_CYAN}{collection_type}{RESET}")
+ print(f"{'=' * 50}")
+ print(f" Source: {source}")
+ print(f" Target: {target}")
+
+ # Check source collection
+ source_info = self.get_collection_info(source)
+ if source_info is None:
+ print(
+ f" {BOLD_YELLOW}⚠{RESET} Source collection '{source}' does not exist, skipping"
+ )
+ return None
+
+ print(f" Source vector dimension: {source_info['vector_size']}d")
+ print(f" Source distance metric: {source_info['distance']}")
+ print(f" Source total records: {source_info['count']:,}")
+
+ # Check workspace data exists BEFORE creating legacy collection
+ workspace_count = self.get_workspace_count(source)
+ print(f" Workspace '{self.workspace}' records: {workspace_count:,}")
+
+ if workspace_count == 0:
+ print(
+ f" {BOLD_YELLOW}⚠{RESET} No data found for workspace '{self.workspace}' in '{source}', skipping"
+ )
+ return None
+
+ # Clear target collection if requested
+ if self.clear_target:
+ if not self.delete_collection(target):
+ return None
+
+ # Create target collection only after confirming workspace data exists
+ if not self.create_legacy_collection(
+ target, source_info["vector_size"], source_info["distance"]
+ ):
+ return None
+
+ # Copy data with workspace filter
+ stats = self.copy_collection_data(
+ source, target, collection_type, workspace_count
+ )
+
+ # Print result
+ if stats.failed_records == 0:
+ print(
+ f" {BOLD_GREEN}✓{RESET} Copied {stats.copied_records:,} records in {stats.elapsed_time:.2f}s"
+ )
+ else:
+ print(
+ f" {BOLD_YELLOW}⚠{RESET} Copied {stats.copied_records:,} records, "
+ f"{BOLD_RED}{stats.failed_records:,} failed{RESET} in {stats.elapsed_time:.2f}s"
+ )
+
+ return stats
+
+ def print_summary(self, all_stats: List[CopyStats]):
+ """Print summary of all operations"""
+ print("\n" + "=" * 60)
+ print("Summary")
+ print("=" * 60)
+
+ total_copied = sum(s.copied_records for s in all_stats)
+ total_failed = sum(s.failed_records for s in all_stats)
+ total_time = sum(s.elapsed_time for s in all_stats)
+
+ for stats in all_stats:
+ status = (
+ f"{BOLD_GREEN}✓{RESET}"
+ if stats.failed_records == 0
+ else f"{BOLD_YELLOW}⚠{RESET}"
+ )
+ print(
+ f" {status} {stats.collection_type}: {stats.copied_records:,}/{stats.total_records:,} "
+ f"({stats.source_collection} → {stats.target_collection})"
+ )
+
+ print("-" * 60)
+ print(f" Total records copied: {BOLD_CYAN}{total_copied:,}{RESET}")
+ if total_failed > 0:
+ print(f" Total records failed: {BOLD_RED}{total_failed:,}{RESET}")
+ print(f" Total time: {total_time:.2f}s")
+
+ if self.dry_run:
+ print(f"\n{BOLD_YELLOW}⚠️ DRY RUN - No actual changes were made{RESET}")
+
+ # Print error details if any
+ all_errors = []
+ for stats in all_stats:
+ all_errors.extend(stats.errors)
+
+ if all_errors:
+ print(f"\n{BOLD_RED}Errors ({len(all_errors)}){RESET}")
+ for i, error in enumerate(all_errors[:5], 1):
+ print(
+ f" {i}. Batch {error['batch']}: {error['error_type']}: {error['error_msg']}"
+ )
+ if len(all_errors) > 5:
+ print(f" ... and {len(all_errors) - 5} more errors")
+
+ print("=" * 60)
+
+ async def run(self, collection_types: Optional[List[str]] = None):
+ """
+ Run the data preparation tool.
+
+ Args:
+ collection_types: List of collection types to process (default: all)
+ """
+ self.print_header()
+
+ # Check connection
+ if not self.check_connection():
+ return
+
+ # Determine which collection types to process
+ if collection_types:
+ types_to_process = [t.strip() for t in collection_types]
+ invalid_types = [
+ t for t in types_to_process if t not in COLLECTION_NAMESPACES
+ ]
+ if invalid_types:
+ print(
+ f"{BOLD_RED}✗{RESET} Invalid collection types: {', '.join(invalid_types)}"
+ )
+ print(f" Valid types: {', '.join(COLLECTION_NAMESPACES.keys())}")
+ return
+ else:
+ types_to_process = list(COLLECTION_NAMESPACES.keys())
+
+ print(f"\nCollection types to process: {', '.join(types_to_process)}")
+
+ # Process each collection type
+ all_stats = []
+ for ctype in types_to_process:
+ stats = self.process_collection_type(ctype)
+ if stats:
+ all_stats.append(stats)
+
+ # Print summary
+ if all_stats:
+ self.print_summary(all_stats)
+ else:
+ print(f"\n{BOLD_YELLOW}⚠{RESET} No collections were processed")
+
+
+def parse_args():
+ """Parse command line arguments"""
+ parser = argparse.ArgumentParser(
+ description="Prepare legacy data in Qdrant for migration testing",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ python -m lightrag.tools.prepare_qdrant_legacy_data
+ python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1
+ python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities
+ python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run
+ """,
+ )
+
+ parser.add_argument(
+ "--workspace",
+ type=str,
+ default="space1",
+ help="Workspace name (default: space1)",
+ )
+
+ parser.add_argument(
+ "--types",
+ type=str,
+ default=None,
+ help="Comma-separated list of collection types (chunks, entities, relationships)",
+ )
+
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=DEFAULT_BATCH_SIZE,
+ help=f"Batch size for copy operations (default: {DEFAULT_BATCH_SIZE})",
+ )
+
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ help="Preview operations without making changes",
+ )
+
+ parser.add_argument(
+ "--clear-target",
+ action="store_true",
+ help="Delete target collections before copying (for clean test environment)",
+ )
+
+ return parser.parse_args()
+
+
+async def main():
+ """Main entry point"""
+ args = parse_args()
+
+ collection_types = None
+ if args.types:
+ collection_types = [t.strip() for t in args.types.split(",")]
+
+ tool = QdrantLegacyDataPreparationTool(
+ workspace=args.workspace,
+ batch_size=args.batch_size,
+ dry_run=args.dry_run,
+ clear_target=args.clear_target,
+ )
+
+ await tool.run(collection_types=collection_types)
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
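The CLI entry point above is the intended interface, but the same class can also be driven directly from async code (for example inside a test fixture). A minimal sketch, assuming the module path shown in the argparse epilog; the batch size here is an arbitrary example value, not necessarily the tool's default:

```python
# Sketch: invoking the preparation tool programmatically instead of via the CLI.
# Assumes the module path from the epilog above; batch_size value is illustrative.
import asyncio

from lightrag.tools.prepare_qdrant_legacy_data import QdrantLegacyDataPreparationTool


async def seed_legacy_chunks() -> None:
    tool = QdrantLegacyDataPreparationTool(
        workspace="space1",
        batch_size=500,      # example value, not necessarily DEFAULT_BATCH_SIZE
        dry_run=True,        # preview only; nothing is written
        clear_target=False,
    )
    await tool.run(collection_types=["chunks"])


if __name__ == "__main__":
    asyncio.run(seed_legacy_chunks())
```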
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 65c1e4bc..d795acdb 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -425,6 +425,9 @@ class EmbeddingFunc:
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
+ model_name: str | None = (
+ None # Model name for implementing per-model data isolation in the vector DB
+ )
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True
@@ -1016,42 +1019,36 @@ def wrap_embedding_func_with_attrs(**kwargs):
Correct usage patterns:
- 1. Direct implementation (decorated):
+ 1. Direct decoration:
```python
- @wrap_embedding_func_with_attrs(embedding_dim=1536)
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
-
- 2. Wrapper calling decorated function (DO NOT decorate wrapper):
+ 2. Wrapping a decorated function (avoid double decoration):
```python
- # my_embed is already decorated above
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model")
+ @retry(...)
+ async def openai_embed(texts, ...):
+ # Base implementation
+ pass
- async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
- # Must call .func to access unwrapped implementation
- return await my_embed.func(texts, **kwargs)
- ```
-
- 3. Wrapper calling decorated function (properly decorated):
- ```python
- @wrap_embedding_func_with_attrs(embedding_dim=1536)
- async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
- # Calling .func avoids double decoration
- return await my_embed.func(texts, **kwargs)
+ @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model")
+ # Note: No @retry here!
+ async def new_openai_embed(texts, ...):
+ # CRITICAL: Call .func to access unwrapped function
+ return await openai_embed.func(texts, ...) # ✅ Correct
+ # return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
+ - model_name: Model name (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
- Double decoration causes:
- - Double injection of embedding_dim parameter
- - Incorrect parameter passing to the underlying implementation
- - Runtime errors due to parameter conflicts
-
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
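+ model_name: Embedding model name (optional)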
@@ -1059,21 +1056,6 @@ def wrap_embedding_func_with_attrs(**kwargs):
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
-
- Example of correct wrapper implementation:
- ```python
- @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
- @retry(...)
- async def openai_embed(texts, ...):
- # Base implementation
- pass
-
- @wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
- async def azure_openai_embed(texts, ...):
- # CRITICAL: Call .func to access unwrapped function
- return await openai_embed.func(texts, ...) # ✅ Correct
- # return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
- ```
"""
def final_decro(func) -> EmbeddingFunc:
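The new model_name attribute exists so that vector storage backends can key tables/collections to the embedding model. Below is a minimal sketch of how a backend might derive such a suffix from a decorated embedding function; the sanitization rule (lowercase, non-alphanumerics to underscores, trailing "<dim>d") is inferred from the table names asserted in the tests later in this patch, not taken from the actual implementation:

```python
# Sketch only: deriving a per-model table suffix from EmbeddingFunc.model_name.
# The sanitization rule below is an assumption inferred from the test names.
import re

import numpy as np

from lightrag.utils import wrap_embedding_func_with_attrs


@wrap_embedding_func_with_attrs(
    embedding_dim=768, max_token_size=8192, model_name="bge-small"
)
async def my_embed(texts: list[str]) -> np.ndarray:
    # Placeholder embedding implementation, for illustration only
    return np.zeros((len(texts), 768))


def model_suffix(embedding_func) -> str:
    """Derive a per-model suffix like 'bge_small_768d' (rule assumed, see above)."""
    name = re.sub(
        r"[^a-z0-9]+", "_", (embedding_func.model_name or "").lower()
    ).strip("_")
    return f"{name}_{embedding_func.embedding_dim}d" if name else ""


print(model_suffix(my_embed))  # -> bge_small_768d
print(f"LIGHTRAG_VDB_CHUNKS_{model_suffix(my_embed)}")
# matches the naming pattern asserted in test_postgres_table_naming
```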
diff --git a/pyproject.toml b/pyproject.toml
index 761a3309..dd3dbc92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,7 +72,6 @@ api = [
# API-specific dependencies
"aiofiles",
"ascii_colors",
- "asyncpg",
"distro",
"fastapi",
"httpcore",
@@ -108,7 +107,8 @@ offline-storage = [
"neo4j>=5.0.0,<7.0.0",
"pymilvus>=2.6.2,<3.0.0",
"pymongo>=4.0.0,<5.0.0",
- "asyncpg>=0.29.0,<1.0.0",
+ "asyncpg>=0.31.0,<1.0.0",
+ "pgvector>=0.4.2,<1.0.0",
"qdrant-client>=1.11.0,<2.0.0",
]
diff --git a/requirements-offline-storage.txt b/requirements-offline-storage.txt
index 13a9c0e2..82caacbd 100644
--- a/requirements-offline-storage.txt
+++ b/requirements-offline-storage.txt
@@ -8,8 +8,9 @@
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt
# Storage backend dependencies (with version constraints matching pyproject.toml)
-asyncpg>=0.29.0,<1.0.0
+asyncpg>=0.31.0,<1.0.0
neo4j>=5.0.0,<7.0.0
+pgvector>=0.4.2,<1.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
qdrant-client>=1.11.0,<2.0.0
diff --git a/requirements-offline.txt b/requirements-offline.txt
index 87ca7a6a..283ced73 100644
--- a/requirements-offline.txt
+++ b/requirements-offline.txt
@@ -7,20 +7,17 @@
# Recommended: Use pip install lightrag-hku[offline] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt
-# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
-
-# Storage backend dependencies
-asyncpg>=0.29.0,<1.0.0
+asyncpg>=0.31.0,<1.0.0
+google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0
-
-# Document processing dependencies
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
+pgvector>=0.4.2,<1.0.0
pycryptodome>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
diff --git a/tests/test_dimension_mismatch.py b/tests/test_dimension_mismatch.py
new file mode 100644
index 00000000..b63fbd35
--- /dev/null
+++ b/tests/test_dimension_mismatch.py
@@ -0,0 +1,377 @@
+"""
+Tests for dimension mismatch handling during migration.
+
+This test module verifies that both PostgreSQL and Qdrant storage backends
+properly detect and handle vector dimension mismatches when migrating from
+legacy collections/tables to new ones with different embedding models.
+"""
+
+import json
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
+from lightrag.kg.postgres_impl import PGVectorStorage
+from lightrag.exceptions import DataMigrationError
+
+
+# Note: Tests should use proper table names that have DDL templates
+# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
+# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
+
+
+class TestQdrantDimensionMismatch:
+ """Test suite for Qdrant dimension mismatch handling."""
+
+ def test_qdrant_dimension_mismatch_raises_error(self):
+ """
+ Test that Qdrant raises DataMigrationError when dimensions don't match.
+
+ Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
+ Expected: DataMigrationError is raised to prevent data corruption.
+ """
+ from qdrant_client import models
+
+ # Setup mock client
+ client = MagicMock()
+
+ # Mock legacy collection with 1536d vectors
+ legacy_collection_info = MagicMock()
+ legacy_collection_info.config.params.vectors.size = 1536
+
+ # Setup collection existence checks
+ def collection_exists_side_effect(name):
+ if (
+ name == "lightrag_vdb_chunks"
+ ): # legacy (matches _find_legacy_collection pattern)
+ return True
+ elif name == "lightrag_chunks_model_3072d": # new
+ return False
+ return False
+
+ client.collection_exists.side_effect = collection_exists_side_effect
+ client.get_collection.return_value = legacy_collection_info
+ client.count.return_value.count = 100 # Legacy has data
+
+ # Patch _find_legacy_collection to return the legacy collection name
+ with patch(
+ "lightrag.kg.qdrant_impl._find_legacy_collection",
+ return_value="lightrag_vdb_chunks",
+ ):
+ # Call setup_collection with 3072d (different from legacy 1536d)
+ # Should raise DataMigrationError due to dimension mismatch
+ with pytest.raises(DataMigrationError) as exc_info:
+ QdrantVectorDBStorage.setup_collection(
+ client,
+ "lightrag_chunks_model_3072d",
+ namespace="chunks",
+ workspace="test",
+ vectors_config=models.VectorParams(
+ size=3072, distance=models.Distance.COSINE
+ ),
+ hnsw_config=models.HnswConfigDiff(
+ payload_m=16,
+ m=0,
+ ),
+ model_suffix="model_3072d",
+ )
+
+ # Verify error message contains dimension information
+ assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
+
+ # Verify new collection was NOT created (error raised before creation)
+ client.create_collection.assert_not_called()
+
+ # Verify migration was NOT attempted
+ client.scroll.assert_not_called()
+ client.upsert.assert_not_called()
+
+ def test_qdrant_dimension_match_proceed_migration(self):
+ """
+ Test that Qdrant proceeds with migration when dimensions match.
+
+ Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
+ Expected: Migration proceeds normally.
+ """
+ from qdrant_client import models
+
+ client = MagicMock()
+
+ # Mock legacy collection with 1536d vectors (matching new)
+ legacy_collection_info = MagicMock()
+ legacy_collection_info.config.params.vectors.size = 1536
+
+ def collection_exists_side_effect(name):
+ if name == "lightrag_chunks": # legacy
+ return True
+ elif name == "lightrag_chunks_model_1536d": # new
+ return False
+ return False
+
+ client.collection_exists.side_effect = collection_exists_side_effect
+ client.get_collection.return_value = legacy_collection_info
+
+ # Track whether upsert has been called (migration occurred)
+ migration_done = {"value": False}
+
+ def upsert_side_effect(*args, **kwargs):
+ migration_done["value"] = True
+ return MagicMock()
+
+ client.upsert.side_effect = upsert_side_effect
+
+ # Mock count to return different values based on collection name and migration state
+ # Before migration: new collection has 0 records
+ # After migration: new collection has 1 record (matching migrated data)
+ def count_side_effect(collection_name, **kwargs):
+ result = MagicMock()
+ if collection_name == "lightrag_chunks": # legacy
+ result.count = 1 # Legacy has 1 record
+ elif collection_name == "lightrag_chunks_model_1536d": # new
+ # Return 0 before migration, 1 after migration
+ result.count = 1 if migration_done["value"] else 0
+ else:
+ result.count = 0
+ return result
+
+ client.count.side_effect = count_side_effect
+
+ # Mock scroll to return sample data (1 record for easier verification)
+ sample_point = MagicMock()
+ sample_point.id = "test_id"
+ sample_point.vector = [0.1] * 1536
+ sample_point.payload = {"id": "test"}
+ client.scroll.return_value = ([sample_point], None)
+
+ # Mock _find_legacy_collection to return the legacy collection name
+ with patch(
+ "lightrag.kg.qdrant_impl._find_legacy_collection",
+ return_value="lightrag_chunks",
+ ):
+ # Call setup_collection with matching 1536d
+ QdrantVectorDBStorage.setup_collection(
+ client,
+ "lightrag_chunks_model_1536d",
+ namespace="chunks",
+ workspace="test",
+ vectors_config=models.VectorParams(
+ size=1536, distance=models.Distance.COSINE
+ ),
+ hnsw_config=models.HnswConfigDiff(
+ payload_m=16,
+ m=0,
+ ),
+ model_suffix="model_1536d",
+ )
+
+ # Verify migration WAS attempted
+ client.create_collection.assert_called_once()
+ client.scroll.assert_called()
+ client.upsert.assert_called()
+
+
+class TestPostgresDimensionMismatch:
+ """Test suite for PostgreSQL dimension mismatch handling."""
+
+ async def test_postgres_dimension_mismatch_raises_error_metadata(self):
+ """
+ Test that PostgreSQL raises DataMigrationError when dimensions don't match.
+
+ Scenario: Legacy table has 1536d vectors, new model expects 3072d.
+ Expected: DataMigrationError is raised to prevent data corruption.
+ """
+ # Setup mock database
+ db = AsyncMock()
+
+ # Mock check_table_exists
+ async def mock_check_table_exists(table_name):
+ if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
+ return True
+ elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
+ return False
+ return False
+
+ db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # Mock table existence and dimension checks
+ async def query_side_effect(query, params, **kwargs):
+ if "COUNT(*)" in query:
+ return {"count": 100} # Legacy has data
+ elif "SELECT content_vector FROM" in query:
+ # Return sample vector with 1536 dimensions
+ return {"content_vector": [0.1] * 1536}
+ return {}
+
+ db.query.side_effect = query_side_effect
+ db.execute = AsyncMock()
+ db._create_vector_index = AsyncMock()
+
+ # Call setup_table with 3072d (different from legacy 1536d)
+ # Should raise DataMigrationError due to dimension mismatch
+ with pytest.raises(DataMigrationError) as exc_info:
+ await PGVectorStorage.setup_table(
+ db,
+ "LIGHTRAG_DOC_CHUNKS_model_3072d",
+ legacy_table_name="LIGHTRAG_DOC_CHUNKS",
+ base_table="LIGHTRAG_DOC_CHUNKS",
+ embedding_dim=3072,
+ workspace="test",
+ )
+
+ # Verify error message contains dimension information
+ assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
+
+ async def test_postgres_dimension_mismatch_raises_error_sampling(self):
+ """
+ Test that PostgreSQL raises error when dimensions don't match (via sampling).
+
+ Scenario: Legacy table vector sampling detects 1536d vs expected 3072d.
+ Expected: DataMigrationError is raised to prevent data corruption.
+ """
+ db = AsyncMock()
+
+ # Mock check_table_exists
+ async def mock_check_table_exists(table_name):
+ if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy
+ return True
+ elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
+ return False
+ return False
+
+ db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # Mock table existence and dimension checks
+ async def query_side_effect(query, params, **kwargs):
+ if "information_schema.tables" in query:
+ if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
+ return {"exists": True}
+ elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
+ return {"exists": False}
+ elif "COUNT(*)" in query:
+ return {"count": 100} # Legacy has data
+ elif "SELECT content_vector FROM" in query:
+ # Return sample vector with 1536 dimensions as a JSON string
+ return {"content_vector": json.dumps([0.1] * 1536)}
+ return {}
+
+ db.query.side_effect = query_side_effect
+ db.execute = AsyncMock()
+ db._create_vector_index = AsyncMock()
+
+ # Call setup_table with 3072d (different from legacy 1536d)
+ # Should raise DataMigrationError due to dimension mismatch
+ with pytest.raises(DataMigrationError) as exc_info:
+ await PGVectorStorage.setup_table(
+ db,
+ "LIGHTRAG_DOC_CHUNKS_model_3072d",
+ legacy_table_name="LIGHTRAG_DOC_CHUNKS",
+ base_table="LIGHTRAG_DOC_CHUNKS",
+ embedding_dim=3072,
+ workspace="test",
+ )
+
+ # Verify error message contains dimension information
+ assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value)
+
+ async def test_postgres_dimension_match_proceed_migration(self):
+ """
+ Test that PostgreSQL proceeds with migration when dimensions match.
+
+ Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
+ Expected: Migration proceeds normally.
+ """
+ db = AsyncMock()
+
+ # Track migration state
+ migration_done = {"value": False}
+
+ # Define exactly 2 records for consistency
+ mock_records = [
+ {
+ "id": "test1",
+ "content_vector": [0.1] * 1536,
+ "workspace": "test",
+ },
+ {
+ "id": "test2",
+ "content_vector": [0.2] * 1536,
+ "workspace": "test",
+ },
+ ]
+
+ # Mock check_table_exists
+ async def mock_check_table_exists(table_name):
+ if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy exists
+ return True
+ elif table_name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist
+ return False
+ return False
+
+ db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ async def query_side_effect(query, params, **kwargs):
+ multirows = kwargs.get("multirows", False)
+ query_upper = query.upper()
+
+ if "information_schema.tables" in query:
+ if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
+ return {"exists": True}
+ elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new
+ return {"exists": False}
+ elif "COUNT(*)" in query_upper:
+ # Return different counts based on table name in query and migration state
+ if "LIGHTRAG_DOC_CHUNKS_MODEL_1536D" in query_upper:
+ # After migration: return migrated count, before: return 0
+ return {
+ "count": len(mock_records) if migration_done["value"] else 0
+ }
+ # Legacy table always has 2 records (matching mock_records)
+ return {"count": len(mock_records)}
+ elif "PG_ATTRIBUTE" in query_upper:
+ return {"vector_dim": 1536} # Legacy has matching 1536d
+ elif "SELECT" in query_upper and "FROM" in query_upper and multirows:
+ # Return sample data for migration using keyset pagination
+ # Handle keyset pagination: params = [workspace, limit] or [workspace, last_id, limit]
+ if "id >" in query.lower():
+ # Keyset pagination: params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ # Find records after last_id
+ found_idx = -1
+ for i, rec in enumerate(mock_records):
+ if rec["id"] == last_id:
+ found_idx = i
+ break
+ if found_idx >= 0:
+ return mock_records[found_idx + 1 :]
+ return []
+ else:
+ # First batch: params = [workspace, limit]
+ return mock_records
+ return {}
+
+ db.query.side_effect = query_side_effect
+
+ # Mock _run_with_retry to track when migration happens
+ migration_executed = []
+
+ async def mock_run_with_retry(operation, *args, **kwargs):
+ migration_executed.append(True)
+ migration_done["value"] = True
+ return None
+
+ db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
+ db.execute = AsyncMock()
+ db._create_vector_index = AsyncMock()
+
+ # Call setup_table with matching 1536d
+ await PGVectorStorage.setup_table(
+ db,
+ "LIGHTRAG_DOC_CHUNKS_model_1536d",
+ legacy_table_name="LIGHTRAG_DOC_CHUNKS",
+ base_table="LIGHTRAG_DOC_CHUNKS",
+ embedding_dim=1536,
+ workspace="test",
+ )
+
+ # Verify migration WAS called (via _run_with_retry for batch operations)
+ assert len(migration_executed) > 0, "Migration should have been executed"
diff --git a/tests/test_no_model_suffix_safety.py b/tests/test_no_model_suffix_safety.py
new file mode 100644
index 00000000..2f438d38
--- /dev/null
+++ b/tests/test_no_model_suffix_safety.py
@@ -0,0 +1,220 @@
+"""
+Tests for safety when the model suffix is absent (no model_name provided).
+
+This test module verifies that the system handles this case correctly,
+preventing accidental deletion of the only table/collection on restart.
+
+Critical Bug: When model_suffix is empty, table_name == legacy_table_name.
+On second startup, Case 1 logic would delete the only table/collection thinking
+it's "legacy", causing all subsequent operations to fail.
+"""
+
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
+from lightrag.kg.postgres_impl import PGVectorStorage
+
+
+class TestNoModelSuffixSafety:
+ """Test suite for preventing data loss when model_suffix is absent."""
+
+ def test_qdrant_no_suffix_second_startup(self):
+ """
+ Test Qdrant doesn't delete collection on second startup when no model_name.
+
+ Scenario:
+ 1. First startup: Creates collection without suffix
+ 2. Collection is empty
+ 3. Second startup: Should NOT delete the collection
+
+ Bug: Without fix, Case 1 would delete the only collection.
+ """
+ from qdrant_client import models
+
+ client = MagicMock()
+
+ # Simulate second startup: collection already exists and is empty
+ # IMPORTANT: Without suffix, collection_name == legacy collection name
+ collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy
+
+ # Both exist (they're the same collection)
+ client.collection_exists.return_value = True
+
+ # Collection is empty
+ client.count.return_value.count = 0
+
+ # Patch _find_legacy_collection to return the SAME collection name
+ # This simulates the scenario where new collection == legacy collection
+ with patch(
+ "lightrag.kg.qdrant_impl._find_legacy_collection",
+ return_value="lightrag_vdb_chunks", # Same as collection_name
+ ):
+ # Call setup_collection
+ # This should detect that new == legacy and skip deletion
+ QdrantVectorDBStorage.setup_collection(
+ client,
+ collection_name,
+ namespace="chunks",
+ workspace="_",
+ vectors_config=models.VectorParams(
+ size=1536, distance=models.Distance.COSINE
+ ),
+ hnsw_config=models.HnswConfigDiff(
+ payload_m=16,
+ m=0,
+ ),
+ model_suffix="", # Empty suffix to simulate no model_name provided
+ )
+
+ # CRITICAL: Collection should NOT be deleted
+ client.delete_collection.assert_not_called()
+
+ # Verify we returned early (skipped Case 1 cleanup)
+ # collection_exists was checked, but the count-based cleanup was skipped
+ # because the new and legacy collection names are identical
+ assert client.collection_exists.call_count >= 1
+
+ async def test_postgres_no_suffix_second_startup(self):
+ """
+ Test PostgreSQL doesn't delete table on second startup when no model_name.
+
+ Scenario:
+ 1. First startup: Creates table without suffix
+ 2. Table is empty
+ 3. Second startup: Should NOT delete the table
+
+ Bug: Without fix, Case 1 would delete the only table.
+ """
+ db = AsyncMock()
+
+ # Configure mock return values to avoid unawaited coroutine warnings
+ db.query.return_value = {"count": 0}
+ db._create_vector_index.return_value = None
+
+ # Simulate second startup: table already exists and is empty
+ # IMPORTANT: table_name and legacy_table_name are THE SAME
+ table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix
+ legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new
+
+ # Setup mock responses using check_table_exists on db
+ async def check_table_exists_side_effect(name):
+ # Both tables exist (they're the same)
+ return True
+
+ db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
+
+ # Call setup_table
+ # This should detect that new == legacy and skip deletion
+ await PGVectorStorage.setup_table(
+ db,
+ table_name,
+ workspace="test_workspace",
+ embedding_dim=1536,
+ legacy_table_name=legacy_table_name,
+ base_table="LIGHTRAG_VDB_CHUNKS",
+ )
+
+ # CRITICAL: Table should NOT be deleted (no DROP TABLE)
+ drop_calls = [
+ call
+ for call in db.execute.call_args_list
+ if call[0][0] and "DROP TABLE" in call[0][0]
+ ]
+ assert (
+ len(drop_calls) == 0
+ ), "Should not drop table when new and legacy are the same"
+
+ # Note: COUNT queries for workspace data are expected behavior in Case 1
+ # (for logging/warning purposes when workspace data is empty).
+ # The critical safety check is that DROP TABLE is not called.
+
+ def test_qdrant_with_suffix_case1_still_works(self):
+ """
+ Test that Case 1 cleanup still works when there IS a suffix.
+
+ This ensures our fix doesn't break the normal Case 1 scenario.
+ """
+ from qdrant_client import models
+
+ client = MagicMock()
+
+ # Different names (normal case)
+ collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix
+ legacy_collection = "lightrag_vdb_chunks" # Without suffix
+
+ # Setup: both exist
+ def collection_exists_side_effect(name):
+ return name in [collection_name, legacy_collection]
+
+ client.collection_exists.side_effect = collection_exists_side_effect
+
+ # Legacy is empty
+ client.count.return_value.count = 0
+
+ # Call setup_collection
+ QdrantVectorDBStorage.setup_collection(
+ client,
+ collection_name,
+ namespace="chunks",
+ workspace="_",
+ vectors_config=models.VectorParams(
+ size=1536, distance=models.Distance.COSINE
+ ),
+ hnsw_config=models.HnswConfigDiff(
+ payload_m=16,
+ m=0,
+ ),
+ model_suffix="ada_002_1536d",
+ )
+
+ # SHOULD delete legacy (normal Case 1 behavior)
+ client.delete_collection.assert_called_once_with(
+ collection_name=legacy_collection
+ )
+
+ async def test_postgres_with_suffix_case1_still_works(self):
+ """
+ Test that Case 1 cleanup still works when there IS a suffix.
+
+ This ensures our fix doesn't break the normal Case 1 scenario.
+ """
+ db = AsyncMock()
+
+ # Different names (normal case)
+ table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix
+ legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix
+
+ # Setup mock responses using check_table_exists on db
+ async def check_table_exists_side_effect(name):
+ # Both tables exist
+ return True
+
+ db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
+
+ # Mock empty table
+ async def query_side_effect(sql, params, **kwargs):
+ if "COUNT(*)" in sql:
+ return {"count": 0}
+ return {}
+
+ db.query.side_effect = query_side_effect
+
+ # Call setup_table
+ await PGVectorStorage.setup_table(
+ db,
+ table_name,
+ workspace="test_workspace",
+ embedding_dim=1536,
+ legacy_table_name=legacy_table_name,
+ base_table="LIGHTRAG_VDB_CHUNKS",
+ )
+
+ # SHOULD delete legacy (normal Case 1 behavior)
+ drop_calls = [
+ call
+ for call in db.execute.call_args_list
+ if call[0][0] and "DROP TABLE" in call[0][0]
+ ]
+ assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1"
+ assert legacy_table_name in drop_calls[0][0][0]
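The failure mode these tests guard against can be stated concretely. A tiny illustration, with the name-composition rule assumed for the example (the real composition lives in the storage backends):

```python
# Illustration only: an empty model suffix collapses the "new" and "legacy"
# names into the same identifier (naming rule assumed for this example).
base = "LIGHTRAG_VDB_CHUNKS"


def compose(base_name: str, model_suffix: str) -> str:
    return f"{base_name}_{model_suffix}" if model_suffix else base_name


new_table = compose(base, "")  # no model_name -> "LIGHTRAG_VDB_CHUNKS"
legacy_table = base            # also "LIGHTRAG_VDB_CHUNKS"

# Case 1 cleanup must detect this equality and skip the drop, otherwise it
# would delete the only table on the second startup.
assert new_table == legacy_table
```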
diff --git a/tests/test_postgres_index_name.py b/tests/test_postgres_index_name.py
new file mode 100644
index 00000000..e0af9834
--- /dev/null
+++ b/tests/test_postgres_index_name.py
@@ -0,0 +1,210 @@
+"""
+Unit tests for PostgreSQL safe index name generation.
+
+This module tests the _safe_index_name helper function which prevents
+PostgreSQL's silent 63-byte identifier truncation from causing index
+lookup failures.
+"""
+
+import pytest
+
+# Mark all tests as offline (no external dependencies)
+pytestmark = pytest.mark.offline
+
+
+class TestSafeIndexName:
+ """Test suite for _safe_index_name function."""
+
+ def test_short_name_unchanged(self):
+ """Short index names should remain unchanged."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # Short table name - should return unchanged
+ result = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
+ assert result == "idx_lightrag_vdb_entity_hnsw_cosine"
+ assert len(result.encode("utf-8")) <= 63
+
+ def test_long_name_gets_hashed(self):
+ """Long table names exceeding 63 bytes should get hashed."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # Long table name that would exceed 63 bytes
+ long_table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
+ result = _safe_index_name(long_table_name, "hnsw_cosine")
+
+ # Should be within 63 bytes
+ assert len(result.encode("utf-8")) <= 63
+
+ # Should start with idx_ prefix
+ assert result.startswith("idx_")
+
+ # Should contain the suffix
+ assert result.endswith("_hnsw_cosine")
+
+ # Should NOT be the naive concatenation (which would be truncated)
+ naive_name = f"idx_{long_table_name.lower()}_hnsw_cosine"
+ assert result != naive_name
+
+ def test_deterministic_output(self):
+ """Same input should always produce same output (deterministic)."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
+ suffix = "hnsw_cosine"
+
+ result1 = _safe_index_name(table_name, suffix)
+ result2 = _safe_index_name(table_name, suffix)
+
+ assert result1 == result2
+
+ def test_different_suffixes_different_results(self):
+ """Different suffixes should produce different index names."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d"
+
+ result1 = _safe_index_name(table_name, "hnsw_cosine")
+ result2 = _safe_index_name(table_name, "ivfflat_cosine")
+
+ assert result1 != result2
+
+ def test_case_insensitive(self):
+ """Table names should be normalized to lowercase."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ result_upper = _safe_index_name("LIGHTRAG_VDB_ENTITY", "hnsw_cosine")
+ result_lower = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine")
+
+ assert result_upper == result_lower
+
+ def test_boundary_case_exactly_63_bytes(self):
+ """Test boundary case where name is exactly at 63-byte limit."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # Create a table name that results in exactly 63 bytes
+ # idx_ (4) + table_name + _ (1) + suffix = 63
+ # So table_name + suffix = 58
+
+ # Test a name that is exactly at the 63-byte limit (should remain unchanged)
+ short_suffix = "id"
+ # idx_ (4) + 56 chars + _ (1) + id (2) = 63
+ table_56 = "a" * 56
+ result = _safe_index_name(table_56, short_suffix)
+ expected = f"idx_{table_56}_{short_suffix}"
+ assert result == expected
+ assert len(result.encode("utf-8")) == 63
+
+ def test_unicode_handling(self):
+ """Unicode characters should be properly handled (bytes, not chars)."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # Unicode characters can take more bytes than visible chars
+ # Chinese characters are 3 bytes each in UTF-8
+ table_name = "lightrag_测试_table" # Contains Chinese chars
+ result = _safe_index_name(table_name, "hnsw_cosine")
+
+ # Should always be within 63 bytes
+ assert len(result.encode("utf-8")) <= 63
+
+ def test_real_world_model_names(self):
+ """Test with real-world embedding model names that cause issues."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # These are actual model names that have caused issues
+ test_cases = [
+ ("LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d", "hnsw_cosine"),
+ ("LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d", "hnsw_cosine"),
+ ("LIGHTRAG_VDB_RELATION_text_embedding_3_large_3072d", "hnsw_cosine"),
+ (
+ "LIGHTRAG_VDB_ENTITY_bge_m3_1024d",
+ "hnsw_cosine",
+ ), # Shorter model name
+ (
+ "LIGHTRAG_VDB_CHUNKS_nomic_embed_text_v1_768d",
+ "ivfflat_cosine",
+ ), # Different index type
+ ]
+
+ for table_name, suffix in test_cases:
+ result = _safe_index_name(table_name, suffix)
+
+ # Critical: must be within PostgreSQL's 63-byte limit
+ assert (
+ len(result.encode("utf-8")) <= 63
+ ), f"Index name too long: {result} for table {table_name}"
+
+ # Must have consistent format
+ assert result.startswith("idx_"), f"Missing idx_ prefix: {result}"
+ assert result.endswith(f"_{suffix}"), f"Missing suffix {suffix}: {result}"
+
+ def test_hash_uniqueness_for_similar_tables(self):
+ """Similar but different table names should produce different hashes."""
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # These tables have similar names but should have different hashes
+ tables = [
+ "LIGHTRAG_VDB_CHUNKS_model_a_1024d",
+ "LIGHTRAG_VDB_CHUNKS_model_b_1024d",
+ "LIGHTRAG_VDB_ENTITY_model_a_1024d",
+ ]
+
+ results = [_safe_index_name(t, "hnsw_cosine") for t in tables]
+
+ # All results should be unique
+ assert len(set(results)) == len(results), "Hash collision detected!"
+
+
+class TestIndexNameIntegration:
+ """Integration-style tests for index name usage patterns."""
+
+ def test_pg_indexes_lookup_compatibility(self):
+ """
+ Test that the generated index name will work with pg_indexes lookup.
+
+ This is the core problem: PostgreSQL stores the truncated name,
+ but we were looking up the untruncated name. Our fix ensures we
+ always use a name that fits within 63 bytes.
+ """
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d"
+ suffix = "hnsw_cosine"
+
+ # Generate the index name
+ index_name = _safe_index_name(table_name, suffix)
+
+ # Simulate what PostgreSQL would store (truncate at 63 bytes)
+ stored_name = index_name.encode("utf-8")[:63].decode("utf-8", errors="ignore")
+
+ # The key fix: our generated name should equal the stored name
+ # because it's already within the 63-byte limit
+ assert (
+ index_name == stored_name
+ ), "Index name would be truncated by PostgreSQL, causing lookup failures!"
+
+ def test_backward_compatibility_short_names(self):
+ """
+ Ensure backward compatibility with existing short index names.
+
+ For tables that have existing indexes with short names (pre-model-suffix era),
+ the function should not change their names.
+ """
+ from lightrag.kg.postgres_impl import _safe_index_name
+
+ # Legacy table names without model suffix
+ legacy_tables = [
+ "LIGHTRAG_VDB_ENTITY",
+ "LIGHTRAG_VDB_RELATION",
+ "LIGHTRAG_VDB_CHUNKS",
+ ]
+
+ for table in legacy_tables:
+ for suffix in ["hnsw_cosine", "ivfflat_cosine", "id"]:
+ result = _safe_index_name(table, suffix)
+ expected = f"idx_{table.lower()}_{suffix}"
+
+ # Short names should remain unchanged for backward compatibility
+ if len(expected.encode("utf-8")) <= 63:
+ assert (
+ result == expected
+ ), f"Short name changed unexpectedly: {result} != {expected}"
diff --git a/tests/test_postgres_migration.py b/tests/test_postgres_migration.py
new file mode 100644
index 00000000..ce431a1d
--- /dev/null
+++ b/tests/test_postgres_migration.py
@@ -0,0 +1,856 @@
+import pytest
+from unittest.mock import patch, AsyncMock
+import numpy as np
+from lightrag.utils import EmbeddingFunc
+from lightrag.kg.postgres_impl import (
+ PGVectorStorage,
+)
+from lightrag.namespace import NameSpace
+
+
+# Mock PostgreSQLDB
+@pytest.fixture
+def mock_pg_db():
+ """Mock PostgreSQL database connection"""
+ db = AsyncMock()
+ db.workspace = "test_workspace"
+
+ # Mock query responses with multirows support
+ async def mock_query(sql, params=None, multirows=False, **kwargs):
+ # Default return value
+ if multirows:
+ return [] # Return empty list for multirows
+ return {"exists": False, "count": 0}
+
+ # Mock for execute that mimics PostgreSQLDB.execute() behavior
+ async def mock_execute(sql, data=None, **kwargs):
+ """
+ Mock that mimics PostgreSQLDB.execute() behavior:
+ - Accepts data as dict[str, Any] | None (second parameter)
+ - Internally converts dict.values() to tuple for AsyncPG
+ """
+ # Mimic real execute() which accepts dict and converts to tuple
+ if data is not None and not isinstance(data, dict):
+ raise TypeError(
+ f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}"
+ )
+ return None
+
+ db.query = AsyncMock(side_effect=mock_query)
+ db.execute = AsyncMock(side_effect=mock_execute)
+
+ return db
+
+
+# Mock get_data_init_lock to avoid async lock issues in tests
+@pytest.fixture(autouse=True)
+def mock_data_init_lock():
+ with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
+ mock_lock_ctx = AsyncMock()
+ mock_lock.return_value = mock_lock_ctx
+ yield mock_lock
+
+
+# Mock ClientManager
+@pytest.fixture
+def mock_client_manager(mock_pg_db):
+ with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
+ mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
+ mock_manager.release_client = AsyncMock()
+ yield mock_manager
+
+
+# Mock Embedding function
+@pytest.fixture
+def mock_embedding_func():
+ async def embed_func(texts, **kwargs):
+ return np.array([[0.1] * 768 for _ in texts])
+
+ func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
+ return func
+
+
+async def test_postgres_table_naming(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """Test if table name is correctly generated with model suffix"""
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Verify table name contains model suffix
+ expected_suffix = "test_model_768d"
+ assert expected_suffix in storage.table_name
+ assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"
+
+ # Verify legacy table name
+ assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
+
+
+async def test_postgres_migration_trigger(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """Test if migration logic is triggered correctly"""
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Setup mocks for migration scenario
+ # 1. New table does not exist, legacy table exists
+ async def mock_check_table_exists(table_name):
+ return table_name == storage.legacy_table_name
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # 2. Legacy table has 100 records
+ mock_rows = [
+ {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
+ for i in range(100)
+ ]
+ migration_state = {"new_table_count": 0}
+
+ async def mock_query(sql, params=None, multirows=False, **kwargs):
+ if "COUNT(*)" in sql:
+ sql_upper = sql.upper()
+ legacy_table = storage.legacy_table_name.upper()
+ new_table = storage.table_name.upper()
+ is_new_table = new_table in sql_upper
+ is_legacy_table = legacy_table in sql_upper and not is_new_table
+
+ if is_new_table:
+ return {"count": migration_state["new_table_count"]}
+ if is_legacy_table:
+ return {"count": 100}
+ return {"count": 0}
+ elif multirows and "SELECT *" in sql:
+ # Mock batch fetch for migration using keyset pagination
+ # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
+ # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
+ if "WHERE workspace" in sql:
+ if "id >" in sql:
+ # Keyset pagination: params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ # Find rows after last_id
+ start_idx = 0
+ for i, row in enumerate(mock_rows):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[2] if len(params) > 2 else 500
+ else:
+ # First batch (no last_id): params = [workspace, limit]
+ start_idx = 0
+ limit = params[1] if len(params) > 1 else 500
+ else:
+ # No workspace filter with keyset
+ if "id >" in sql:
+ last_id = params[0] if params else None
+ start_idx = 0
+ for i, row in enumerate(mock_rows):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[1] if len(params) > 1 else 500
+ else:
+ start_idx = 0
+ limit = params[0] if params else 500
+ end = min(start_idx + limit, len(mock_rows))
+ return mock_rows[start_idx:end]
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query)
+
+ # Track migration through _run_with_retry calls
+ migration_executed = []
+
+ async def mock_run_with_retry(operation, **kwargs):
+ # Track that migration batch operation was called
+ migration_executed.append(True)
+ migration_state["new_table_count"] = 100
+ return None
+
+ mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
+
+ with patch(
+ "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
+ ):
+ # Initialize storage (should trigger migration)
+ await storage.initialize()
+
+ # Verify migration was executed by checking _run_with_retry was called
+ # (batch migration uses _run_with_retry with executemany)
+ assert len(migration_executed) > 0, "Migration should have been executed"
+
+
+async def test_postgres_no_migration_needed(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """Test scenario where new table already exists (no migration needed)"""
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Mock: new table already exists
+ async def mock_check_table_exists(table_name):
+ return table_name == storage.table_name
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ with patch(
+ "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
+ ) as mock_create:
+ await storage.initialize()
+
+ # Verify no table creation was attempted
+ mock_create.assert_not_called()
+
+
+async def test_scenario_1_new_workspace_creation(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Scenario 1: New workspace creation
+
+ Expected behavior:
+ - No legacy table exists
+ - Directly create new table with model suffix
+ - No migration needed
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ embedding_func = EmbeddingFunc(
+ embedding_dim=3072,
+ func=mock_embedding_func.func,
+ model_name="text-embedding-3-large",
+ )
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="new_workspace",
+ )
+
+ # Mock: neither table exists
+ async def mock_check_table_exists(table_name):
+ return False
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ with patch(
+ "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
+ ) as mock_create:
+ await storage.initialize()
+
+ # Verify table name format
+ assert "text_embedding_3_large_3072d" in storage.table_name
+
+ # Verify new table creation was called
+ mock_create.assert_called_once()
+ call_args = mock_create.call_args
+ assert (
+ call_args[0][1] == storage.table_name
+ ) # table_name is second positional arg
+
+
+async def test_scenario_2_legacy_upgrade_migration(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Scenario 2: Upgrade from legacy version
+
+ Expected behavior:
+ - Legacy table exists (without model suffix)
+ - New table doesn't exist
+ - Automatically migrate data to new table with suffix
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ embedding_func = EmbeddingFunc(
+ embedding_dim=1536,
+ func=mock_embedding_func.func,
+ model_name="text-embedding-ada-002",
+ )
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="legacy_workspace",
+ )
+
+ # Mock: only legacy table exists
+ async def mock_check_table_exists(table_name):
+ return table_name == storage.legacy_table_name
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # Mock: legacy table has 50 records
+ mock_rows = [
+ {
+ "id": f"legacy_id_{i}",
+ "content": f"legacy_content_{i}",
+ "workspace": "legacy_workspace",
+ }
+ for i in range(50)
+ ]
+
+ # Track which queries have been made for proper response
+ query_history = []
+ migration_state = {"new_table_count": 0}
+
+ async def mock_query(sql, params=None, multirows=False, **kwargs):
+ query_history.append(sql)
+
+ if "COUNT(*)" in sql:
+ # Determine table type:
+ # - Legacy: contains base name but NOT model suffix
+ # - New: contains model suffix (e.g., text_embedding_ada_002_1536d)
+ sql_upper = sql.upper()
+ base_name = storage.legacy_table_name.upper()
+
+ # Check if this is querying the new table (has model suffix)
+ has_model_suffix = storage.table_name.upper() in sql_upper
+
+ is_legacy_table = base_name in sql_upper and not has_model_suffix
+ has_workspace_filter = "WHERE workspace" in sql
+
+ if is_legacy_table and has_workspace_filter:
+ # Count for legacy table with workspace filter (before migration)
+ return {"count": 50}
+ elif is_legacy_table and not has_workspace_filter:
+ # Total count for legacy table
+ return {"count": 50}
+ else:
+ # New table count (before/after migration)
+ return {"count": migration_state["new_table_count"]}
+ elif multirows and "SELECT *" in sql:
+ # Mock batch fetch for migration using keyset pagination
+ # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
+ # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2
+ if "WHERE workspace" in sql:
+ if "id >" in sql:
+ # Keyset pagination: params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ # Find rows after last_id
+ start_idx = 0
+ for i, row in enumerate(mock_rows):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[2] if len(params) > 2 else 500
+ else:
+ # First batch (no last_id): params = [workspace, limit]
+ start_idx = 0
+ limit = params[1] if len(params) > 1 else 500
+ else:
+ # No workspace filter with keyset
+ if "id >" in sql:
+ last_id = params[0] if params else None
+ start_idx = 0
+ for i, row in enumerate(mock_rows):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[1] if len(params) > 1 else 500
+ else:
+ start_idx = 0
+ limit = params[0] if params else 500
+ end = min(start_idx + limit, len(mock_rows))
+ return mock_rows[start_idx:end]
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query)
+
+ # Track migration through _run_with_retry calls
+ migration_executed = []
+
+ async def mock_run_with_retry(operation, **kwargs):
+ # Track that migration batch operation was called
+ migration_executed.append(True)
+ migration_state["new_table_count"] = 50
+ return None
+
+ mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
+
+ with patch(
+ "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
+ ) as mock_create:
+ await storage.initialize()
+
+ # Verify table name contains ada-002
+ assert "text_embedding_ada_002_1536d" in storage.table_name
+
+ # Verify migration was executed (batch migration uses _run_with_retry)
+ assert len(migration_executed) > 0, "Migration should have been executed"
+ mock_create.assert_called_once()
+
+
+async def test_scenario_3_multi_model_coexistence(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Scenario 3: Multiple embedding models coexist
+
+ Expected behavior:
+ - Different embedding models create separate tables
+ - Tables are isolated by model suffix
+ - No interference between different models
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ # Workspace A: uses bge-small (768d)
+ embedding_func_a = EmbeddingFunc(
+ embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small"
+ )
+
+ storage_a = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func_a,
+ workspace="workspace_a",
+ )
+
+ # Workspace B: uses bge-large (1024d)
+ async def embed_func_b(texts, **kwargs):
+ return np.array([[0.1] * 1024 for _ in texts])
+
+ embedding_func_b = EmbeddingFunc(
+ embedding_dim=1024, func=embed_func_b, model_name="bge-large"
+ )
+
+ storage_b = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func_b,
+ workspace="workspace_b",
+ )
+
+ # Verify different table names
+ assert storage_a.table_name != storage_b.table_name
+ assert "bge_small_768d" in storage_a.table_name
+ assert "bge_large_1024d" in storage_b.table_name
+
+ # Mock: both tables don't exist yet
+ async def mock_check_table_exists(table_name):
+ return False
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ with patch(
+ "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock()
+ ) as mock_create:
+ # Initialize both storages
+ await storage_a.initialize()
+ await storage_b.initialize()
+
+ # Verify two separate tables were created
+ assert mock_create.call_count == 2
+
+ # Verify table names are different
+ call_args_list = mock_create.call_args_list
+ table_names = [call[0][1] for call in call_args_list] # Second positional arg
+ assert len(set(table_names)) == 2 # Two unique table names
+ assert storage_a.table_name in table_names
+ assert storage_b.table_name in table_names
+
+
+async def test_case1_empty_legacy_auto_cleanup(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Case 1a: Both new and legacy tables exist, but legacy is EMPTY
+ Expected: Automatically delete empty legacy table (safe cleanup)
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ embedding_func = EmbeddingFunc(
+ embedding_dim=1536,
+ func=mock_embedding_func.func,
+ model_name="test-model",
+ )
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="test_ws",
+ )
+
+ # Mock: Both tables exist
+ async def mock_check_table_exists(table_name):
+ return True # Both new and legacy exist
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # Mock: Legacy table is empty (0 records)
+ async def mock_query(sql, params=None, multirows=False, **kwargs):
+ if "COUNT(*)" in sql:
+ if storage.legacy_table_name in sql:
+ return {"count": 0} # Empty legacy table
+ else:
+ return {"count": 100} # New table has data
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query)
+
+ with patch("lightrag.kg.postgres_impl.logger"):
+ await storage.initialize()
+
+ # Verify: Empty legacy table should be automatically cleaned up
+ # Empty tables are safe to delete without data loss risk
+ delete_calls = [
+ call
+ for call in mock_pg_db.execute.call_args_list
+ if call[0][0] and "DROP TABLE" in call[0][0]
+ ]
+ assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted"
+ # Check if legacy table was dropped
+ dropped_table = storage.legacy_table_name
+ assert any(
+ dropped_table in str(call) for call in delete_calls
+ ), f"Expected to drop empty legacy table '{dropped_table}'"
+
+ print(
+ f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully"
+ )
+
+
+async def test_case1_nonempty_legacy_warning(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Case 1b: Both new and legacy tables exist, and legacy HAS DATA
+ Expected: Log warning, do not delete legacy (preserve data)
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ embedding_func = EmbeddingFunc(
+ embedding_dim=1536,
+ func=mock_embedding_func.func,
+ model_name="test-model",
+ )
+
+ storage = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="test_ws",
+ )
+
+ # Mock: Both tables exist
+ async def mock_check_table_exists(table_name):
+ return True # Both new and legacy exist
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists)
+
+ # Mock: Legacy table has data (50 records)
+ async def mock_query(sql, params=None, multirows=False, **kwargs):
+ if "COUNT(*)" in sql:
+ if storage.legacy_table_name in sql:
+ return {"count": 50} # Legacy has data
+ else:
+ return {"count": 100} # New table has data
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query)
+
+ with patch("lightrag.kg.postgres_impl.logger"):
+ await storage.initialize()
+
+ # Verify: Legacy table with data should be preserved
+ # We never auto-delete tables that contain data to prevent accidental data loss
+ delete_calls = [
+ call
+ for call in mock_pg_db.execute.call_args_list
+ if call[0][0] and "DROP TABLE" in call[0][0]
+ ]
+ # Check if legacy table was deleted (it should not be)
+ dropped_table = storage.legacy_table_name
+ legacy_deleted = any(dropped_table in str(call) for call in delete_calls)
+ assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted"
+
+ print(
+ f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)"
+ )
+
+
+async def test_case1_sequential_workspace_migration(
+ mock_client_manager, mock_pg_db, mock_embedding_func
+):
+ """
+ Case 1c: Sequential workspace migration (Multi-tenant scenario)
+
+ Critical bug fix verification:
+ Timeline:
+ 1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
+ 2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data
+ 3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data
+ 4. Verify workspace B's data is correctly migrated to new table
+
+ This test verifies the migration logic correctly handles multi-tenant scenarios
+ where different workspaces migrate sequentially.
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ embedding_func = EmbeddingFunc(
+ embedding_dim=1536,
+ func=mock_embedding_func.func,
+ model_name="test-model",
+ )
+
+ # Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b)
+ mock_rows_a = [
+ {"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"}
+ for i in range(3)
+ ]
+ mock_rows_b = [
+ {"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"}
+ for i in range(3)
+ ]
+
+ # Track migration state
+ migration_state = {
+ "new_table_exists": False,
+ "workspace_a_migrated": False,
+ "workspace_a_migration_count": 0,
+ "workspace_b_migration_count": 0,
+ }
+
+ # Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists)
+ # CRITICAL: Set db.workspace to workspace_a
+ mock_pg_db.workspace = "workspace_a"
+
+ storage_a = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="workspace_a",
+ )
+
+ # Mock table_exists for workspace_a
+ async def mock_check_table_exists_a(table_name):
+ if table_name == storage_a.legacy_table_name:
+ return True
+ if table_name == storage_a.table_name:
+ return migration_state["new_table_exists"]
+ return False
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_a)
+
+ # Mock query for workspace_a (Case 3)
+ async def mock_query_a(sql, params=None, multirows=False, **kwargs):
+ sql_upper = sql.upper()
+ base_name = storage_a.legacy_table_name.upper()
+
+ if "COUNT(*)" in sql:
+ has_model_suffix = "TEST_MODEL_1536D" in sql_upper
+ is_legacy = base_name in sql_upper and not has_model_suffix
+ has_workspace_filter = "WHERE workspace" in sql
+
+ if is_legacy and has_workspace_filter:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_a":
+ return {"count": 3}
+ elif workspace == "workspace_b":
+ return {"count": 3}
+ elif is_legacy and not has_workspace_filter:
+ # Global count in legacy table
+ return {"count": 6}
+ elif has_model_suffix:
+ if has_workspace_filter:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_a":
+ return {"count": migration_state["workspace_a_migration_count"]}
+ if workspace == "workspace_b":
+ return {"count": migration_state["workspace_b_migration_count"]}
+ return {
+ "count": migration_state["workspace_a_migration_count"]
+ + migration_state["workspace_b_migration_count"]
+ }
+ elif multirows and "SELECT *" in sql:
+ if "WHERE workspace" in sql:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_a":
+ # Handle keyset pagination
+ if "id >" in sql:
+ # params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ start_idx = 0
+ for i, row in enumerate(mock_rows_a):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[2] if len(params) > 2 else 500
+ else:
+ # First batch: params = [workspace, limit]
+ start_idx = 0
+ limit = params[1] if len(params) > 1 else 500
+ end = min(start_idx + limit, len(mock_rows_a))
+ return mock_rows_a[start_idx:end]
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
+
+ # Track migration via _run_with_retry (batch migration uses this)
+ migration_a_executed = []
+
+ async def mock_run_with_retry_a(operation, **kwargs):
+ migration_a_executed.append(True)
+ migration_state["workspace_a_migration_count"] = len(mock_rows_a)
+ return None
+
+ mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a)
+
+ # Initialize workspace_a (Case 3)
+ with patch("lightrag.kg.postgres_impl.logger"):
+ await storage_a.initialize()
+ migration_state["new_table_exists"] = True
+ migration_state["workspace_a_migrated"] = True
+
+ print("✅ Step 1: Workspace A initialized")
+ # Verify migration was executed via _run_with_retry (batch migration uses executemany)
+ assert (
+ len(migration_a_executed) > 0
+ ), "Migration should have been executed for workspace_a"
+ print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)")
+
+    # Step 2: Simulate workspace_b initialization (both tables exist, but legacy still holds workspace_b's data)
+ # CRITICAL: Set db.workspace to workspace_b
+ mock_pg_db.workspace = "workspace_b"
+
+ storage_b = PGVectorStorage(
+ namespace=NameSpace.VECTOR_STORE_CHUNKS,
+ global_config=config,
+ embedding_func=embedding_func,
+ workspace="workspace_b",
+ )
+
+ mock_pg_db.reset_mock()
+
+ # Mock table_exists for workspace_b (both exist)
+ async def mock_check_table_exists_b(table_name):
+ return True # Both tables exist
+
+ mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_b)
+
+ # Mock query for workspace_b (Case 3)
+ async def mock_query_b(sql, params=None, multirows=False, **kwargs):
+ sql_upper = sql.upper()
+ base_name = storage_b.legacy_table_name.upper()
+
+ if "COUNT(*)" in sql:
+ has_model_suffix = "TEST_MODEL_1536D" in sql_upper
+ is_legacy = base_name in sql_upper and not has_model_suffix
+ has_workspace_filter = "WHERE workspace" in sql
+
+ if is_legacy and has_workspace_filter:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_b":
+ return {"count": 3} # workspace_b still has data in legacy
+ elif workspace == "workspace_a":
+ return {"count": 0} # workspace_a already migrated
+ elif is_legacy and not has_workspace_filter:
+ # Global count: only workspace_b data remains
+ return {"count": 3}
+ elif has_model_suffix:
+ if has_workspace_filter:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_b":
+ return {"count": migration_state["workspace_b_migration_count"]}
+ elif workspace == "workspace_a":
+ return {"count": 3}
+ else:
+ return {"count": 3 + migration_state["workspace_b_migration_count"]}
+ elif multirows and "SELECT *" in sql:
+ if "WHERE workspace" in sql:
+ workspace = params[0] if params and len(params) > 0 else None
+ if workspace == "workspace_b":
+ # Handle keyset pagination
+ if "id >" in sql:
+ # params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ start_idx = 0
+ for i, row in enumerate(mock_rows_b):
+ if row["id"] == last_id:
+ start_idx = i + 1
+ break
+ limit = params[2] if len(params) > 2 else 500
+ else:
+ # First batch: params = [workspace, limit]
+ start_idx = 0
+ limit = params[1] if len(params) > 1 else 500
+ end = min(start_idx + limit, len(mock_rows_b))
+ return mock_rows_b[start_idx:end]
+ return {}
+
+ mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
+
+ # Track migration via _run_with_retry for workspace_b
+ migration_b_executed = []
+
+ async def mock_run_with_retry_b(operation, **kwargs):
+ migration_b_executed.append(True)
+ migration_state["workspace_b_migration_count"] = len(mock_rows_b)
+ return None
+
+ mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b)
+
+    # Initialize workspace_b (both tables exist)
+ with patch("lightrag.kg.postgres_impl.logger"):
+ await storage_b.initialize()
+
+ print("✅ Step 2: Workspace B initialized")
+
+ # Verify workspace_b migration happens when new table has no workspace_b data
+ # but legacy table still has workspace_b data.
+ assert (
+ len(migration_b_executed) > 0
+ ), "Migration should have been executed for workspace_b"
+ print("✅ Step 2: Migration executed for workspace_b")
+
+ print("\n🎉 Case 1c: Sequential workspace migration verification complete!")
+ print(" - Workspace A: Migrated successfully (only legacy existed)")
+ print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")
diff --git a/tests/test_qdrant_migration.py b/tests/test_qdrant_migration.py
new file mode 100644
index 00000000..25d4eca9
--- /dev/null
+++ b/tests/test_qdrant_migration.py
@@ -0,0 +1,562 @@
+import pytest
+from unittest.mock import MagicMock, patch, AsyncMock
+import numpy as np
+from qdrant_client import models
+from lightrag.utils import EmbeddingFunc
+from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
+
+
+# Mock QdrantClient
+@pytest.fixture
+def mock_qdrant_client():
+ with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls:
+ client = mock_client_cls.return_value
+ client.collection_exists.return_value = False
+ client.count.return_value.count = 0
+ # Mock payload schema and vector config for get_collection
+ collection_info = MagicMock()
+ collection_info.payload_schema = {}
+ # Mock vector dimension to match mock_embedding_func (768d)
+ collection_info.config.params.vectors.size = 768
+ client.get_collection.return_value = collection_info
+ yield client
+
+
+# Mock get_data_init_lock to avoid async lock issues in tests
+@pytest.fixture(autouse=True)
+def mock_data_init_lock():
+ with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock:
+ mock_lock_ctx = AsyncMock()
+ mock_lock.return_value = mock_lock_ctx
+ yield mock_lock
+
+
+# Mock Embedding function
+@pytest.fixture
+def mock_embedding_func():
+ async def embed_func(texts, **kwargs):
+ return np.array([[0.1] * 768 for _ in texts])
+
+ func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model")
+ return func
+
+
+async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func):
+ """Test if collection name is correctly generated with model suffix"""
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Verify collection name contains model suffix
+ expected_suffix = "test_model_768d"
+ assert expected_suffix in storage.final_namespace
+ assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}"
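+    # As exercised by the assertions above, the collection name is built as
+    # "lightrag_vdb_" + namespace + "_" + the model name sanitized for identifiers
+    # (hyphens become underscores) + "_" + embedding_dim + "d".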
+
+
+async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func):
+ """Test if migration logic is triggered correctly"""
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Legacy collection name (without model suffix)
+ legacy_collection = "lightrag_vdb_chunks"
+
+ # Setup mocks for migration scenario
+ # 1. New collection does not exist, only legacy exists
+ mock_qdrant_client.collection_exists.side_effect = (
+ lambda name: name == legacy_collection
+ )
+
+ # 2. Legacy collection exists and has data
+ migration_state = {"new_workspace_count": 0}
+
+ def count_mock(collection_name, exact=True, count_filter=None):
+ mock_result = MagicMock()
+ if collection_name == legacy_collection:
+ mock_result.count = 100
+ elif collection_name == storage.final_namespace:
+ mock_result.count = migration_state["new_workspace_count"]
+ else:
+ mock_result.count = 0
+ return mock_result
+
+ mock_qdrant_client.count.side_effect = count_mock
+
+ # 3. Mock scroll for data migration
+ mock_point = MagicMock()
+ mock_point.id = "old_id"
+ mock_point.vector = [0.1] * 768
+ mock_point.payload = {"content": "test"} # No workspace_id in payload
+
+ # When payload_schema is empty, the code first samples payloads to detect workspace_id
+ # Then proceeds with migration batches
+ # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
+ mock_qdrant_client.scroll.side_effect = [
+ ([mock_point], "_"), # Sampling scroll - no workspace_id found
+ ([mock_point], "next_offset"), # Migration batch
+ ([], None), # End of migration
+ ]
+
+ def upsert_mock(*args, **kwargs):
+ migration_state["new_workspace_count"] = 100
+ return None
+
+ mock_qdrant_client.upsert.side_effect = upsert_mock
+
+ # Initialize storage (triggers migration)
+ await storage.initialize()
+
+ # Verify migration steps
+ # 1. Legacy count checked
+ mock_qdrant_client.count.assert_any_call(
+ collection_name=legacy_collection, exact=True
+ )
+
+ # 2. New collection created
+ mock_qdrant_client.create_collection.assert_called()
+
+ # 3. Data scrolled from legacy
+ # First call (index 0) is sampling scroll with limit=10
+ # Second call (index 1) is migration batch with limit=500
+ assert mock_qdrant_client.scroll.call_count >= 2
+ # Check sampling scroll
+ sampling_call = mock_qdrant_client.scroll.call_args_list[0]
+ assert sampling_call.kwargs["collection_name"] == legacy_collection
+ assert sampling_call.kwargs["limit"] == 10
+ # Check migration batch scroll
+ migration_call = mock_qdrant_client.scroll.call_args_list[1]
+ assert migration_call.kwargs["collection_name"] == legacy_collection
+ assert migration_call.kwargs["limit"] == 500
+
+ # 4. Data upserted to new
+ mock_qdrant_client.upsert.assert_called()
+
+ # 5. Payload index created
+ mock_qdrant_client.create_payload_index.assert_called()
+
+
+async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
+ """Test scenario where new collection already exists (Case 1 in setup_collection)
+
+ When only the new collection exists and no legacy collection is found,
+ the implementation should:
+ 1. Create payload index on the new collection (ensure index exists)
+ 2. NOT attempt any data migration (no scroll calls)
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Only new collection exists (no legacy collection found)
+ mock_qdrant_client.collection_exists.side_effect = (
+ lambda name: name == storage.final_namespace
+ )
+
+ # Initialize
+ await storage.initialize()
+
+ # Should create payload index on the new collection (ensure index)
+ mock_qdrant_client.create_payload_index.assert_called_with(
+ collection_name=storage.final_namespace,
+ field_name="workspace_id",
+ field_schema=models.KeywordIndexParams(
+ type=models.KeywordIndexType.KEYWORD,
+ is_tenant=True,
+ ),
+ )
+ # Should NOT migrate (no scroll calls since no legacy collection exists)
+ mock_qdrant_client.scroll.assert_not_called()
+
+
+# ============================================================================
+# Tests for scenarios described in design document (Lines 606-649)
+# ============================================================================
+
+
+async def test_scenario_1_new_workspace_creation(
+ mock_qdrant_client, mock_embedding_func
+):
+ """
+    Scenario 1: A brand-new workspace is created.
+    Expected: the collection lightrag_vdb_chunks_text_embedding_3_large_3072d is created directly (no migration).
+ """
+ # Use a large embedding model
+ large_model_func = EmbeddingFunc(
+ embedding_dim=3072,
+ func=mock_embedding_func.func,
+ model_name="text-embedding-3-large",
+ )
+
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=large_model_func,
+ workspace="test_new",
+ )
+
+ # Case 3: Neither legacy nor new collection exists
+ mock_qdrant_client.collection_exists.return_value = False
+
+ # Initialize storage
+ await storage.initialize()
+
+ # Verify: Should create new collection with model suffix
+ expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
+ assert storage.final_namespace == expected_collection
+
+ # Verify create_collection was called with correct name
+ create_calls = [
+ call for call in mock_qdrant_client.create_collection.call_args_list
+ ]
+ assert len(create_calls) > 0
+ assert (
+ create_calls[0][0][0] == expected_collection
+ or create_calls[0].kwargs.get("collection_name") == expected_collection
+ )
+
+ # Verify no migration was attempted
+ mock_qdrant_client.scroll.assert_not_called()
+
+ print(
+ f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
+ )
+
+
+async def test_scenario_2_legacy_upgrade_migration(
+ mock_qdrant_client, mock_embedding_func
+):
+ """
+    Scenario 2: Upgrading from an older version.
+    A legacy collection lightrag_vdb_chunks (without a model suffix) already exists.
+    Expected: data is automatically migrated to lightrag_vdb_chunks_text_embedding_ada_002_1536d.
+    Note: the legacy collection is no longer deleted automatically after migration; it must be removed manually.
+ """
+ # Use ada-002 model
+ ada_func = EmbeddingFunc(
+ embedding_dim=1536,
+ func=mock_embedding_func.func,
+ model_name="text-embedding-ada-002",
+ )
+
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=ada_func,
+ workspace="test_legacy",
+ )
+
+ # Legacy collection name (without model suffix)
+ legacy_collection = "lightrag_vdb_chunks"
+ new_collection = storage.final_namespace
+
+ # Case 4: Only legacy collection exists
+ mock_qdrant_client.collection_exists.side_effect = (
+ lambda name: name == legacy_collection
+ )
+
+ # Mock legacy collection info with 1536d vectors
+ legacy_collection_info = MagicMock()
+ legacy_collection_info.payload_schema = {}
+ legacy_collection_info.config.params.vectors.size = 1536
+ mock_qdrant_client.get_collection.return_value = legacy_collection_info
+
+ migration_state = {"new_workspace_count": 0}
+
+ def count_mock(collection_name, exact=True, count_filter=None):
+ mock_result = MagicMock()
+ if collection_name == legacy_collection:
+ mock_result.count = 150
+ elif collection_name == new_collection:
+ mock_result.count = migration_state["new_workspace_count"]
+ else:
+ mock_result.count = 0
+ return mock_result
+
+ mock_qdrant_client.count.side_effect = count_mock
+
+ # Mock scroll results (simulate migration in batches)
+ mock_points = []
+ for i in range(10):
+ point = MagicMock()
+ point.id = f"legacy-{i}"
+ point.vector = [0.1] * 1536
+ # No workspace_id in payload - simulates legacy data
+ point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
+ mock_points.append(point)
+
+ # When payload_schema is empty, the code first samples payloads to detect workspace_id
+ # Then proceeds with migration batches
+ # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration
+ mock_qdrant_client.scroll.side_effect = [
+ (mock_points, "_"), # Sampling scroll - no workspace_id found in payloads
+ (mock_points, "offset1"), # Migration batch
+ ([], None), # End of migration
+ ]
+
+ def upsert_mock(*args, **kwargs):
+ migration_state["new_workspace_count"] = 150
+ return None
+
+ mock_qdrant_client.upsert.side_effect = upsert_mock
+
+ # Initialize (triggers migration)
+ await storage.initialize()
+
+ # Verify: New collection should be created
+ expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
+ assert storage.final_namespace == expected_new_collection
+
+ # Verify migration steps
+ # 1. Check legacy count
+ mock_qdrant_client.count.assert_any_call(
+ collection_name=legacy_collection, exact=True
+ )
+
+ # 2. Create new collection
+ mock_qdrant_client.create_collection.assert_called()
+
+ # 3. Scroll legacy data
+ scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
+ assert len(scroll_calls) >= 1
+ assert scroll_calls[0].kwargs["collection_name"] == legacy_collection
+
+ # 4. Upsert to new collection
+ upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
+ assert len(upsert_calls) >= 1
+ assert upsert_calls[0].kwargs["collection_name"] == new_collection
+
+ # Note: Legacy collection is NOT automatically deleted after migration
+ # Manual deletion is required after data migration verification
+
+ print(
+ f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'"
+ )
+
+
+async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
+ """
+    Scenario 3: Multiple embedding models coexisting.
+    Expected: two independent collections that do not interfere with each other.
+ """
+
+ # Model A: bge-small with 768d
+ async def embed_func_a(texts, **kwargs):
+ return np.array([[0.1] * 768 for _ in texts])
+
+ model_a_func = EmbeddingFunc(
+ embedding_dim=768, func=embed_func_a, model_name="bge-small"
+ )
+
+ # Model B: bge-large with 1024d
+ async def embed_func_b(texts, **kwargs):
+ return np.array([[0.2] * 1024 for _ in texts])
+
+ model_b_func = EmbeddingFunc(
+ embedding_dim=1024, func=embed_func_b, model_name="bge-large"
+ )
+
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ # Create storage for workspace A with model A
+ storage_a = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=model_a_func,
+ workspace="workspace_a",
+ )
+
+ # Create storage for workspace B with model B
+ storage_b = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=model_b_func,
+ workspace="workspace_b",
+ )
+
+ # Verify: Collection names are different
+ assert storage_a.final_namespace != storage_b.final_namespace
+
+ # Verify: Model A collection
+ expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
+ assert storage_a.final_namespace == expected_collection_a
+
+ # Verify: Model B collection
+ expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
+ assert storage_b.final_namespace == expected_collection_b
+
+ # Verify: Different embedding dimensions are preserved
+ assert storage_a.embedding_func.embedding_dim == 768
+ assert storage_b.embedding_func.embedding_dim == 1024
+
+ print("✅ Scenario 3: Multi-model coexistence verified")
+ print(f" - Workspace A: {expected_collection_a} (768d)")
+ print(f" - Workspace B: {expected_collection_b} (1024d)")
+ print(" - Collections are independent")
+
+
+async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
+ """
+    Case 1a: Both the new and legacy collections exist, and the legacy collection is empty.
+    Expected: the legacy collection is deleted automatically.
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Legacy collection name (without model suffix)
+ legacy_collection = "lightrag_vdb_chunks"
+ new_collection = storage.final_namespace
+
+ # Mock: Both collections exist
+ mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
+ legacy_collection,
+ new_collection,
+ ]
+
+ # Mock: Legacy collection is empty (0 records)
+ def count_mock(collection_name, exact=True, count_filter=None):
+ mock_result = MagicMock()
+ if collection_name == legacy_collection:
+ mock_result.count = 0 # Empty legacy collection
+ else:
+ mock_result.count = 100 # New collection has data
+ return mock_result
+
+ mock_qdrant_client.count.side_effect = count_mock
+
+ # Mock get_collection for Case 2 check
+ collection_info = MagicMock()
+ collection_info.payload_schema = {"workspace_id": True}
+ mock_qdrant_client.get_collection.return_value = collection_info
+
+ # Initialize storage
+ await storage.initialize()
+
+ # Verify: Empty legacy collection should be automatically cleaned up
+ # Empty collections are safe to delete without data loss risk
+ delete_calls = [
+ call for call in mock_qdrant_client.delete_collection.call_args_list
+ ]
+ assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
+ deleted_collection = (
+ delete_calls[0][0][0]
+ if delete_calls[0][0]
+ else delete_calls[0].kwargs.get("collection_name")
+ )
+ assert (
+ deleted_collection == legacy_collection
+ ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"
+
+ print(
+ f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
+ )
+
+
+async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
+ """
+    Case 1b: Both the new and legacy collections exist, and the legacy collection still contains data.
+    Expected: a warning is logged, but the collection is not deleted.
+ """
+ config = {
+ "embedding_batch_num": 10,
+ "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+ }
+
+ storage = QdrantVectorDBStorage(
+ namespace="chunks",
+ global_config=config,
+ embedding_func=mock_embedding_func,
+ workspace="test_ws",
+ )
+
+ # Legacy collection name (without model suffix)
+ legacy_collection = "lightrag_vdb_chunks"
+ new_collection = storage.final_namespace
+
+ # Mock: Both collections exist
+ mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
+ legacy_collection,
+ new_collection,
+ ]
+
+ # Mock: Legacy collection has data (50 records)
+ def count_mock(collection_name, exact=True, count_filter=None):
+ mock_result = MagicMock()
+ if collection_name == legacy_collection:
+ mock_result.count = 50 # Legacy has data
+ else:
+ mock_result.count = 100 # New collection has data
+ return mock_result
+
+ mock_qdrant_client.count.side_effect = count_mock
+
+ # Mock get_collection for Case 2 check
+ collection_info = MagicMock()
+ collection_info.payload_schema = {"workspace_id": True}
+ mock_qdrant_client.get_collection.return_value = collection_info
+
+ # Initialize storage
+ await storage.initialize()
+
+ # Verify: Legacy collection with data should be preserved
+ # We never auto-delete collections that contain data to prevent accidental data loss
+ delete_calls = [
+ call for call in mock_qdrant_client.delete_collection.call_args_list
+ ]
+ # Check if legacy collection was deleted (it should not be)
+ legacy_deleted = any(
+ (call[0][0] if call[0] else call.kwargs.get("collection_name"))
+ == legacy_collection
+ for call in delete_calls
+ )
+ assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"
+
+ print(
+ f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
+ )
diff --git a/tests/test_unified_lock_safety.py b/tests/test_unified_lock_safety.py
new file mode 100644
index 00000000..1d83190a
--- /dev/null
+++ b/tests/test_unified_lock_safety.py
@@ -0,0 +1,122 @@
+"""
+Tests for UnifiedLock safety when lock is None.
+
+This test module verifies that get_internal_lock() and get_data_init_lock()
+raise RuntimeError when shared data is not initialized, preventing a false
+sense of security and potential race conditions.
+
+Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
+to the lock factory functions (get_internal_lock, get_data_init_lock) for
+early failure detection.
+
+Critical Bug 1 (Fixed): When self._lock is None, the code would fail with
+AttributeError. Now the check is in factory functions for clearer errors.
+
+Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
+recovery logic would attempt to release it again, causing double-release issues.
+"""
+
+from unittest.mock import MagicMock, AsyncMock
+
+import pytest
+
+from lightrag.kg.shared_storage import (
+ UnifiedLock,
+ get_internal_lock,
+ get_data_init_lock,
+ finalize_share_data,
+)
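+
+# The guard exercised below lives in the factory functions rather than in
+# UnifiedLock itself; conceptually it looks like this (an illustrative sketch;
+# the module-level variable name is assumed, not the exact shared_storage code):
+#
+#     def get_internal_lock() -> UnifiedLock:
+#         if _internal_lock is None:  # initialize_share_data() was never called
+#             raise RuntimeError(
+#                 "Shared data not initialized. Call initialize_share_data() first."
+#             )
+#         return UnifiedLock(lock=_internal_lock, is_async=True, name="internal_lock")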
+
+
+class TestUnifiedLockSafety:
+ """Test suite for UnifiedLock None safety checks."""
+
+ def setup_method(self):
+ """Ensure shared data is finalized before each test."""
+ finalize_share_data()
+
+ def teardown_method(self):
+ """Clean up after each test."""
+ finalize_share_data()
+
+ def test_get_internal_lock_raises_when_not_initialized(self):
+ """
+ Test that get_internal_lock() raises RuntimeError when shared data is not initialized.
+
+ Scenario: Call get_internal_lock() before initialize_share_data() is called.
+ Expected: RuntimeError raised with clear error message.
+
+ This test verifies the None check has been moved to the factory function.
+ """
+ with pytest.raises(
+ RuntimeError, match="Shared data not initialized.*initialize_share_data"
+ ):
+ get_internal_lock()
+
+ def test_get_data_init_lock_raises_when_not_initialized(self):
+ """
+ Test that get_data_init_lock() raises RuntimeError when shared data is not initialized.
+
+ Scenario: Call get_data_init_lock() before initialize_share_data() is called.
+ Expected: RuntimeError raised with clear error message.
+
+ This test verifies the None check has been moved to the factory function.
+ """
+ with pytest.raises(
+ RuntimeError, match="Shared data not initialized.*initialize_share_data"
+ ):
+ get_data_init_lock()
+
+ @pytest.mark.offline
+ async def test_aexit_no_double_release_on_async_lock_failure(self):
+ """
+ Test that __aexit__ doesn't attempt to release async_lock twice when it fails.
+
+ Scenario: async_lock.release() fails during normal release.
+ Expected: Recovery logic should NOT attempt to release async_lock again,
+ preventing double-release issues.
+
+ This tests Bug 2 fix: async_lock_released tracking prevents double release.
+ """
+ # Create mock locks
+ main_lock = MagicMock()
+ main_lock.acquire = MagicMock()
+ main_lock.release = MagicMock()
+
+ async_lock = AsyncMock()
+ async_lock.acquire = AsyncMock()
+
+ # Make async_lock.release() fail
+ release_call_count = 0
+
+ def mock_release_fail():
+ nonlocal release_call_count
+ release_call_count += 1
+ raise RuntimeError("Async lock release failed")
+
+ async_lock.release = MagicMock(side_effect=mock_release_fail)
+
+ # Create UnifiedLock with both locks (sync mode with async_lock)
+ lock = UnifiedLock(
+ lock=main_lock,
+ is_async=False,
+ name="test_double_release",
+ enable_logging=False,
+ )
+ lock._async_lock = async_lock
+
+ # Try to use the lock - should fail during __aexit__
+ try:
+ async with lock:
+ pass
+ except RuntimeError as e:
+ # Should get the async lock release error
+ assert "Async lock release failed" in str(e)
+
+ # Verify async_lock.release() was called only ONCE, not twice
+ assert (
+ release_call_count == 1
+ ), f"async_lock.release() should be called only once, but was called {release_call_count} times"
+
+ # Main lock should have been released successfully
+ main_lock.release.assert_called_once()
diff --git a/tests/test_workspace_isolation.py b/tests/test_workspace_isolation.py
index 68f7f8ec..0aac3186 100644
--- a/tests/test_workspace_isolation.py
+++ b/tests/test_workspace_isolation.py
@@ -149,7 +149,6 @@ def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None:
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_pipeline_status_isolation():
"""
Test that pipeline status is isolated between different workspaces.
@@ -204,7 +203,6 @@ async def test_pipeline_status_isolation():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_lock_mechanism(stress_test_mode, parallel_workers):
"""
Test that the new keyed lock mechanism works correctly without deadlocks.
@@ -274,7 +272,6 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers):
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_backward_compatibility():
"""
Test that legacy code without workspace parameter still works correctly.
@@ -348,7 +345,6 @@ async def test_backward_compatibility():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_multi_workspace_concurrency():
"""
Test that multiple workspaces can operate concurrently without interference.
@@ -432,7 +428,6 @@ async def test_multi_workspace_concurrency():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_namespace_lock_reentrance():
"""
Test that NamespaceLock prevents re-entrance in the same coroutine
@@ -506,7 +501,6 @@ async def test_namespace_lock_reentrance():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_different_namespace_lock_isolation():
"""
Test that locks for different namespaces (same workspace) are independent.
@@ -546,7 +540,6 @@ async def test_different_namespace_lock_isolation():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_error_handling():
"""
Test error handling for invalid workspace configurations.
@@ -597,7 +590,6 @@ async def test_error_handling():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_update_flags_workspace_isolation():
"""
Test that update flags are properly isolated between workspaces.
@@ -727,7 +719,6 @@ async def test_update_flags_workspace_isolation():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_empty_workspace_standardization():
"""
Test that empty workspace is properly standardized to "" instead of "_".
@@ -781,7 +772,6 @@ async def test_empty_workspace_standardization():
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):
"""
Integration test: Verify JsonKVStorage properly isolates data between workspaces.
@@ -961,7 +951,6 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts):
@pytest.mark.offline
-@pytest.mark.asyncio
async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts):
"""
End-to-end test: Create two LightRAG instances with different workspaces,
diff --git a/tests/test_workspace_migration_isolation.py b/tests/test_workspace_migration_isolation.py
new file mode 100644
index 00000000..d0e3bfd2
--- /dev/null
+++ b/tests/test_workspace_migration_isolation.py
@@ -0,0 +1,288 @@
+"""
+Tests for workspace isolation during PostgreSQL migration.
+
+This test module verifies that setup_table() properly filters migration data
+by workspace, preventing cross-workspace data leakage during legacy table migration.
+
+Critical Bug: Migration copied ALL records from legacy table regardless of workspace,
+causing workspace A to receive workspace B's data, violating multi-tenant isolation.
+"""
+
+import pytest
+from unittest.mock import AsyncMock
+
+from lightrag.kg.postgres_impl import PGVectorStorage
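+
+# The fix under test scopes the legacy-data SELECT to a single workspace. Conceptually
+# (illustrative SQL, not necessarily the exact statements emitted by setup_table):
+#
+#     SELECT * FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = $1 ORDER BY id LIMIT $2
+#     -- subsequent batches use keyset pagination:
+#     SELECT * FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3
+#
+# The mocks below emulate this parameter convention: [workspace, limit] for the first
+# batch and [workspace, last_id, limit] for subsequent batches.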
+
+
+class TestWorkspaceMigrationIsolation:
+ """Test suite for workspace-scoped migration in PostgreSQL."""
+
+ async def test_migration_filters_by_workspace(self):
+ """
+ Test that migration only copies data from the specified workspace.
+
+ Scenario: Legacy table contains data from multiple workspaces.
+ Migrate only workspace_a's data to new table.
+ Expected: New table contains only workspace_a data, workspace_b data excluded.
+ """
+ db = AsyncMock()
+
+ # Configure mock return values to avoid unawaited coroutine warnings
+ db._create_vector_index.return_value = None
+
+ # Track state for new table count (starts at 0, increases after migration)
+ new_table_record_count = {"count": 0}
+
+ # Mock table existence checks
+ async def table_exists_side_effect(db_instance, name):
+ if name.lower() == "lightrag_doc_chunks": # legacy
+ return True
+ elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
+ return False # New table doesn't exist initially
+ return False
+
+ # Mock data for workspace_a
+ mock_records_a = [
+ {
+ "id": "a1",
+ "workspace": "workspace_a",
+ "content": "content_a1",
+ "content_vector": [0.1] * 1536,
+ },
+ {
+ "id": "a2",
+ "workspace": "workspace_a",
+ "content": "content_a2",
+ "content_vector": [0.2] * 1536,
+ },
+ ]
+
+ # Mock query responses
+ async def query_side_effect(sql, params, **kwargs):
+ multirows = kwargs.get("multirows", False)
+ sql_upper = sql.upper()
+
+ # Count query for new table workspace data (verification before migration)
+ if (
+ "COUNT(*)" in sql_upper
+ and "MODEL_1536D" in sql_upper
+ and "WHERE WORKSPACE" in sql_upper
+ ):
+ return new_table_record_count # Initially 0
+
+ # Count query with workspace filter (legacy table) - for workspace count
+ elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
+ if params and params[0] == "workspace_a":
+ return {"count": 2} # workspace_a has 2 records
+ elif params and params[0] == "workspace_b":
+ return {"count": 3} # workspace_b has 3 records
+ return {"count": 0}
+
+ # Count query for legacy table (total, no workspace filter)
+ elif (
+ "COUNT(*)" in sql_upper
+ and "LIGHTRAG" in sql_upper
+ and "WHERE WORKSPACE" not in sql_upper
+ ):
+ return {"count": 5} # Total records in legacy
+
+ # SELECT with workspace filter for migration (multirows)
+ elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
+ workspace = params[0] if params else None
+ if workspace == "workspace_a":
+ # Handle keyset pagination: check for "id >" pattern
+ if "id >" in sql.lower():
+ # Keyset pagination: params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ # Find records after last_id
+ found_idx = -1
+ for i, rec in enumerate(mock_records_a):
+ if rec["id"] == last_id:
+ found_idx = i
+ break
+ if found_idx >= 0:
+ return mock_records_a[found_idx + 1 :]
+ return []
+ else:
+ # First batch: params = [workspace, limit]
+ return mock_records_a
+ return [] # No data for other workspaces
+
+ return {}
+
+ db.query.side_effect = query_side_effect
+ db.execute = AsyncMock()
+
+ # Mock check_table_exists on db
+ async def check_table_exists_side_effect(name):
+ if name.lower() == "lightrag_doc_chunks": # legacy
+ return True
+ elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
+ return False # New table doesn't exist initially
+ return False
+
+ db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
+
+ # Track migration through _run_with_retry calls
+ migration_executed = []
+
+ async def mock_run_with_retry(operation, *args, **kwargs):
+ migration_executed.append(True)
+ new_table_record_count["count"] = 2 # Simulate 2 records migrated
+ return None
+
+ db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
+
+ # Migrate for workspace_a only - correct parameter order
+ await PGVectorStorage.setup_table(
+ db,
+ "LIGHTRAG_DOC_CHUNKS_model_1536d",
+ workspace="workspace_a", # CRITICAL: Only migrate workspace_a
+ embedding_dim=1536,
+ legacy_table_name="LIGHTRAG_DOC_CHUNKS",
+ base_table="LIGHTRAG_DOC_CHUNKS",
+ )
+
+ # Verify the migration was triggered
+ assert (
+ len(migration_executed) > 0
+ ), "Migration should have been executed for workspace_a"
+
+ async def test_migration_without_workspace_raises_error(self):
+ """
+ Test that migration without workspace parameter raises ValueError.
+
+ Scenario: setup_table called without workspace parameter.
+ Expected: ValueError is raised because workspace is required.
+ """
+ db = AsyncMock()
+
+ # workspace is now a required parameter - calling with None should raise ValueError
+ with pytest.raises(ValueError, match="workspace must be provided"):
+ await PGVectorStorage.setup_table(
+ db,
+ "lightrag_doc_chunks_model_1536d",
+ workspace=None, # No workspace - should raise ValueError
+ embedding_dim=1536,
+ legacy_table_name="lightrag_doc_chunks",
+ base_table="lightrag_doc_chunks",
+ )
+
+ async def test_no_cross_workspace_contamination(self):
+ """
+ Test that workspace B's migration doesn't include workspace A's data.
+
+ Scenario: Migration for workspace_b only.
+ Expected: Only workspace_b data is queried, workspace_a data excluded.
+ """
+ db = AsyncMock()
+
+ # Configure mock return values to avoid unawaited coroutine warnings
+ db._create_vector_index.return_value = None
+
+ # Track which workspace is being queried
+ queried_workspace = None
+ new_table_count = {"count": 0}
+
+ # Mock data for workspace_b
+ mock_records_b = [
+ {
+ "id": "b1",
+ "workspace": "workspace_b",
+ "content": "content_b1",
+ "content_vector": [0.3] * 1536,
+ },
+ ]
+
+ async def table_exists_side_effect(db_instance, name):
+ if name.lower() == "lightrag_doc_chunks": # legacy
+ return True
+ elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
+ return False
+ return False
+
+ async def query_side_effect(sql, params, **kwargs):
+ nonlocal queried_workspace
+ multirows = kwargs.get("multirows", False)
+ sql_upper = sql.upper()
+
+ # Count query for new table workspace data (should be 0 initially)
+ if (
+ "COUNT(*)" in sql_upper
+ and "MODEL_1536D" in sql_upper
+ and "WHERE WORKSPACE" in sql_upper
+ ):
+ return new_table_count
+
+ # Count query with workspace filter (legacy table)
+ elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper:
+ queried_workspace = params[0] if params else None
+ return {"count": 1} # 1 record for the queried workspace
+
+ # Count query for legacy table total (no workspace filter)
+ elif (
+ "COUNT(*)" in sql_upper
+ and "LIGHTRAG" in sql_upper
+ and "WHERE WORKSPACE" not in sql_upper
+ ):
+ return {"count": 3} # 3 total records in legacy
+
+ # SELECT with workspace filter for migration (multirows)
+ elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows:
+ workspace = params[0] if params else None
+ if workspace == "workspace_b":
+ # Handle keyset pagination: check for "id >" pattern
+ if "id >" in sql.lower():
+ # Keyset pagination: params = [workspace, last_id, limit]
+ last_id = params[1] if len(params) > 1 else None
+ # Find records after last_id
+ found_idx = -1
+ for i, rec in enumerate(mock_records_b):
+ if rec["id"] == last_id:
+ found_idx = i
+ break
+ if found_idx >= 0:
+ return mock_records_b[found_idx + 1 :]
+ return []
+ else:
+ # First batch: params = [workspace, limit]
+ return mock_records_b
+ return [] # No data for other workspaces
+
+ return {}
+
+ db.query.side_effect = query_side_effect
+ db.execute = AsyncMock()
+
+ # Mock check_table_exists on db
+ async def check_table_exists_side_effect(name):
+ if name.lower() == "lightrag_doc_chunks": # legacy
+ return True
+ elif name.lower() == "lightrag_doc_chunks_model_1536d": # new
+ return False
+ return False
+
+ db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect)
+
+ # Track migration through _run_with_retry calls
+ migration_executed = []
+
+ async def mock_run_with_retry(operation, *args, **kwargs):
+ migration_executed.append(True)
+ new_table_count["count"] = 1 # Simulate migration
+ return None
+
+ db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
+
+ # Migrate workspace_b - correct parameter order
+ await PGVectorStorage.setup_table(
+ db,
+ "LIGHTRAG_DOC_CHUNKS_model_1536d",
+ workspace="workspace_b", # Only migrate workspace_b
+ embedding_dim=1536,
+ legacy_table_name="LIGHTRAG_DOC_CHUNKS",
+ base_table="LIGHTRAG_DOC_CHUNKS",
+ )
+
+ # Verify only workspace_b was queried
+ assert queried_workspace == "workspace_b", "Should only query workspace_b"
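+
+        # Mirror the check from test_migration_filters_by_workspace: the batched
+        # migration should actually have run for workspace_b as well.
+        assert (
+            len(migration_executed) > 0
+        ), "Migration should have been executed for workspace_b"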