Refactor PG vector storage and add index creation

* Move helper functions to static methods
* Move check table exists functions to PostgreSQLDB
* Create ID and workspace indexes in DDL
yangdx
2025-12-21 16:25:58 +08:00
parent 2228a75dd0
commit 8ef86c4898

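A minimal usage sketch of the helpers this commit refactors, assuming an initialized PostgreSQLDB instance; the table name, base-table key, and embedding dimension below are hypothetical placeholders, not values taken from this commit:

    async def ensure_vector_table(db: PostgreSQLDB) -> None:
        # Hypothetical names/values for illustration only
        table_name = "lightrag_vdb_entity_1536"
        if not await db.check_table_exists(table_name):
            await PGVectorStorage._pg_create_table(
                db, table_name, base_table="LIGHTRAG_VDB_ENTITY", embedding_dim=1536
            )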

@@ -1538,6 +1538,24 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database, error:{e}")
raise
async def check_table_exists(self, table_name: str) -> bool:
"""Check if a table exists in PostgreSQL database
Args:
table_name: Name of the table to check
Returns:
bool: True if table exists, False otherwise
"""
query = """
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = $1
)
"""
result = await self.query(query, [table_name.lower()])
return result.get("exists", False) if result else False
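# Note (descriptive, not part of this commit): the information_schema lookup above
# is not schema-qualified, so it reports True if a table with the given name exists
# in any schema visible to the connection. A schema-scoped variant would add an
# extra filter such as `AND table_schema = 'public'` to the WHERE clause.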
async def execute(
self,
sql: str,
@@ -2239,143 +2257,6 @@ class PGKVStorage(BaseKVStorage):
return {"status": "error", "message": str(e)}
async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
"""Check if a table exists in PostgreSQL database"""
query = """
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = $1
)
"""
result = await db.query(query, [table_name.lower()])
return result.get("exists", False) if result else False
async def _pg_create_table(
db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
) -> None:
"""Create a new vector table by replacing the table name in DDL template"""
if base_table not in TABLES:
raise ValueError(f"No DDL template found for table: {base_table}")
ddl_template = TABLES[base_table]["ddl"]
# Replace embedding dimension placeholder if exists
ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
# Replace table name
ddl = ddl.replace(base_table, table_name)
await db.execute(ddl)
async def _pg_migrate_workspace_data(
db: PostgreSQLDB,
legacy_table_name: str,
new_table_name: str,
workspace: str,
expected_count: int,
embedding_dim: int,
) -> int:
"""Migrate workspace data from legacy table to new table using batch insert.
This function uses asyncpg's executemany for efficient batch insertion,
reducing database round-trips from N to 1 per batch.
Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
This ensures every legacy row is migrated exactly once, avoiding the
non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
"""
migrated_count = 0
last_id: str | None = None
batch_size = 500
while True:
# Use keyset pagination with ORDER BY id for deterministic ordering
# This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
if workspace:
if last_id is not None:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
rows = await db.query(
select_query, [workspace, last_id, batch_size], multirows=True
)
else:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
rows = await db.query(
select_query, [workspace, batch_size], multirows=True
)
else:
if last_id is not None:
select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
rows = await db.query(
select_query, [last_id, batch_size], multirows=True
)
else:
select_query = f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
rows = await db.query(select_query, [batch_size], multirows=True)
if not rows:
break
# Track the last ID for keyset pagination cursor
last_id = rows[-1]["id"]
# Batch insert optimization: use executemany instead of individual inserts
# Get column names from the first row
first_row = dict(rows[0])
columns = list(first_row.keys())
columns_str = ", ".join(columns)
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_query = f"""
INSERT INTO {new_table_name} ({columns_str})
VALUES ({placeholders})
ON CONFLICT (workspace, id) DO NOTHING
"""
# Prepare batch data: convert rows to list of tuples
batch_values = []
for row in rows:
row_dict = dict(row)
# FIX: Parse vector strings from connections without register_vector codec.
# When pgvector codec is not registered on the read connection, vector
# columns are returned as text strings like "[0.1,0.2,...]" instead of
# lists/arrays. We need to convert these to numpy arrays before passing
# to executemany, which uses a connection WITH register_vector codec
# that expects list/tuple/ndarray types.
if "content_vector" in row_dict:
vec = row_dict["content_vector"]
if isinstance(vec, str):
# pgvector text format: "[0.1,0.2,0.3,...]"
vec = vec.strip("[]")
if vec:
row_dict["content_vector"] = np.array(
[float(x) for x in vec.split(",")], dtype=np.float32
)
else:
row_dict["content_vector"] = None
# Extract values in column order to match placeholders
values_tuple = tuple(row_dict[col] for col in columns)
batch_values.append(values_tuple)
# Use executemany for batch execution - significantly reduces DB round-trips
# Note: register_vector is already called on pool init, no need to call it again
async def _batch_insert(connection: asyncpg.Connection) -> None:
await connection.executemany(insert_query, batch_values)
await db._run_with_retry(_batch_insert)
migrated_count += len(rows)
workspace_info = f" for workspace '{workspace}'" if workspace else ""
logger.info(
f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
)
return migrated_count
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
@@ -2423,6 +2304,181 @@ class PGVectorStorage(BaseVectorStorage):
f"Consider using a shorter embedding model name or workspace name."
)
@staticmethod
async def _pg_create_table(
db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
) -> None:
"""Create a new vector table by replacing the table name in DDL template,
and create indexes on id and (workspace, id) columns.
Args:
db: PostgreSQLDB instance
table_name: Name of the new table to create
base_table: Base table name for DDL template lookup
embedding_dim: Embedding dimension for vector column
"""
if base_table not in TABLES:
raise ValueError(f"No DDL template found for table: {base_table}")
ddl_template = TABLES[base_table]["ddl"]
# Replace embedding dimension placeholder if exists
ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
# Replace table name
ddl = ddl.replace(base_table, table_name)
await db.execute(ddl)
# Create indexes similar to check_tables() but with safe index names
# Create index for id column
id_index_name = _safe_index_name(table_name, "id")
try:
create_id_index_sql = f"CREATE INDEX {id_index_name} ON {table_name}(id)"
logger.info(
f"PostgreSQL, Creating index {id_index_name} on table {table_name}"
)
await db.execute(create_id_index_sql)
except Exception as e:
logger.error(
f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}"
)
# Create composite index for (workspace, id)
workspace_id_index_name = _safe_index_name(table_name, "workspace_id")
try:
create_composite_index_sql = (
f"CREATE INDEX {workspace_id_index_name} ON {table_name}(workspace, id)"
)
logger.info(
f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}"
)
await db.execute(create_composite_index_sql)
except Exception as e:
logger.error(
f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}"
)
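# Note (descriptive, not part of this commit): index-creation failures above are
# logged but not re-raised, so the new table remains usable even when an index
# could not be built. For a hypothetical table "lightrag_vdb_entity_1536" the
# statements issued would be roughly:
#   CREATE INDEX <name from _safe_index_name> ON lightrag_vdb_entity_1536(id)
#   CREATE INDEX <name from _safe_index_name> ON lightrag_vdb_entity_1536(workspace, id)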
@staticmethod
async def _pg_migrate_workspace_data(
db: PostgreSQLDB,
legacy_table_name: str,
new_table_name: str,
workspace: str,
expected_count: int,
embedding_dim: int,
) -> int:
"""Migrate workspace data from legacy table to new table using batch insert.
This function uses asyncpg's executemany for efficient batch insertion,
reducing database round-trips from N to 1 per batch.
Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
This ensures every legacy row is migrated exactly once, avoiding the
non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
Args:
db: PostgreSQLDB instance
legacy_table_name: Name of the legacy table to migrate from
new_table_name: Name of the new table to migrate to
workspace: Workspace to filter records for migration
expected_count: Expected number of records to migrate
embedding_dim: Embedding dimension for vector column
Returns:
Number of records migrated
"""
migrated_count = 0
last_id: str | None = None
batch_size = 500
while True:
# Use keyset pagination with ORDER BY id for deterministic ordering
# This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
if workspace:
if last_id is not None:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
rows = await db.query(
select_query, [workspace, last_id, batch_size], multirows=True
)
else:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
rows = await db.query(
select_query, [workspace, batch_size], multirows=True
)
else:
if last_id is not None:
select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
rows = await db.query(
select_query, [last_id, batch_size], multirows=True
)
else:
select_query = (
f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
)
rows = await db.query(select_query, [batch_size], multirows=True)
if not rows:
break
# Track the last ID for keyset pagination cursor
last_id = rows[-1]["id"]
# Batch insert optimization: use executemany instead of individual inserts
# Get column names from the first row
first_row = dict(rows[0])
columns = list(first_row.keys())
columns_str = ", ".join(columns)
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_query = f"""
INSERT INTO {new_table_name} ({columns_str})
VALUES ({placeholders})
ON CONFLICT (workspace, id) DO NOTHING
"""
# Prepare batch data: convert rows to list of tuples
batch_values = []
for row in rows:
row_dict = dict(row)
# FIX: Parse vector strings from connections without register_vector codec.
# When pgvector codec is not registered on the read connection, vector
# columns are returned as text strings like "[0.1,0.2,...]" instead of
# lists/arrays. We need to convert these to numpy arrays before passing
# to executemany, which uses a connection WITH register_vector codec
# that expects list/tuple/ndarray types.
if "content_vector" in row_dict:
vec = row_dict["content_vector"]
if isinstance(vec, str):
# pgvector text format: "[0.1,0.2,0.3,...]"
vec = vec.strip("[]")
if vec:
row_dict["content_vector"] = np.array(
[float(x) for x in vec.split(",")], dtype=np.float32
)
else:
row_dict["content_vector"] = None
# Extract values in column order to match placeholders
values_tuple = tuple(row_dict[col] for col in columns)
batch_values.append(values_tuple)
# Use executemany for batch execution - significantly reduces DB round-trips
# Note: register_vector is already called on pool init, no need to call it again
async def _batch_insert(connection: asyncpg.Connection) -> None:
await connection.executemany(insert_query, batch_values)
await db._run_with_retry(_batch_insert)
migrated_count += len(rows)
workspace_info = f" for workspace '{workspace}'" if workspace else ""
logger.info(
f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
)
return migrated_count
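# Standalone sketch of the content_vector text parsing done in the loop above,
# assuming pgvector's text output format "[x,y,z]"; parse_pgvector_text is a
# hypothetical helper used here only for illustration, not part of this commit:
#
#   import numpy as np
#
#   def parse_pgvector_text(text: str) -> np.ndarray | None:
#       inner = text.strip("[]")
#       if not inner:
#           return None
#       return np.array([float(x) for x in inner.split(",")], dtype=np.float32)
#
#   parse_pgvector_text("[0.1,0.2,0.3]")  # array([0.1, 0.2, 0.3], dtype=float32)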
@staticmethod
async def setup_table(
db: PostgreSQLDB,
@@ -2439,6 +2495,7 @@ class PGVectorStorage(BaseVectorStorage):
Check vector dimension compatibility before new table creation.
Drop legacy table if it exists and is empty.
Only migrate data from legacy table to new table when new table first created and legacy table is not empty.
This function must be called after ClientManager.get_client(), so that the legacy table has already been migrated to the latest schema.
Args:
db: PostgreSQLDB instance
@@ -2451,9 +2508,9 @@ class PGVectorStorage(BaseVectorStorage):
if not workspace:
raise ValueError("workspace must be provided")
- new_table_exists = await _pg_table_exists(db, table_name)
- legacy_exists = legacy_table_name and await _pg_table_exists(
- db, legacy_table_name
+ new_table_exists = await db.check_table_exists(table_name)
+ legacy_exists = legacy_table_name and await db.check_table_exists(
+ legacy_table_name
)
# Case 1: Only new table exists or new table is the same as legacy table
@@ -2535,7 +2592,9 @@ class PGVectorStorage(BaseVectorStorage):
f"Proceeding with caution..."
)
- await _pg_create_table(db, table_name, base_table, embedding_dim)
+ await PGVectorStorage._pg_create_table(
+ db, table_name, base_table, embedding_dim
+ )
logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
if not legacy_exists:
@@ -2603,7 +2662,7 @@ class PGVectorStorage(BaseVectorStorage):
)
try:
- migrated_count = await _pg_migrate_workspace_data(
+ migrated_count = await PGVectorStorage._pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,