From ada5f10be7b82ae468747cd54fc111f5ac08192e Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 19 Dec 2025 12:05:22 +0800 Subject: [PATCH] Optimize Postgres batch operations and refine workspace migration logic - Use executemany for efficient upserts - Optimize data migration with batching - Refine multi-workspace migration logic - Add pgvector dependency - Update DDL templates for dynamic dims --- lightrag/kg/postgres_impl.py | 700 +++++++++++++------------------ pyproject.toml | 4 +- requirements-offline-storage.txt | 3 +- requirements-offline.txt | 9 +- tests/test_postgres_migration.py | 225 +++++----- 5 files changed, 402 insertions(+), 539 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7f5b2ce5..5b0591dc 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -31,6 +31,7 @@ from ..base import ( DocStatus, DocStatusStorage, ) +from ..exceptions import DataMigrationError from ..namespace import NameSpace, is_namespace from ..utils import logger from ..kg.shared_storage import get_data_init_lock @@ -39,9 +40,12 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") +if not pm.is_installed("pgvector"): + pm.install("pgvector") import asyncpg # type: ignore from asyncpg import Pool # type: ignore +from pgvector.asyncpg import register_vector # type: ignore from dotenv import load_dotenv @@ -2191,9 +2195,7 @@ async def _pg_create_table( ddl_template = TABLES[base_table]["ddl"] # Replace embedding dimension placeholder if exists - ddl = ddl_template.replace( - f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})" - ) + ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})") # Replace table name ddl = ddl.replace(base_table, table_name) @@ -2209,7 +2211,11 @@ async def _pg_migrate_workspace_data( expected_count: int, embedding_dim: int, ) -> int: - """Migrate workspace data from legacy table to new table""" + """Migrate workspace data from legacy table to new table using batch insert. + + This function uses asyncpg's executemany for efficient batch insertion, + reducing database round-trips from N to 1 per batch. 
+ """ migrated_count = 0 offset = 0 batch_size = 500 @@ -2227,20 +2233,34 @@ async def _pg_migrate_workspace_data( if not rows: break + # Batch insert optimization: use executemany instead of individual inserts + # Get column names from the first row + first_row = dict(rows[0]) + columns = list(first_row.keys()) + columns_str = ", ".join(columns) + placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) + + insert_query = f""" + INSERT INTO {new_table_name} ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT (workspace, id) DO NOTHING + """ + + # Prepare batch data: convert rows to list of tuples + batch_values = [] for row in rows: row_dict = dict(row) - columns = list(row_dict.keys()) - columns_str = ", ".join(columns) - placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) - insert_query = f""" - INSERT INTO {new_table_name} ({columns_str}) - VALUES ({placeholders}) - ON CONFLICT (workspace, id) DO NOTHING - """ - # Rebuild dict in columns order to ensure values() matches placeholders order - # Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values()) - values = {col: row_dict[col] for col in columns} - await db.execute(insert_query, values) + # Extract values in column order to match placeholders + values_tuple = tuple(row_dict[col] for col in columns) + batch_values.append(values_tuple) + + # Use executemany for batch execution - significantly reduces DB round-trips + # Register pgvector codec to handle vector fields alongside other fields seamlessly + async def _batch_insert(connection: asyncpg.Connection) -> None: + await register_vector(connection) + await connection.executemany(insert_query, batch_values) + + await db._run_with_retry(_batch_insert) migrated_count += len(rows) workspace_info = f" for workspace '{workspace}'" if workspace else "" @@ -2284,395 +2304,208 @@ class PGVectorStorage(BaseVectorStorage): # Fallback: use base table name if model_suffix is unavailable self.table_name = base_table logger.warning( - f"Model suffix unavailable, using base table name '{base_table}'. " - f"Ensure embedding_func has model_name for proper model isolation." + "Missing collection suffix. Ensure embedding_func has model_name for proper model isolation." ) # Legacy table name (without suffix, for migration) self.legacy_table_name = base_table - logger.debug( - f"PostgreSQL table naming: " - f"new='{self.table_name}', " - f"legacy='{self.legacy_table_name}', " - f"model_suffix='{self.model_suffix}'" - ) + logger.info(f"PostgreSQL table name: {self.table_name}") @staticmethod async def setup_table( db: PostgreSQLDB, table_name: str, - legacy_table_name: str = None, - base_table: str = None, - embedding_dim: int = None, - workspace: str = None, + workspace: str, + embedding_dim: int, + legacy_table_name: str, + base_table: str, ): """ Setup PostgreSQL table with migration support from legacy tables. - This method mirrors Qdrant's setup_collection approach to maintain consistency. + Ensure final table has workspace isolation index. + Check vector dimension compatibility before new table creation. + Drop legacy table if it exists and is empty. + Only migrate data from legacy table to new table when new table first created and legacy table is not empty. 
Args: db: PostgreSQLDB instance table_name: Name of the new table - legacy_table_name: Name of the legacy table (if exists) + workspace: Workspace to filter records for migration + legacy_table_name: Name of the legacy table to check for migration base_table: Base table name for DDL template lookup embedding_dim: Embedding dimension for vector column """ + if not workspace: + raise ValueError("workspace must be provided") + new_table_exists = await _pg_table_exists(db, table_name) legacy_exists = legacy_table_name and await _pg_table_exists( db, legacy_table_name ) - # Case 1: Both new and legacy tables exist - if new_table_exists and legacy_exists: - if table_name.lower() == legacy_table_name.lower(): - logger.debug( - f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup." - ) - return - - try: - workspace_info = f" for workspace '{workspace}'" if workspace else "" - - if workspace: - count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" - count_result = await db.query(count_query, [workspace]) - else: - count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" - count_result = await db.query(count_query, []) - - workspace_count = count_result.get("count", 0) if count_result else 0 - - if workspace_count > 0: - logger.info( - f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..." - ) - - legacy_dim = None - try: - dim_query = """ - SELECT - CASE - WHEN typname = 'vector' THEN - COALESCE(atttypmod, -1) - ELSE -1 - END as vector_dim - FROM pg_attribute a - JOIN pg_type t ON a.atttypid = t.oid - WHERE a.attrelid = $1::regclass - AND a.attname = 'content_vector' - """ - dim_result = await db.query(dim_query, [legacy_table_name]) - legacy_dim = ( - dim_result.get("vector_dim", -1) if dim_result else -1 - ) - - if legacy_dim <= 0: - sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1" - sample_result = await db.query(sample_query, []) - if sample_result and sample_result.get("content_vector"): - vector_data = sample_result["content_vector"] - if isinstance(vector_data, (list, tuple)): - legacy_dim = len(vector_data) - elif isinstance(vector_data, str): - import json - - vector_list = json.loads(vector_data) - legacy_dim = len(vector_list) - - if ( - legacy_dim > 0 - and embedding_dim - and legacy_dim != embedding_dim - ): - logger.warning( - f"PostgreSQL: Dimension mismatch - " - f"legacy table has {legacy_dim}d vectors, " - f"new embedding model expects {embedding_dim}d. " - f"Skipping migration{workspace_info}." - ) - await db._create_vector_index(table_name, embedding_dim) - return - - except Exception as e: - logger.warning( - f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..." - ) - - migrated_count = await _pg_migrate_workspace_data( - db, - legacy_table_name, - table_name, - workspace, - workspace_count, - embedding_dim, - ) - - if workspace: - new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1" - new_count_result = await db.query(new_count_query, [workspace]) - else: - new_count_query = f"SELECT COUNT(*) as count FROM {table_name}" - new_count_result = await db.query(new_count_query, []) - - new_count = ( - new_count_result.get("count", 0) if new_count_result else 0 - ) - - if new_count < workspace_count: - logger.warning( - f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. " - f"Some records may have been skipped due to conflicts." 
- ) - else: - logger.info( - f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}" - ) - - if workspace: - delete_query = ( - f"DELETE FROM {legacy_table_name} WHERE workspace = $1" - ) - await db.execute(delete_query, {"workspace": workspace}) - logger.info( - f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table" - ) - - total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" - total_count_result = await db.query(total_count_query, []) - total_count = ( - total_count_result.get("count", 0) if total_count_result else 0 - ) - - if total_count == 0: - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..." - ) - drop_query = f"DROP TABLE {legacy_table_name}" - await db.execute(drop_query, None) - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully" - ) - else: - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' preserved " - f"({total_count} records from other workspaces remain)" - ) - - except Exception as e: - logger.warning( - f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured." - ) - + # Case 1: Only new table exists or new table is the same as legacy table + # No data migration needed, ensuring index is created then return + if (new_table_exists and not legacy_exists) or ( + table_name.lower() == legacy_table_name.lower() + ): await db._create_vector_index(table_name, embedding_dim) return - # Case 2: Only new table exists - Already migrated or newly created - if new_table_exists: - logger.debug(f"PostgreSQL: Table '{table_name}' already exists") - # Ensure vector index exists with correct embedding dimension - await db._create_vector_index(table_name, embedding_dim) - return - - # Case 3: Neither exists - Create new table - if not legacy_exists: - logger.info(f"PostgreSQL: Creating new table '{table_name}'") - await _pg_create_table(db, table_name, base_table, embedding_dim) - logger.info(f"PostgreSQL: Table '{table_name}' created successfully") - # Create vector index with correct embedding dimension - await db._create_vector_index(table_name, embedding_dim) - return - - # Case 4: Only legacy exists - Migrate data - logger.info( - f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'" - ) - - try: - # Get legacy table count (with workspace filtering) - if workspace: + legacy_count = None + if not new_table_exists: + # Check vector dimension compatibility before creating new table + if legacy_exists: count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" count_result = await db.query(count_query, [workspace]) - else: - count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" - count_result = await db.query(count_query, []) - logger.warning( - "PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!" 
- ) + legacy_count = count_result.get("count", 0) if count_result else 0 - legacy_count = count_result.get("count", 0) if count_result else 0 - workspace_info = f" for workspace '{workspace}'" if workspace else "" - logger.info( - f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}" + if legacy_count > 0: + legacy_dim = None + try: + sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1" + sample_result = await db.query(sample_query, [workspace]) + if sample_result and sample_result.get("content_vector"): + vector_data = sample_result["content_vector"] + # pgvector returns list directly + if isinstance(vector_data, (list, tuple)): + legacy_dim = len(vector_data) + elif isinstance(vector_data, str): + import json + + vector_list = json.loads(vector_data) + legacy_dim = len(vector_list) + + if legacy_dim and legacy_dim != embedding_dim: + logger.error( + f"PostgreSQL: Dimension mismatch detected! " + f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, " + f"but new embedding model expects {embedding_dim}d." + ) + raise DataMigrationError( + f"Dimension mismatch between legacy table '{legacy_table_name}' " + f"and new embedding model. Expected {embedding_dim} but got {legacy_dim}." + ) + + except DataMigrationError: + # Re-raise DataMigrationError as-is to preserve specific error messages + raise + except Exception as e: + raise DataMigrationError( + f"Could not verify legacy table vector dimension: {e}. " + f"Proceeding with caution..." + ) + + await _pg_create_table(db, table_name, base_table, embedding_dim) + logger.info(f"PostgreSQL: New table '{table_name}' created successfully") + + # Ensure vector index is created + await db._create_vector_index(table_name, embedding_dim) + + # Case 2: Legacy table exist + if legacy_exists: + workspace_info = f" for workspace '{workspace}'" + + # Only drop legacy table if entire table is empty + total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" + total_count_result = await db.query(total_count_query, []) + total_count = ( + total_count_result.get("count", 0) if total_count_result else 0 ) - - if legacy_count == 0: - logger.info("PostgreSQL: Legacy table is empty, skipping migration") - await _pg_create_table(db, table_name, base_table, embedding_dim) - # Create vector index with correct embedding dimension - await db._create_vector_index(table_name, embedding_dim) + if total_count == 0: + logger.info( + f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully" + ) + drop_query = f"DROP TABLE {legacy_table_name}" + await db.execute(drop_query, None) return - # Check vector dimension compatibility before migration - legacy_dim = None - try: - # Try to get vector dimension from pg_attribute metadata - dim_query = """ - SELECT - CASE - WHEN typname = 'vector' THEN - COALESCE(atttypmod, -1) - ELSE -1 - END as vector_dim - FROM pg_attribute a - JOIN pg_type t ON a.atttypid = t.oid - WHERE a.attrelid = $1::regclass - AND a.attname = 'content_vector' - """ - dim_result = await db.query(dim_query, [legacy_table_name]) - legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1 + # No data migration needed if legacy workspace is empty + if legacy_count is None: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" + count_result = await db.query(count_query, [workspace]) + legacy_count = count_result.get("count", 0) if count_result else 0 - if legacy_dim <= 0: - # Alternative: Try to detect by sampling 
a vector - logger.info( - "PostgreSQL: Metadata dimension check failed, trying vector sampling..." - ) - sample_query = ( - f"SELECT content_vector FROM {legacy_table_name} LIMIT 1" - ) - sample_result = await db.query(sample_query, []) - if sample_result and sample_result.get("content_vector"): - vector_data = sample_result["content_vector"] - # pgvector returns list directly - if isinstance(vector_data, (list, tuple)): - legacy_dim = len(vector_data) - elif isinstance(vector_data, str): - import json + if legacy_count == 0: + logger.info( + f"PostgreSQL: No records{workspace_info} found in legacy table. " + f"No data migration needed." + ) + return - vector_list = json.loads(vector_data) - legacy_dim = len(vector_list) + new_count_query = ( + f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1" + ) + new_count_result = await db.query(new_count_query, [workspace]) + new_table_workspace_count = ( + new_count_result.get("count", 0) if new_count_result else 0 + ) - if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim: - logger.warning( - f"PostgreSQL: Dimension mismatch detected! " - f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, " - f"but new embedding model expects {embedding_dim}d. " - f"Migration skipped to prevent data loss. " - f"Legacy table preserved as '{legacy_table_name}'. " - f"Creating new empty table '{table_name}' for new data." - ) - - # Create new table but skip migration - await _pg_create_table(db, table_name, base_table, embedding_dim) - await db._create_vector_index(table_name, embedding_dim) - - logger.info( - f"PostgreSQL: New table '{table_name}' created. " - f"To query legacy data, please use a {legacy_dim}d embedding model." - ) - return - - except Exception as e: + if new_table_workspace_count > 0: logger.warning( - f"PostgreSQL: Could not verify legacy table vector dimension: {e}. " - f"Proceeding with caution..." + f"PostgreSQL: New table '{table_name}' already has " + f"{new_table_workspace_count} records{workspace_info}. " + "Data migration skipped to avoid duplicates." ) + return - logger.info(f"PostgreSQL: Creating new table '{table_name}'") - await _pg_create_table(db, table_name, base_table, embedding_dim) - - migrated_count = await _pg_migrate_workspace_data( - db, - legacy_table_name, - table_name, - workspace, - legacy_count, - embedding_dim, - ) - - logger.info("PostgreSQL: Verifying migration...") - new_count_query = f"SELECT COUNT(*) as count FROM {table_name}" - new_count_result = await db.query(new_count_query, []) - new_count = new_count_result.get("count", 0) if new_count_result else 0 - - if new_count != legacy_count: - error_msg = ( - f"PostgreSQL: Migration verification failed, " - f"expected {legacy_count} records, got {new_count} in new table" - ) - logger.error(error_msg) - raise PostgreSQLMigrationError(error_msg) - + # Case 3: Legacy has workspace data and new table is empty for workspace logger.info( - f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated" + f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}." 
) + logger.info( + f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'" + ) + + try: + migrated_count = await _pg_migrate_workspace_data( + db, + legacy_table_name, + table_name, + workspace, + legacy_count, + embedding_dim, + ) + if migrated_count != legacy_count: + logger.warning( + "PostgreSQL: Read %s legacy records%s during migration, expected %s.", + migrated_count, + workspace_info, + legacy_count, + ) + + new_count_result = await db.query(new_count_query, [workspace]) + new_table_count_after = ( + new_count_result.get("count", 0) if new_count_result else 0 + ) + inserted_count = new_table_count_after - new_table_workspace_count + + if inserted_count != legacy_count: + error_msg = ( + "PostgreSQL: Migration verification failed, " + f"expected {legacy_count} inserted records, got {inserted_count}." + ) + logger.error(error_msg) + raise DataMigrationError(error_msg) + + except DataMigrationError: + # Re-raise DataMigrationError as-is to preserve specific error messages + raise + except Exception as e: + logger.error( + f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}" + ) + raise DataMigrationError( + f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'" + ) from e + logger.info( f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully" ) - - await db._create_vector_index(table_name, embedding_dim) - - try: - if workspace: - logger.info( - f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..." - ) - delete_query = ( - f"DELETE FROM {legacy_table_name} WHERE workspace = $1" - ) - await db.execute(delete_query, {"workspace": workspace}) - logger.info( - f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table" - ) - - remaining_query = ( - f"SELECT COUNT(*) as count FROM {legacy_table_name}" - ) - remaining_result = await db.query(remaining_query, []) - remaining_count = ( - remaining_result.get("count", 0) if remaining_result else 0 - ) - - if remaining_count == 0: - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..." - ) - drop_query = f"DROP TABLE {legacy_table_name}" - await db.execute(drop_query, None) - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully" - ) - else: - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)" - ) - else: - logger.warning( - f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..." - ) - drop_query = f"DROP TABLE {legacy_table_name}" - await db.execute(drop_query, None) - logger.info( - f"PostgreSQL: Legacy table '{legacy_table_name}' deleted" - ) - - except Exception as delete_error: - # If cleanup fails, log warning but don't fail migration - logger.warning( - f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. " - "Migration succeeded, but manual cleanup may be needed." 
- ) - - except PostgreSQLMigrationError: - # Re-raise migration errors without wrapping - raise - except Exception as e: - error_msg = f"PostgreSQL: Migration failed with error: {e}" - logger.error(error_msg) - # Mirror Qdrant behavior: no automatic rollback - # Reason: partial data can be continued by re-running migration - raise PostgreSQLMigrationError(error_msg) from e + logger.info( + "PostgreSQL: Manual deletion is required after data migration verification." + ) async def initialize(self): async with get_data_init_lock(): @@ -2694,10 +2527,10 @@ class PGVectorStorage(BaseVectorStorage): await PGVectorStorage.setup_table( self.db, self.table_name, + self.workspace, # CRITICAL: Filter migration by workspace + embedding_dim=self.embedding_func.embedding_dim, legacy_table_name=self.legacy_table_name, base_table=self.legacy_table_name, # base_table for DDL template lookup - embedding_dim=self.embedding_func.embedding_dim, - workspace=self.workspace, # CRITICAL: Filter migration by workspace ) async def finalize(self): @@ -2707,34 +2540,45 @@ class PGVectorStorage(BaseVectorStorage): def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for chunks. + + Returns: + Tuple of (SQL template, values tuple for executemany) + """ try: upsert_sql = SQL_TEMPLATES["upsert_chunk"].format( table_name=self.table_name ) - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "file_path": item["file_path"], - "create_time": current_time, - "update_time": current_time, - } + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] = ( + self.workspace, # $1 + item["__id__"], # $2 + item["tokens"], # $3 + item["chunk_order_index"], # $4 + item["full_doc_id"], # $5 + item["content"], # $6 + item["__vector__"], # $7 - numpy array, handled by pgvector codec + item["file_path"], # $8 + current_time, # $9 + current_time, # $10 + ) except Exception as e: logger.error( - f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}" + f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}" ) raise - return upsert_sql, data + return upsert_sql, values def _upsert_entities( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for entities. + + Returns: + Tuple of (SQL template, values tuple for executemany) + """ upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name) source_id = item["source_id"] if isinstance(source_id, str) and "" in source_id: @@ -2742,22 +2586,28 @@ class PGVectorStorage(BaseVectorStorage): else: chunk_ids = [source_id] - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "entity_name": item["entity_name"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] 
= ( + self.workspace, # $1 + item["__id__"], # $2 + item["entity_name"], # $3 + item["content"], # $4 + item["__vector__"], # $5 - numpy array, handled by pgvector codec + chunk_ids, # $6 + item.get("file_path", None), # $7 + current_time, # $8 + current_time, # $9 + ) + return upsert_sql, values def _upsert_relationships( self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: + ) -> tuple[str, tuple[Any, ...]]: + """Prepare upsert data for relationships. + + Returns: + Tuple of (SQL template, values tuple for executemany) + """ upsert_sql = SQL_TEMPLATES["upsert_relationship"].format( table_name=self.table_name ) @@ -2767,19 +2617,20 @@ class PGVectorStorage(BaseVectorStorage): else: chunk_ids = [source_id] - data: dict[str, Any] = { - "workspace": self.workspace, - "id": item["__id__"], - "source_id": item["src_id"], - "target_id": item["tgt_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data + # Return tuple in the exact order of SQL parameters ($1, $2, ...) + values: tuple[Any, ...] = ( + self.workspace, # $1 + item["__id__"], # $2 + item["src_id"], # $3 + item["tgt_id"], # $4 + item["content"], # $5 + item["__vector__"], # $6 - numpy array, handled by pgvector codec + chunk_ids, # $7 + item.get("file_path", None), # $8 + current_time, # $9 + current_time, # $10 + ) + return upsert_sql, values async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") @@ -2807,17 +2658,34 @@ class PGVectorStorage(BaseVectorStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] + + # Prepare batch values for executemany + batch_values: list[tuple[Any, ...]] = [] + upsert_sql = None + for item in list_data: if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): - upsert_sql, data = self._upsert_chunks(item, current_time) + upsert_sql, values = self._upsert_chunks(item, current_time) elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): - upsert_sql, data = self._upsert_entities(item, current_time) + upsert_sql, values = self._upsert_entities(item, current_time) elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): - upsert_sql, data = self._upsert_relationships(item, current_time) + upsert_sql, values = self._upsert_relationships(item, current_time) else: raise ValueError(f"{self.namespace} is not supported") - await self.db.execute(upsert_sql, data) + batch_values.append(values) + + # Use executemany for batch execution - significantly reduces DB round-trips + if batch_values and upsert_sql: + + async def _batch_upsert(connection: asyncpg.Connection) -> None: + await register_vector(connection) + await connection.executemany(upsert_sql, batch_values) + + await self.db._run_with_retry(_batch_upsert) + logger.debug( + f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace}" + ) #################### query method ############### async def query( @@ -3658,12 +3526,6 @@ class PGDocStatusStorage(DocStatusStorage): return {"status": "error", "message": str(e)} -class PostgreSQLMigrationError(Exception): - """Exception for PostgreSQL table migration errors.""" - - pass - - class PGGraphQueryException(Exception): """Exception for the AGE queries.""" @@ 
-5263,14 +5125,14 @@ TABLES = { )""" }, "LIGHTRAG_VDB_CHUNKS": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_CHUNKS ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS ( id VARCHAR(255), workspace VARCHAR(255), full_doc_id VARCHAR(256), chunk_order_index INTEGER, tokens INTEGER, content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), file_path TEXT NULL, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, @@ -5278,12 +5140,12 @@ TABLES = { )""" }, "LIGHTRAG_VDB_ENTITY": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_ENTITY ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( id VARCHAR(255), workspace VARCHAR(255), entity_name VARCHAR(512), content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, chunk_ids VARCHAR(255)[] NULL, @@ -5292,13 +5154,13 @@ TABLES = { )""" }, "LIGHTRAG_VDB_RELATION": { - "ddl": f"""CREATE TABLE LIGHTRAG_VDB_RELATION ( + "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( id VARCHAR(255), workspace VARCHAR(255), source_id VARCHAR(512), target_id VARCHAR(512), content TEXT, - content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), + content_vector VECTOR(dimension), create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, chunk_ids VARCHAR(255)[] NULL, diff --git a/pyproject.toml b/pyproject.toml index 761a3309..dd3dbc92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,6 @@ api = [ # API-specific dependencies "aiofiles", "ascii_colors", - "asyncpg", "distro", "fastapi", "httpcore", @@ -108,7 +107,8 @@ offline-storage = [ "neo4j>=5.0.0,<7.0.0", "pymilvus>=2.6.2,<3.0.0", "pymongo>=4.0.0,<5.0.0", - "asyncpg>=0.29.0,<1.0.0", + "asyncpg>=0.31.0,<1.0.0", + "pgvector>=0.4.2,<1.0.0", "qdrant-client>=1.11.0,<2.0.0", ] diff --git a/requirements-offline-storage.txt b/requirements-offline-storage.txt index 13a9c0e2..82caacbd 100644 --- a/requirements-offline-storage.txt +++ b/requirements-offline-storage.txt @@ -8,8 +8,9 @@ # Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt # Storage backend dependencies (with version constraints matching pyproject.toml) -asyncpg>=0.29.0,<1.0.0 +asyncpg>=0.31.0,<1.0.0 neo4j>=5.0.0,<7.0.0 +pgvector>=0.4.2,<1.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 qdrant-client>=1.11.0,<2.0.0 diff --git a/requirements-offline.txt b/requirements-offline.txt index 87ca7a6a..283ced73 100644 --- a/requirements-offline.txt +++ b/requirements-offline.txt @@ -7,20 +7,17 @@ # Recommended: Use pip install lightrag-hku[offline] for the same effect # Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt -# LLM provider dependencies (with version constraints matching pyproject.toml) aioboto3>=12.0.0,<16.0.0 anthropic>=0.18.0,<1.0.0 - -# Storage backend dependencies -asyncpg>=0.29.0,<1.0.0 +asyncpg>=0.31.0,<1.0.0 +google-api-core>=2.0.0,<3.0.0 google-genai>=1.0.0,<2.0.0 - -# Document processing dependencies llama-index>=0.9.0,<1.0.0 neo4j>=5.0.0,<7.0.0 ollama>=0.1.0,<1.0.0 openai>=2.0.0,<3.0.0 openpyxl>=3.0.0,<4.0.0 +pgvector>=0.4.2,<1.0.0 pycryptodome>=3.0.0,<4.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 diff --git a/tests/test_postgres_migration.py b/tests/test_postgres_migration.py index df88e700..7509562f 100644 --- 
a/tests/test_postgres_migration.py +++ b/tests/test_postgres_migration.py @@ -123,10 +123,21 @@ async def test_postgres_migration_trigger( {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"} for i in range(100) ] + migration_state = {"new_table_count": 0} async def mock_query(sql, params=None, multirows=False, **kwargs): if "COUNT(*)" in sql: - return {"count": 100} + sql_upper = sql.upper() + legacy_table = storage.legacy_table_name.upper() + new_table = storage.table_name.upper() + is_new_table = new_table in sql_upper + is_legacy_table = legacy_table in sql_upper and not is_new_table + + if is_new_table: + return {"count": migration_state["new_table_count"]} + if is_legacy_table: + return {"count": 100} + return {"count": 0} elif multirows and "SELECT *" in sql: # Mock batch fetch for migration # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit] @@ -145,6 +156,17 @@ async def test_postgres_migration_trigger( mock_pg_db.query = AsyncMock(side_effect=mock_query) + # Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, **kwargs): + # Track that migration batch operation was called + migration_executed.append(True) + migration_state["new_table_count"] = 100 + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + with ( patch( "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists @@ -154,9 +176,9 @@ async def test_postgres_migration_trigger( # Initialize storage (should trigger migration) await storage.initialize() - # Verify migration was executed - # Check that execute was called for inserting rows - assert mock_pg_db.execute.call_count > 0 + # Verify migration was executed by checking _run_with_retry was called + # (batch migration uses _run_with_retry with executemany) + assert len(migration_executed) > 0, "Migration should have been executed" @pytest.mark.asyncio @@ -291,6 +313,7 @@ async def test_scenario_2_legacy_upgrade_migration( # Track which queries have been made for proper response query_history = [] + migration_state = {"new_table_count": 0} async def mock_query(sql, params=None, multirows=False, **kwargs): query_history.append(sql) @@ -303,27 +326,20 @@ async def test_scenario_2_legacy_upgrade_migration( base_name = storage.legacy_table_name.upper() # Check if this is querying the new table (has model suffix) - has_model_suffix = any( - suffix in sql_upper - for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"] - ) + has_model_suffix = storage.table_name.upper() in sql_upper is_legacy_table = base_name in sql_upper and not has_model_suffix - is_new_table = has_model_suffix has_workspace_filter = "WHERE workspace" in sql if is_legacy_table and has_workspace_filter: # Count for legacy table with workspace filter (before migration) return {"count": 50} elif is_legacy_table and not has_workspace_filter: - # Total count for legacy table (after deletion, checking remaining) - return {"count": 0} - elif is_new_table: - # Count for new table (verification after migration) + # Total count for legacy table return {"count": 50} else: - # Fallback - return {"count": 0} + # New table count (before/after migration) + return {"count": migration_state["new_table_count"]} elif multirows and "SELECT *" in sql: # Mock batch fetch for migration # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit] @@ -342,6 +358,17 @@ async def test_scenario_2_legacy_upgrade_migration( 
mock_pg_db.query = AsyncMock(side_effect=mock_query) + # Track migration through _run_with_retry calls + migration_executed = [] + + async def mock_run_with_retry(operation, **kwargs): + # Track that migration batch operation was called + migration_executed.append(True) + migration_state["new_table_count"] = 50 + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry) + with ( patch( "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists @@ -353,26 +380,10 @@ async def test_scenario_2_legacy_upgrade_migration( # Verify table name contains ada-002 assert "text_embedding_ada_002_1536d" in storage.table_name - # Verify migration was executed - assert mock_pg_db.execute.call_count >= 50 # At least one execute per row + # Verify migration was executed (batch migration uses _run_with_retry) + assert len(migration_executed) > 0, "Migration should have been executed" mock_create.assert_called_once() - # Verify legacy table was automatically deleted after successful migration - # This prevents Case 1 warnings on next startup - delete_calls = [ - call - for call in mock_pg_db.execute.call_args_list - if call[0][0] and "DROP TABLE" in call[0][0] - ] - assert ( - len(delete_calls) >= 1 - ), "Legacy table should be deleted after successful migration" - # Check if legacy table was dropped - dropped_table = storage.legacy_table_name - assert any( - dropped_table in str(call) for call in delete_calls - ), f"Expected to drop '{dropped_table}'" - @pytest.mark.asyncio async def test_scenario_3_multi_model_coexistence( @@ -586,13 +597,12 @@ async def test_case1_sequential_workspace_migration( Critical bug fix verification: Timeline: 1. Legacy table has workspace_a (3 records) + workspace_b (3 records) - 2. Workspace A initializes first → Case 4 (only legacy exists) → migrates A's data - 3. Workspace B initializes later → Case 1 (both tables exist) → should migrate B's data + 2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data + 3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data 4. Verify workspace B's data is correctly migrated to new table - 5. Verify legacy table is cleaned up after both workspaces migrate - This test verifies the fix where Case 1 now checks and migrates current - workspace's data instead of just checking if legacy table is empty globally. + This test verifies the migration logic correctly handles multi-tenant scenarios + where different workspaces migrate sequentially. 
""" config = { "embedding_batch_num": 10, @@ -616,9 +626,14 @@ async def test_case1_sequential_workspace_migration( ] # Track migration state - migration_state = {"new_table_exists": False, "workspace_a_migrated": False} + migration_state = { + "new_table_exists": False, + "workspace_a_migrated": False, + "workspace_a_migration_count": 0, + "workspace_b_migration_count": 0, + } - # Step 1: Simulate workspace_a initialization (Case 4) + # Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists) # CRITICAL: Set db.workspace to workspace_a mock_pg_db.workspace = "workspace_a" @@ -637,16 +652,7 @@ async def test_case1_sequential_workspace_migration( return migration_state["new_table_exists"] return False - # Track inserted records count for verification - inserted_count = {"workspace_a": 0} - - # Mock execute to track inserts - async def mock_execute_a(sql, data=None, **kwargs): - if sql and "INSERT INTO" in sql.upper(): - inserted_count["workspace_a"] += 1 - return None - - # Mock query for workspace_a (Case 4) + # Mock query for workspace_a (Case 3) async def mock_query_a(sql, params=None, multirows=False, **kwargs): sql_upper = sql.upper() base_name = storage_a.legacy_table_name.upper() @@ -659,17 +665,23 @@ async def test_case1_sequential_workspace_migration( if is_legacy and has_workspace_filter: workspace = params[0] if params and len(params) > 0 else None if workspace == "workspace_a": - # After migration starts, pretend legacy is empty for this workspace - return {"count": 3 - inserted_count["workspace_a"]} + return {"count": 3} elif workspace == "workspace_b": return {"count": 3} elif is_legacy and not has_workspace_filter: # Global count in legacy table - remaining = 6 - inserted_count["workspace_a"] - return {"count": remaining} + return {"count": 6} elif has_model_suffix: - # New table count (for verification) - return {"count": inserted_count["workspace_a"]} + if has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_a": + return {"count": migration_state["workspace_a_migration_count"]} + if workspace == "workspace_b": + return {"count": migration_state["workspace_b_migration_count"]} + return { + "count": migration_state["workspace_a_migration_count"] + + migration_state["workspace_b_migration_count"] + } elif multirows and "SELECT *" in sql: if "WHERE workspace" in sql: workspace = params[0] if params and len(params) > 0 else None @@ -680,9 +692,18 @@ async def test_case1_sequential_workspace_migration( return {} mock_pg_db.query = AsyncMock(side_effect=mock_query_a) - mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a) - # Initialize workspace_a (Case 4) + # Track migration via _run_with_retry (batch migration uses this) + migration_a_executed = [] + + async def mock_run_with_retry_a(operation, **kwargs): + migration_a_executed.append(True) + migration_state["workspace_a_migration_count"] = len(mock_rows_a) + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a) + + # Initialize workspace_a (Case 3) with ( patch( "lightrag.kg.postgres_impl._pg_table_exists", @@ -694,11 +715,14 @@ async def test_case1_sequential_workspace_migration( migration_state["new_table_exists"] = True migration_state["workspace_a_migrated"] = True - print("✅ Step 1: Workspace A initialized (Case 4)") - assert mock_pg_db.execute.call_count >= 3 - print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls") + print("✅ Step 1: Workspace A initialized") + # Verify migration was executed 
via _run_with_retry (batch migration uses executemany) + assert ( + len(migration_a_executed) > 0 + ), "Migration should have been executed for workspace_a" + print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)") - # Step 2: Simulate workspace_b initialization (Case 1) + # Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data) # CRITICAL: Set db.workspace to workspace_b mock_pg_db.workspace = "workspace_b" @@ -710,22 +734,12 @@ async def test_case1_sequential_workspace_migration( ) mock_pg_db.reset_mock() - migration_state["workspace_b_migrated"] = False # Mock table_exists for workspace_b (both exist) async def mock_table_exists_b(db, table_name): - return True + return True # Both tables exist - # Track inserted records count for workspace_b - inserted_count["workspace_b"] = 0 - - # Mock execute for workspace_b to track inserts - async def mock_execute_b(sql, data=None, **kwargs): - if sql and "INSERT INTO" in sql.upper(): - inserted_count["workspace_b"] += 1 - return None - - # Mock query for workspace_b (Case 1) + # Mock query for workspace_b (Case 3) async def mock_query_b(sql, params=None, multirows=False, **kwargs): sql_upper = sql.upper() base_name = storage_b.legacy_table_name.upper() @@ -738,24 +752,21 @@ async def test_case1_sequential_workspace_migration( if is_legacy and has_workspace_filter: workspace = params[0] if params and len(params) > 0 else None if workspace == "workspace_b": - # After migration starts, pretend legacy is empty for this workspace - return {"count": 3 - inserted_count["workspace_b"]} + return {"count": 3} # workspace_b still has data in legacy elif workspace == "workspace_a": - return {"count": 0} # Already migrated + return {"count": 0} # workspace_a already migrated elif is_legacy and not has_workspace_filter: # Global count: only workspace_b data remains - return {"count": 3 - inserted_count["workspace_b"]} + return {"count": 3} elif has_model_suffix: - # New table total count (workspace_a: 3 + workspace_b: inserted) if has_workspace_filter: workspace = params[0] if params and len(params) > 0 else None if workspace == "workspace_b": - return {"count": inserted_count["workspace_b"]} + return {"count": migration_state["workspace_b_migration_count"]} elif workspace == "workspace_a": return {"count": 3} else: - # Total count in new table (for verification) - return {"count": 3 + inserted_count["workspace_b"]} + return {"count": 3 + migration_state["workspace_b_migration_count"]} elif multirows and "SELECT *" in sql: if "WHERE workspace" in sql: workspace = params[0] if params and len(params) > 0 else None @@ -766,40 +777,32 @@ async def test_case1_sequential_workspace_migration( return {} mock_pg_db.query = AsyncMock(side_effect=mock_query_b) - mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b) - # Initialize workspace_b (Case 1) + # Track migration via _run_with_retry for workspace_b + migration_b_executed = [] + + async def mock_run_with_retry_b(operation, **kwargs): + migration_b_executed.append(True) + migration_state["workspace_b_migration_count"] = len(mock_rows_b) + return None + + mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b) + + # Initialize workspace_b (Case 3 - both tables exist) with patch( "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b ): await storage_b.initialize() - migration_state["workspace_b_migrated"] = True - print("✅ Step 2: Workspace B initialized (Case 1)") + print("✅ Step 2: Workspace B initialized") - # 
Verify workspace_b migration happened - execute_calls = mock_pg_db.execute.call_args_list - insert_calls = [ - call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0] - ] - assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}" - print(f"✅ Step 2: {len(insert_calls)} insert calls") + # Verify workspace_b migration happens when new table has no workspace_b data + # but legacy table still has workspace_b data. + assert ( + len(migration_b_executed) > 0 + ), "Migration should have been executed for workspace_b" + print("✅ Step 2: Migration executed for workspace_b") - # Verify DELETE and DROP TABLE - delete_calls = [ - call - for call in execute_calls - if call[0][0] - and "DELETE FROM" in call[0][0] - and "WHERE workspace" in call[0][0] - ] - assert len(delete_calls) >= 1, "Expected DELETE workspace_b data" - print("✅ Step 2: DELETE workspace_b from legacy") - - drop_calls = [ - call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0] - ] - assert len(drop_calls) >= 1, "Expected DROP TABLE" - print("✅ Step 2: Legacy table dropped") - - print("\n🎉 Case 1c: Sequential workspace migration verified!") + print("\n🎉 Case 1c: Sequential workspace migration verification complete!") + print(" - Workspace A: Migrated successfully (only legacy existed)") + print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")
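
Editor's note (illustrative, not part of the patch above): a minimal standalone sketch of the batched-upsert pattern this change adopts — register the pgvector codec on an asyncpg connection, then issue one executemany() call per batch instead of one execute() per row. The DSN, the demo_chunks table, and the 1024-d vectors are hypothetical placeholders; the table is assumed to have a unique key on (workspace, id), a VECTOR(1024) column, and the pgvector extension already installed in the target database.

import asyncio

import asyncpg
import numpy as np
from pgvector.asyncpg import register_vector

# Hypothetical target table; assumes UNIQUE (workspace, id) and content_vector VECTOR(1024).
UPSERT_SQL = """
INSERT INTO demo_chunks (workspace, id, content, content_vector)
VALUES ($1, $2, $3, $4)
ON CONFLICT (workspace, id) DO UPDATE
SET content = EXCLUDED.content,
    content_vector = EXCLUDED.content_vector
"""


async def batch_upsert(dsn: str) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        # Register the pgvector codec so numpy arrays / lists bind to VECTOR columns.
        await register_vector(conn)

        # Build the batch as tuples in the same order as the $1..$4 placeholders.
        batch = [
            (
                "ws1",
                f"chunk-{i}",
                f"content {i}",
                np.random.rand(1024).astype(np.float32),
            )
            for i in range(500)
        ]

        # One prepared statement executed over the whole batch,
        # rather than 500 individual execute() calls.
        await conn.executemany(UPSERT_SQL, batch)
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(batch_upsert("postgresql://user:password@localhost:5432/lightrag"))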