Refactor PG vector storage and add index creation
* Move helper functions to static methods
* Move the table-existence check into PostgreSQLDB
* Create ID and workspace indexes in DDL
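The resulting call pattern, as a minimal sketch (the table name, base table key, and embedding dimension are placeholders, not values from this commit; it assumes the lightrag PostgreSQL module where PostgreSQLDB and PGVectorStorage are defined):

    async def ensure_chunk_table(db: "PostgreSQLDB") -> None:
        # Before this commit: module-level helpers _pg_table_exists()/_pg_create_table().
        # After this commit: the existence check lives on PostgreSQLDB and table
        # creation (including index creation) is a PGVectorStorage static method.
        if not await db.check_table_exists("lightrag_vdb_chunks_1536"):
            await PGVectorStorage._pg_create_table(
                db, "lightrag_vdb_chunks_1536", "LIGHTRAG_VDB_CHUNKS", 1536
            )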
@@ -1538,6 +1538,24 @@ class PostgreSQLDB:
             logger.error(f"PostgreSQL database, error:{e}")
             raise
 
+    async def check_table_exists(self, table_name: str) -> bool:
+        """Check if a table exists in PostgreSQL database
+
+        Args:
+            table_name: Name of the table to check
+
+        Returns:
+            bool: True if table exists, False otherwise
+        """
+        query = """
+            SELECT EXISTS (
+                SELECT FROM information_schema.tables
+                WHERE table_name = $1
+            )
+        """
+        result = await self.query(query, [table_name.lower()])
+        return result.get("exists", False) if result else False
+
     async def execute(
         self,
         sql: str,
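The same existence check can be exercised directly against asyncpg, outside of PostgreSQLDB; a standalone sketch (the DSN is a placeholder):

    import asyncpg

    async def table_exists(dsn: str, table_name: str) -> bool:
        # Same information_schema query that check_table_exists() wraps; the
        # table name is lower-cased to match how PostgreSQL stores identifiers.
        conn = await asyncpg.connect(dsn)
        try:
            return await conn.fetchval(
                "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = $1)",
                table_name.lower(),
            )
        finally:
            await conn.close()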
@@ -2239,143 +2257,6 @@ class PGKVStorage(BaseKVStorage):
             return {"status": "error", "message": str(e)}
 
 
-async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
-    """Check if a table exists in PostgreSQL database"""
-    query = """
-        SELECT EXISTS (
-            SELECT FROM information_schema.tables
-            WHERE table_name = $1
-        )
-    """
-    result = await db.query(query, [table_name.lower()])
-    return result.get("exists", False) if result else False
-
-
-async def _pg_create_table(
-    db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
-) -> None:
-    """Create a new vector table by replacing the table name in DDL template"""
-    if base_table not in TABLES:
-        raise ValueError(f"No DDL template found for table: {base_table}")
-
-    ddl_template = TABLES[base_table]["ddl"]
-
-    # Replace embedding dimension placeholder if exists
-    ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
-
-    # Replace table name
-    ddl = ddl.replace(base_table, table_name)
-
-    await db.execute(ddl)
-
-
-async def _pg_migrate_workspace_data(
-    db: PostgreSQLDB,
-    legacy_table_name: str,
-    new_table_name: str,
-    workspace: str,
-    expected_count: int,
-    embedding_dim: int,
-) -> int:
-    """Migrate workspace data from legacy table to new table using batch insert.
-
-    This function uses asyncpg's executemany for efficient batch insertion,
-    reducing database round-trips from N to 1 per batch.
-
-    Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
-    This ensures every legacy row is migrated exactly once, avoiding the
-    non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
-    """
-    migrated_count = 0
-    last_id: str | None = None
-    batch_size = 500
-
-    while True:
-        # Use keyset pagination with ORDER BY id for deterministic ordering
-        # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
-        if workspace:
-            if last_id is not None:
-                select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
-                rows = await db.query(
-                    select_query, [workspace, last_id, batch_size], multirows=True
-                )
-            else:
-                select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
-                rows = await db.query(
-                    select_query, [workspace, batch_size], multirows=True
-                )
-        else:
-            if last_id is not None:
-                select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
-                rows = await db.query(
-                    select_query, [last_id, batch_size], multirows=True
-                )
-            else:
-                select_query = f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
-                rows = await db.query(select_query, [batch_size], multirows=True)
-
-        if not rows:
-            break
-
-        # Track the last ID for keyset pagination cursor
-        last_id = rows[-1]["id"]
-
-        # Batch insert optimization: use executemany instead of individual inserts
-        # Get column names from the first row
-        first_row = dict(rows[0])
-        columns = list(first_row.keys())
-        columns_str = ", ".join(columns)
-        placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
-
-        insert_query = f"""
-            INSERT INTO {new_table_name} ({columns_str})
-            VALUES ({placeholders})
-            ON CONFLICT (workspace, id) DO NOTHING
-        """
-
-        # Prepare batch data: convert rows to list of tuples
-        batch_values = []
-        for row in rows:
-            row_dict = dict(row)
-
-            # FIX: Parse vector strings from connections without register_vector codec.
-            # When pgvector codec is not registered on the read connection, vector
-            # columns are returned as text strings like "[0.1,0.2,...]" instead of
-            # lists/arrays. We need to convert these to numpy arrays before passing
-            # to executemany, which uses a connection WITH register_vector codec
-            # that expects list/tuple/ndarray types.
-            if "content_vector" in row_dict:
-                vec = row_dict["content_vector"]
-                if isinstance(vec, str):
-                    # pgvector text format: "[0.1,0.2,0.3,...]"
-                    vec = vec.strip("[]")
-                    if vec:
-                        row_dict["content_vector"] = np.array(
-                            [float(x) for x in vec.split(",")], dtype=np.float32
-                        )
-                    else:
-                        row_dict["content_vector"] = None
-
-            # Extract values in column order to match placeholders
-            values_tuple = tuple(row_dict[col] for col in columns)
-            batch_values.append(values_tuple)
-
-        # Use executemany for batch execution - significantly reduces DB round-trips
-        # Note: register_vector is already called on pool init, no need to call it again
-        async def _batch_insert(connection: asyncpg.Connection) -> None:
-            await connection.executemany(insert_query, batch_values)
-
-        await db._run_with_retry(_batch_insert)
-
-        migrated_count += len(rows)
-        workspace_info = f" for workspace '{workspace}'" if workspace else ""
-        logger.info(
-            f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
-        )
-
-    return migrated_count
-
-
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
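The migration helper removed here (re-added below as a static method) walks the legacy table with keyset pagination instead of OFFSET/LIMIT; the pattern in isolation, with a hypothetical fetch_page callable standing in for db.query:

    from typing import Any, Awaitable, Callable

    Row = dict[str, Any]

    async def collect_with_keyset(
        fetch_page: Callable[[str | None, int], Awaitable[list[Row]]],
        batch_size: int = 500,
    ) -> list[Row]:
        # Each page is ordered by id and filtered with "id > last_id", so every
        # row is visited exactly once even though pages are fetched separately.
        rows_out: list[Row] = []
        last_id: str | None = None
        while True:
            page = await fetch_page(last_id, batch_size)
            if not page:
                break
            last_id = page[-1]["id"]
            rows_out.extend(page)
        return rows_out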
@@ -2423,6 +2304,181 @@ class PGVectorStorage(BaseVectorStorage):
                 f"Consider using a shorter embedding model name or workspace name."
             )
 
+    @staticmethod
+    async def _pg_create_table(
+        db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
+    ) -> None:
+        """Create a new vector table by replacing the table name in DDL template,
+        and create indexes on id and (workspace, id) columns.
+
+        Args:
+            db: PostgreSQLDB instance
+            table_name: Name of the new table to create
+            base_table: Base table name for DDL template lookup
+            embedding_dim: Embedding dimension for vector column
+        """
+        if base_table not in TABLES:
+            raise ValueError(f"No DDL template found for table: {base_table}")
+
+        ddl_template = TABLES[base_table]["ddl"]
+
+        # Replace embedding dimension placeholder if exists
+        ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
+
+        # Replace table name
+        ddl = ddl.replace(base_table, table_name)
+
+        await db.execute(ddl)
+
+        # Create indexes similar to check_tables() but with safe index names
+        # Create index for id column
+        id_index_name = _safe_index_name(table_name, "id")
+        try:
+            create_id_index_sql = f"CREATE INDEX {id_index_name} ON {table_name}(id)"
+            logger.info(
+                f"PostgreSQL, Creating index {id_index_name} on table {table_name}"
+            )
+            await db.execute(create_id_index_sql)
+        except Exception as e:
+            logger.error(
+                f"PostgreSQL, Failed to create index {id_index_name}, Got: {e}"
+            )
+
+        # Create composite index for (workspace, id)
+        workspace_id_index_name = _safe_index_name(table_name, "workspace_id")
+        try:
+            create_composite_index_sql = (
+                f"CREATE INDEX {workspace_id_index_name} ON {table_name}(workspace, id)"
+            )
+            logger.info(
+                f"PostgreSQL, Creating composite index {workspace_id_index_name} on table {table_name}"
+            )
+            await db.execute(create_composite_index_sql)
+        except Exception as e:
+            logger.error(
+                f"PostgreSQL, Failed to create composite index {workspace_id_index_name}, Got: {e}"
+            )
+
+    @staticmethod
+    async def _pg_migrate_workspace_data(
+        db: PostgreSQLDB,
+        legacy_table_name: str,
+        new_table_name: str,
+        workspace: str,
+        expected_count: int,
+        embedding_dim: int,
+    ) -> int:
+        """Migrate workspace data from legacy table to new table using batch insert.
+
+        This function uses asyncpg's executemany for efficient batch insertion,
+        reducing database round-trips from N to 1 per batch.
+
+        Uses keyset pagination (cursor-based) with ORDER BY id for stable ordering.
+        This ensures every legacy row is migrated exactly once, avoiding the
+        non-deterministic row ordering issues with OFFSET/LIMIT without ORDER BY.
+
+        Args:
+            db: PostgreSQLDB instance
+            legacy_table_name: Name of the legacy table to migrate from
+            new_table_name: Name of the new table to migrate to
+            workspace: Workspace to filter records for migration
+            expected_count: Expected number of records to migrate
+            embedding_dim: Embedding dimension for vector column
+
+        Returns:
+            Number of records migrated
+        """
+        migrated_count = 0
+        last_id: str | None = None
+        batch_size = 500
+
+        while True:
+            # Use keyset pagination with ORDER BY id for deterministic ordering
+            # This avoids OFFSET/LIMIT without ORDER BY which can skip or duplicate rows
+            if workspace:
+                if last_id is not None:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 AND id > $2 ORDER BY id LIMIT $3"
+                    rows = await db.query(
+                        select_query, [workspace, last_id, batch_size], multirows=True
+                    )
+                else:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 ORDER BY id LIMIT $2"
+                    rows = await db.query(
+                        select_query, [workspace, batch_size], multirows=True
+                    )
+            else:
+                if last_id is not None:
+                    select_query = f"SELECT * FROM {legacy_table_name} WHERE id > $1 ORDER BY id LIMIT $2"
+                    rows = await db.query(
+                        select_query, [last_id, batch_size], multirows=True
+                    )
+                else:
+                    select_query = (
+                        f"SELECT * FROM {legacy_table_name} ORDER BY id LIMIT $1"
+                    )
+                    rows = await db.query(select_query, [batch_size], multirows=True)
+
+            if not rows:
+                break
+
+            # Track the last ID for keyset pagination cursor
+            last_id = rows[-1]["id"]
+
+            # Batch insert optimization: use executemany instead of individual inserts
+            # Get column names from the first row
+            first_row = dict(rows[0])
+            columns = list(first_row.keys())
+            columns_str = ", ".join(columns)
+            placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
+
+            insert_query = f"""
+                INSERT INTO {new_table_name} ({columns_str})
+                VALUES ({placeholders})
+                ON CONFLICT (workspace, id) DO NOTHING
+            """
+
+            # Prepare batch data: convert rows to list of tuples
+            batch_values = []
+            for row in rows:
+                row_dict = dict(row)
+
+                # FIX: Parse vector strings from connections without register_vector codec.
+                # When pgvector codec is not registered on the read connection, vector
+                # columns are returned as text strings like "[0.1,0.2,...]" instead of
+                # lists/arrays. We need to convert these to numpy arrays before passing
+                # to executemany, which uses a connection WITH register_vector codec
+                # that expects list/tuple/ndarray types.
+                if "content_vector" in row_dict:
+                    vec = row_dict["content_vector"]
+                    if isinstance(vec, str):
+                        # pgvector text format: "[0.1,0.2,0.3,...]"
+                        vec = vec.strip("[]")
+                        if vec:
+                            row_dict["content_vector"] = np.array(
+                                [float(x) for x in vec.split(",")], dtype=np.float32
+                            )
+                        else:
+                            row_dict["content_vector"] = None
+
+                # Extract values in column order to match placeholders
+                values_tuple = tuple(row_dict[col] for col in columns)
+                batch_values.append(values_tuple)
+
+            # Use executemany for batch execution - significantly reduces DB round-trips
+            # Note: register_vector is already called on pool init, no need to call it again
+            async def _batch_insert(connection: asyncpg.Connection) -> None:
+                await connection.executemany(insert_query, batch_values)
+
+            await db._run_with_retry(_batch_insert)
+
+            migrated_count += len(rows)
+            workspace_info = f" for workspace '{workspace}'" if workspace else ""
+            logger.info(
+                f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
+            )
+
+        return migrated_count
+
     @staticmethod
     async def setup_table(
         db: PostgreSQLDB,
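The content_vector handling inside the migration loop can be tested on its own; a minimal sketch of the same text-to-array conversion (the function name is illustrative):

    import numpy as np

    def parse_pgvector_text(value):
        # pgvector text form "[0.1,0.2,...]" -> float32 ndarray; values that are
        # already lists/ndarrays (connections with register_vector) pass through,
        # and an empty vector string becomes None, as in the migration code above.
        if not isinstance(value, str):
            return value
        inner = value.strip("[]")
        if not inner:
            return None
        return np.array([float(x) for x in inner.split(",")], dtype=np.float32)

    assert parse_pgvector_text("[0.1,0.2,0.3]").shape == (3,)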
@@ -2439,6 +2495,7 @@ class PGVectorStorage(BaseVectorStorage):
         Check vector dimension compatibility before new table creation.
         Drop legacy table if it exists and is empty.
         Only migrate data from legacy table to new table when new table first created and legacy table is not empty.
+        This function must be called after ClientManager.get_client() so that the legacy table has been migrated to the latest schema.
 
         Args:
             db: PostgreSQLDB instance
@@ -2451,9 +2508,9 @@ class PGVectorStorage(BaseVectorStorage):
         if not workspace:
             raise ValueError("workspace must be provided")
 
-        new_table_exists = await _pg_table_exists(db, table_name)
-        legacy_exists = legacy_table_name and await _pg_table_exists(
-            db, legacy_table_name
+        new_table_exists = await db.check_table_exists(table_name)
+        legacy_exists = legacy_table_name and await db.check_table_exists(
+            legacy_table_name
         )
 
         # Case 1: Only new table exists or new table is the same as legacy table
@@ -2535,7 +2592,9 @@ class PGVectorStorage(BaseVectorStorage):
                 f"Proceeding with caution..."
             )
 
-        await _pg_create_table(db, table_name, base_table, embedding_dim)
+        await PGVectorStorage._pg_create_table(
+            db, table_name, base_table, embedding_dim
+        )
         logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
 
         if not legacy_exists:
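_pg_create_table, invoked here, also issues the two CREATE INDEX statements on id and (workspace, id). An idempotent variant using CREATE INDEX IF NOT EXISTS is sketched below; the index-name helper is a placeholder standing in for _safe_index_name, and this is not what the commit itself does (it uses plain CREATE INDEX and logs failures):

    def build_index_statements(table_name: str) -> list[str]:
        def index_name(suffix: str) -> str:
            return f"idx_{table_name}_{suffix}"  # placeholder for _safe_index_name()

        return [
            f"CREATE INDEX IF NOT EXISTS {index_name('id')} ON {table_name}(id)",
            f"CREATE INDEX IF NOT EXISTS {index_name('workspace_id')} ON {table_name}(workspace, id)",
        ]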
@@ -2603,7 +2662,7 @@ class PGVectorStorage(BaseVectorStorage):
             )
 
             try:
-                migrated_count = await _pg_migrate_workspace_data(
+                migrated_count = await PGVectorStorage._pg_migrate_workspace_data(
                     db,
                     legacy_table_name,
                     table_name,
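The batch insert inside _pg_migrate_workspace_data hands a closure to db._run_with_retry; the underlying asyncpg pattern in isolation (the pool acquisition shown here is a simplified assumption, the real code reuses the PostgreSQLDB pool and its retry wrapper):

    import asyncpg

    async def insert_batch(
        pool: asyncpg.Pool, insert_query: str, batch_values: list[tuple]
    ) -> None:
        # executemany sends the whole batch in a single call per batch instead
        # of issuing one INSERT round-trip per row.
        async with pool.acquire() as connection:
            await connection.executemany(insert_query, batch_values)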