From 8ef86c489883fcb6a5c9f4254269f1539292aaaa Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 21 Dec 2025 16:25:58 +0800
Subject: [PATCH] Refactor PG vector storage and add index creation

* Move helper functions to static methods
* Move check table exists functions to PostgreSQLDB
* Create ID and workspace indexes in DDL
---
 lightrag/kg/postgres_impl.py | 343 ++++++++++++++++++++---------------
 1 file changed, 201 insertions(+), 142 deletions(-)

diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 100d76ff..6174ac48 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -1538,6 +1538,24 @@ class PostgreSQLDB:
             logger.error(f"PostgreSQL database, error:{e}")
             raise
 
+    async def check_table_exists(self, table_name: str) -> bool:
+        """Check if a table exists in PostgreSQL database
+
+        Args:
+            table_name: Name of the table to check
+
+        Returns:
+            bool: True if table exists, False otherwise
+        """
+        query = """
+        SELECT EXISTS (
+            SELECT FROM information_schema.tables
+            WHERE table_name = $1
+        )
+        """
+        result = await self.query(query, [table_name.lower()])
+        return result.get("exists", False) if result else False
+
     async def execute(
         self,
         sql: str,
@@ -2239,143 +2257,6 @@ class PGKVStorage(BaseKVStorage):
             return {"status": "error", "message": str(e)}
 
 
-async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
-    """Check if a table exists in PostgreSQL database"""
-    query = """
-    SELECT EXISTS (
-        SELECT FROM information_schema.tables
-        WHERE table_name = $1
-    )
-    """
-    result = await db.query(query, [table_name.lower()])
-    return result.get("exists", False) if result else False
-
-
-async def _pg_create_table(
-    db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
-) -> None:
-    """Create a new vector table by replacing the table name in DDL template"""
-    if base_table not in TABLES:
-        raise ValueError(f"No DDL template found for table: {base_table}")
-
-    ddl_template = TABLES[base_table]["ddl"]
-
-    # Replace embedding dimension placeholder if exists
-    ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
-
-    # Replace table name
-    ddl = ddl.replace(base_table, table_name)
-
-    await db.execute(ddl)
-
-
-async def _pg_migrate_workspace_data(
-    db: PostgreSQLDB,
-    legacy_table_name: str,
-    new_table_name: str,
-    workspace: str,
-    expected_count: int,
-    embedding_dim: int,
-) -> int:
-    """Migrate workspace data from legacy table to new table using batch insert.
-
-    This function uses asyncpg's executemany for efficient batch insertion,
-    reducing database round-trips from N to 1 per batch.
-
-    Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
-    This ensures every legacy row is migrated exactly once, avoiding the
-    non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
- """ - migrated_count = 0 - last_id: str | None = None - batch_size = 500 - - while True: - # Use keyset pagination with ORDER BY id for deterministic ordering - # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows - if workspace: - if last_id is not None: - select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3" - rows = await db.query( - select_query, [workspace, last_id, batch_size], multirows=True - ) - else: - select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2" - rows = await db.query( - select_query, [workspace, batch_size], multirows=True - ) - else: - if last_id is not None: - select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2" - rows = await db.query( - select_query, [last_id, batch_size], multirows=True - ) - else: - select_query = f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1" - rows = await db.query(select_query, [batch_size], multirows=True) - - if not rows: - break - - # Track the last ID for keyset pagination cursor - last_id = rows[-1]["id"] - - # Batch insert optimization: use executemany instead of individual inserts - # Get column names from the first row - first_row = dict(rows[0]) - columns = list(first_row.keys()) - columns_str = ", ".join(columns) - placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) - - insert_query = f""" - INSERT INTO {new_table_name} ({columns_str}) - VALUES ({placeholders}) - ON CONFLICT (workspace, id) DO NOTHING - """ - - # Prepare batch data: convert rows to list of tuples - batch_values = [] - for row in rows: - row_dict = dict(row) - - # FIX: Parse vector strings from connections without register_vector codec. - # When pgvector codec is not registered on the read connection, vector - # columns are returned as text strings like "[0.1,0.2,...]" instead of - # lists/arrays. We need to convert these to numpy arrays before passing - # to executemany, which uses a connection WITH register_vector codec - # that expects list/tuple/ndarray types. - if "content_vector" in row_dict: - vec = row_dict["content_vector"] - if isinstance(vec, str): - # pgvector text format: "[0.1,0.2,0.3,...]" - vec = vec.strip("[]") - if vec: - row_dict["content_vector"] = np.array( - [float(x) for x in vec.split(",")], dtype=np.float32 - ) - else: - row_dict["content_vector"] = None - - # Extract values in column order to match placeholders - values_tuple = tuple(row_dict[col] for col in columns) - batch_values.append(values_tuple) - - # Use executemany for batch execution - significantly reduces DB round-trips - # Note: register_vector is already called on pool init, no need to call it again - async def _batch_insert(connection: asyncpg.Connection) -> None: - await connection.executemany(insert_query, batch_values) - - await db._run_with_retry(_batch_insert) - - migrated_count += len(rows) - workspace_info = f" for workspace '{workspace}'" if workspace else "" - logger.info( - f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}" - ) - - return migrated_count - - @final @dataclass class PGVectorStorage(BaseVectorStorage): @@ -2423,6 +2304,181 @@ class PGVectorStorage(BaseVectorStorage): f"Consider using a shorter embedding model name or workspace name." 
             )
 
+    @staticmethod
+    async def _pg_create_table(
+        db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
+    ) -> None:
+        """Create a new vector table by replacing the table name in DDL template,
+        and create indexes on id and (workspace, id) columns.
+
+        Args:
+            db: PostgreSQLDB instance
+            table_name: Name of the new table to create
+            base_table: Base table name for DDL template lookup
+            embedding_dim: Embedding dimension for vector column
+        """
+        if base_table not in TABLES:
+            raise ValueError(f"No DDL template found for table: {base_table}")
+
+        ddl_template = TABLES[base_table]["ddl"]
+
+        # Replace embedding dimension placeholder if exists
+        ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
+
+        # Replace table name
+        ddl = ddl.replace(base_table, table_name)
+
+        await db.execute(ddl)
+
+        # Create indexes similar to check_tables() but with safe index names
+        # Create index for id column
+        id_index_name = _safe_index_name(table_name, "id")
+        try:
+            create_id_index_sql = f"CREATE INDEX {id_index_name} ON {table_name}(id)"
+            logger.info(
+                f"PostgreSQL, Creating index {id_index_name} on table {table_name}"
+            )
+            await db.execute(create_id_index_sql)
+        except Exception as e:
+            logger.error(
+                f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}"
+            )
+
+        # Create composite index for (workspace, id)
+        workspace_id_index_name = _safe_index_name(table_name, "workspace_id")
+        try:
+            create_composite_index_sql = (
+                f"CREATE INDEX {workspace_id_index_name} ON {table_name}(workspace, id)"
+            )
+            logger.info(
+                f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}"
+            )
+            await db.execute(create_composite_index_sql)
+        except Exception as e:
+            logger.error(
+                f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}"
+            )
+
+    @staticmethod
+    async def _pg_migrate_workspace_data(
+        db: PostgreSQLDB,
+        legacy_table_name: str,
+        new_table_name: str,
+        workspace: str,
+        expected_count: int,
+        embedding_dim: int,
+    ) -> int:
+        """Migrate workspace data from legacy table to new table using batch insert.
+
+        This function uses asyncpg's executemany for efficient batch insertion,
+        reducing database round-trips from N to 1 per batch.
+
+        Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
+        This ensures every legacy row is migrated exactly once, avoiding the
+        non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
+
+        Args:
+            db: PostgreSQLDB instance
+            legacy_table_name: Name of the legacy table to migrate from
+            new_table_name: Name of the new table to migrate to
+            workspace: Workspace to filter records for migration
+            expected_count: Expected number of records to migrate
+            embedding_dim: Embedding dimension for vector column
+
+        Returns:
+            Number of records migrated
+        """
+        migrated_count = 0
+        last_id: str | None = None
+        batch_size = 500
+
+        while True:
+            # Use keyset pagination with ORDER BY id for deterministic ordering
+            # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
+            if workspace:
+                if last_id is not None:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
+                    rows = await db.query(
+                        select_query, [workspace, last_id, batch_size], multirows=True
+                    )
+                else:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
+                    rows = await db.query(
+                        select_query, [workspace, batch_size], multirows=True
+                    )
+            else:
+                if last_id is not None:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
+                    rows = await db.query(
+                        select_query, [last_id, batch_size], multirows=True
+                    )
+                else:
+                    select_query = (
+                        f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
+                    )
+                    rows = await db.query(select_query, [batch_size], multirows=True)
+
+            if not rows:
+                break
+
+            # Track the last ID for keyset pagination cursor
+            last_id = rows[-1]["id"]
+
+            # Batch insert optimization: use executemany instead of individual inserts
+            # Get column names from the first row
+            first_row = dict(rows[0])
+            columns = list(first_row.keys())
+            columns_str = ", ".join(columns)
+            placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
+
+            insert_query = f"""
+                INSERT INTO {new_table_name} ({columns_str})
+                VALUES ({placeholders})
+                ON CONFLICT (workspace, id) DO NOTHING
+            """
+
+            # Prepare batch data: convert rows to list of tuples
+            batch_values = []
+            for row in rows:
+                row_dict = dict(row)
+
+                # FIX: Parse vector strings from connections without register_vector codec.
+                # When pgvector codec is not registered on the read connection, vector
+                # columns are returned as text strings like "[0.1,0.2,...]" instead of
+                # lists/arrays. We need to convert these to numpy arrays before passing
+                # to executemany, which uses a connection WITH register_vector codec
+                # that expects list/tuple/ndarray types.
+ if "content_vector" in row_dict: + vec = row_dict["content_vector"] + if isinstance(vec, str): + # pgvector text format: "[0.1,0.2,0.3,...]" + vec = vec.strip("[]") + if vec: + row_dict["content_vector"] = np.array( + [float(x) for x in vec.split(",")], dtype=np.float32 + ) + else: + row_dict["content_vector"] = None + + # Extract values in column order to match placeholders + values_tuple = tuple(row_dict[col] for col in columns) + batch_values.append(values_tuple) + + # Use executemany for batch execution - significantly reduces DB round-trips + # Note: register_vector is already called on pool init, no need to call it again + async def _batch_insert(connection: asyncpg.Connection) -> None: + await connection.executemany(insert_query, batch_values) + + await db._run_with_retry(_batch_insert) + + migrated_count += len(rows) + workspace_info = f" for workspace '{workspace}'" if workspace else "" + logger.info( + f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}" + ) + + return migrated_count + @staticmethod async def setup_table( db: PostgreSQLDB, @@ -2439,6 +2495,7 @@ class PGVectorStorage(BaseVectorStorage): Check vector dimension compatibility before new table creation. Drop legacy table if it exists and is empty. Only migrate data from legacy table to new table when new table first created and legacy table is not empty. + This function must be call ClientManager.get_client() to legacy table is migrated to latest schema. Args: db: PostgreSQLDB instance @@ -2451,9 +2508,9 @@ class PGVectorStorage(BaseVectorStorage): if not workspace: raise ValueError("workspace must be provided") - new_table_exists = await _pg_table_exists(db, table_name) - legacy_exists = legacy_table_name and await _pg_table_exists( - db, legacy_table_name + new_table_exists = await db.check_table_exists(table_name) + legacy_exists = legacy_table_name and await db.check_table_exists( + legacy_table_name ) # Case 1: Only new table exists or new table is the same as legacy table @@ -2535,7 +2592,9 @@ class PGVectorStorage(BaseVectorStorage): f"Proceeding with caution..." ) - await _pg_create_table(db, table_name, base_table, embedding_dim) + await PGVectorStorage._pg_create_table( + db, table_name, base_table, embedding_dim + ) logger.info(f"PostgreSQL: New table '{table_name}' created successfully") if not legacy_exists: @@ -2603,7 +2662,7 @@ class PGVectorStorage(BaseVectorStorage): ) try: - migrated_count = await _pg_migrate_workspace_data( + migrated_count = await PGVectorStorage._pg_migrate_workspace_data( db, legacy_table_name, table_name,