diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f54a7ee3..c9fb6c0b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: - python-version: ['3.12', '3.13', '3.14'] + python-version: ['3.12', '3.14'] steps: - uses: actions/checkout@v6 diff --git a/README-zh.md b/README-zh.md index 5a331b39..f72f4e01 100644 --- a/README-zh.md +++ b/README-zh.md @@ -286,7 +286,7 @@ if __name__ == "__main__": 参数 | **参数** | **类型** | **说明** | **默认值** | -|--------------|----------|-----------------|-------------| +| -------------- | ---------- | ----------------- | ------------- | | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` | | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` | | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` | @@ -425,7 +425,7 @@ async def llm_model_func( **kwargs ) -@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query") async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embed.func( texts, @@ -490,7 +490,7 @@ import numpy as np from lightrag.utils import wrap_embedding_func_with_attrs from lightrag.llm.ollama import ollama_model_complete, ollama_embed -@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text") async def embedding_func(texts: list[str]) -> np.ndarray: return await ollama_embed.func(texts, embed_model="nomic-embed-text") @@ -542,7 +542,7 @@ import numpy as np from lightrag.utils import wrap_embedding_func_with_attrs from lightrag.llm.ollama import ollama_model_complete, ollama_embed -@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text") async def embedding_func(texts: list[str]) -> np.ndarray: return await ollama_embed.func(texts, embed_model="nomic-embed-text") @@ -1633,24 +1633,24 @@ LightRAG使用以下提示生成高级查询,相应的代码在`example/genera ### 总体性能表 -| |**农业**| |**计算机科学**| |**法律**| |**混合**| | +||**农业**||**计算机科学**||**法律**||**混合**|| |----------------------|---------------|------------|------|------------|---------|------------|-------|------------| -| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**| +||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**| |**全面性**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**| |**多样性**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**| |**赋能性**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**| |**总体**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**| -| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**| +||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**| |**全面性**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**| |**多样性**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**| 
|**赋能性**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**| |**总体**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**| -| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**| +||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**| |**全面性**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**| |**多样性**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**| |**赋能性**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**| |**总体**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**| -| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**| +||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**| |**全面性**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%| |**多样性**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**| |**赋能性**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%| diff --git a/README.md b/README.md index b157c350..d7a4f563 100644 --- a/README.md +++ b/README.md @@ -287,9 +287,9 @@ A full list of LightRAG init parameters: Parameters | **Parameter** | **Type** | **Explanation** | **Default** | -|--------------|----------|-----------------|-------------| +| -------------- | ---------- | ----------------- | ------------- | | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | -| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | | +| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | | | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` | | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` | | **graph_storage** | `str` | Storage type for graph edges and nodes. 
Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` | @@ -421,7 +421,7 @@ async def llm_model_func( **kwargs ) -@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=4096, max_token_size=8192, model_name="solar-embedding-1-large-query") async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embed.func( texts, @@ -488,7 +488,7 @@ import numpy as np from lightrag.utils import wrap_embedding_func_with_attrs from lightrag.llm.ollama import ollama_model_complete, ollama_embed -@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text") async def embedding_func(texts: list[str]) -> np.ndarray: return await ollama_embed.func(texts, embed_model="nomic-embed-text") @@ -540,7 +540,7 @@ import numpy as np from lightrag.utils import wrap_embedding_func_with_attrs from lightrag.llm.ollama import ollama_model_complete, ollama_embed -@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192, model_name="nomic-embed-text") async def embedding_func(texts: list[str]) -> np.ndarray: return await ollama_embed.func(texts, embed_model="nomic-embed-text") @@ -1701,24 +1701,24 @@ Output your evaluation in the following JSON format: ### Overall Performance Table -| |**Agriculture**| |**CS**| |**Legal**| |**Mix**| | +||**Agriculture**||**CS**||**Legal**||**Mix**|| |----------------------|---------------|------------|------|------------|---------|------------|-------|------------| -| |NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**| +||NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**|NaiveRAG|**LightRAG**| |**Comprehensiveness**|32.4%|**67.6%**|38.4%|**61.6%**|16.4%|**83.6%**|38.8%|**61.2%**| |**Diversity**|23.6%|**76.4%**|38.0%|**62.0%**|13.6%|**86.4%**|32.4%|**67.6%**| |**Empowerment**|32.4%|**67.6%**|38.8%|**61.2%**|16.4%|**83.6%**|42.8%|**57.2%**| |**Overall**|32.4%|**67.6%**|38.8%|**61.2%**|15.2%|**84.8%**|40.0%|**60.0%**| -| |RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**| +||RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**|RQ-RAG|**LightRAG**| |**Comprehensiveness**|31.6%|**68.4%**|38.8%|**61.2%**|15.2%|**84.8%**|39.2%|**60.8%**| |**Diversity**|29.2%|**70.8%**|39.2%|**60.8%**|11.6%|**88.4%**|30.8%|**69.2%**| |**Empowerment**|31.6%|**68.4%**|36.4%|**63.6%**|15.2%|**84.8%**|42.4%|**57.6%**| |**Overall**|32.4%|**67.6%**|38.0%|**62.0%**|14.4%|**85.6%**|40.0%|**60.0%**| -| |HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**| +||HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**|HyDE|**LightRAG**| |**Comprehensiveness**|26.0%|**74.0%**|41.6%|**58.4%**|26.8%|**73.2%**|40.4%|**59.6%**| |**Diversity**|24.0%|**76.0%**|38.8%|**61.2%**|20.0%|**80.0%**|32.4%|**67.6%**| |**Empowerment**|25.2%|**74.8%**|40.8%|**59.2%**|26.0%|**74.0%**|46.0%|**54.0%**| |**Overall**|24.8%|**75.2%**|41.6%|**58.4%**|26.4%|**73.6%**|42.4%|**57.6%**| -| |GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**| +||GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**|GraphRAG|**LightRAG**| |**Comprehensiveness**|45.6%|**54.4%**|48.4%|**51.6%**|48.4%|**51.6%**|**50.4%**|49.6%| 
|**Diversity**|22.8%|**77.2%**|40.8%|**59.2%**|26.4%|**73.6%**|36.0%|**64.0%**| |**Empowerment**|41.2%|**58.8%**|45.2%|**54.8%**|43.6%|**56.4%**|**50.8%**|49.2%| diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 692be453..9151f02e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -868,6 +868,7 @@ def create_app(args): func=optimized_embedding_function, max_token_size=final_max_token_size, send_dimensions=False, # Will be set later based on binding requirements + model_name=model, ) # Log final embedding configuration diff --git a/lightrag/base.py b/lightrag/base.py index bae0728b..75059377 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -220,6 +220,37 @@ class BaseVectorStorage(StorageNameSpace, ABC): cosine_better_than_threshold: float = field(default=0.2) meta_fields: set[str] = field(default_factory=set) + def __post_init__(self): + """Validate required embedding_func for vector storage.""" + if self.embedding_func is None: + raise ValueError( + "embedding_func is required for vector storage. " + "Please provide a valid EmbeddingFunc instance." + ) + + def _generate_collection_suffix(self) -> str | None: + """Generates collection/table suffix from embedding_func. + + Return suffix if model_name exists in embedding_func, otherwise return None. + Note: embedding_func is guaranteed to exist (validated in __post_init__). + + Returns: + str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available + """ + import re + + # Check if model_name exists (model_name is optional in EmbeddingFunc) + model_name = getattr(self.embedding_func, "model_name", None) + if not model_name: + return None + + # embedding_dim is required in EmbeddingFunc + embedding_dim = self.embedding_func.embedding_dim + + # Generate suffix: clean model name and append dimension + safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower()) + return f"{safe_model_name}_{embedding_dim}d" + @abstractmethod async def query( self, query: str, top_k: int, query_embedding: list[float] = None diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py index 709f294d..7c9accef 100644 --- a/lightrag/exceptions.py +++ b/lightrag/exceptions.py @@ -128,8 +128,8 @@ class ChunkTokenLimitExceededError(ValueError): self.chunk_preview = truncated_preview -class QdrantMigrationError(Exception): - """Raised when Qdrant data migration from legacy collections fails.""" +class DataMigrationError(Exception): + """Raised when data migration from legacy collection/table fails.""" def __init__(self, message: str): super().__init__(message) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index adb0058b..e299211b 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -28,6 +28,7 @@ class FaissVectorDBStorage(BaseVectorStorage): """ def __post_init__(self): + super().__post_init__() # Grab config values if available kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -358,9 +359,22 @@ class FaissVectorDBStorage(BaseVectorStorage): ) return + dim_mismatch = False try: # Load the Faiss index self._index = faiss.read_index(self._faiss_index_file) + + # Verify dimension consistency between loaded index and embedding function + if self._index.d != self._dim: + error_msg = ( + f"Dimension mismatch: loaded Faiss index has dimension {self._index.d}, " + f"but embedding function expects dimension {self._dim}. 
" + f"Please ensure the embedding model matches the stored index or rebuild the index." + ) + logger.error(error_msg) + dim_mismatch = True + raise ValueError(error_msg) + # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) @@ -375,6 +389,8 @@ class FaissVectorDBStorage(BaseVectorStorage): f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}" ) except Exception as e: + if dim_mismatch: + raise logger.error( f"[{self.workspace}] Failed to load Faiss index or metadata: {e}" ) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index d42c91a7..50c233a8 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -934,6 +934,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): raise def __post_init__(self): + super().__post_init__() # Check for MILVUS_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Milvus storage instances milvus_workspace = os.environ.get("MILVUS_WORKSPACE") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e11e6411..351e0039 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -89,7 +89,7 @@ class MongoKVStorage(BaseKVStorage): global_config=global_config, embedding_func=embedding_func, ) - self.__post_init__() + # __post_init__() is automatically called by super().__init__() def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) @@ -317,7 +317,7 @@ class MongoDocStatusStorage(DocStatusStorage): global_config=global_config, embedding_func=embedding_func, ) - self.__post_init__() + # __post_init__() is automatically called by super().__init__() def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) @@ -2052,9 +2052,12 @@ class MongoVectorDBStorage(BaseVectorStorage): embedding_func=embedding_func, meta_fields=meta_fields or set(), ) - self.__post_init__() + # __post_init__() is automatically called by super().__init__() def __post_init__(self): + # Call parent class __post_init__ to validate embedding_func + super().__post_init__() + # Check for MONGODB_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all MongoDB storage instances mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") @@ -2131,8 +2134,32 @@ class MongoVectorDBStorage(BaseVectorStorage): indexes = await indexes_cursor.to_list(length=None) for index in indexes: if index["name"] == self._index_name: + # Check if the existing index has matching vector dimensions + existing_dim = None + definition = index.get("latestDefinition", {}) + fields = definition.get("fields", []) + for field in fields: + if ( + field.get("type") == "vector" + and field.get("path") == "vector" + ): + existing_dim = field.get("numDimensions") + break + + expected_dim = self.embedding_func.embedding_dim + + if existing_dim is not None and existing_dim != expected_dim: + error_msg = ( + f"Vector dimension mismatch! Index '{self._index_name}' has " + f"dimension {existing_dim}, but current embedding model expects " + f"dimension {expected_dim}. Please drop the existing index or " + f"use an embedding model with matching dimensions." 
+ ) + logger.error(f"[{self.workspace}] {error_msg}") + raise ValueError(error_msg) + logger.info( - f"[{self.workspace}] vector index {self._index_name} already exist" + f"[{self.workspace}] vector index {self._index_name} already exists with matching dimensions ({expected_dim})" ) return diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index d390c37b..9b868c11 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -25,6 +25,7 @@ from .shared_storage import ( @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): + super().__post_init__() # Initialize basic attributes self._client = None self._storage_lock = None diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 49069ce3..6b70b505 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import json import os import re @@ -31,6 +32,7 @@ from ..base import ( DocStatus, DocStatusStorage, ) +from ..exceptions import DataMigrationError from ..namespace import NameSpace, is_namespace from ..utils import logger from ..kg.shared_storage import get_data_init_lock @@ -39,9 +41,12 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") +if not pm.is_installed("pgvector"): + pm.install("pgvector") import asyncpg # type: ignore from asyncpg import Pool # type: ignore +from pgvector.asyncpg import register_vector # type: ignore from dotenv import load_dotenv @@ -52,6 +57,42 @@ load_dotenv(dotenv_path=".env", override=False) T = TypeVar("T") +# PostgreSQL identifier length limit (in bytes) +PG_MAX_IDENTIFIER_LENGTH = 63 + + +def _safe_index_name(table_name: str, index_suffix: str) -> str: + """ + Generate a PostgreSQL-safe index name that won't be truncated. + + PostgreSQL silently truncates identifiers to 63 bytes. This function + ensures index names stay within that limit by hashing long table names. + + Args: + table_name: The table name (may be long with model suffix) + index_suffix: The index type suffix (e.g., 'hnsw_cosine', 'id', 'workspace_id') + + Returns: + A deterministic index name that fits within 63 bytes + """ + # Construct the full index name + full_name = f"idx_{table_name.lower()}_{index_suffix}" + + # If it fits within the limit, use it as-is + if len(full_name.encode("utf-8")) <= PG_MAX_IDENTIFIER_LENGTH: + return full_name + + # Otherwise, hash the table name to create a shorter unique identifier + # Keep 'idx_' prefix and suffix readable, hash the middle + hash_input = table_name.lower().encode("utf-8") + table_hash = hashlib.md5(hash_input).hexdigest()[:12] # 12 hex chars + + # Format: idx_{hash}_{suffix} - guaranteed to fit + # Maximum: idx_ (4) + hash (12) + _ (1) + suffix (variable) = 17 + suffix + shortened_name = f"idx_{table_hash}_{index_suffix}" + + return shortened_name + class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -252,14 +293,41 @@ class PostgreSQLDB: else wait_fixed(0) ) + async def _init_connection(connection: asyncpg.Connection) -> None: + """Initialize each connection with pgvector codec. + + This callback is invoked by asyncpg for every new connection in the pool. + Registering the vector codec here ensures ALL connections can properly + encode/decode vector columns, eliminating non-deterministic behavior + where some connections have the codec and others don't. 
+ """ + await register_vector(connection) + async def _create_pool_once() -> None: - pool = await asyncpg.create_pool(**connection_params) # type: ignore + # STEP 1: Bootstrap - ensure vector extension exists BEFORE pool creation. + # On a fresh database, register_vector() in _init_connection will fail + # if the vector extension doesn't exist yet, because the 'vector' type + # won't be found in pg_catalog. We must create the extension first + # using a standalone bootstrap connection. + bootstrap_conn = await asyncpg.connect( + user=self.user, + password=self.password, + database=self.database, + host=self.host, + port=self.port, + ssl=connection_params.get("ssl"), + ) try: - async with pool.acquire() as connection: - await self.configure_vector_extension(connection) - except Exception: - await pool.close() - raise + await self.configure_vector_extension(bootstrap_conn) + finally: + await bootstrap_conn.close() + + # STEP 2: Now safe to create pool with register_vector callback. + # The vector extension is guaranteed to exist at this point. + pool = await asyncpg.create_pool( + **connection_params, + init=_init_connection, # Register pgvector codec on every connection + ) # type: ignore self.pool = pool try: @@ -564,6 +632,24 @@ class PostgreSQLDB: } try: + # Filter out tables that don't exist (e.g., legacy vector tables may not exist) + existing_tables = {} + for table_name, columns in tables_to_migrate.items(): + if await self.check_table_exists(table_name): + existing_tables[table_name] = columns + else: + logger.debug( + f"Table {table_name} does not exist, skipping timestamp migration" + ) + + # Skip if no tables to migrate + if not existing_tables: + logger.debug("No tables found for timestamp migration") + return + + # Use filtered tables for migration + tables_to_migrate = existing_tables + # Optimization: Batch check all columns in one query instead of 8 separate queries table_names_lower = [t.lower() for t in tables_to_migrate.keys()] all_column_names = list( @@ -640,6 +726,22 @@ class PostgreSQLDB: """ try: + # 0. Check if both tables exist before proceeding + vdb_chunks_exists = await self.check_table_exists("LIGHTRAG_VDB_CHUNKS") + doc_chunks_exists = await self.check_table_exists("LIGHTRAG_DOC_CHUNKS") + + if not vdb_chunks_exists: + logger.debug( + "Skipping migration: LIGHTRAG_VDB_CHUNKS table does not exist" + ) + return + + if not doc_chunks_exists: + logger.debug( + "Skipping migration: LIGHTRAG_DOC_CHUNKS table does not exist" + ) + return + # 1. 
Check if the new table LIGHTRAG_VDB_CHUNKS is empty vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS" vdb_chunks_count_result = await self.query(vdb_chunks_count_sql) @@ -1008,6 +1110,24 @@ class PostgreSQLDB: ] try: + # Filter out tables that don't exist (e.g., legacy vector tables may not exist) + existing_migrations = [] + for migration in field_migrations: + if await self.check_table_exists(migration["table"]): + existing_migrations.append(migration) + else: + logger.debug( + f"Table {migration['table']} does not exist, skipping field length migration for {migration['column']}" + ) + + # Skip if no migrations to process + if not existing_migrations: + logger.debug("No tables found for field length migration") + return + + # Use filtered migrations for processing + field_migrations = existing_migrations + # Optimization: Batch check all columns in one query instead of 5 separate queries unique_tables = list(set(m["table"].lower() for m in field_migrations)) unique_columns = list(set(m["column"] for m in field_migrations)) @@ -1092,8 +1212,20 @@ class PostgreSQLDB: logger.error(f"Failed to batch check field lengths: {e}") async def check_tables(self): - # First create all tables + # Vector tables that should be skipped - they are created by PGVectorStorage.setup_table() + # with proper embedding model and dimension suffix for data isolation + vector_tables_to_skip = { + "LIGHTRAG_VDB_CHUNKS", + "LIGHTRAG_VDB_ENTITY", + "LIGHTRAG_VDB_RELATION", + } + + # First create all tables (except vector tables) for k, v in TABLES.items(): + # Skip vector tables - they are created by PGVectorStorage.setup_table() + if k in vector_tables_to_skip: + continue + try: await self.query(f"SELECT 1 FROM {k} LIMIT 1") except Exception: @@ -1111,7 +1243,8 @@ class PostgreSQLDB: # Batch check all indexes at once (optimization: single query instead of N queries) try: - table_names = list(TABLES.keys()) + # Exclude vector tables from index creation since they are created by PGVectorStorage.setup_table() + table_names = [k for k in TABLES.keys() if k not in vector_tables_to_skip] table_names_lower = [t.lower() for t in table_names] # Get all existing indexes for our tables in one query @@ -1163,23 +1296,9 @@ class PostgreSQLDB: except Exception as e: logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}") - # Create vector indexs - if self.vector_index_type: - logger.info( - f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}" - ) - try: - if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]: - await self._create_vector_indexes() - else: - logger.warning( - "Doesn't support this vector index type: {self.vector_index_type}. 
" - "Supported types: HNSW, IVFFLAT, VCHORDRQ" - ) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}" - ) + # NOTE: Vector index creation moved to PGVectorStorage.setup_table() + # Each vector storage instance creates its own index with correct embedding_dim + # After all tables are created, attempt to migrate timestamp fields try: await self._migrate_timestamp_columns() @@ -1381,64 +1500,74 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to create index {index['name']}: {e}") - async def _create_vector_indexes(self): - vdb_tables = [ - "LIGHTRAG_VDB_CHUNKS", - "LIGHTRAG_VDB_ENTITY", - "LIGHTRAG_VDB_RELATION", - ] + async def _create_vector_index(self, table_name: str, embedding_dim: int): + """ + Create vector index for a specific table. + + Args: + table_name: Name of the table to create index on + embedding_dim: Embedding dimension for the vector column + """ + if not self.vector_index_type: + return create_sql = { "HNSW": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING hnsw (content_vector vector_cosine_ops) + ON {{table_name}} USING hnsw (content_vector vector_cosine_ops) WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) """, "IVFFLAT": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING ivfflat (content_vector vector_cosine_ops) + ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops) WITH (lists = {self.ivfflat_lists}) """, "VCHORDRQ": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING vchordrq (content_vector vector_cosine_ops) - {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''} + ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops) + {f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""} """, } - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) - for k in vdb_tables: - vector_index_name = ( - f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" + if self.vector_index_type not in create_sql: + logger.warning( + f"Unsupported vector index type: {self.vector_index_type}. 
" + "Supported types: HNSW, IVFFLAT, VCHORDRQ" ) - check_vector_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' - """ - try: - vector_index_exists = await self.query(check_vector_index_sql) - if not vector_index_exists: - # Only set vector dimension when index doesn't exist - alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" - await self.execute(alter_sql) - logger.debug(f"Ensured vector dimension for {k}") - logger.info( - f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" + return + + k = table_name + # Use _safe_index_name to avoid PostgreSQL's 63-byte identifier truncation + index_suffix = f"{self.vector_index_type.lower()}_cosine" + vector_index_name = _safe_index_name(k, index_suffix) + check_vector_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' + """ + try: + vector_index_exists = await self.query(check_vector_index_sql) + if not vector_index_exists: + # Only set vector dimension when index doesn't exist + alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" + await self.execute(alter_sql) + logger.debug(f"Ensured vector dimension for {k}") + logger.info( + f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" + ) + await self.execute( + create_sql[self.vector_index_type].format( + vector_index_name=vector_index_name, table_name=k ) - await self.execute( - create_sql[self.vector_index_type].format( - vector_index_name=vector_index_name, k=k - ) - ) - logger.info( - f"Successfully created vector index {vector_index_name} on table {k}" - ) - else: - logger.info( - f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" - ) - except Exception as e: - logger.error(f"Failed to create vector index on table {k}, Got: {e}") + ) + logger.info( + f"Successfully created vector index {vector_index_name} on table {k}" + ) + else: + logger.info( + f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" + ) + except Exception as e: + logger.error(f"Failed to create vector index on table {k}, Got: {e}") async def query( self, @@ -1474,6 +1603,24 @@ class PostgreSQLDB: logger.error(f"PostgreSQL database, error:{e}") raise + async def check_table_exists(self, table_name: str) -> bool: + """Check if a table exists in PostgreSQL database + + Args: + table_name: Name of the table to check + + Returns: + bool: True if table exists, False otherwise + """ + query = """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + ) + """ + result = await self.query(query, [table_name.lower()]) + return result.get("exists", False) if result else False + async def execute( self, sql: str, @@ -2181,6 +2328,7 @@ class PGVectorStorage(BaseVectorStorage): db: PostgreSQLDB | None = field(default=None) def __post_init__(self): + super().__post_init__() self._max_batch_size = self.global_config["embedding_batch_num"] config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") @@ -2190,6 +2338,443 @@ class PGVectorStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold + # Generate model suffix for table isolation + self.model_suffix = self._generate_collection_suffix() + + # Get base table name + base_table = namespace_to_table_name(self.namespace) + if not base_table: + raise 
ValueError(f"Unknown namespace: {self.namespace}") + + # New table name (with suffix) + # Ensure model_suffix is not empty before appending + if self.model_suffix: + self.table_name = f"{base_table}_{self.model_suffix}" + logger.info(f"PostgreSQL table: {self.table_name}") + else: + # Fallback: use base table name if model_suffix is unavailable + self.table_name = base_table + logger.warning( + f"PostgreSQL table: {self.table_name} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation." + ) + + # Legacy table name (without suffix, for migration) + self.legacy_table_name = base_table + + # Validate table name length (PostgreSQL identifier limit is 63 characters) + if len(self.table_name) > PG_MAX_IDENTIFIER_LENGTH: + raise ValueError( + f"PostgreSQL table name exceeds {PG_MAX_IDENTIFIER_LENGTH} character limit: '{self.table_name}' " + f"(length: {len(self.table_name)}). " + f"Consider using a shorter embedding model name or workspace name." + ) + + @staticmethod + async def _pg_create_table( + db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int + ) -> None: + """Create a new vector table by replacing the table name in DDL template, + and create indexes on id and (workspace, id) columns. + + Args: + db: PostgreSQLDB instance + table_name: Name of the new table to create + base_table: Base table name for DDL template lookup + embedding_dim: Embedding dimension for vector column + """ + if base_table not in TABLES: + raise ValueError(f"No DDL template found for table: {base_table}") + + ddl_template = TABLES[base_table]["ddl"] + + # Replace embedding dimension placeholder if exists + ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})") + + # Replace table name + ddl = ddl.replace(base_table, table_name) + + await db.execute(ddl) + + # Create indexes similar to check_tables() but with safe index names + # Create index for id column + id_index_name = _safe_index_name(table_name, "id") + try: + create_id_index_sql = f"CREATE INDEX {id_index_name} ON {table_name}(id)" + logger.info( + f"PostgreSQL, Creating index {id_index_name} on table {table_name}" + ) + await db.execute(create_id_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}" + ) + + # Create composite index for (workspace, id) + workspace_id_index_name = _safe_index_name(table_name, "workspace_id") + try: + create_composite_index_sql = ( + f"CREATE INDEX {workspace_id_index_name} ON {table_name}(workspace, id)" + ) + logger.info( + f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}" + ) + await db.execute(create_composite_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}" + ) + + @staticmethod + async def _pg_migrate_workspace_data( + db: PostgreSQLDB, + legacy_table_name: str, + new_table_name: str, + workspace: str, + expected_count: int, + embedding_dim: int, + ) -> int: + """Migrate workspace data from legacy table to new table using batch insert. + + This function uses asyncpg's executemany for efficient batch insertion, + reducing database round-trips from N to 1 per batch. + + Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering. + This ensures every legacy row is migrated exactly once, avoiding the + non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY. 
+ + Args: + db: PostgreSQLDB instance + legacy_table_name: Name of the legacy table to migrate from + new_table_name: Name of the new table to migrate to + workspace: Workspace to filter records for migration + expected_count: Expected number of records to migrate + embedding_dim: Embedding dimension for vector column + + Returns: + Number of records migrated + """ + migrated_count = 0 + last_id: str | None = None + batch_size = 500 + + while True: + # Use keyset pagination with ORDER BY id for deterministic ordering + # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows + if workspace: + if last_id is not None: + select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3" + rows = await db.query( + select_query, [workspace, last_id, batch_size], multirows=True + ) + else: + select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2" + rows = await db.query( + select_query, [workspace, batch_size], multirows=True + ) + else: + if last_id is not None: + select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2" + rows = await db.query( + select_query, [last_id, batch_size], multirows=True + ) + else: + select_query = ( + f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1" + ) + rows = await db.query(select_query, [batch_size], multirows=True) + + if not rows: + break + + # Track the last ID for keyset pagination cursor + last_id = rows[-1]["id"] + + # Batch insert optimization: use executemany instead of individual inserts + # Get column names from the first row + first_row = dict(rows[0]) + columns = list(first_row.keys()) + columns_str = ", ".join(columns) + placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) + + insert_query = f""" + INSERT INTO {new_table_name} ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT (workspace, id) DO NOTHING + """ + + # Prepare batch data: convert rows to list of tuples + batch_values = [] + for row in rows: + row_dict = dict(row) + + # FIX: Parse vector strings from connections without register_vector codec. + # When pgvector codec is not registered on the read connection, vector + # columns are returned as text strings like "[0.1,0.2,...]" instead of + # lists/arrays. We need to convert these to numpy arrays before passing + # to executemany, which uses a connection WITH register_vector codec + # that expects list/tuple/ndarray types. 
+ if "content_vector" in row_dict: + vec = row_dict["content_vector"] + if isinstance(vec, str): + # pgvector text format: "[0.1,0.2,0.3,...]" + vec = vec.strip("[]") + if vec: + row_dict["content_vector"] = np.array( + [float(x) for x in vec.split(",")], dtype=np.float32 + ) + else: + row_dict["content_vector"] = None + + # Extract values in column order to match placeholders + values_tuple = tuple(row_dict[col] for col in columns) + batch_values.append(values_tuple) + + # Use executemany for batch execution - significantly reduces DB round-trips + # Note: register_vector is already called on pool init, no need to call it again + async def _batch_insert(connection: asyncpg.Connection) -> None: + await connection.executemany(insert_query, batch_values) + + await db._run_with_retry(_batch_insert) + + migrated_count += len(rows) + workspace_info = f" for workspace '{workspace}'" if workspace else "" + logger.info( + f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}" + ) + + return migrated_count + + @staticmethod + async def setup_table( + db: PostgreSQLDB, + table_name: str, + workspace: str, + embedding_dim: int, + legacy_table_name: str, + base_table: str, + ): + """ + Setup PostgreSQL table with migration support from legacy tables. + + Ensure final table has workspace isolation index. + Check vector dimension compatibility before new table creation. + Drop legacy table if it exists and is empty. + Only migrate data from legacy table to new table when new table first created and legacy table is not empty. + This function must be call ClientManager.get_client() to legacy table is migrated to latest schema. + + Args: + db: PostgreSQLDB instance + table_name: Name of the new table + workspace: Workspace to filter records for migration + legacy_table_name: Name of the legacy table to check for migration + base_table: Base table name for DDL template lookup + embedding_dim: Embedding dimension for vector column + """ + if not workspace: + raise ValueError("workspace must be provided") + + new_table_exists = await db.check_table_exists(table_name) + legacy_exists = legacy_table_name and await db.check_table_exists( + legacy_table_name + ) + + # Case 1: Only new table exists or new table is the same as legacy table + # No data migration needed, ensuring index is created then return + if (new_table_exists and not legacy_exists) or ( + new_table_exists and (table_name.lower() == legacy_table_name.lower()) + ): + await db._create_vector_index(table_name, embedding_dim) + + workspace_count_query = ( + f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1" + ) + workspace_count_result = await db.query(workspace_count_query, [workspace]) + workspace_count = ( + workspace_count_result.get("count", 0) if workspace_count_result else 0 + ) + if workspace_count == 0 and not ( + table_name.lower() == legacy_table_name.lower() + ): + logger.warning( + f"PostgreSQL: workspace data in table '{table_name}' is empty. " + f"Ensure it is caused by new workspace setup and not an unexpected embedding model change." 
+ ) + + return + + legacy_count = None + if not new_table_exists: + # Check vector dimension compatibility before creating new table + if legacy_exists: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" + count_result = await db.query(count_query, [workspace]) + legacy_count = count_result.get("count", 0) if count_result else 0 + + if legacy_count > 0: + legacy_dim = None + try: + sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1" + sample_result = await db.query(sample_query, [workspace]) + # Fix: Use 'is not None' instead of truthiness check to avoid + # NumPy array boolean ambiguity error + if ( + sample_result + and sample_result.get("content_vector") is not None + ): + vector_data = sample_result["content_vector"] + # pgvector returns list directly, but may also return NumPy arrays + # when register_vector codec is active on the connection + if isinstance(vector_data, (list, tuple)): + legacy_dim = len(vector_data) + elif hasattr(vector_data, "__len__") and not isinstance( + vector_data, str + ): + # Handle NumPy arrays and other array-like objects + legacy_dim = len(vector_data) + elif isinstance(vector_data, str): + import json + + vector_list = json.loads(vector_data) + legacy_dim = len(vector_list) + + if legacy_dim and legacy_dim != embedding_dim: + logger.error( + f"PostgreSQL: Dimension mismatch detected! " + f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, " + f"but new embedding model expects {embedding_dim}d." + ) + raise DataMigrationError( + f"Dimension mismatch between legacy table '{legacy_table_name}' " + f"and new embedding model. Expected {embedding_dim}d but got {legacy_dim}d." + ) + + except DataMigrationError: + # Re-raise DataMigrationError as-is to preserve specific error messages + raise + except Exception as e: + raise DataMigrationError( + f"Could not verify legacy table vector dimension: {e}. " + f"Proceeding with caution..." + ) + + await PGVectorStorage._pg_create_table( + db, table_name, base_table, embedding_dim + ) + logger.info(f"PostgreSQL: New table '{table_name}' created successfully") + + if not legacy_exists: + await db._create_vector_index(table_name, embedding_dim) + logger.info( + "Ensure this new table creation is caused by new workspace setup and not an unexpected embedding model change." + ) + return + + # Ensure vector index is created + await db._create_vector_index(table_name, embedding_dim) + + # Case 2: Legacy table exist + if legacy_exists: + workspace_info = f" for workspace '{workspace}'" + + # Only drop legacy table if entire table is empty + total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" + total_count_result = await db.query(total_count_query, []) + total_count = ( + total_count_result.get("count", 0) if total_count_result else 0 + ) + if total_count == 0: + logger.info( + f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully" + ) + drop_query = f"DROP TABLE {legacy_table_name}" + await db.execute(drop_query, None) + return + + # No data migration needed if legacy workspace is empty + if legacy_count is None: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" + count_result = await db.query(count_query, [workspace]) + legacy_count = count_result.get("count", 0) if count_result else 0 + + if legacy_count == 0: + logger.info( + f"PostgreSQL: No records{workspace_info} found in legacy table. " + f"No data migration needed." 
+ ) + return + + new_count_query = ( + f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1" + ) + new_count_result = await db.query(new_count_query, [workspace]) + new_table_workspace_count = ( + new_count_result.get("count", 0) if new_count_result else 0 + ) + + if new_table_workspace_count > 0: + logger.warning( + f"PostgreSQL: Both new and legacy collection have data. " + f"{legacy_count} records in {legacy_table_name} require manual deletion after migration verification." + ) + return + + # Case 3: Legacy has workspace data and new table is empty for workspace + logger.info( + f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}." + ) + logger.info( + f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'" + ) + + try: + migrated_count = await PGVectorStorage._pg_migrate_workspace_data( + db, + legacy_table_name, + table_name, + workspace, + legacy_count, + embedding_dim, + ) + if migrated_count != legacy_count: + logger.warning( + "PostgreSQL: Read %s legacy records%s during migration, expected %s.", + migrated_count, + workspace_info, + legacy_count, + ) + + new_count_result = await db.query(new_count_query, [workspace]) + new_table_count_after = ( + new_count_result.get("count", 0) if new_count_result else 0 + ) + inserted_count = new_table_count_after - new_table_workspace_count + + if inserted_count != legacy_count: + error_msg = ( + "PostgreSQL: Migration verification failed, " + f"expected {legacy_count} inserted records, got {inserted_count}." + ) + logger.error(error_msg) + raise DataMigrationError(error_msg) + + except DataMigrationError: + # Re-raise DataMigrationError as-is to preserve specific error messages + raise + except Exception as e: + logger.error( + f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}" + ) + raise DataMigrationError( + f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'" + ) from e + + logger.info( + f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully" + ) + logger.warning( + "PostgreSQL: Manual deletion is required after data migration verification." + ) + async def initialize(self): async with get_data_init_lock(): if self.db is None: @@ -2206,6 +2791,16 @@ class PGVectorStorage(BaseVectorStorage): # Use "default" for compatibility (lowest priority) self.workspace = "default" + # Setup table (create if not exists and handle migration) + await PGVectorStorage.setup_table( + self.db, + self.table_name, + self.workspace, # CRITICAL: Filter migration by workspace + embedding_dim=self.embedding_func.embedding_dim, + legacy_table_name=self.legacy_table_name, + base_table=self.legacy_table_name, # base_table for DDL template lookup + ) + async def finalize(self): if self.db is not None: await ClientManager.release_client(self.db) @@ -2213,75 +2808,97 @@ class PGVectorStorage(BaseVectorStorage): def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for chunks. 
+ + Returns: + Tuple of (SQL template, values tuple for executemany) + """ try: - upsert_sql = SQL_TEMPLATES["upsert_chunk"] - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "file_path": item["file_path"], - "create_time": current_time, - "update_time": current_time, - } + upsert_sql = SQL_TEMPLATES["upsert_chunk"].format( + table_name=self.table_name + ) + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] = ( + self.workspace, # $1 + item["__id__"], # $2 + item["tokens"], # $3 + item["chunk_order_index"], # $4 + item["full_doc_id"], # $5 + item["content"], # $6 + item["__vector__"], # $7 - numpy array, handled by pgvector codec + item["file_path"], # $8 + current_time, # $9 + current_time, # $10 + ) except Exception as e: logger.error( - f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}" + f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}" ) raise - return upsert_sql, data + return upsert_sql, values def _upsert_entities( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: - upsert_sql = SQL_TEMPLATES["upsert_entity"] + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for entities. + + Returns: + Tuple of (SQL template, values tuple for executemany) + """ + upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name) source_id = item["source_id"] if isinstance(source_id, str) and "" in source_id: chunk_ids = source_id.split("") else: chunk_ids = [source_id] - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "entity_name": item["entity_name"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] = ( + self.workspace, # $1 + item["__id__"], # $2 + item["entity_name"], # $3 + item["content"], # $4 + item["__vector__"], # $5 - numpy array, handled by pgvector codec + chunk_ids, # $6 + item.get("file_path", None), # $7 + current_time, # $8 + current_time, # $9 + ) + return upsert_sql, values def _upsert_relationships( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: - upsert_sql = SQL_TEMPLATES["upsert_relationship"] + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for relationships. 
+ + Returns: + Tuple of (SQL template, values tuple for executemany) + """ + upsert_sql = SQL_TEMPLATES["upsert_relationship"].format( + table_name=self.table_name + ) source_id = item["source_id"] if isinstance(source_id, str) and "" in source_id: chunk_ids = source_id.split("") else: chunk_ids = [source_id] - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "source_id": item["src_id"], - "target_id": item["tgt_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] = ( + self.workspace, # $1 + item["__id__"], # $2 + item["src_id"], # $3 + item["tgt_id"], # $4 + item["content"], # $5 + item["__vector__"], # $6 - numpy array, handled by pgvector codec + chunk_ids, # $7 + item.get("file_path", None), # $8 + current_time, # $9 + current_time, # $10 + ) + return upsert_sql, values async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") @@ -2309,17 +2926,34 @@ class PGVectorStorage(BaseVectorStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] + + # Prepare batch values for executemany + batch_values: list[tuple[Any, ...]] = [] + upsert_sql = None + for item in list_data: if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): - upsert_sql, data = self._upsert_chunks(item, current_time) + upsert_sql, values = self._upsert_chunks(item, current_time) elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): - upsert_sql, data = self._upsert_entities(item, current_time) + upsert_sql, values = self._upsert_entities(item, current_time) elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): - upsert_sql, data = self._upsert_relationships(item, current_time) + upsert_sql, values = self._upsert_relationships(item, current_time) else: raise ValueError(f"{self.namespace} is not supported") - await self.db.execute(upsert_sql, data) + batch_values.append(values) + + # Use executemany for batch execution - significantly reduces DB round-trips + # Note: register_vector is already called on pool init, no need to call it again + if batch_values and upsert_sql: + + async def _batch_upsert(connection: asyncpg.Connection) -> None: + await connection.executemany(upsert_sql, batch_values) + + await self.db._run_with_retry(_batch_upsert) + logger.debug( + f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace}" + ) #################### query method ############### async def query( @@ -2335,7 +2969,9 @@ class PGVectorStorage(BaseVectorStorage): embedding_string = ",".join(map(str, embedding)) - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + sql = SQL_TEMPLATES[self.namespace].format( + embedding_string=embedding_string, table_name=self.table_name + ) params = { "workspace": self.workspace, "closer_than_threshold": 1 - self.cosine_better_than_threshold, @@ -2357,14 +2993,9 @@ class PGVectorStorage(BaseVectorStorage): if not ids: return - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}" - ) - return - - delete_sql = f"DELETE 
FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + delete_sql = ( + f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)" + ) try: await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids}) @@ -2383,8 +3014,8 @@ class PGVectorStorage(BaseVectorStorage): entity_name: The name of the entity to delete """ try: - # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + # Construct SQL to delete the entity using dynamic table name + delete_sql = f"""DELETE FROM {self.table_name} WHERE workspace=$1 AND entity_name=$2""" await self.db.execute( @@ -2404,7 +3035,7 @@ class PGVectorStorage(BaseVectorStorage): """ try: # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + delete_sql = f"""DELETE FROM {self.table_name} WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" await self.db.execute( @@ -2427,14 +3058,7 @@ class PGVectorStorage(BaseVectorStorage): Returns: The vector data if found, or None if not found """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}" - ) - return None - - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2" + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id=$2" params = {"workspace": self.workspace, "id": id} try: @@ -2460,15 +3084,8 @@ class PGVectorStorage(BaseVectorStorage): if not ids: return [] - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}" - ) - return [] - ids_str = ",".join([f"'{id}'" for id in ids]) - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})" params = {"workspace": self.workspace} try: @@ -2509,15 +3126,8 @@ class PGVectorStorage(BaseVectorStorage): if not ids: return {} - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}" - ) - return {} - ids_str = ",".join([f"'{id}'" for id in ids]) - query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + query = f"SELECT id, content_vector FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})" params = {"workspace": self.workspace} try: @@ -2527,10 +3137,18 @@ class PGVectorStorage(BaseVectorStorage): for result in results: if result and "content_vector" in result and "id" in result: try: - # Parse JSON string to get vector as list of floats - vector_data = json.loads(result["content_vector"]) - if isinstance(vector_data, list): - vectors_dict[result["id"]] = vector_data + vector_data = result["content_vector"] + # Handle both pgvector-registered connections (returns list/tuple) + # and non-registered connections (returns JSON string) + if isinstance(vector_data, (list, tuple)): + vectors_dict[result["id"]] = list(vector_data) + elif isinstance(vector_data, str): + parsed = json.loads(vector_data) + if isinstance(parsed, list): + vectors_dict[result["id"]] = parsed + # Handle numpy arrays from pgvector + elif 
hasattr(vector_data, "tolist"): + vectors_dict[result["id"]] = vector_data.tolist() except (json.JSONDecodeError, TypeError) as e: logger.warning( f"[{self.workspace}] Failed to parse vector data for ID {result['id']}: {e}" @@ -2546,15 +3164,8 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name + table_name=self.table_name ) await self.db.execute(drop_sql, {"workspace": self.workspace}) return {"status": "success", "message": "data dropped"} @@ -2593,6 +3204,9 @@ class PGDocStatusStorage(DocStatusStorage): # Use "default" for compatibility (lowest priority) self.workspace = "default" + # NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization + # No need to create table here as it's already created in the TABLES dict + async def finalize(self): if self.db is not None: await ClientManager.release_client(self.db) @@ -4787,14 +5401,14 @@ TABLES = { )""" }, "LIGHTRAG_VDB_CHUNKS": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_CHUNKS ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS ( id VARCHAR(255), workspace VARCHAR(255), full_doc_id VARCHAR(256), chunk_order_index INTEGER, tokens INTEGER, content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), file_path TEXT NULL, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, @@ -4802,12 +5416,12 @@ TABLES = { )""" }, "LIGHTRAG_VDB_ENTITY": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_ENTITY ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( id VARCHAR(255), workspace VARCHAR(255), entity_name VARCHAR(512), content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, chunk_ids VARCHAR(255)[] NULL, @@ -4816,13 +5430,13 @@ TABLES = { )""" }, "LIGHTRAG_VDB_RELATION": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_RELATION ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( id VARCHAR(255), workspace VARCHAR(255), source_id VARCHAR(512), target_id VARCHAR(512), content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, chunk_ids VARCHAR(255)[] NULL, @@ -5047,7 +5661,7 @@ SQL_TEMPLATES = { update_time = EXCLUDED.update_time """, # SQL for VectorStorage - "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, + "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -5060,7 +5674,7 @@ SQL_TEMPLATES = { file_path=EXCLUDED.file_path, update_time = EXCLUDED.update_time """, - "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, + "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content, content_vector, chunk_ids, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) ON CONFLICT (workspace,id) DO UPDATE @@ -5071,7 +5685,7 @@ SQL_TEMPLATES = { 
file_path=EXCLUDED.file_path, update_time=EXCLUDED.update_time """, - "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, + "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id, target_id, content, content_vector, chunk_ids, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10) ON CONFLICT (workspace,id) DO UPDATE @@ -5087,7 +5701,7 @@ SQL_TEMPLATES = { SELECT r.source_id AS src_id, r.target_id AS tgt_id, EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_RELATION r + FROM {table_name} r WHERE r.workspace = $1 AND r.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY r.content_vector <=> '[{embedding_string}]'::vector @@ -5096,7 +5710,7 @@ SQL_TEMPLATES = { "entities": """ SELECT e.entity_name, EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_ENTITY e + FROM {table_name} e WHERE e.workspace = $1 AND e.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY e.content_vector <=> '[{embedding_string}]'::vector @@ -5107,7 +5721,7 @@ SQL_TEMPLATES = { c.content, c.file_path, EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_CHUNKS c + FROM {table_name} c WHERE c.workspace = $1 AND c.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY c.content_vector <=> '[{embedding_string}]'::vector diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 75de2613..81f0f4e4 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -10,7 +10,7 @@ import numpy as np import pipmaster as pm from ..base import BaseVectorStorage -from ..exceptions import QdrantMigrationError +from ..exceptions import DataMigrationError from ..kg.shared_storage import get_data_init_lock from ..utils import compute_mdhash_id, logger @@ -66,6 +66,51 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition: ) +def _find_legacy_collection( + client: QdrantClient, + namespace: str, + workspace: str = None, + model_suffix: str = None, +) -> str | None: + """ + Find legacy collection with backward compatibility support. + + This function tries multiple naming patterns to locate legacy collections + created by older versions of LightRAG: + + 1. lightrag_vdb_{namespace} - if model_suffix is provided (HIGHEST PRIORITY) + 2. {workspace}_{namespace} or {namespace} - no matter if model_suffix is provided or not + 3. 
lightrag_vdb_{namespace} - fall back value no matter if model_suffix is provided or not (LOWEST PRIORITY) + + Args: + client: QdrantClient instance + namespace: Base namespace (e.g., "chunks", "entities") + workspace: Optional workspace identifier + model_suffix: Optional model suffix for new collection + + Returns: + Collection name if found, None otherwise + """ + # Try multiple naming patterns for backward compatibility + # More specific names (with workspace) have higher priority + candidates = [ + f"lightrag_vdb_{namespace}" if model_suffix else None, + f"{workspace}_{namespace}" if workspace else None, + f"lightrag_vdb_{namespace}", + namespace, + ] + + for candidate in candidates: + if candidate and client.collection_exists(candidate): + logger.info( + f"Qdrant: Found legacy collection '{candidate}' " + f"(namespace={namespace}, workspace={workspace or 'none'})" + ) + return candidate + + return None + + @final @dataclass class QdrantVectorDBStorage(BaseVectorStorage): @@ -79,63 +124,52 @@ class QdrantVectorDBStorage(BaseVectorStorage): embedding_func=embedding_func, meta_fields=meta_fields or set(), ) - self.__post_init__() + # __post_init__() is automatically called by super().__init__() @staticmethod def setup_collection( client: QdrantClient, collection_name: str, - legacy_namespace: str = None, - workspace: str = None, - **kwargs, + namespace: str, + workspace: str, + vectors_config: models.VectorParams, + hnsw_config: models.HnswConfigDiff, + model_suffix: str, ): """ Setup Qdrant collection with migration support from legacy collections. + Ensure final collection has workspace isolation index. + Check vector dimension compatibility before new collection creation. + Drop legacy collection if it exists and is empty. + Only migrate data from legacy collection to new collection when new collection first created and legacy collection is not empty. + Args: client: QdrantClient instance - collection_name: Name of the new collection - legacy_namespace: Name of the legacy collection (if exists) + collection_name: Name of the final collection + namespace: Base namespace (e.g., "chunks", "entities") workspace: Workspace identifier for data isolation - **kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.) + vectors_config: Vector configuration parameters for the collection + hnsw_config: HNSW index configuration diff for the collection """ + if not namespace or not workspace: + raise ValueError("namespace and workspace must be provided") + + workspace_count_filter = models.Filter( + must=[workspace_filter_condition(workspace)] + ) + new_collection_exists = client.collection_exists(collection_name) - legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace) + legacy_collection = _find_legacy_collection( + client, namespace, workspace, model_suffix + ) - # Case 1: Both new and legacy collections exist - Warning only (no migration) - if new_collection_exists and legacy_exists: - logger.warning( - f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete." 
- ) - return - - # Case 2: Only new collection exists - Ensure index exists - if new_collection_exists: - # Check if workspace index exists, create if missing - try: - collection_info = client.get_collection(collection_name) - if WORKSPACE_ID_FIELD not in collection_info.payload_schema: - logger.info( - f"Qdrant: Creating missing workspace index for '{collection_name}'" - ) - client.create_payload_index( - collection_name=collection_name, - field_name=WORKSPACE_ID_FIELD, - field_schema=models.KeywordIndexParams( - type=models.KeywordIndexType.KEYWORD, - is_tenant=True, - ), - ) - except Exception as e: - logger.warning( - f"Qdrant: Could not verify/create workspace index for '{collection_name}': {e}" - ) - return - - # Case 3: Neither exists - Create new collection - if not legacy_exists: - logger.info(f"Qdrant: Creating new collection '{collection_name}'") - client.create_collection(collection_name, **kwargs) + # Case 1: Only new collection exists or new collection is the same as legacy collection + # No data migration needed, and ensuring index is created then return + if (new_collection_exists and not legacy_collection) or ( + collection_name == legacy_collection + ): + # create_payload_index return without error if index already exists client.create_payload_index( collection_name=collection_name, field_name=WORKSPACE_ID_FIELD, @@ -144,132 +178,244 @@ class QdrantVectorDBStorage(BaseVectorStorage): is_tenant=True, ), ) - logger.info(f"Qdrant: Collection '{collection_name}' created successfully") + new_workspace_count = client.count( + collection_name=collection_name, + count_filter=workspace_count_filter, + exact=True, + ).count + + # Skip data migration if new collection already has workspace data + if new_workspace_count == 0 and not (collection_name == legacy_collection): + logger.warning( + f"Qdrant: workspace data in collection '{collection_name}' is empty. " + f"Ensure it is caused by new workspace setup and not an unexpected embedding model change." + ) + return - # Case 4: Only legacy exists - Migrate data - logger.info( - f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'" + legacy_count = None + if not new_collection_exists: + # Check vector dimension compatibility before creating new collection + if legacy_collection: + legacy_count = client.count( + collection_name=legacy_collection, exact=True + ).count + if legacy_count > 0: + legacy_info = client.get_collection(legacy_collection) + legacy_dim = legacy_info.config.params.vectors.size + + if vectors_config.size and legacy_dim != vectors_config.size: + logger.error( + f"Qdrant: Dimension mismatch detected! " + f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, " + f"but new embedding model expects {vectors_config.size}d." + ) + + raise DataMigrationError( + f"Dimension mismatch between legacy collection '{legacy_collection}' " + f"and new collection. Expected {vectors_config.size}d but got {legacy_dim}d." + ) + + client.create_collection( + collection_name, vectors_config=vectors_config, hnsw_config=hnsw_config + ) + logger.info(f"Qdrant: Collection '{collection_name}' created successfully") + if not legacy_collection: + logger.warning( + "Qdrant: Ensure this new collection creation is caused by new workspace setup and not an unexpected embedding model change." 
+ ) + + # create_payload_index return without error if index already exists + client.create_payload_index( + collection_name=collection_name, + field_name=WORKSPACE_ID_FIELD, + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + is_tenant=True, + ), ) - try: - # Get legacy collection count - legacy_count = client.count( - collection_name=legacy_namespace, exact=True - ).count - logger.info(f"Qdrant: Found {legacy_count} records in legacy collection") - + # Case 2: Legacy collection exist + if legacy_collection: + # Only drop legacy collection if it's empty + if legacy_count is None: + legacy_count = client.count( + collection_name=legacy_collection, exact=True + ).count if legacy_count == 0: - logger.info("Qdrant: Legacy collection is empty, skipping migration") - # Create new empty collection - client.create_collection(collection_name, **kwargs) - client.create_payload_index( - collection_name=collection_name, - field_name=WORKSPACE_ID_FIELD, - field_schema=models.KeywordIndexParams( - type=models.KeywordIndexType.KEYWORD, - is_tenant=True, - ), + client.delete_collection(collection_name=legacy_collection) + logger.info( + f"Qdrant: Empty legacy collection '{legacy_collection}' deleted successfully" ) return - # Create new collection first - logger.info(f"Qdrant: Creating new collection '{collection_name}'") - client.create_collection(collection_name, **kwargs) + new_workspace_count = client.count( + collection_name=collection_name, + count_filter=workspace_count_filter, + exact=True, + ).count - # Batch migration (500 records per batch) - migrated_count = 0 - offset = None - batch_size = 500 - - while True: - # Scroll through legacy data - result = client.scroll( - collection_name=legacy_namespace, - limit=batch_size, - offset=offset, - with_vectors=True, - with_payload=True, + # Skip data migration if new collection already has workspace data + if new_workspace_count > 0: + logger.warning( + f"Qdrant: Both new and legacy collection have data. " + f"{legacy_count} records in {legacy_collection} require manual deletion after migration verification." 
) - points, next_offset = result + return - if not points: - break + # Case 3: Only legacy exists - migrate data from legacy collection to new collection + # Check if legacy collection has workspace_id to determine migration strategy + # Note: payload_schema only reflects INDEXED fields, so we also sample + # actual payloads to detect unindexed workspace_id fields + legacy_info = client.get_collection(legacy_collection) + has_workspace_index = WORKSPACE_ID_FIELD in ( + legacy_info.payload_schema or {} + ) - # Transform points for new collection - new_points = [] - for point in points: - # Add workspace_id to payload - new_payload = dict(point.payload or {}) - new_payload[WORKSPACE_ID_FIELD] = workspace or DEFAULT_WORKSPACE - - # Create new point with workspace-prefixed ID - original_id = new_payload.get(ID_FIELD) - if original_id: - new_point_id = compute_mdhash_id_for_qdrant( - original_id, prefix=workspace or DEFAULT_WORKSPACE + # Detect workspace_id field presence by sampling payloads if not indexed + # This prevents cross-workspace data leakage when workspace_id exists but isn't indexed + has_workspace_field = has_workspace_index + if not has_workspace_index: + # Sample a small batch of points to check for workspace_id in payloads + # All points must have workspace_id if any point has it + sample_result = client.scroll( + collection_name=legacy_collection, + limit=10, # Small sample is sufficient for detection + with_payload=True, + with_vectors=False, + ) + sample_points, _ = sample_result + for point in sample_points: + if point.payload and WORKSPACE_ID_FIELD in point.payload: + has_workspace_field = True + logger.info( + f"Qdrant: Detected unindexed {WORKSPACE_ID_FIELD} field " + f"in legacy collection '{legacy_collection}' via payload sampling" ) - else: - # Fallback: use original point ID - new_point_id = str(point.id) + break - new_points.append( - models.PointStruct( - id=new_point_id, - vector=point.vector, - payload=new_payload, + # Build workspace filter if legacy collection has workspace support + # This prevents cross-workspace data leakage during migration + legacy_scroll_filter = None + if has_workspace_field: + legacy_scroll_filter = models.Filter( + must=[workspace_filter_condition(workspace)] + ) + # Recount with workspace filter for accurate migration tracking + legacy_count = client.count( + collection_name=legacy_collection, + count_filter=legacy_scroll_filter, + exact=True, + ).count + logger.info( + f"Qdrant: Legacy collection has workspace support, " + f"filtering to {legacy_count} records for workspace '{workspace}'" + ) + + logger.info( + f"Qdrant: Found legacy collection '{legacy_collection}' with {legacy_count} records to migrate." 
+ ) + logger.info( + f"Qdrant: Migrating data from legacy collection '{legacy_collection}' to new collection '{collection_name}'" + ) + + try: + # Batch migration (500 records per batch) + migrated_count = 0 + offset = None + batch_size = 500 + + while True: + # Scroll through legacy data with optional workspace filter + result = client.scroll( + collection_name=legacy_collection, + scroll_filter=legacy_scroll_filter, + limit=batch_size, + offset=offset, + with_vectors=True, + with_payload=True, + ) + points, next_offset = result + + if not points: + break + + # Transform points for new collection + new_points = [] + for point in points: + # Set workspace_id in payload + new_payload = dict(point.payload or {}) + new_payload[WORKSPACE_ID_FIELD] = workspace + + # Create new point with workspace-prefixed ID + original_id = new_payload.get(ID_FIELD) + if original_id: + new_point_id = compute_mdhash_id_for_qdrant( + original_id, prefix=workspace + ) + else: + # Fallback: use original point ID + new_point_id = str(point.id) + + new_points.append( + models.PointStruct( + id=new_point_id, + vector=point.vector, + payload=new_payload, + ) ) + + # Upsert to new collection + client.upsert( + collection_name=collection_name, points=new_points, wait=True ) - # Upsert to new collection - client.upsert( - collection_name=collection_name, points=new_points, wait=True + migrated_count += len(points) + logger.info( + f"Qdrant: {migrated_count}/{legacy_count} records migrated" + ) + + # Check if we've reached the end + if next_offset is None: + break + offset = next_offset + + new_count_after = client.count( + collection_name=collection_name, + count_filter=workspace_count_filter, + exact=True, + ).count + inserted_count = new_count_after - new_workspace_count + if inserted_count != legacy_count: + error_msg = ( + "Qdrant: Migration verification failed, expected " + f"{legacy_count} inserted records, got {inserted_count}." 
+ ) + logger.error(error_msg) + raise DataMigrationError(error_msg) + + except DataMigrationError: + # Re-raise DataMigrationError as-is to preserve specific error messages + raise + except Exception as e: + logger.error( + f"Qdrant: Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}': {e}" ) - - migrated_count += len(points) - logger.info(f"Qdrant: {migrated_count}/{legacy_count} records migrated") - - # Check if we've reached the end - if next_offset is None: - break - offset = next_offset - - # Verify migration by comparing counts - logger.info("Verifying migration...") - new_count = client.count(collection_name=collection_name, exact=True).count - - if new_count != legacy_count: - error_msg = f"Qdrant: Migration verification failed, expected {legacy_count} records, got {new_count} in new collection" - logger.error(error_msg) - raise QdrantMigrationError(error_msg) + raise DataMigrationError( + f"Failed to migrate data from legacy collection '{legacy_collection}' to new collection '{collection_name}'" + ) from e logger.info( - f"Qdrant: Migration completed successfully: {migrated_count} records migrated" + f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully" ) - - # Create payload index after successful migration - logger.info("Qdrant: Creating workspace payload index...") - client.create_payload_index( - collection_name=collection_name, - field_name=WORKSPACE_ID_FIELD, - field_schema=models.KeywordIndexParams( - type=models.KeywordIndexType.KEYWORD, - is_tenant=True, - ), + logger.warning( + "Qdrant: Manual deletion is required after data migration verification." ) - logger.info( - f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully" - ) - - except QdrantMigrationError: - # Re-raise migration errors without wrapping - raise - except Exception as e: - error_msg = f"Qdrant: Migration failed with error: {e}" - logger.error(error_msg) - raise QdrantMigrationError(error_msg) from e def __post_init__(self): + # Call parent class __post_init__ to validate embedding_func + super().__post_init__() + # Check for QDRANT_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Qdrant storage instances qdrant_workspace = os.environ.get("QDRANT_WORKSPACE") @@ -287,20 +433,23 @@ class QdrantVectorDBStorage(BaseVectorStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Get legacy namespace for data migration from old version - if effective_workspace: - self.legacy_namespace = f"{effective_workspace}_{self.namespace}" - else: - self.legacy_namespace = self.namespace - self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE - # Use a shared collection with payload-based partitioning (Qdrant's recommended approach) - # Ref: https://qdrant.tech/documentation/guides/multiple-partitions/ - self.final_namespace = f"lightrag_vdb_{self.namespace}" - logger.debug( - f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning" - ) + # Generate model suffix + self.model_suffix = self._generate_collection_suffix() + + # New naming scheme with model isolation + # Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + # Ensure model_suffix is not empty before appending + if self.model_suffix: + self.final_namespace = f"lightrag_vdb_{self.namespace}_{self.model_suffix}" + logger.info(f"Qdrant collection: 
{self.final_namespace}") + else: + # Fallback: use legacy namespace if model_suffix is unavailable + self.final_namespace = f"lightrag_vdb_{self.namespace}" + logger.warning( + f"Qdrant collection: {self.final_namespace} missing suffix. Pls add model_name to embedding_func for proper workspace data isolation." + ) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -338,11 +487,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) # Setup collection (create if not exists and configure indexes) - # Pass legacy_namespace and workspace for migration support + # Pass namespace and workspace for backward-compatible migration support QdrantVectorDBStorage.setup_collection( self._client, self.final_namespace, - legacy_namespace=self.legacy_namespace, + namespace=self.namespace, workspace=self.effective_workspace, vectors_config=models.VectorParams( size=self.embedding_func.embedding_dim, @@ -352,8 +501,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): payload_m=16, m=0, ), + model_suffix=self.model_suffix, ) + # Removed duplicate max batch size initialization + self._initialized = True logger.info( f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully" @@ -481,21 +633,44 @@ class QdrantVectorDBStorage(BaseVectorStorage): entity_name: Name of the entity to delete """ try: - # Generate the entity ID using the same function as used for storage + # Compute entity ID from name (same as Milvus) entity_id = compute_mdhash_id(entity_name, prefix=ENTITY_PREFIX) - qdrant_entity_id = compute_mdhash_id_for_qdrant( - entity_id, prefix=self.effective_workspace + logger.debug( + f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}" ) - # Delete the entity point by its Qdrant ID directly - self._client.delete( + # Scroll to find the entity by its ID field in payload with workspace filtering + # This is safer than reconstructing the Qdrant point ID + results = self._client.scroll( collection_name=self.final_namespace, - points_selector=models.PointIdsList(points=[qdrant_entity_id]), - wait=True, - ) - logger.debug( - f"[{self.workspace}] Successfully deleted entity {entity_name}" + scroll_filter=models.Filter( + must=[ + workspace_filter_condition(self.effective_workspace), + models.FieldCondition( + key=ID_FIELD, match=models.MatchValue(value=entity_id) + ), + ] + ), + with_payload=False, + limit=1, ) + + # Extract point IDs to delete + points = results[0] + if points: + ids_to_delete = [point.id for point in points] + self._client.delete( + collection_name=self.final_namespace, + points_selector=models.PointIdsList(points=ids_to_delete), + wait=True, + ) + logger.debug( + f"[{self.workspace}] Successfully deleted entity {entity_name}" + ) + else: + logger.debug( + f"[{self.workspace}] Entity {entity_name} not found in storage" + ) except Exception as e: logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}") @@ -506,38 +681,60 @@ class QdrantVectorDBStorage(BaseVectorStorage): entity_name: Name of the entity whose relations should be deleted """ try: - # Find relations where the entity is either source or target, with workspace filtering - results = self._client.scroll( - collection_name=self.final_namespace, - scroll_filter=models.Filter( - must=[workspace_filter_condition(self.effective_workspace)], - should=[ - models.FieldCondition( - key="src_id", match=models.MatchValue(value=entity_name) - ), - models.FieldCondition( - key="tgt_id", 
match=models.MatchValue(value=entity_name) - ), - ], - ), - with_payload=True, - limit=1000, # Adjust as needed for your use case + # Build the filter to find relations where entity is either source or target + # must + should = workspace_id matches AND (src_id matches OR tgt_id matches) + relation_filter = models.Filter( + must=[workspace_filter_condition(self.effective_workspace)], + should=[ + models.FieldCondition( + key="src_id", match=models.MatchValue(value=entity_name) + ), + models.FieldCondition( + key="tgt_id", match=models.MatchValue(value=entity_name) + ), + ], ) - # Extract points that need to be deleted - relation_points = results[0] - ids_to_delete = [point.id for point in relation_points] + # Paginate through all matching relations to handle large datasets + total_deleted = 0 + offset = None + batch_size = 1000 - if ids_to_delete: - # Delete the relations with workspace filtering - assert isinstance(self._client, QdrantClient) + while True: + # Scroll to find relations, using with_payload=False for efficiency + # since we only need point IDs for deletion + results = self._client.scroll( + collection_name=self.final_namespace, + scroll_filter=relation_filter, + with_payload=False, + with_vectors=False, + limit=batch_size, + offset=offset, + ) + + points, next_offset = results + if not points: + break + + # Extract point IDs to delete + ids_to_delete = [point.id for point in points] + + # Delete the batch of relations self._client.delete( collection_name=self.final_namespace, points_selector=models.PointIdsList(points=ids_to_delete), wait=True, ) + total_deleted += len(ids_to_delete) + + # Check if we've reached the end + if next_offset is None: + break + offset = next_offset + + if total_deleted > 0: logger.debug( - f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}" + f"[{self.workspace}] Deleted {total_deleted} relations for {entity_name}" ) else: logger.debug( diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index ef0f61e2..6da56308 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -163,7 +163,9 @@ class UnifiedLock(Generic[T]): enable_output=self._enable_logging, ) - # Then acquire the main lock + # Acquire the main lock + # Note: self._lock should never be None here as the check has been moved + # to get_internal_lock() and get_data_init_lock() functions if self._is_async: await self._lock.acquire() else: @@ -193,19 +195,21 @@ class UnifiedLock(Generic[T]): async def __aexit__(self, exc_type, exc_val, exc_tb): main_lock_released = False + async_lock_released = False try: # Release main lock first - if self._is_async: - self._lock.release() - else: - self._lock.release() - main_lock_released = True + if self._lock is not None: + if self._is_async: + self._lock.release() + else: + self._lock.release() - direct_log( - f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", - level="INFO", - enable_output=self._enable_logging, - ) + direct_log( + f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", + level="INFO", + enable_output=self._enable_logging, + ) + main_lock_released = True # Then release async lock if in multiprocess mode if not self._is_async and self._async_lock is not None: @@ -215,6 +219,7 @@ class UnifiedLock(Generic[T]): level="DEBUG", enable_output=self._enable_logging, ) + async_lock_released = True except Exception as e: direct_log( @@ -223,9 +228,10 @@ class UnifiedLock(Generic[T]): enable_output=True, ) 
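
The relation-deletion change earlier in the qdrant_impl.py hunk replaces a single scroll capped at 1,000 points with a paginated scroll-and-delete loop, so large relation sets are removed in batches. Below is a minimal standalone sketch of that same pattern, assuming a reachable Qdrant instance and points whose payloads carry the `workspace_id`, `src_id` and `tgt_id` fields used above; the helper name and signature are illustrative, not part of LightRAG's public API.

```python
# Sketch of the paginated scroll-and-delete pattern from delete_entity_relation():
# filter by workspace plus src_id/tgt_id, page through point IDs only, and delete
# each batch until the scroll cursor is exhausted.
from qdrant_client import QdrantClient, models


def delete_relations_for_entity(
    client: QdrantClient,
    collection: str,
    workspace: str,
    entity_name: str,
    batch_size: int = 1000,
) -> int:
    # must + should = workspace_id matches AND (src_id matches OR tgt_id matches)
    relation_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="workspace_id", match=models.MatchValue(value=workspace)
            )
        ],
        should=[
            models.FieldCondition(
                key="src_id", match=models.MatchValue(value=entity_name)
            ),
            models.FieldCondition(
                key="tgt_id", match=models.MatchValue(value=entity_name)
            ),
        ],
    )

    total_deleted = 0
    offset = None
    while True:
        points, next_offset = client.scroll(
            collection_name=collection,
            scroll_filter=relation_filter,
            with_payload=False,  # only point IDs are needed for deletion
            with_vectors=False,
            limit=batch_size,
            offset=offset,
        )
        if not points:
            break
        client.delete(
            collection_name=collection,
            points_selector=models.PointIdsList(points=[p.id for p in points]),
            wait=True,
        )
        total_deleted += len(points)
        if next_offset is None:
            break
        offset = next_offset
    return total_deleted
```
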
- # If main lock release failed but async lock hasn't been released, try to release it + # If main lock release failed but async lock hasn't been attempted yet, try to release it if ( not main_lock_released + and not async_lock_released and not self._is_async and self._async_lock is not None ): @@ -255,6 +261,10 @@ class UnifiedLock(Generic[T]): try: if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") + + # Acquire the main lock + # Note: self._lock should never be None here as the check has been moved + # to get_internal_lock() and get_data_init_lock() functions direct_log( f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)", level="DEBUG", @@ -1060,6 +1070,10 @@ class _KeyedLockContext: def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" + if _internal_lock is None: + raise RuntimeError( + "Shared data not initialized. Call initialize_share_data() before using locks!" + ) async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None return UnifiedLock( lock=_internal_lock, @@ -1090,6 +1104,10 @@ def get_storage_keyed_lock( def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: """return unified data initialization lock for ensuring atomic data initialization""" + if _data_init_lock is None: + raise RuntimeError( + "Shared data not initialized. Call initialize_share_data() before using locks!" + ) async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None return UnifiedLock( lock=_data_init_lock, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8a638759..d1e2bac3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -7,7 +7,7 @@ import inspect import os import time import warnings -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, replace from datetime import datetime, timezone from functools import partial from typing import ( @@ -518,14 +518,9 @@ class LightRAG: f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})" ) - # Fix global_config now - global_config = asdict(self) - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - # Init Embedding - # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes) + # Step 1: Capture embedding_func and max_token_size before applying decorator + original_embedding_func = self.embedding_func embedding_max_token_size = None if self.embedding_func and hasattr(self.embedding_func, "max_token_size"): embedding_max_token_size = self.embedding_func.max_token_size @@ -534,12 +529,26 @@ class LightRAG: ) self.embedding_token_limit = embedding_max_token_size - # Step 2: Apply priority wrapper decorator - self.embedding_func = priority_limit_async_func_call( - self.embedding_func_max_async, - llm_timeout=self.default_embedding_timeout, - queue_name="Embedding func", - )(self.embedding_func) + # Fix global_config now + global_config = asdict(self) + # Restore original EmbeddingFunc object (asdict converts it to dict) + global_config["embedding_func"] = original_embedding_func + + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + + # Step 2: Apply priority wrapper decorator to EmbeddingFunc's inner func + # Create a NEW 
EmbeddingFunc instance with the wrapped func to avoid mutating the caller's object + # This ensures _generate_collection_suffix can still access attributes (model_name, embedding_dim) + # while preventing side effects when the same EmbeddingFunc is reused across multiple LightRAG instances + if self.embedding_func is not None: + wrapped_func = priority_limit_async_func_call( + self.embedding_func_max_async, + llm_timeout=self.default_embedding_timeout, + queue_name="Embedding func", + )(self.embedding_func.func) + # Use dataclasses.replace() to create a new instance, leaving the original unchanged + self.embedding_func = replace(self.embedding_func, func=wrapped_func) # Initialize all storages self.key_string_value_json_storage_cls: type[BaseKVStorage] = ( diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index f6871422..e651e3c8 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -351,7 +351,9 @@ async def bedrock_complete( return result -@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1024, max_token_size=8192, model_name="amazon.titan-embed-text-v2:0" +) @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py index 37ce7206..5e438ceb 100644 --- a/lightrag/llm/gemini.py +++ b/lightrag/llm/gemini.py @@ -453,7 +453,9 @@ async def gemini_model_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048) +@wrap_embedding_func_with_attrs( + embedding_dim=1536, max_token_size=2048, model_name="gemini-embedding-001" +) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py index 447f95c3..eff89650 100644 --- a/lightrag/llm/hf.py +++ b/lightrag/llm/hf.py @@ -142,7 +142,9 @@ async def hf_model_complete( return result -@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1024, max_token_size=8192, model_name="hf_embedding_model" +) async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray: # Detect the appropriate device if torch.cuda.is_available(): diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py index 41251f4a..5c380854 100644 --- a/lightrag/llm/jina.py +++ b/lightrag/llm/jina.py @@ -58,7 +58,9 @@ async def fetch_data(url, headers, data): return data_list -@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=2048, max_token_size=8192, model_name="jina-embeddings-v4" +) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py index 2f2a1dbf..3eaef1af 100644 --- a/lightrag/llm/lollms.py +++ b/lightrag/llm/lollms.py @@ -138,7 +138,9 @@ async def lollms_model_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1024, max_token_size=8192, model_name="lollms_embedding_model" +) async def lollms_embed( texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs ) -> np.ndarray: diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py index 1ebaf3a6..9025ec13 100644 --- a/lightrag/llm/nvidia_openai.py +++ b/lightrag/llm/nvidia_openai.py @@ -33,7 +33,9 @@ from lightrag.utils import ( import numpy as np 
-@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=2048, max_token_size=8192, model_name="nvidia_embedding_model" +) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index cd633e80..62269296 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -172,7 +172,9 @@ async def ollama_model_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest" +) async def ollama_embed( texts: list[str], embed_model: str = "bge-m3:latest", **kwargs ) -> np.ndarray: diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 9c3d0261..b49cac71 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -677,7 +677,9 @@ async def nvidia_openai_complete( return result -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small" +) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -867,7 +869,11 @@ async def azure_openai_complete( return result -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@wrap_embedding_func_with_attrs( + embedding_dim=1536, + max_token_size=8192, + model_name="my-text-embedding-3-large-deployment", +) async def azure_openai_embed( texts: list[str], model: str | None = None, diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py index d90f3cc1..5caa82bf 100644 --- a/lightrag/llm/zhipu.py +++ b/lightrag/llm/zhipu.py @@ -179,7 +179,9 @@ async def zhipu_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1024) +@wrap_embedding_func_with_attrs( + embedding_dim=1024, max_token_size=8192, model_name="embedding-3" +) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), diff --git a/lightrag/tools/prepare_qdrant_legacy_data.py b/lightrag/tools/prepare_qdrant_legacy_data.py new file mode 100644 index 00000000..2ac90196 --- /dev/null +++ b/lightrag/tools/prepare_qdrant_legacy_data.py @@ -0,0 +1,720 @@ +#!/usr/bin/env python3 +""" +Qdrant Legacy Data Preparation Tool for LightRAG + +This tool copies data from new collections to legacy collections for testing +the data migration logic in setup_collection function. + +New Collections (with workspace_id): + - lightrag_vdb_chunks + - lightrag_vdb_entities + - lightrag_vdb_relationships + +Legacy Collections (without workspace_id, dynamically named as {workspace}_{suffix}): + - {workspace}_chunks (e.g., space1_chunks) + - {workspace}_entities (e.g., space1_entities) + - {workspace}_relationships (e.g., space1_relationships) + +The tool: + 1. Filters source data by workspace_id + 2. Verifies workspace data exists before creating legacy collections + 3. Removes workspace_id field to simulate legacy data format + 4. 
Copies only the specified workspace's data to legacy collections + +Usage: + python -m lightrag.tools.prepare_qdrant_legacy_data + # or + python lightrag/tools/prepare_qdrant_legacy_data.py + + # Specify custom workspace + python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1 + + # Process specific collection types only + python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities + + # Dry run (preview only, no actual changes) + python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run +""" + +import argparse +import asyncio +import configparser +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import pipmaster as pm +from dotenv import load_dotenv +from qdrant_client import QdrantClient, models # type: ignore + +# Add project root to path for imports +sys.path.insert( + 0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +# Load environment variables +load_dotenv(dotenv_path=".env", override=False) + +# Ensure qdrant-client is installed +if not pm.is_installed("qdrant-client"): + pm.install("qdrant-client") + +# Collection namespace mapping: new collection pattern -> legacy suffix +# Legacy collection will be named as: {workspace}_{suffix} +COLLECTION_NAMESPACES = { + "chunks": { + "new": "lightrag_vdb_chunks", + "suffix": "chunks", + }, + "entities": { + "new": "lightrag_vdb_entities", + "suffix": "entities", + }, + "relationships": { + "new": "lightrag_vdb_relationships", + "suffix": "relationships", + }, +} + +# Default batch size for copy operations +DEFAULT_BATCH_SIZE = 500 + +# Field to remove from legacy data +WORKSPACE_ID_FIELD = "workspace_id" + +# ANSI color codes for terminal output +BOLD_CYAN = "\033[1;36m" +BOLD_GREEN = "\033[1;32m" +BOLD_YELLOW = "\033[1;33m" +BOLD_RED = "\033[1;31m" +RESET = "\033[0m" + + +@dataclass +class CopyStats: + """Copy operation statistics""" + + collection_type: str + source_collection: str + target_collection: str + total_records: int = 0 + copied_records: int = 0 + failed_records: int = 0 + errors: List[Dict[str, Any]] = field(default_factory=list) + elapsed_time: float = 0.0 + + def add_error(self, batch_idx: int, error: Exception, batch_size: int): + """Record batch error""" + self.errors.append( + { + "batch": batch_idx, + "error_type": type(error).__name__, + "error_msg": str(error), + "records_lost": batch_size, + "timestamp": time.time(), + } + ) + self.failed_records += batch_size + + +class QdrantLegacyDataPreparationTool: + """Tool for preparing legacy data in Qdrant for migration testing""" + + def __init__( + self, + workspace: str = "space1", + batch_size: int = DEFAULT_BATCH_SIZE, + dry_run: bool = False, + clear_target: bool = False, + ): + """ + Initialize the tool. 
+ + Args: + workspace: Workspace to use for filtering new collection data + batch_size: Number of records to process per batch + dry_run: If True, only preview operations without making changes + clear_target: If True, delete target collection before copying data + """ + self.workspace = workspace + self.batch_size = batch_size + self.dry_run = dry_run + self.clear_target = clear_target + self._client: Optional[QdrantClient] = None + + def _get_client(self) -> QdrantClient: + """Get or create QdrantClient instance""" + if self._client is None: + config = configparser.ConfigParser() + config.read("config.ini", "utf-8") + + self._client = QdrantClient( + url=os.environ.get( + "QDRANT_URL", config.get("qdrant", "uri", fallback=None) + ), + api_key=os.environ.get( + "QDRANT_API_KEY", + config.get("qdrant", "apikey", fallback=None), + ), + ) + return self._client + + def print_header(self): + """Print tool header""" + print("\n" + "=" * 60) + print("Qdrant Legacy Data Preparation Tool - LightRAG") + print("=" * 60) + if self.dry_run: + print(f"{BOLD_YELLOW}⚠️ DRY RUN MODE - No changes will be made{RESET}") + if self.clear_target: + print( + f"{BOLD_RED}⚠️ CLEAR TARGET MODE - Target collections will be deleted first{RESET}" + ) + print(f"Workspace: {BOLD_CYAN}{self.workspace}{RESET}") + print(f"Batch Size: {self.batch_size}") + print("=" * 60) + + def check_connection(self) -> bool: + """Check Qdrant connection""" + try: + client = self._get_client() + # Try to list collections to verify connection + client.get_collections() + print(f"{BOLD_GREEN}✓{RESET} Qdrant connection successful") + return True + except Exception as e: + print(f"{BOLD_RED}✗{RESET} Qdrant connection failed: {e}") + return False + + def get_collection_info(self, collection_name: str) -> Optional[Dict[str, Any]]: + """ + Get collection information. + + Args: + collection_name: Name of the collection + + Returns: + Dictionary with collection info (vector_size, count) or None if not exists + """ + client = self._get_client() + + if not client.collection_exists(collection_name): + return None + + info = client.get_collection(collection_name) + count = client.count(collection_name=collection_name, exact=True).count + + # Handle both object and dict formats for vectors config + vectors_config = info.config.params.vectors + if isinstance(vectors_config, dict): + # Named vectors format or dict format + if vectors_config: + first_key = next(iter(vectors_config.keys()), None) + if first_key and hasattr(vectors_config[first_key], "size"): + vector_size = vectors_config[first_key].size + distance = vectors_config[first_key].distance + else: + # Try to get from dict values + first_val = next(iter(vectors_config.values()), {}) + vector_size = ( + first_val.get("size") + if isinstance(first_val, dict) + else getattr(first_val, "size", None) + ) + distance = ( + first_val.get("distance") + if isinstance(first_val, dict) + else getattr(first_val, "distance", None) + ) + else: + vector_size = None + distance = None + else: + # Standard single vector format + vector_size = vectors_config.size + distance = vectors_config.distance + + return { + "name": collection_name, + "vector_size": vector_size, + "count": count, + "distance": distance, + } + + def delete_collection(self, collection_name: str) -> bool: + """ + Delete a collection if it exists. 
+ + Args: + collection_name: Name of the collection to delete + + Returns: + True if deleted or doesn't exist + """ + client = self._get_client() + + if not client.collection_exists(collection_name): + return True + + if self.dry_run: + target_info = self.get_collection_info(collection_name) + count = target_info["count"] if target_info else 0 + print( + f" {BOLD_YELLOW}[DRY RUN]{RESET} Would delete collection '{collection_name}' ({count:,} records)" + ) + return True + + try: + target_info = self.get_collection_info(collection_name) + count = target_info["count"] if target_info else 0 + client.delete_collection(collection_name=collection_name) + print( + f" {BOLD_RED}✗{RESET} Deleted collection '{collection_name}' ({count:,} records)" + ) + return True + except Exception as e: + print(f" {BOLD_RED}✗{RESET} Failed to delete collection: {e}") + return False + + def create_legacy_collection( + self, collection_name: str, vector_size: int, distance: models.Distance + ) -> bool: + """ + Create legacy collection if it doesn't exist. + + Args: + collection_name: Name of the collection to create + vector_size: Dimension of vectors + distance: Distance metric + + Returns: + True if created or already exists + """ + client = self._get_client() + + if client.collection_exists(collection_name): + print(f" Collection '{collection_name}' already exists") + return True + + if self.dry_run: + print( + f" {BOLD_YELLOW}[DRY RUN]{RESET} Would create collection '{collection_name}' with {vector_size}d vectors" + ) + return True + + try: + client.create_collection( + collection_name=collection_name, + vectors_config=models.VectorParams( + size=vector_size, + distance=distance, + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, + m=0, + ), + ) + print( + f" {BOLD_GREEN}✓{RESET} Created collection '{collection_name}' with {vector_size}d vectors" + ) + return True + except Exception as e: + print(f" {BOLD_RED}✗{RESET} Failed to create collection: {e}") + return False + + def _get_workspace_filter(self) -> models.Filter: + """Create workspace filter for Qdrant queries""" + return models.Filter( + must=[ + models.FieldCondition( + key=WORKSPACE_ID_FIELD, + match=models.MatchValue(value=self.workspace), + ) + ] + ) + + def get_workspace_count(self, collection_name: str) -> int: + """ + Get count of records for the current workspace in a collection. + + Args: + collection_name: Name of the collection + + Returns: + Count of records for the workspace + """ + client = self._get_client() + return client.count( + collection_name=collection_name, + count_filter=self._get_workspace_filter(), + exact=True, + ).count + + def copy_collection_data( + self, + source_collection: str, + target_collection: str, + collection_type: str, + workspace_count: int, + ) -> CopyStats: + """ + Copy data from source to target collection. + + This filters by workspace_id and removes it from payload to simulate legacy data format. 
+ + Args: + source_collection: Source collection name + target_collection: Target collection name + collection_type: Type of collection (chunks, entities, relationships) + workspace_count: Pre-computed count of workspace records + + Returns: + CopyStats with operation results + """ + client = self._get_client() + stats = CopyStats( + collection_type=collection_type, + source_collection=source_collection, + target_collection=target_collection, + ) + + start_time = time.time() + stats.total_records = workspace_count + + if workspace_count == 0: + print(f" No records for workspace '{self.workspace}', skipping") + stats.elapsed_time = time.time() - start_time + return stats + + print(f" Workspace records: {workspace_count:,}") + + if self.dry_run: + print( + f" {BOLD_YELLOW}[DRY RUN]{RESET} Would copy {workspace_count:,} records to '{target_collection}'" + ) + stats.copied_records = workspace_count + stats.elapsed_time = time.time() - start_time + return stats + + # Batch copy using scroll with workspace filter + workspace_filter = self._get_workspace_filter() + offset = None + batch_idx = 0 + + while True: + # Scroll source collection with workspace filter + result = client.scroll( + collection_name=source_collection, + scroll_filter=workspace_filter, + limit=self.batch_size, + offset=offset, + with_vectors=True, + with_payload=True, + ) + points, next_offset = result + + if not points: + break + + batch_idx += 1 + + # Transform points: remove workspace_id from payload + new_points = [] + for point in points: + new_payload = dict(point.payload or {}) + # Remove workspace_id to simulate legacy format + new_payload.pop(WORKSPACE_ID_FIELD, None) + + # Use original id from payload if available, otherwise use point.id + original_id = new_payload.get("id") + if original_id: + # Generate a simple deterministic id for legacy format + # Use original id directly (legacy format didn't have workspace prefix) + import hashlib + import uuid + + hashed = hashlib.sha256(original_id.encode("utf-8")).digest() + point_id = uuid.UUID(bytes=hashed[:16], version=4).hex + else: + point_id = str(point.id) + + new_points.append( + models.PointStruct( + id=point_id, + vector=point.vector, + payload=new_payload, + ) + ) + + try: + # Upsert to target collection + client.upsert( + collection_name=target_collection, points=new_points, wait=True + ) + stats.copied_records += len(new_points) + + # Progress bar + progress = (stats.copied_records / workspace_count) * 100 + bar_length = 30 + filled = int(bar_length * stats.copied_records // workspace_count) + bar = "█" * filled + "░" * (bar_length - filled) + + print( + f"\r Copying: {bar} {stats.copied_records:,}/{workspace_count:,} ({progress:.1f}%) ", + end="", + flush=True, + ) + + except Exception as e: + stats.add_error(batch_idx, e, len(new_points)) + print( + f"\n {BOLD_RED}✗{RESET} Batch {batch_idx} failed: {type(e).__name__}: {e}" + ) + + if next_offset is None: + break + offset = next_offset + + print() # New line after progress bar + stats.elapsed_time = time.time() - start_time + + return stats + + def process_collection_type(self, collection_type: str) -> Optional[CopyStats]: + """ + Process a single collection type. 
+ + Args: + collection_type: Type of collection (chunks, entities, relationships) + + Returns: + CopyStats or None if error + """ + namespace_config = COLLECTION_NAMESPACES.get(collection_type) + if not namespace_config: + print(f"{BOLD_RED}✗{RESET} Unknown collection type: {collection_type}") + return None + + source = namespace_config["new"] + # Generate legacy collection name dynamically: {workspace}_{suffix} + target = f"{self.workspace}_{namespace_config['suffix']}" + + print(f"\n{'=' * 50}") + print(f"Processing: {BOLD_CYAN}{collection_type}{RESET}") + print(f"{'=' * 50}") + print(f" Source: {source}") + print(f" Target: {target}") + + # Check source collection + source_info = self.get_collection_info(source) + if source_info is None: + print( + f" {BOLD_YELLOW}⚠{RESET} Source collection '{source}' does not exist, skipping" + ) + return None + + print(f" Source vector dimension: {source_info['vector_size']}d") + print(f" Source distance metric: {source_info['distance']}") + print(f" Source total records: {source_info['count']:,}") + + # Check workspace data exists BEFORE creating legacy collection + workspace_count = self.get_workspace_count(source) + print(f" Workspace '{self.workspace}' records: {workspace_count:,}") + + if workspace_count == 0: + print( + f" {BOLD_YELLOW}⚠{RESET} No data found for workspace '{self.workspace}' in '{source}', skipping" + ) + return None + + # Clear target collection if requested + if self.clear_target: + if not self.delete_collection(target): + return None + + # Create target collection only after confirming workspace data exists + if not self.create_legacy_collection( + target, source_info["vector_size"], source_info["distance"] + ): + return None + + # Copy data with workspace filter + stats = self.copy_collection_data( + source, target, collection_type, workspace_count + ) + + # Print result + if stats.failed_records == 0: + print( + f" {BOLD_GREEN}✓{RESET} Copied {stats.copied_records:,} records in {stats.elapsed_time:.2f}s" + ) + else: + print( + f" {BOLD_YELLOW}⚠{RESET} Copied {stats.copied_records:,} records, " + f"{BOLD_RED}{stats.failed_records:,} failed{RESET} in {stats.elapsed_time:.2f}s" + ) + + return stats + + def print_summary(self, all_stats: List[CopyStats]): + """Print summary of all operations""" + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + + total_copied = sum(s.copied_records for s in all_stats) + total_failed = sum(s.failed_records for s in all_stats) + total_time = sum(s.elapsed_time for s in all_stats) + + for stats in all_stats: + status = ( + f"{BOLD_GREEN}✓{RESET}" + if stats.failed_records == 0 + else f"{BOLD_YELLOW}⚠{RESET}" + ) + print( + f" {status} {stats.collection_type}: {stats.copied_records:,}/{stats.total_records:,} " + f"({stats.source_collection} → {stats.target_collection})" + ) + + print("-" * 60) + print(f" Total records copied: {BOLD_CYAN}{total_copied:,}{RESET}") + if total_failed > 0: + print(f" Total records failed: {BOLD_RED}{total_failed:,}{RESET}") + print(f" Total time: {total_time:.2f}s") + + if self.dry_run: + print(f"\n{BOLD_YELLOW}⚠️ DRY RUN - No actual changes were made{RESET}") + + # Print error details if any + all_errors = [] + for stats in all_stats: + all_errors.extend(stats.errors) + + if all_errors: + print(f"\n{BOLD_RED}Errors ({len(all_errors)}){RESET}") + for i, error in enumerate(all_errors[:5], 1): + print( + f" {i}. Batch {error['batch']}: {error['error_type']}: {error['error_msg']}" + ) + if len(all_errors) > 5: + print(f" ... 
and {len(all_errors) - 5} more errors") + + print("=" * 60) + + async def run(self, collection_types: Optional[List[str]] = None): + """ + Run the data preparation tool. + + Args: + collection_types: List of collection types to process (default: all) + """ + self.print_header() + + # Check connection + if not self.check_connection(): + return + + # Determine which collection types to process + if collection_types: + types_to_process = [t.strip() for t in collection_types] + invalid_types = [ + t for t in types_to_process if t not in COLLECTION_NAMESPACES + ] + if invalid_types: + print( + f"{BOLD_RED}✗{RESET} Invalid collection types: {', '.join(invalid_types)}" + ) + print(f" Valid types: {', '.join(COLLECTION_NAMESPACES.keys())}") + return + else: + types_to_process = list(COLLECTION_NAMESPACES.keys()) + + print(f"\nCollection types to process: {', '.join(types_to_process)}") + + # Process each collection type + all_stats = [] + for ctype in types_to_process: + stats = self.process_collection_type(ctype) + if stats: + all_stats.append(stats) + + # Print summary + if all_stats: + self.print_summary(all_stats) + else: + print(f"\n{BOLD_YELLOW}⚠{RESET} No collections were processed") + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Prepare legacy data in Qdrant for migration testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m lightrag.tools.prepare_qdrant_legacy_data + python -m lightrag.tools.prepare_qdrant_legacy_data --workspace space1 + python -m lightrag.tools.prepare_qdrant_legacy_data --types chunks,entities + python -m lightrag.tools.prepare_qdrant_legacy_data --dry-run + """, + ) + + parser.add_argument( + "--workspace", + type=str, + default="space1", + help="Workspace name (default: space1)", + ) + + parser.add_argument( + "--types", + type=str, + default=None, + help="Comma-separated list of collection types (chunks, entities, relationships)", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=DEFAULT_BATCH_SIZE, + help=f"Batch size for copy operations (default: {DEFAULT_BATCH_SIZE})", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Preview operations without making changes", + ) + + parser.add_argument( + "--clear-target", + action="store_true", + help="Delete target collections before copying (for clean test environment)", + ) + + return parser.parse_args() + + +async def main(): + """Main entry point""" + args = parse_args() + + collection_types = None + if args.types: + collection_types = [t.strip() for t in args.types.split(",")] + + tool = QdrantLegacyDataPreparationTool( + workspace=args.workspace, + batch_size=args.batch_size, + dry_run=args.dry_run, + clear_target=args.clear_target, + ) + + await tool.run(collection_types=collection_types) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightrag/utils.py b/lightrag/utils.py index 65c1e4bc..d795acdb 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -425,6 +425,9 @@ class EmbeddingFunc: send_dimensions: bool = ( False # Control whether to send embedding_dim to the function ) + model_name: str | None = ( + None # Model name for implementating workspace data isolation in vector DB + ) async def __call__(self, *args, **kwargs) -> np.ndarray: # Only inject embedding_dim when send_dimensions is True @@ -1016,42 +1019,36 @@ def wrap_embedding_func_with_attrs(**kwargs): Correct usage patterns: - 1. Direct implementation (decorated): + 1. 
Direct decoration: ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536) + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model") async def my_embed(texts, embedding_dim=None): # Direct implementation return embeddings ``` - - 2. Wrapper calling decorated function (DO NOT decorate wrapper): + 2. Double decoration: ```python - # my_embed is already decorated above + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model") + @retry(...) + async def openai_embed(texts, ...): + # Base implementation + pass - async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this! - # Must call .func to access unwrapped implementation - return await my_embed.func(texts, **kwargs) - ``` - - 3. Wrapper calling decorated function (properly decorated): - ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536) - async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func - # Calling .func avoids double decoration - return await my_embed.func(texts, **kwargs) + @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model") + # Note: No @retry here! + async def new_openai_embed(texts, ...): + # CRITICAL: Call .func to access unwrapped function + return await openai_embed.func(texts, ...) # ✅ Correct + # return await openai_embed(texts, ...) # ❌ Wrong - double decoration! ``` The decorated function becomes an EmbeddingFunc instance with: - embedding_dim: The embedding dimension - max_token_size: Maximum token limit (optional) + - model_name: Model name (optional) - func: The original unwrapped function (access via .func) - __call__: Wrapper that injects embedding_dim parameter - Double decoration causes: - - Double injection of embedding_dim parameter - - Incorrect parameter passing to the underlying implementation - - Runtime errors due to parameter conflicts - Args: embedding_dim: The dimension of embedding vectors max_token_size: Maximum number of tokens (optional) @@ -1059,21 +1056,6 @@ def wrap_embedding_func_with_attrs(**kwargs): Returns: A decorator that wraps the function as an EmbeddingFunc instance - - Example of correct wrapper implementation: - ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) - @retry(...) - async def openai_embed(texts, ...): - # Base implementation - pass - - @wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here! - async def azure_openai_embed(texts, ...): - # CRITICAL: Call .func to access unwrapped function - return await openai_embed.func(texts, ...) # ✅ Correct - # return await openai_embed(texts, ...) # ❌ Wrong - double decoration! 
- ``` """ def final_decro(func) -> EmbeddingFunc: diff --git a/pyproject.toml b/pyproject.toml index 761a3309..dd3dbc92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,6 @@ api = [ # API-specific dependencies "aiofiles", "ascii_colors", - "asyncpg", "distro", "fastapi", "httpcore", @@ -108,7 +107,8 @@ offline-storage = [ "neo4j>=5.0.0,<7.0.0", "pymilvus>=2.6.2,<3.0.0", "pymongo>=4.0.0,<5.0.0", - "asyncpg>=0.29.0,<1.0.0", + "asyncpg>=0.31.0,<1.0.0", + "pgvector>=0.4.2,<1.0.0", "qdrant-client>=1.11.0,<2.0.0", ] diff --git a/requirements-offline-storage.txt b/requirements-offline-storage.txt index 13a9c0e2..82caacbd 100644 --- a/requirements-offline-storage.txt +++ b/requirements-offline-storage.txt @@ -8,8 +8,9 @@ # Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt # Storage backend dependencies (with version constraints matching pyproject.toml) -asyncpg>=0.29.0,<1.0.0 +asyncpg>=0.31.0,<1.0.0 neo4j>=5.0.0,<7.0.0 +pgvector>=0.4.2,<1.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 qdrant-client>=1.11.0,<2.0.0 diff --git a/requirements-offline.txt b/requirements-offline.txt index 87ca7a6a..283ced73 100644 --- a/requirements-offline.txt +++ b/requirements-offline.txt @@ -7,20 +7,17 @@ # Recommended: Use pip install lightrag-hku[offline] for the same effect # Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt -# LLM provider dependencies (with version constraints matching pyproject.toml) aioboto3>=12.0.0,<16.0.0 anthropic>=0.18.0,<1.0.0 - -# Storage backend dependencies -asyncpg>=0.29.0,<1.0.0 +asyncpg>=0.31.0,<1.0.0 +google-api-core>=2.0.0,<3.0.0 google-genai>=1.0.0,<2.0.0 - -# Document processing dependencies llama-index>=0.9.0,<1.0.0 neo4j>=5.0.0,<7.0.0 ollama>=0.1.0,<1.0.0 openai>=2.0.0,<3.0.0 openpyxl>=3.0.0,<4.0.0 +pgvector>=0.4.2,<1.0.0 pycryptodome>=3.0.0,<4.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 diff --git a/tests/test_dimension_mismatch.py b/tests/test_dimension_mismatch.py new file mode 100644 index 00000000..b63fbd35 --- /dev/null +++ b/tests/test_dimension_mismatch.py @@ -0,0 +1,377 @@ +""" +Tests for dimension mismatch handling during migration. + +This test module verifies that both PostgreSQL and Qdrant storage backends +properly detect and handle vector dimension mismatches when migrating from +legacy collections/tables to new ones with different embedding models. +""" + +import json +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage +from lightrag.kg.postgres_impl import PGVectorStorage +from lightrag.exceptions import DataMigrationError + + +# Note: Tests should use proper table names that have DDL templates +# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS, +# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS + + +class TestQdrantDimensionMismatch: + """Test suite for Qdrant dimension mismatch handling.""" + + def test_qdrant_dimension_mismatch_raises_error(self): + """ + Test that Qdrant raises DataMigrationError when dimensions don't match. + + Scenario: Legacy collection has 1536d vectors, new model expects 3072d. + Expected: DataMigrationError is raised to prevent data corruption. 
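For reference, the behaviour this test pins down can be pictured as a small pre-migration guard. This is a minimal sketch only: the helper name `check_legacy_dimension` is hypothetical, the real check lives inside `QdrantVectorDBStorage.setup_collection`, and it assumes the single unnamed vector config shape that the mock below reproduces.

```python
from qdrant_client import QdrantClient
from lightrag.exceptions import DataMigrationError


def check_legacy_dimension(
    client: QdrantClient, legacy_collection: str, expected_dim: int
) -> None:
    """Abort before any copy if the legacy vectors have a different size."""
    info = client.get_collection(legacy_collection)
    legacy_dim = info.config.params.vectors.size  # same shape the test mocks below
    if legacy_dim != expected_dim:
        # Raising here prevents creating the new collection and corrupting data
        raise DataMigrationError(
            f"Vector dimension mismatch: legacy '{legacy_collection}' stores "
            f"{legacy_dim}d vectors, new embedding model produces {expected_dim}d"
        )
```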
+ """ + from qdrant_client import models + + # Setup mock client + client = MagicMock() + + # Mock legacy collection with 1536d vectors + legacy_collection_info = MagicMock() + legacy_collection_info.config.params.vectors.size = 1536 + + # Setup collection existence checks + def collection_exists_side_effect(name): + if ( + name == "lightrag_vdb_chunks" + ): # legacy (matches _find_legacy_collection pattern) + return True + elif name == "lightrag_chunks_model_3072d": # new + return False + return False + + client.collection_exists.side_effect = collection_exists_side_effect + client.get_collection.return_value = legacy_collection_info + client.count.return_value.count = 100 # Legacy has data + + # Patch _find_legacy_collection to return the legacy collection name + with patch( + "lightrag.kg.qdrant_impl._find_legacy_collection", + return_value="lightrag_vdb_chunks", + ): + # Call setup_collection with 3072d (different from legacy 1536d) + # Should raise DataMigrationError due to dimension mismatch + with pytest.raises(DataMigrationError) as exc_info: + QdrantVectorDBStorage.setup_collection( + client, + "lightrag_chunks_model_3072d", + namespace="chunks", + workspace="test", + vectors_config=models.VectorParams( + size=3072, distance=models.Distance.COSINE + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, + m=0, + ), + model_suffix="model_3072d", + ) + + # Verify error message contains dimension information + assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value) + + # Verify new collection was NOT created (error raised before creation) + client.create_collection.assert_not_called() + + # Verify migration was NOT attempted + client.scroll.assert_not_called() + client.upsert.assert_not_called() + + def test_qdrant_dimension_match_proceed_migration(self): + """ + Test that Qdrant proceeds with migration when dimensions match. + + Scenario: Legacy collection has 1536d vectors, new model also expects 1536d. + Expected: Migration proceeds normally. 
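The test body that follows (and several later ones) relies on a small mutable-state idiom: a dict captured by the mock side effects lets `count()` report different values before and after `upsert()` has run. A standalone illustration of just that trick, not part of the change set:

```python
from unittest.mock import MagicMock

state = {"migrated": False}


def upsert_side_effect(*args, **kwargs):
    # Flipping the flag changes what count() reports from now on
    state["migrated"] = True
    return MagicMock()


def count_side_effect(collection_name, **kwargs):
    result = MagicMock()
    result.count = 1 if state["migrated"] else 0
    return result


client = MagicMock()
client.upsert.side_effect = upsert_side_effect
client.count.side_effect = count_side_effect

assert client.count("c").count == 0   # before migration
client.upsert(points=[])
assert client.count("c").count == 1   # after migration
```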
+ """ + from qdrant_client import models + + client = MagicMock() + + # Mock legacy collection with 1536d vectors (matching new) + legacy_collection_info = MagicMock() + legacy_collection_info.config.params.vectors.size = 1536 + + def collection_exists_side_effect(name): + if name == "lightrag_chunks": # legacy + return True + elif name == "lightrag_chunks_model_1536d": # new + return False + return False + + client.collection_exists.side_effect = collection_exists_side_effect + client.get_collection.return_value = legacy_collection_info + + # Track whether upsert has been called (migration occurred) + migration_done = {"value": False} + + def upsert_side_effect(*args, **kwargs): + migration_done["value"] = True + return MagicMock() + + client.upsert.side_effect = upsert_side_effect + + # Mock count to return different values based on collection name and migration state + # Before migration: new collection has 0 records + # After migration: new collection has 1 record (matching migrated data) + def count_side_effect(collection_name, **kwargs): + result = MagicMock() + if collection_name == "lightrag_chunks": # legacy + result.count = 1 # Legacy has 1 record + elif collection_name == "lightrag_chunks_model_1536d": # new + # Return 0 before migration, 1 after migration + result.count = 1 if migration_done["value"] else 0 + else: + result.count = 0 + return result + + client.count.side_effect = count_side_effect + + # Mock scroll to return sample data (1 record for easier verification) + sample_point = MagicMock() + sample_point.id = "test_id" + sample_point.vector = [0.1] * 1536 + sample_point.payload = {"id": "test"} + client.scroll.return_value = ([sample_point], None) + + # Mock _find_legacy_collection to return the legacy collection name + with patch( + "lightrag.kg.qdrant_impl._find_legacy_collection", + return_value="lightrag_chunks", + ): + # Call setup_collection with matching 1536d + QdrantVectorDBStorage.setup_collection( + client, + "lightrag_chunks_model_1536d", + namespace="chunks", + workspace="test", + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, + m=0, + ), + model_suffix="model_1536d", + ) + + # Verify migration WAS attempted + client.create_collection.assert_called_once() + client.scroll.assert_called() + client.upsert.assert_called() + + +class TestPostgresDimensionMismatch: + """Test suite for PostgreSQL dimension mismatch handling.""" + + async def test_postgres_dimension_mismatch_raises_error_metadata(self): + """ + Test that PostgreSQL raises DataMigrationError when dimensions don't match. + + Scenario: Legacy table has 1536d vectors, new model expects 3072d. + Expected: DataMigrationError is raised to prevent data corruption. 
+ """ + # Setup mock database + db = AsyncMock() + + # Mock check_table_exists + async def mock_check_table_exists(table_name): + if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy + return True + elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new + return False + return False + + db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # Mock table existence and dimension checks + async def query_side_effect(query, params, **kwargs): + if "COUNT(*)" in query: + return {"count": 100} # Legacy has data + elif "SELECT content_vector FROM" in query: + # Return sample vector with 1536 dimensions + return {"content_vector": [0.1] * 1536} + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Call setup_table with 3072d (different from legacy 1536d) + # Should raise DataMigrationError due to dimension mismatch + with pytest.raises(DataMigrationError) as exc_info: + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_3072d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=3072, + workspace="test", + ) + + # Verify error message contains dimension information + assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value) + + async def test_postgres_dimension_mismatch_raises_error_sampling(self): + """ + Test that PostgreSQL raises error when dimensions don't match (via sampling). + + Scenario: Legacy table vector sampling detects 1536d vs expected 3072d. + Expected: DataMigrationError is raised to prevent data corruption. + """ + db = AsyncMock() + + # Mock check_table_exists + async def mock_check_table_exists(table_name): + if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy + return True + elif table_name == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new + return False + return False + + db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # Mock table existence and dimension checks + async def query_side_effect(query, params, **kwargs): + if "information_schema.tables" in query: + if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy + return {"exists": True} + elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new + return {"exists": False} + elif "COUNT(*)" in query: + return {"count": 100} # Legacy has data + elif "SELECT content_vector FROM" in query: + # Return sample vector with 1536 dimensions as a JSON string + return {"content_vector": json.dumps([0.1] * 1536)} + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Call setup_table with 3072d (different from legacy 1536d) + # Should raise DataMigrationError due to dimension mismatch + with pytest.raises(DataMigrationError) as exc_info: + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_3072d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=3072, + workspace="test", + ) + + # Verify error message contains dimension information + assert "3072" in str(exc_info.value) or "1536" in str(exc_info.value) + + async def test_postgres_dimension_match_proceed_migration(self): + """ + Test that PostgreSQL proceeds with migration when dimensions match. + + Scenario: Legacy table has 1536d vectors, new model also expects 1536d. + Expected: Migration proceeds normally. 
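The two mismatch tests feed the sampled `content_vector` back either as a Python list or as a JSON-encoded string, mirroring the two forms a sampled pgvector value may take. A tiny, illustrative normaliser (the helper name is an assumption, not LightRAG API):

```python
import json
from typing import Any


def infer_vector_dim(sampled_value: Any) -> int:
    """Return the dimension of a sampled content_vector value."""
    if isinstance(sampled_value, str):
        sampled_value = json.loads(sampled_value)  # value may arrive as text
    return len(sampled_value)


assert infer_vector_dim([0.1] * 1536) == 1536
assert infer_vector_dim(json.dumps([0.1] * 1536)) == 1536
```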
+ """ + db = AsyncMock() + + # Track migration state + migration_done = {"value": False} + + # Define exactly 2 records for consistency + mock_records = [ + { + "id": "test1", + "content_vector": [0.1] * 1536, + "workspace": "test", + }, + { + "id": "test2", + "content_vector": [0.2] * 1536, + "workspace": "test", + }, + ] + + # Mock check_table_exists + async def mock_check_table_exists(table_name): + if table_name == "LIGHTRAG_DOC_CHUNKS": # legacy exists + return True + elif table_name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist + return False + return False + + db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + async def query_side_effect(query, params, **kwargs): + multirows = kwargs.get("multirows", False) + query_upper = query.upper() + + if "information_schema.tables" in query: + if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy + return {"exists": True} + elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new + return {"exists": False} + elif "COUNT(*)" in query_upper: + # Return different counts based on table name in query and migration state + if "LIGHTRAG_DOC_CHUNKS_MODEL_1536D" in query_upper: + # After migration: return migrated count, before: return 0 + return { + "count": len(mock_records) if migration_done["value"] else 0 + } + # Legacy table always has 2 records (matching mock_records) + return {"count": len(mock_records)} + elif "PG_ATTRIBUTE" in query_upper: + return {"vector_dim": 1536} # Legacy has matching 1536d + elif "SELECT" in query_upper and "FROM" in query_upper and multirows: + # Return sample data for migration using keyset pagination + # Handle keyset pagination: params = [workspace, limit] or [workspace, last_id, limit] + if "id >" in query.lower(): + # Keyset pagination: params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + # Find records after last_id + found_idx = -1 + for i, rec in enumerate(mock_records): + if rec["id"] == last_id: + found_idx = i + break + if found_idx >= 0: + return mock_records[found_idx + 1 :] + return [] + else: + # First batch: params = [workspace, limit] + return mock_records + return {} + + db.query.side_effect = query_side_effect + + # Mock _run_with_retry to track when migration happens + migration_executed = [] + + async def mock_run_with_retry(operation, *args, **kwargs): + migration_executed.append(True) + migration_done["value"] = True + return None + + db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Call setup_table with matching 1536d + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_1536d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=1536, + workspace="test", + ) + + # Verify migration WAS called (via _run_with_retry for batch operations) + assert len(migration_executed) > 0, "Migration should have been executed" diff --git a/tests/test_no_model_suffix_safety.py b/tests/test_no_model_suffix_safety.py new file mode 100644 index 00000000..2f438d38 --- /dev/null +++ b/tests/test_no_model_suffix_safety.py @@ -0,0 +1,220 @@ +""" +Tests for safety when model suffix is absent (no model_name provided). + +This test module verifies that the system correctly handles the case when +no model_name is provided, preventing accidental deletion of the only table/collection +on restart. + +Critical Bug: When model_suffix is empty, table_name == legacy_table_name. 
+On second startup, Case 1 logic would delete the only table/collection thinking +it's "legacy", causing all subsequent operations to fail. +""" + +from unittest.mock import MagicMock, AsyncMock, patch + +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage +from lightrag.kg.postgres_impl import PGVectorStorage + + +class TestNoModelSuffixSafety: + """Test suite for preventing data loss when model_suffix is absent.""" + + def test_qdrant_no_suffix_second_startup(self): + """ + Test Qdrant doesn't delete collection on second startup when no model_name. + + Scenario: + 1. First startup: Creates collection without suffix + 2. Collection is empty + 3. Second startup: Should NOT delete the collection + + Bug: Without fix, Case 1 would delete the only collection. + """ + from qdrant_client import models + + client = MagicMock() + + # Simulate second startup: collection already exists and is empty + # IMPORTANT: Without suffix, collection_name == legacy collection name + collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy + + # Both exist (they're the same collection) + client.collection_exists.return_value = True + + # Collection is empty + client.count.return_value.count = 0 + + # Patch _find_legacy_collection to return the SAME collection name + # This simulates the scenario where new collection == legacy collection + with patch( + "lightrag.kg.qdrant_impl._find_legacy_collection", + return_value="lightrag_vdb_chunks", # Same as collection_name + ): + # Call setup_collection + # This should detect that new == legacy and skip deletion + QdrantVectorDBStorage.setup_collection( + client, + collection_name, + namespace="chunks", + workspace="_", + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, + m=0, + ), + model_suffix="", # Empty suffix to simulate no model_name provided + ) + + # CRITICAL: Collection should NOT be deleted + client.delete_collection.assert_not_called() + + # Verify we returned early (skipped Case 1 cleanup) + # The collection_exists was checked, but we didn't proceed to count + # because we detected same name + assert client.collection_exists.call_count >= 1 + + async def test_postgres_no_suffix_second_startup(self): + """ + Test PostgreSQL doesn't delete table on second startup when no model_name. + + Scenario: + 1. First startup: Creates table without suffix + 2. Table is empty + 3. Second startup: Should NOT delete the table + + Bug: Without fix, Case 1 would delete the only table. 
+ """ + db = AsyncMock() + + # Configure mock return values to avoid unawaited coroutine warnings + db.query.return_value = {"count": 0} + db._create_vector_index.return_value = None + + # Simulate second startup: table already exists and is empty + # IMPORTANT: table_name and legacy_table_name are THE SAME + table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix + legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new + + # Setup mock responses using check_table_exists on db + async def check_table_exists_side_effect(name): + # Both tables exist (they're the same) + return True + + db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect) + + # Call setup_table + # This should detect that new == legacy and skip deletion + await PGVectorStorage.setup_table( + db, + table_name, + workspace="test_workspace", + embedding_dim=1536, + legacy_table_name=legacy_table_name, + base_table="LIGHTRAG_VDB_CHUNKS", + ) + + # CRITICAL: Table should NOT be deleted (no DROP TABLE) + drop_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert ( + len(drop_calls) == 0 + ), "Should not drop table when new and legacy are the same" + + # Note: COUNT queries for workspace data are expected behavior in Case 1 + # (for logging/warning purposes when workspace data is empty). + # The critical safety check is that DROP TABLE is not called. + + def test_qdrant_with_suffix_case1_still_works(self): + """ + Test that Case 1 cleanup still works when there IS a suffix. + + This ensures our fix doesn't break the normal Case 1 scenario. + """ + from qdrant_client import models + + client = MagicMock() + + # Different names (normal case) + collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix + legacy_collection = "lightrag_vdb_chunks" # Without suffix + + # Setup: both exist + def collection_exists_side_effect(name): + return name in [collection_name, legacy_collection] + + client.collection_exists.side_effect = collection_exists_side_effect + + # Legacy is empty + client.count.return_value.count = 0 + + # Call setup_collection + QdrantVectorDBStorage.setup_collection( + client, + collection_name, + namespace="chunks", + workspace="_", + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, + m=0, + ), + model_suffix="ada_002_1536d", + ) + + # SHOULD delete legacy (normal Case 1 behavior) + client.delete_collection.assert_called_once_with( + collection_name=legacy_collection + ) + + async def test_postgres_with_suffix_case1_still_works(self): + """ + Test that Case 1 cleanup still works when there IS a suffix. + + This ensures our fix doesn't break the normal Case 1 scenario. 
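Taken together, the tests in this class pin down one small decision rule for legacy cleanup. A reduced sketch of that rule, with illustrative names only (the real logic sits in the storage `setup_*` methods):

```python
def should_clean_up_legacy(new_name: str, legacy_name: str, legacy_count: int) -> bool:
    """Only delete a legacy store that is distinct from the new one and empty."""
    if new_name == legacy_name:
        return False  # no model suffix: "legacy" is the live store, never delete it
    return legacy_count == 0  # non-empty legacy stores are preserved (warning only)


assert should_clean_up_legacy("lightrag_vdb_chunks", "lightrag_vdb_chunks", 0) is False
assert should_clean_up_legacy("lightrag_vdb_chunks_ada_002_1536d", "lightrag_vdb_chunks", 0) is True
assert should_clean_up_legacy("lightrag_vdb_chunks_ada_002_1536d", "lightrag_vdb_chunks", 50) is False
```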
+ """ + db = AsyncMock() + + # Different names (normal case) + table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix + legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix + + # Setup mock responses using check_table_exists on db + async def check_table_exists_side_effect(name): + # Both tables exist + return True + + db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect) + + # Mock empty table + async def query_side_effect(sql, params, **kwargs): + if "COUNT(*)" in sql: + return {"count": 0} + return {} + + db.query.side_effect = query_side_effect + + # Call setup_table + await PGVectorStorage.setup_table( + db, + table_name, + workspace="test_workspace", + embedding_dim=1536, + legacy_table_name=legacy_table_name, + base_table="LIGHTRAG_VDB_CHUNKS", + ) + + # SHOULD delete legacy (normal Case 1 behavior) + drop_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1" + assert legacy_table_name in drop_calls[0][0][0] diff --git a/tests/test_postgres_index_name.py b/tests/test_postgres_index_name.py new file mode 100644 index 00000000..e0af9834 --- /dev/null +++ b/tests/test_postgres_index_name.py @@ -0,0 +1,210 @@ +""" +Unit tests for PostgreSQL safe index name generation. + +This module tests the _safe_index_name helper function which prevents +PostgreSQL's silent 63-byte identifier truncation from causing index +lookup failures. +""" + +import pytest + +# Mark all tests as offline (no external dependencies) +pytestmark = pytest.mark.offline + + +class TestSafeIndexName: + """Test suite for _safe_index_name function.""" + + def test_short_name_unchanged(self): + """Short index names should remain unchanged.""" + from lightrag.kg.postgres_impl import _safe_index_name + + # Short table name - should return unchanged + result = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine") + assert result == "idx_lightrag_vdb_entity_hnsw_cosine" + assert len(result.encode("utf-8")) <= 63 + + def test_long_name_gets_hashed(self): + """Long table names exceeding 63 bytes should get hashed.""" + from lightrag.kg.postgres_impl import _safe_index_name + + # Long table name that would exceed 63 bytes + long_table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d" + result = _safe_index_name(long_table_name, "hnsw_cosine") + + # Should be within 63 bytes + assert len(result.encode("utf-8")) <= 63 + + # Should start with idx_ prefix + assert result.startswith("idx_") + + # Should contain the suffix + assert result.endswith("_hnsw_cosine") + + # Should NOT be the naive concatenation (which would be truncated) + naive_name = f"idx_{long_table_name.lower()}_hnsw_cosine" + assert result != naive_name + + def test_deterministic_output(self): + """Same input should always produce same output (deterministic).""" + from lightrag.kg.postgres_impl import _safe_index_name + + table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d" + suffix = "hnsw_cosine" + + result1 = _safe_index_name(table_name, suffix) + result2 = _safe_index_name(table_name, suffix) + + assert result1 == result2 + + def test_different_suffixes_different_results(self): + """Different suffixes should produce different index names.""" + from lightrag.kg.postgres_impl import _safe_index_name + + table_name = "LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d" + + result1 = _safe_index_name(table_name, "hnsw_cosine") + result2 = _safe_index_name(table_name, "ivfflat_cosine") + 
+ assert result1 != result2 + + def test_case_insensitive(self): + """Table names should be normalized to lowercase.""" + from lightrag.kg.postgres_impl import _safe_index_name + + result_upper = _safe_index_name("LIGHTRAG_VDB_ENTITY", "hnsw_cosine") + result_lower = _safe_index_name("lightrag_vdb_entity", "hnsw_cosine") + + assert result_upper == result_lower + + def test_boundary_case_exactly_63_bytes(self): + """Test boundary case where name is exactly at 63-byte limit.""" + from lightrag.kg.postgres_impl import _safe_index_name + + # Create a table name that results in exactly 63 bytes + # idx_ (4) + table_name + _ (1) + suffix = 63 + # So table_name + suffix = 58 + + # Test a name that's just under the limit (should remain unchanged) + short_suffix = "id" + # idx_ (4) + 56 chars + _ (1) + id (2) = 63 + table_56 = "a" * 56 + result = _safe_index_name(table_56, short_suffix) + expected = f"idx_{table_56}_{short_suffix}" + assert result == expected + assert len(result.encode("utf-8")) == 63 + + def test_unicode_handling(self): + """Unicode characters should be properly handled (bytes, not chars).""" + from lightrag.kg.postgres_impl import _safe_index_name + + # Unicode characters can take more bytes than visible chars + # Chinese characters are 3 bytes each in UTF-8 + table_name = "lightrag_测试_table" # Contains Chinese chars + result = _safe_index_name(table_name, "hnsw_cosine") + + # Should always be within 63 bytes + assert len(result.encode("utf-8")) <= 63 + + def test_real_world_model_names(self): + """Test with real-world embedding model names that cause issues.""" + from lightrag.kg.postgres_impl import _safe_index_name + + # These are actual model names that have caused issues + test_cases = [ + ("LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d", "hnsw_cosine"), + ("LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d", "hnsw_cosine"), + ("LIGHTRAG_VDB_RELATION_text_embedding_3_large_3072d", "hnsw_cosine"), + ( + "LIGHTRAG_VDB_ENTITY_bge_m3_1024d", + "hnsw_cosine", + ), # Shorter model name + ( + "LIGHTRAG_VDB_CHUNKS_nomic_embed_text_v1_768d", + "ivfflat_cosine", + ), # Different index type + ] + + for table_name, suffix in test_cases: + result = _safe_index_name(table_name, suffix) + + # Critical: must be within PostgreSQL's 63-byte limit + assert ( + len(result.encode("utf-8")) <= 63 + ), f"Index name too long: {result} for table {table_name}" + + # Must have consistent format + assert result.startswith("idx_"), f"Missing idx_ prefix: {result}" + assert result.endswith(f"_{suffix}"), f"Missing suffix {suffix}: {result}" + + def test_hash_uniqueness_for_similar_tables(self): + """Similar but different table names should produce different hashes.""" + from lightrag.kg.postgres_impl import _safe_index_name + + # These tables have similar names but should have different hashes + tables = [ + "LIGHTRAG_VDB_CHUNKS_model_a_1024d", + "LIGHTRAG_VDB_CHUNKS_model_b_1024d", + "LIGHTRAG_VDB_ENTITY_model_a_1024d", + ] + + results = [_safe_index_name(t, "hnsw_cosine") for t in tables] + + # All results should be unique + assert len(set(results)) == len(results), "Hash collision detected!" + + +class TestIndexNameIntegration: + """Integration-style tests for index name usage patterns.""" + + def test_pg_indexes_lookup_compatibility(self): + """ + Test that the generated index name will work with pg_indexes lookup. + + This is the core problem: PostgreSQL stores the truncated name, + but we were looking up the untruncated name. Our fix ensures we + always use a name that fits within 63 bytes. 
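As a concrete picture of the strategy these tests encode, here is a minimal sketch of a 63-byte-safe index name builder. It only assumes the properties asserted in this file (idx_ prefix, lowercased table name, suffix preserved, deterministic hash fallback); the actual `_safe_index_name` in `lightrag.kg.postgres_impl` may choose a different hash or layout.

```python
import hashlib

PG_MAX_IDENTIFIER_BYTES = 63  # PostgreSQL silently truncates longer identifiers


def safe_index_name(table_name: str, suffix: str) -> str:
    naive = f"idx_{table_name.lower()}_{suffix}"
    if len(naive.encode("utf-8")) <= PG_MAX_IDENTIFIER_BYTES:
        return naive  # short names stay unchanged for backward compatibility
    # Deterministic fallback: replace the long table name with a short hash
    digest = hashlib.sha1(table_name.lower().encode("utf-8")).hexdigest()[:12]
    return f"idx_{digest}_{suffix}"


assert safe_index_name("LIGHTRAG_VDB_ENTITY", "hnsw_cosine") == "idx_lightrag_vdb_entity_hnsw_cosine"
name = safe_index_name("LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d", "hnsw_cosine")
assert name.startswith("idx_") and name.endswith("_hnsw_cosine")
assert len(name.encode("utf-8")) <= PG_MAX_IDENTIFIER_BYTES
```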
+ """ + from lightrag.kg.postgres_impl import _safe_index_name + + table_name = "LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d" + suffix = "hnsw_cosine" + + # Generate the index name + index_name = _safe_index_name(table_name, suffix) + + # Simulate what PostgreSQL would store (truncate at 63 bytes) + stored_name = index_name.encode("utf-8")[:63].decode("utf-8", errors="ignore") + + # The key fix: our generated name should equal the stored name + # because it's already within the 63-byte limit + assert ( + index_name == stored_name + ), "Index name would be truncated by PostgreSQL, causing lookup failures!" + + def test_backward_compatibility_short_names(self): + """ + Ensure backward compatibility with existing short index names. + + For tables that have existing indexes with short names (pre-model-suffix era), + the function should not change their names. + """ + from lightrag.kg.postgres_impl import _safe_index_name + + # Legacy table names without model suffix + legacy_tables = [ + "LIGHTRAG_VDB_ENTITY", + "LIGHTRAG_VDB_RELATION", + "LIGHTRAG_VDB_CHUNKS", + ] + + for table in legacy_tables: + for suffix in ["hnsw_cosine", "ivfflat_cosine", "id"]: + result = _safe_index_name(table, suffix) + expected = f"idx_{table.lower()}_{suffix}" + + # Short names should remain unchanged for backward compatibility + if len(expected.encode("utf-8")) <= 63: + assert ( + result == expected + ), f"Short name changed unexpectedly: {result} != {expected}" diff --git a/tests/test_postgres_migration.py b/tests/test_postgres_migration.py new file mode 100644 index 00000000..ce431a1d --- /dev/null +++ b/tests/test_postgres_migration.py @@ -0,0 +1,856 @@ +import pytest +from unittest.mock import patch, AsyncMock +import numpy as np +from lightrag.utils import EmbeddingFunc +from lightrag.kg.postgres_impl import ( + PGVectorStorage, +) +from lightrag.namespace import NameSpace + + +# Mock PostgreSQLDB +@pytest.fixture +def mock_pg_db(): + """Mock PostgreSQL database connection""" + db = AsyncMock() + db.workspace = "test_workspace" + + # Mock query responses with multirows support + async def mock_query(sql, params=None, multirows=False, **kwargs): + # Default return value + if multirows: + return [] # Return empty list for multirows + return {"exists": False, "count": 0} + + # Mock for execute that mimics PostgreSQLDB.execute() behavior + async def mock_execute(sql, data=None, **kwargs): + """ + Mock that mimics PostgreSQLDB.execute() behavior: + - Accepts data as dict[str, Any] | None (second parameter) + - Internally converts dict.values() to tuple for AsyncPG + """ + # Mimic real execute() which accepts dict and converts to tuple + if data is not None and not isinstance(data, dict): + raise TypeError( + f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}" + ) + return None + + db.query = AsyncMock(side_effect=mock_query) + db.execute = AsyncMock(side_effect=mock_execute) + + return db + + +# Mock get_data_init_lock to avoid async lock issues in tests +@pytest.fixture(autouse=True) +def mock_data_init_lock(): + with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock: + mock_lock_ctx = AsyncMock() + mock_lock.return_value = mock_lock_ctx + yield mock_lock + + +# Mock ClientManager +@pytest.fixture +def mock_client_manager(mock_pg_db): + with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager: + mock_manager.get_client = AsyncMock(return_value=mock_pg_db) + mock_manager.release_client = AsyncMock() + yield mock_manager + + +# Mock Embedding function 
+@pytest.fixture +def mock_embedding_func(): + async def embed_func(texts, **kwargs): + return np.array([[0.1] * 768 for _ in texts]) + + func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model") + return func + + +async def test_postgres_table_naming( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test if table name is correctly generated with model suffix""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Verify table name contains model suffix + expected_suffix = "test_model_768d" + assert expected_suffix in storage.table_name + assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}" + + # Verify legacy table name + assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS" + + +async def test_postgres_migration_trigger( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test if migration logic is triggered correctly""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Setup mocks for migration scenario + # 1. New table does not exist, legacy table exists + async def mock_check_table_exists(table_name): + return table_name == storage.legacy_table_name + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # 2. Legacy table has 100 records + mock_rows = [ + {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"} + for i in range(100) + ] + migration_state = {"new_table_count": 0} + + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + sql_upper = sql.upper() + legacy_table = storage.legacy_table_name.upper() + new_table = storage.table_name.upper() + is_new_table = new_table in sql_upper + is_legacy_table = legacy_table in sql_upper and not is_new_table + + if is_new_table: + return {"count": migration_state["new_table_count"]} + if is_legacy_table: + return {"count": 100} + return {"count": 0} + elif multirows and "SELECT *" in sql: + # Mock batch fetch for migration using keyset pagination + # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3 + # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2 + if "WHERE workspace" in sql: + if "id >" in sql: + # Keyset pagination: params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + # Find rows after last_id + start_idx = 0 + for i, row in enumerate(mock_rows): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[2] if len(params) > 2 else 500 + else: + # First batch (no last_id): params = [workspace, limit] + start_idx = 0 + limit = params[1] if len(params) > 1 else 500 + else: + # No workspace filter with keyset + if "id >" in sql: + last_id = params[0] if params else None + start_idx = 0 + for i, row in enumerate(mock_rows): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[1] if len(params) > 1 else 500 + else: + start_idx = 0 + limit = params[0] if params else 500 + end = min(start_idx + limit, len(mock_rows)) + return mock_rows[start_idx:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + # 
Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, **kwargs): + # Track that migration batch operation was called + migration_executed.append(True) + migration_state["new_table_count"] = 100 + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + + with patch( + "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock() + ): + # Initialize storage (should trigger migration) + await storage.initialize() + + # Verify migration was executed by checking _run_with_retry was called + # (batch migration uses _run_with_retry with executemany) + assert len(migration_executed) > 0, "Migration should have been executed" + + +async def test_postgres_no_migration_needed( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test scenario where new table already exists (no migration needed)""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Mock: new table already exists + async def mock_check_table_exists(table_name): + return table_name == storage.table_name + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + with patch( + "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock() + ) as mock_create: + await storage.initialize() + + # Verify no table creation was attempted + mock_create.assert_not_called() + + +async def test_scenario_1_new_workspace_creation( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 1: New workspace creation + + Expected behavior: + - No legacy table exists + - Directly create new table with model suffix + - No migration needed + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + func=mock_embedding_func.func, + model_name="text-embedding-3-large", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="new_workspace", + ) + + # Mock: neither table exists + async def mock_check_table_exists(table_name): + return False + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + with patch( + "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock() + ) as mock_create: + await storage.initialize() + + # Verify table name format + assert "text_embedding_3_large_3072d" in storage.table_name + + # Verify new table creation was called + mock_create.assert_called_once() + call_args = mock_create.call_args + assert ( + call_args[0][1] == storage.table_name + ) # table_name is second positional arg + + +async def test_scenario_2_legacy_upgrade_migration( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 2: Upgrade from legacy version + + Expected behavior: + - Legacy table exists (without model suffix) + - New table doesn't exist + - Automatically migrate data to new table with suffix + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="text-embedding-ada-002", + ) + + storage = 
PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="legacy_workspace", + ) + + # Mock: only legacy table exists + async def mock_check_table_exists(table_name): + return table_name == storage.legacy_table_name + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # Mock: legacy table has 50 records + mock_rows = [ + { + "id": f"legacy_id_{i}", + "content": f"legacy_content_{i}", + "workspace": "legacy_workspace", + } + for i in range(50) + ] + + # Track which queries have been made for proper response + query_history = [] + migration_state = {"new_table_count": 0} + + async def mock_query(sql, params=None, multirows=False, **kwargs): + query_history.append(sql) + + if "COUNT(*)" in sql: + # Determine table type: + # - Legacy: contains base name but NOT model suffix + # - New: contains model suffix (e.g., text_embedding_ada_002_1536d) + sql_upper = sql.upper() + base_name = storage.legacy_table_name.upper() + + # Check if this is querying the new table (has model suffix) + has_model_suffix = storage.table_name.upper() in sql_upper + + is_legacy_table = base_name in sql_upper and not has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy_table and has_workspace_filter: + # Count for legacy table with workspace filter (before migration) + return {"count": 50} + elif is_legacy_table and not has_workspace_filter: + # Total count for legacy table + return {"count": 50} + else: + # New table count (before/after migration) + return {"count": migration_state["new_table_count"]} + elif multirows and "SELECT *" in sql: + # Mock batch fetch for migration using keyset pagination + # New pattern: WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3 + # or first batch: WHERE workspace = $1 ORDER BY id LIMIT $2 + if "WHERE workspace" in sql: + if "id >" in sql: + # Keyset pagination: params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + # Find rows after last_id + start_idx = 0 + for i, row in enumerate(mock_rows): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[2] if len(params) > 2 else 500 + else: + # First batch (no last_id): params = [workspace, limit] + start_idx = 0 + limit = params[1] if len(params) > 1 else 500 + else: + # No workspace filter with keyset + if "id >" in sql: + last_id = params[0] if params else None + start_idx = 0 + for i, row in enumerate(mock_rows): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[1] if len(params) > 1 else 500 + else: + start_idx = 0 + limit = params[0] if params else 500 + end = min(start_idx + limit, len(mock_rows)) + return mock_rows[start_idx:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + # Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, **kwargs): + # Track that migration batch operation was called + migration_executed.append(True) + migration_state["new_table_count"] = 50 + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + + with patch( + "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock() + ) as mock_create: + await storage.initialize() + + # Verify table name contains ada-002 + assert "text_embedding_ada_002_1536d" in storage.table_name + + # Verify migration was executed (batch migration uses _run_with_retry) + assert len(migration_executed) > 0, "Migration should have 
been executed" + mock_create.assert_called_once() + + +async def test_scenario_3_multi_model_coexistence( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 3: Multiple embedding models coexist + + Expected behavior: + - Different embedding models create separate tables + - Tables are isolated by model suffix + - No interference between different models + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + # Workspace A: uses bge-small (768d) + embedding_func_a = EmbeddingFunc( + embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small" + ) + + storage_a = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func_a, + workspace="workspace_a", + ) + + # Workspace B: uses bge-large (1024d) + async def embed_func_b(texts, **kwargs): + return np.array([[0.1] * 1024 for _ in texts]) + + embedding_func_b = EmbeddingFunc( + embedding_dim=1024, func=embed_func_b, model_name="bge-large" + ) + + storage_b = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func_b, + workspace="workspace_b", + ) + + # Verify different table names + assert storage_a.table_name != storage_b.table_name + assert "bge_small_768d" in storage_a.table_name + assert "bge_large_1024d" in storage_b.table_name + + # Mock: both tables don't exist yet + async def mock_check_table_exists(table_name): + return False + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + with patch( + "lightrag.kg.postgres_impl.PGVectorStorage._pg_create_table", AsyncMock() + ) as mock_create: + # Initialize both storages + await storage_a.initialize() + await storage_b.initialize() + + # Verify two separate tables were created + assert mock_create.call_count == 2 + + # Verify table names are different + call_args_list = mock_create.call_args_list + table_names = [call[0][1] for call in call_args_list] # Second positional arg + assert len(set(table_names)) == 2 # Two unique table names + assert storage_a.table_name in table_names + assert storage_b.table_name in table_names + + +async def test_case1_empty_legacy_auto_cleanup( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1a: Both new and legacy tables exist, but legacy is EMPTY + Expected: Automatically delete empty legacy table (safe cleanup) + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="test_ws", + ) + + # Mock: Both tables exist + async def mock_check_table_exists(table_name): + return True # Both new and legacy exist + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # Mock: Legacy table is empty (0 records) + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + if storage.legacy_table_name in sql: + return {"count": 0} # Empty legacy table + else: + return {"count": 100} # New table has data + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with patch("lightrag.kg.postgres_impl.logger"): + await storage.initialize() + + # Verify: Empty legacy table should be automatically 
cleaned up + # Empty tables are safe to delete without data loss risk + delete_calls = [ + call + for call in mock_pg_db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted" + # Check if legacy table was dropped + dropped_table = storage.legacy_table_name + assert any( + dropped_table in str(call) for call in delete_calls + ), f"Expected to drop empty legacy table '{dropped_table}'" + + print( + f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully" + ) + + +async def test_case1_nonempty_legacy_warning( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1b: Both new and legacy tables exist, and legacy HAS DATA + Expected: Log warning, do not delete legacy (preserve data) + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="test_ws", + ) + + # Mock: Both tables exist + async def mock_check_table_exists(table_name): + return True # Both new and legacy exist + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists) + + # Mock: Legacy table has data (50 records) + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + if storage.legacy_table_name in sql: + return {"count": 50} # Legacy has data + else: + return {"count": 100} # New table has data + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with patch("lightrag.kg.postgres_impl.logger"): + await storage.initialize() + + # Verify: Legacy table with data should be preserved + # We never auto-delete tables that contain data to prevent accidental data loss + delete_calls = [ + call + for call in mock_pg_db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + # Check if legacy table was deleted (it should not be) + dropped_table = storage.legacy_table_name + legacy_deleted = any(dropped_table in str(call) for call in delete_calls) + assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted" + + print( + f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)" + ) + + +async def test_case1_sequential_workspace_migration( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1c: Sequential workspace migration (Multi-tenant scenario) + + Critical bug fix verification: + Timeline: + 1. Legacy table has workspace_a (3 records) + workspace_b (3 records) + 2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data + 3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data + 4. Verify workspace B's data is correctly migrated to new table + + This test verifies the migration logic correctly handles multi-tenant scenarios + where different workspaces migrate sequentially. 
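The mocks below emulate a keyset-paginated, per-workspace copy. A simplified sketch of that loop, under the assumptions the mocks encode (asyncpg-style `$n` placeholders, a `db.query(sql, params, multirows=True)` helper, internal table names so the f-string is safe, and a caller-supplied batch writer); the real code batches through `_run_with_retry`/`executemany` inside `PGVectorStorage.setup_table`:

```python
from typing import Any, Awaitable, Callable, Sequence

Row = dict[str, Any]


async def migrate_workspace_rows(
    db: Any,
    write_batch: Callable[[Sequence[Row]], Awaitable[None]],
    legacy_table: str,
    workspace: str,
    batch_size: int = 500,
) -> int:
    """Copy one workspace's rows out of legacy_table in id order (keyset pagination)."""
    migrated, last_id = 0, None
    while True:
        if last_id is None:  # first batch: params = [workspace, limit]
            sql = f"SELECT * FROM {legacy_table} WHERE workspace = $1 ORDER BY id LIMIT $2"
            rows = await db.query(sql, [workspace, batch_size], multirows=True)
        else:  # later batches: params = [workspace, last_id, limit]
            sql = (
                f"SELECT * FROM {legacy_table} "
                f"WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
            )
            rows = await db.query(sql, [workspace, last_id, batch_size], multirows=True)
        if not rows:
            return migrated
        await write_batch(rows)      # e.g. executemany INSERT into the suffixed table
        migrated += len(rows)
        last_id = rows[-1]["id"]     # resume strictly after the last id seen
```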
+ """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + # Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b) + mock_rows_a = [ + {"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"} + for i in range(3) + ] + mock_rows_b = [ + {"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"} + for i in range(3) + ] + + # Track migration state + migration_state = { + "new_table_exists": False, + "workspace_a_migrated": False, + "workspace_a_migration_count": 0, + "workspace_b_migration_count": 0, + } + + # Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists) + # CRITICAL: Set db.workspace to workspace_a + mock_pg_db.workspace = "workspace_a" + + storage_a = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="workspace_a", + ) + + # Mock table_exists for workspace_a + async def mock_check_table_exists_a(table_name): + if table_name == storage_a.legacy_table_name: + return True + if table_name == storage_a.table_name: + return migration_state["new_table_exists"] + return False + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_a) + + # Mock query for workspace_a (Case 3) + async def mock_query_a(sql, params=None, multirows=False, **kwargs): + sql_upper = sql.upper() + base_name = storage_a.legacy_table_name.upper() + + if "COUNT(*)" in sql: + has_model_suffix = "TEST_MODEL_1536D" in sql_upper + is_legacy = base_name in sql_upper and not has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy and has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_a": + return {"count": 3} + elif workspace == "workspace_b": + return {"count": 3} + elif is_legacy and not has_workspace_filter: + # Global count in legacy table + return {"count": 6} + elif has_model_suffix: + if has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_a": + return {"count": migration_state["workspace_a_migration_count"]} + if workspace == "workspace_b": + return {"count": migration_state["workspace_b_migration_count"]} + return { + "count": migration_state["workspace_a_migration_count"] + + migration_state["workspace_b_migration_count"] + } + elif multirows and "SELECT *" in sql: + if "WHERE workspace" in sql: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_a": + # Handle keyset pagination + if "id >" in sql: + # params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + start_idx = 0 + for i, row in enumerate(mock_rows_a): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[2] if len(params) > 2 else 500 + else: + # First batch: params = [workspace, limit] + start_idx = 0 + limit = params[1] if len(params) > 1 else 500 + end = min(start_idx + limit, len(mock_rows_a)) + return mock_rows_a[start_idx:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query_a) + + # Track migration via _run_with_retry (batch migration uses this) + migration_a_executed = [] + + async def mock_run_with_retry_a(operation, **kwargs): + migration_a_executed.append(True) + migration_state["workspace_a_migration_count"] = 
len(mock_rows_a) + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a) + + # Initialize workspace_a (Case 3) + with patch("lightrag.kg.postgres_impl.logger"): + await storage_a.initialize() + migration_state["new_table_exists"] = True + migration_state["workspace_a_migrated"] = True + + print("✅ Step 1: Workspace A initialized") + # Verify migration was executed via _run_with_retry (batch migration uses executemany) + assert ( + len(migration_a_executed) > 0 + ), "Migration should have been executed for workspace_a" + print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)") + + # Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data) + # CRITICAL: Set db.workspace to workspace_b + mock_pg_db.workspace = "workspace_b" + + storage_b = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="workspace_b", + ) + + mock_pg_db.reset_mock() + + # Mock table_exists for workspace_b (both exist) + async def mock_check_table_exists_b(table_name): + return True # Both tables exist + + mock_pg_db.check_table_exists = AsyncMock(side_effect=mock_check_table_exists_b) + + # Mock query for workspace_b (Case 3) + async def mock_query_b(sql, params=None, multirows=False, **kwargs): + sql_upper = sql.upper() + base_name = storage_b.legacy_table_name.upper() + + if "COUNT(*)" in sql: + has_model_suffix = "TEST_MODEL_1536D" in sql_upper + is_legacy = base_name in sql_upper and not has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy and has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + return {"count": 3} # workspace_b still has data in legacy + elif workspace == "workspace_a": + return {"count": 0} # workspace_a already migrated + elif is_legacy and not has_workspace_filter: + # Global count: only workspace_b data remains + return {"count": 3} + elif has_model_suffix: + if has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + return {"count": migration_state["workspace_b_migration_count"]} + elif workspace == "workspace_a": + return {"count": 3} + else: + return {"count": 3 + migration_state["workspace_b_migration_count"]} + elif multirows and "SELECT *" in sql: + if "WHERE workspace" in sql: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + # Handle keyset pagination + if "id >" in sql: + # params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + start_idx = 0 + for i, row in enumerate(mock_rows_b): + if row["id"] == last_id: + start_idx = i + 1 + break + limit = params[2] if len(params) > 2 else 500 + else: + # First batch: params = [workspace, limit] + start_idx = 0 + limit = params[1] if len(params) > 1 else 500 + end = min(start_idx + limit, len(mock_rows_b)) + return mock_rows_b[start_idx:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query_b) + + # Track migration via _run_with_retry for workspace_b + migration_b_executed = [] + + async def mock_run_with_retry_b(operation, **kwargs): + migration_b_executed.append(True) + migration_state["workspace_b_migration_count"] = len(mock_rows_b) + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b) + + # Initialize workspace_b (Case 3 - both tables exist) + with 
patch("lightrag.kg.postgres_impl.logger"): + await storage_b.initialize() + + print("✅ Step 2: Workspace B initialized") + + # Verify workspace_b migration happens when new table has no workspace_b data + # but legacy table still has workspace_b data. + assert ( + len(migration_b_executed) > 0 + ), "Migration should have been executed for workspace_b" + print("✅ Step 2: Migration executed for workspace_b") + + print("\n🎉 Case 1c: Sequential workspace migration verification complete!") + print(" - Workspace A: Migrated successfully (only legacy existed)") + print(" - Workspace B: Migrated successfully (new table empty for workspace_b)") diff --git a/tests/test_qdrant_migration.py b/tests/test_qdrant_migration.py new file mode 100644 index 00000000..25d4eca9 --- /dev/null +++ b/tests/test_qdrant_migration.py @@ -0,0 +1,562 @@ +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +import numpy as np +from qdrant_client import models +from lightrag.utils import EmbeddingFunc +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage + + +# Mock QdrantClient +@pytest.fixture +def mock_qdrant_client(): + with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls: + client = mock_client_cls.return_value + client.collection_exists.return_value = False + client.count.return_value.count = 0 + # Mock payload schema and vector config for get_collection + collection_info = MagicMock() + collection_info.payload_schema = {} + # Mock vector dimension to match mock_embedding_func (768d) + collection_info.config.params.vectors.size = 768 + client.get_collection.return_value = collection_info + yield client + + +# Mock get_data_init_lock to avoid async lock issues in tests +@pytest.fixture(autouse=True) +def mock_data_init_lock(): + with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock: + mock_lock_ctx = AsyncMock() + mock_lock.return_value = mock_lock_ctx + yield mock_lock + + +# Mock Embedding function +@pytest.fixture +def mock_embedding_func(): + async def embed_func(texts, **kwargs): + return np.array([[0.1] * 768 for _ in texts]) + + func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model") + return func + + +async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func): + """Test if collection name is correctly generated with model suffix""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Verify collection name contains model suffix + expected_suffix = "test_model_768d" + assert expected_suffix in storage.final_namespace + assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}" + + +async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func): + """Test if migration logic is triggered correctly""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Legacy collection name (without model suffix) + legacy_collection = "lightrag_vdb_chunks" + + # Setup mocks for migration scenario + # 1. New collection does not exist, only legacy exists + mock_qdrant_client.collection_exists.side_effect = ( + lambda name: name == legacy_collection + ) + + # 2. 
Legacy collection exists and has data + migration_state = {"new_workspace_count": 0} + + def count_mock(collection_name, exact=True, count_filter=None): + mock_result = MagicMock() + if collection_name == legacy_collection: + mock_result.count = 100 + elif collection_name == storage.final_namespace: + mock_result.count = migration_state["new_workspace_count"] + else: + mock_result.count = 0 + return mock_result + + mock_qdrant_client.count.side_effect = count_mock + + # 3. Mock scroll for data migration + mock_point = MagicMock() + mock_point.id = "old_id" + mock_point.vector = [0.1] * 768 + mock_point.payload = {"content": "test"} # No workspace_id in payload + + # When payload_schema is empty, the code first samples payloads to detect workspace_id + # Then proceeds with migration batches + # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration + mock_qdrant_client.scroll.side_effect = [ + ([mock_point], "_"), # Sampling scroll - no workspace_id found + ([mock_point], "next_offset"), # Migration batch + ([], None), # End of migration + ] + + def upsert_mock(*args, **kwargs): + migration_state["new_workspace_count"] = 100 + return None + + mock_qdrant_client.upsert.side_effect = upsert_mock + + # Initialize storage (triggers migration) + await storage.initialize() + + # Verify migration steps + # 1. Legacy count checked + mock_qdrant_client.count.assert_any_call( + collection_name=legacy_collection, exact=True + ) + + # 2. New collection created + mock_qdrant_client.create_collection.assert_called() + + # 3. Data scrolled from legacy + # First call (index 0) is sampling scroll with limit=10 + # Second call (index 1) is migration batch with limit=500 + assert mock_qdrant_client.scroll.call_count >= 2 + # Check sampling scroll + sampling_call = mock_qdrant_client.scroll.call_args_list[0] + assert sampling_call.kwargs["collection_name"] == legacy_collection + assert sampling_call.kwargs["limit"] == 10 + # Check migration batch scroll + migration_call = mock_qdrant_client.scroll.call_args_list[1] + assert migration_call.kwargs["collection_name"] == legacy_collection + assert migration_call.kwargs["limit"] == 500 + + # 4. Data upserted to new + mock_qdrant_client.upsert.assert_called() + + # 5. Payload index created + mock_qdrant_client.create_payload_index.assert_called() + + +async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func): + """Test scenario where new collection already exists (Case 1 in setup_collection) + + When only the new collection exists and no legacy collection is found, + the implementation should: + 1. Create payload index on the new collection (ensure index exists) + 2. 
NOT attempt any data migration (no scroll calls)
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws",
+    )
+
+    # Only new collection exists (no legacy collection found)
+    mock_qdrant_client.collection_exists.side_effect = (
+        lambda name: name == storage.final_namespace
+    )
+
+    # Initialize
+    await storage.initialize()
+
+    # Should create payload index on the new collection (ensure index)
+    mock_qdrant_client.create_payload_index.assert_called_with(
+        collection_name=storage.final_namespace,
+        field_name="workspace_id",
+        field_schema=models.KeywordIndexParams(
+            type=models.KeywordIndexType.KEYWORD,
+            is_tenant=True,
+        ),
+    )
+    # Should NOT migrate (no scroll calls since no legacy collection exists)
+    mock_qdrant_client.scroll.assert_not_called()
+
+
+# ============================================================================
+# Tests for scenarios described in design document (Lines 606-649)
+# ============================================================================
+
+
+async def test_scenario_1_new_workspace_creation(
+    mock_qdrant_client, mock_embedding_func
+):
+    """
+    Scenario 1: Creating a brand-new workspace.
+    Expected: lightrag_vdb_chunks_text_embedding_3_large_3072d is created directly.
+    """
+    # Use a large embedding model
+    large_model_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large",
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=large_model_func,
+        workspace="test_new",
+    )
+
+    # Case 3: Neither legacy nor new collection exists
+    mock_qdrant_client.collection_exists.return_value = False
+
+    # Initialize storage
+    await storage.initialize()
+
+    # Verify: Should create new collection with model suffix
+    expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
+    assert storage.final_namespace == expected_collection
+
+    # Verify create_collection was called with correct name
+    create_calls = [
+        call for call in mock_qdrant_client.create_collection.call_args_list
+    ]
+    assert len(create_calls) > 0
+    assert (
+        create_calls[0][0][0] == expected_collection
+        or create_calls[0].kwargs.get("collection_name") == expected_collection
+    )
+
+    # Verify no migration was attempted
+    mock_qdrant_client.scroll.assert_not_called()
+
+    print(
+        f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
+    )
+
+
+async def test_scenario_2_legacy_upgrade_migration(
+    mock_qdrant_client, mock_embedding_func
+):
+    """
+    Scenario 2: Upgrading from an older version.
+    lightrag_vdb_chunks (without a model suffix) already exists.
+    Expected: data is migrated automatically to lightrag_vdb_chunks_text_embedding_ada_002_1536d.
+    Note: the legacy collection is no longer deleted automatically after migration; it must be removed manually.
+    """
+    # Use ada-002 model
+    ada_func = EmbeddingFunc(
+        embedding_dim=1536,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-ada-002",
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=ada_func,
+        workspace="test_legacy",
+    )
+
+    # Legacy collection name (without model suffix)
+    legacy_collection = "lightrag_vdb_chunks"
+    new_collection = storage.final_namespace
+
+    # Case 4: Only legacy collection exists +
mock_qdrant_client.collection_exists.side_effect = ( + lambda name: name == legacy_collection + ) + + # Mock legacy collection info with 1536d vectors + legacy_collection_info = MagicMock() + legacy_collection_info.payload_schema = {} + legacy_collection_info.config.params.vectors.size = 1536 + mock_qdrant_client.get_collection.return_value = legacy_collection_info + + migration_state = {"new_workspace_count": 0} + + def count_mock(collection_name, exact=True, count_filter=None): + mock_result = MagicMock() + if collection_name == legacy_collection: + mock_result.count = 150 + elif collection_name == new_collection: + mock_result.count = migration_state["new_workspace_count"] + else: + mock_result.count = 0 + return mock_result + + mock_qdrant_client.count.side_effect = count_mock + + # Mock scroll results (simulate migration in batches) + mock_points = [] + for i in range(10): + point = MagicMock() + point.id = f"legacy-{i}" + point.vector = [0.1] * 1536 + # No workspace_id in payload - simulates legacy data + point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"} + mock_points.append(point) + + # When payload_schema is empty, the code first samples payloads to detect workspace_id + # Then proceeds with migration batches + # Scroll calls: 1) Sampling (limit=10), 2) Migration batch, 3) End of migration + mock_qdrant_client.scroll.side_effect = [ + (mock_points, "_"), # Sampling scroll - no workspace_id found in payloads + (mock_points, "offset1"), # Migration batch + ([], None), # End of migration + ] + + def upsert_mock(*args, **kwargs): + migration_state["new_workspace_count"] = 150 + return None + + mock_qdrant_client.upsert.side_effect = upsert_mock + + # Initialize (triggers migration) + await storage.initialize() + + # Verify: New collection should be created + expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + assert storage.final_namespace == expected_new_collection + + # Verify migration steps + # 1. Check legacy count + mock_qdrant_client.count.assert_any_call( + collection_name=legacy_collection, exact=True + ) + + # 2. Create new collection + mock_qdrant_client.create_collection.assert_called() + + # 3. Scroll legacy data + scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list] + assert len(scroll_calls) >= 1 + assert scroll_calls[0].kwargs["collection_name"] == legacy_collection + + # 4. 
Upsert to new collection
+    upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
+    assert len(upsert_calls) >= 1
+    assert upsert_calls[0].kwargs["collection_name"] == new_collection
+
+    # Note: Legacy collection is NOT automatically deleted after migration
+    # Manual deletion is required after data migration verification
+
+    print(
+        f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'"
+    )
+
+
+async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
+    """
+    Scenario 3: Multiple embedding models coexisting.
+    Expected: two independent collections that do not interfere with each other.
+    """
+
+    # Model A: bge-small with 768d
+    async def embed_func_a(texts, **kwargs):
+        return np.array([[0.1] * 768 for _ in texts])
+
+    model_a_func = EmbeddingFunc(
+        embedding_dim=768, func=embed_func_a, model_name="bge-small"
+    )
+
+    # Model B: bge-large with 1024d
+    async def embed_func_b(texts, **kwargs):
+        return np.array([[0.2] * 1024 for _ in texts])
+
+    model_b_func = EmbeddingFunc(
+        embedding_dim=1024, func=embed_func_b, model_name="bge-large"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    # Create storage for workspace A with model A
+    storage_a = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_a_func,
+        workspace="workspace_a",
+    )
+
+    # Create storage for workspace B with model B
+    storage_b = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_b_func,
+        workspace="workspace_b",
+    )
+
+    # Verify: Collection names are different
+    assert storage_a.final_namespace != storage_b.final_namespace
+
+    # Verify: Model A collection
+    expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
+    assert storage_a.final_namespace == expected_collection_a
+
+    # Verify: Model B collection
+    expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
+    assert storage_b.final_namespace == expected_collection_b
+
+    # Verify: Different embedding dimensions are preserved
+    assert storage_a.embedding_func.embedding_dim == 768
+    assert storage_b.embedding_func.embedding_dim == 1024
+
+    print("✅ Scenario 3: Multi-model coexistence verified")
+    print(f" - Workspace A: {expected_collection_a} (768d)")
+    print(f" - Workspace B: {expected_collection_b} (1024d)")
+    print(" - Collections are independent")
+
+
+async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
+    """
+    Case 1a: Both the new and legacy collections exist, and the legacy collection is empty.
+    Expected: the legacy collection is deleted automatically.
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws",
+    )
+
+    # Legacy collection name (without model suffix)
+    legacy_collection = "lightrag_vdb_chunks"
+    new_collection = storage.final_namespace
+
+    # Mock: Both collections exist
+    mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
+        legacy_collection,
+        new_collection,
+    ]
+
+    # Mock: Legacy collection is empty (0 records)
+    def count_mock(collection_name, exact=True, count_filter=None):
+        mock_result = MagicMock()
+        if collection_name == legacy_collection:
+            mock_result.count = 0  # Empty legacy collection
+        else:
+            mock_result.count = 100  # New collection has data
+        return mock_result
+
+    mock_qdrant_client.count.side_effect = count_mock
+
+    # Mock get_collection for Case 2 check
+    collection_info = MagicMock()
+    collection_info.payload_schema = {"workspace_id": True}
+    mock_qdrant_client.get_collection.return_value = collection_info
+
+    # Initialize storage
+    await storage.initialize()
+
+    # Verify: Empty legacy collection should be automatically cleaned up
+    # Empty collections are safe to delete without data loss risk
+    delete_calls = [
+        call for call in mock_qdrant_client.delete_collection.call_args_list
+    ]
+    assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
+    deleted_collection = (
+        delete_calls[0][0][0]
+        if delete_calls[0][0]
+        else delete_calls[0].kwargs.get("collection_name")
+    )
+    assert (
+        deleted_collection == legacy_collection
+    ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"
+
+    print(
+        f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
+    )
+
+
+async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
+    """
+    Case 1b: Both the new and legacy collections exist, and the legacy collection still contains data.
+    Expected: a warning is logged, but nothing is deleted.
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws",
+    )
+
+    # Legacy collection name (without model suffix)
+    legacy_collection = "lightrag_vdb_chunks"
+    new_collection = storage.final_namespace
+
+    # Mock: Both collections exist
+    mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
+        legacy_collection,
+        new_collection,
+    ]
+
+    # Mock: Legacy collection has data (50 records)
+    def count_mock(collection_name, exact=True, count_filter=None):
+        mock_result = MagicMock()
+        if collection_name == legacy_collection:
+            mock_result.count = 50  # Legacy has data
+        else:
+            mock_result.count = 100  # New collection has data
+        return mock_result
+
+    mock_qdrant_client.count.side_effect = count_mock
+
+    # Mock get_collection for Case 2 check
+    collection_info = MagicMock()
+    collection_info.payload_schema = {"workspace_id": True}
+    mock_qdrant_client.get_collection.return_value = collection_info
+
+    # Initialize storage
+    await storage.initialize()
+
+    # Verify: Legacy collection with data should be preserved
+    # We never auto-delete collections that contain data to prevent accidental data loss
+    delete_calls = [
+        call for call in mock_qdrant_client.delete_collection.call_args_list
+    ]
+    # Check if legacy collection was deleted (it should not be)
+    legacy_deleted = any(
+        (call[0][0] if call[0] else call.kwargs.get("collection_name"))
+        == legacy_collection
+        for call in delete_calls
+    )
+    assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"
+
+    print(
+        f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
+    )
diff --git a/tests/test_unified_lock_safety.py b/tests/test_unified_lock_safety.py
new file mode 100644
index 00000000..1d83190a
--- /dev/null
+++ b/tests/test_unified_lock_safety.py
@@ -0,0 +1,122 @@
+"""
+Tests for UnifiedLock safety when lock is None.
+
+This test module verifies that get_internal_lock() and get_data_init_lock()
+raise RuntimeError when shared data is not initialized, preventing a false
+sense of security and potential race conditions.
+
+Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
+to the lock factory functions (get_internal_lock, get_data_init_lock) for
+early failure detection.
+ +Critical Bug 1 (Fixed): When self._lock is None, the code would fail with +AttributeError. Now the check is in factory functions for clearer errors. + +Critical Bug 2: In __aexit__, when async_lock.release() fails, the error +recovery logic would attempt to release it again, causing double-release issues. +""" + +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from lightrag.kg.shared_storage import ( + UnifiedLock, + get_internal_lock, + get_data_init_lock, + finalize_share_data, +) + + +class TestUnifiedLockSafety: + """Test suite for UnifiedLock None safety checks.""" + + def setup_method(self): + """Ensure shared data is finalized before each test.""" + finalize_share_data() + + def teardown_method(self): + """Clean up after each test.""" + finalize_share_data() + + def test_get_internal_lock_raises_when_not_initialized(self): + """ + Test that get_internal_lock() raises RuntimeError when shared data is not initialized. + + Scenario: Call get_internal_lock() before initialize_share_data() is called. + Expected: RuntimeError raised with clear error message. + + This test verifies the None check has been moved to the factory function. + """ + with pytest.raises( + RuntimeError, match="Shared data not initialized.*initialize_share_data" + ): + get_internal_lock() + + def test_get_data_init_lock_raises_when_not_initialized(self): + """ + Test that get_data_init_lock() raises RuntimeError when shared data is not initialized. + + Scenario: Call get_data_init_lock() before initialize_share_data() is called. + Expected: RuntimeError raised with clear error message. + + This test verifies the None check has been moved to the factory function. + """ + with pytest.raises( + RuntimeError, match="Shared data not initialized.*initialize_share_data" + ): + get_data_init_lock() + + @pytest.mark.offline + async def test_aexit_no_double_release_on_async_lock_failure(self): + """ + Test that __aexit__ doesn't attempt to release async_lock twice when it fails. + + Scenario: async_lock.release() fails during normal release. + Expected: Recovery logic should NOT attempt to release async_lock again, + preventing double-release issues. + + This tests Bug 2 fix: async_lock_released tracking prevents double release. 
+ """ + # Create mock locks + main_lock = MagicMock() + main_lock.acquire = MagicMock() + main_lock.release = MagicMock() + + async_lock = AsyncMock() + async_lock.acquire = AsyncMock() + + # Make async_lock.release() fail + release_call_count = 0 + + def mock_release_fail(): + nonlocal release_call_count + release_call_count += 1 + raise RuntimeError("Async lock release failed") + + async_lock.release = MagicMock(side_effect=mock_release_fail) + + # Create UnifiedLock with both locks (sync mode with async_lock) + lock = UnifiedLock( + lock=main_lock, + is_async=False, + name="test_double_release", + enable_logging=False, + ) + lock._async_lock = async_lock + + # Try to use the lock - should fail during __aexit__ + try: + async with lock: + pass + except RuntimeError as e: + # Should get the async lock release error + assert "Async lock release failed" in str(e) + + # Verify async_lock.release() was called only ONCE, not twice + assert ( + release_call_count == 1 + ), f"async_lock.release() should be called only once, but was called {release_call_count} times" + + # Main lock should have been released successfully + main_lock.release.assert_called_once() diff --git a/tests/test_workspace_isolation.py b/tests/test_workspace_isolation.py index 68f7f8ec..0aac3186 100644 --- a/tests/test_workspace_isolation.py +++ b/tests/test_workspace_isolation.py @@ -149,7 +149,6 @@ def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None: @pytest.mark.offline -@pytest.mark.asyncio async def test_pipeline_status_isolation(): """ Test that pipeline status is isolated between different workspaces. @@ -204,7 +203,6 @@ async def test_pipeline_status_isolation(): @pytest.mark.offline -@pytest.mark.asyncio async def test_lock_mechanism(stress_test_mode, parallel_workers): """ Test that the new keyed lock mechanism works correctly without deadlocks. @@ -274,7 +272,6 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers): @pytest.mark.offline -@pytest.mark.asyncio async def test_backward_compatibility(): """ Test that legacy code without workspace parameter still works correctly. @@ -348,7 +345,6 @@ async def test_backward_compatibility(): @pytest.mark.offline -@pytest.mark.asyncio async def test_multi_workspace_concurrency(): """ Test that multiple workspaces can operate concurrently without interference. @@ -432,7 +428,6 @@ async def test_multi_workspace_concurrency(): @pytest.mark.offline -@pytest.mark.asyncio async def test_namespace_lock_reentrance(): """ Test that NamespaceLock prevents re-entrance in the same coroutine @@ -506,7 +501,6 @@ async def test_namespace_lock_reentrance(): @pytest.mark.offline -@pytest.mark.asyncio async def test_different_namespace_lock_isolation(): """ Test that locks for different namespaces (same workspace) are independent. @@ -546,7 +540,6 @@ async def test_different_namespace_lock_isolation(): @pytest.mark.offline -@pytest.mark.asyncio async def test_error_handling(): """ Test error handling for invalid workspace configurations. @@ -597,7 +590,6 @@ async def test_error_handling(): @pytest.mark.offline -@pytest.mark.asyncio async def test_update_flags_workspace_isolation(): """ Test that update flags are properly isolated between workspaces. @@ -727,7 +719,6 @@ async def test_update_flags_workspace_isolation(): @pytest.mark.offline -@pytest.mark.asyncio async def test_empty_workspace_standardization(): """ Test that empty workspace is properly standardized to "" instead of "_". 
@@ -781,7 +772,6 @@ async def test_empty_workspace_standardization(): @pytest.mark.offline -@pytest.mark.asyncio async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): """ Integration test: Verify JsonKVStorage properly isolates data between workspaces. @@ -961,7 +951,6 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): @pytest.mark.offline -@pytest.mark.asyncio async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts): """ End-to-end test: Create two LightRAG instances with different workspaces, diff --git a/tests/test_workspace_migration_isolation.py b/tests/test_workspace_migration_isolation.py new file mode 100644 index 00000000..d0e3bfd2 --- /dev/null +++ b/tests/test_workspace_migration_isolation.py @@ -0,0 +1,288 @@ +""" +Tests for workspace isolation during PostgreSQL migration. + +This test module verifies that setup_table() properly filters migration data +by workspace, preventing cross-workspace data leakage during legacy table migration. + +Critical Bug: Migration copied ALL records from legacy table regardless of workspace, +causing workspace A to receive workspace B's data, violating multi-tenant isolation. +""" + +import pytest +from unittest.mock import AsyncMock + +from lightrag.kg.postgres_impl import PGVectorStorage + + +class TestWorkspaceMigrationIsolation: + """Test suite for workspace-scoped migration in PostgreSQL.""" + + async def test_migration_filters_by_workspace(self): + """ + Test that migration only copies data from the specified workspace. + + Scenario: Legacy table contains data from multiple workspaces. + Migrate only workspace_a's data to new table. + Expected: New table contains only workspace_a data, workspace_b data excluded. + """ + db = AsyncMock() + + # Configure mock return values to avoid unawaited coroutine warnings + db._create_vector_index.return_value = None + + # Track state for new table count (starts at 0, increases after migration) + new_table_record_count = {"count": 0} + + # Mock table existence checks + async def table_exists_side_effect(db_instance, name): + if name.lower() == "lightrag_doc_chunks": # legacy + return True + elif name.lower() == "lightrag_doc_chunks_model_1536d": # new + return False # New table doesn't exist initially + return False + + # Mock data for workspace_a + mock_records_a = [ + { + "id": "a1", + "workspace": "workspace_a", + "content": "content_a1", + "content_vector": [0.1] * 1536, + }, + { + "id": "a2", + "workspace": "workspace_a", + "content": "content_a2", + "content_vector": [0.2] * 1536, + }, + ] + + # Mock query responses + async def query_side_effect(sql, params, **kwargs): + multirows = kwargs.get("multirows", False) + sql_upper = sql.upper() + + # Count query for new table workspace data (verification before migration) + if ( + "COUNT(*)" in sql_upper + and "MODEL_1536D" in sql_upper + and "WHERE WORKSPACE" in sql_upper + ): + return new_table_record_count # Initially 0 + + # Count query with workspace filter (legacy table) - for workspace count + elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper: + if params and params[0] == "workspace_a": + return {"count": 2} # workspace_a has 2 records + elif params and params[0] == "workspace_b": + return {"count": 3} # workspace_b has 3 records + return {"count": 0} + + # Count query for legacy table (total, no workspace filter) + elif ( + "COUNT(*)" in sql_upper + and "LIGHTRAG" in sql_upper + and "WHERE WORKSPACE" not in sql_upper + ): + return {"count": 5} # Total records in legacy + 
+ # SELECT with workspace filter for migration (multirows) + elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows: + workspace = params[0] if params else None + if workspace == "workspace_a": + # Handle keyset pagination: check for "id >" pattern + if "id >" in sql.lower(): + # Keyset pagination: params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + # Find records after last_id + found_idx = -1 + for i, rec in enumerate(mock_records_a): + if rec["id"] == last_id: + found_idx = i + break + if found_idx >= 0: + return mock_records_a[found_idx + 1 :] + return [] + else: + # First batch: params = [workspace, limit] + return mock_records_a + return [] # No data for other workspaces + + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + + # Mock check_table_exists on db + async def check_table_exists_side_effect(name): + if name.lower() == "lightrag_doc_chunks": # legacy + return True + elif name.lower() == "lightrag_doc_chunks_model_1536d": # new + return False # New table doesn't exist initially + return False + + db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect) + + # Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, *args, **kwargs): + migration_executed.append(True) + new_table_record_count["count"] = 2 # Simulate 2 records migrated + return None + + db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + + # Migrate for workspace_a only - correct parameter order + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_1536d", + workspace="workspace_a", # CRITICAL: Only migrate workspace_a + embedding_dim=1536, + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + ) + + # Verify the migration was triggered + assert ( + len(migration_executed) > 0 + ), "Migration should have been executed for workspace_a" + + async def test_migration_without_workspace_raises_error(self): + """ + Test that migration without workspace parameter raises ValueError. + + Scenario: setup_table called without workspace parameter. + Expected: ValueError is raised because workspace is required. + """ + db = AsyncMock() + + # workspace is now a required parameter - calling with None should raise ValueError + with pytest.raises(ValueError, match="workspace must be provided"): + await PGVectorStorage.setup_table( + db, + "lightrag_doc_chunks_model_1536d", + workspace=None, # No workspace - should raise ValueError + embedding_dim=1536, + legacy_table_name="lightrag_doc_chunks", + base_table="lightrag_doc_chunks", + ) + + async def test_no_cross_workspace_contamination(self): + """ + Test that workspace B's migration doesn't include workspace A's data. + + Scenario: Migration for workspace_b only. + Expected: Only workspace_b data is queried, workspace_a data excluded. 
+ """ + db = AsyncMock() + + # Configure mock return values to avoid unawaited coroutine warnings + db._create_vector_index.return_value = None + + # Track which workspace is being queried + queried_workspace = None + new_table_count = {"count": 0} + + # Mock data for workspace_b + mock_records_b = [ + { + "id": "b1", + "workspace": "workspace_b", + "content": "content_b1", + "content_vector": [0.3] * 1536, + }, + ] + + async def table_exists_side_effect(db_instance, name): + if name.lower() == "lightrag_doc_chunks": # legacy + return True + elif name.lower() == "lightrag_doc_chunks_model_1536d": # new + return False + return False + + async def query_side_effect(sql, params, **kwargs): + nonlocal queried_workspace + multirows = kwargs.get("multirows", False) + sql_upper = sql.upper() + + # Count query for new table workspace data (should be 0 initially) + if ( + "COUNT(*)" in sql_upper + and "MODEL_1536D" in sql_upper + and "WHERE WORKSPACE" in sql_upper + ): + return new_table_count + + # Count query with workspace filter (legacy table) + elif "COUNT(*)" in sql_upper and "WHERE WORKSPACE" in sql_upper: + queried_workspace = params[0] if params else None + return {"count": 1} # 1 record for the queried workspace + + # Count query for legacy table total (no workspace filter) + elif ( + "COUNT(*)" in sql_upper + and "LIGHTRAG" in sql_upper + and "WHERE WORKSPACE" not in sql_upper + ): + return {"count": 3} # 3 total records in legacy + + # SELECT with workspace filter for migration (multirows) + elif "SELECT" in sql_upper and "FROM" in sql_upper and multirows: + workspace = params[0] if params else None + if workspace == "workspace_b": + # Handle keyset pagination: check for "id >" pattern + if "id >" in sql.lower(): + # Keyset pagination: params = [workspace, last_id, limit] + last_id = params[1] if len(params) > 1 else None + # Find records after last_id + found_idx = -1 + for i, rec in enumerate(mock_records_b): + if rec["id"] == last_id: + found_idx = i + break + if found_idx >= 0: + return mock_records_b[found_idx + 1 :] + return [] + else: + # First batch: params = [workspace, limit] + return mock_records_b + return [] # No data for other workspaces + + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + + # Mock check_table_exists on db + async def check_table_exists_side_effect(name): + if name.lower() == "lightrag_doc_chunks": # legacy + return True + elif name.lower() == "lightrag_doc_chunks_model_1536d": # new + return False + return False + + db.check_table_exists = AsyncMock(side_effect=check_table_exists_side_effect) + + # Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, *args, **kwargs): + migration_executed.append(True) + new_table_count["count"] = 1 # Simulate migration + return None + + db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + + # Migrate workspace_b - correct parameter order + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_1536d", + workspace="workspace_b", # Only migrate workspace_b + embedding_dim=1536, + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + ) + + # Verify only workspace_b was queried + assert queried_workspace == "workspace_b", "Should only query workspace_b"