Optimize Postgres batch operations and refine workspace migration logic

- Use executemany for efficient upserts
- Optimize data migration with batching
- Refine multi-workspace migration logic
- Add pgvector dependency
- Update DDL templates for dynamic dims
Author: yangdx
Date: 2025-12-19 12:05:22 +08:00
Parent: 0ae60d36bc
Commit: ada5f10be7
5 changed files with 402 additions and 539 deletions
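The heart of the change is a recurring pattern: register the pgvector codec on a connection, then push a whole batch through a single executemany() call instead of one execute() per row. A minimal standalone sketch of that pattern (the dsn argument and the demo_chunks table are illustrative assumptions, not code from this repository):

import asyncpg  # asyncpg>=0.31.0
from pgvector.asyncpg import register_vector  # pgvector>=0.4.2

async def batch_insert(dsn: str, rows: list[tuple]) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        # Teach asyncpg to encode/decode pgvector's `vector` type so that
        # lists or numpy arrays can be bound directly as query parameters.
        await register_vector(conn)
        # One round-trip per batch instead of one per row.
        await conn.executemany(
            "INSERT INTO demo_chunks (workspace, id, content_vector) "
            "VALUES ($1, $2, $3) ON CONFLICT (workspace, id) DO NOTHING",
            rows,
        )
    finally:
        await conn.close()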


@@ -31,6 +31,7 @@ from ..base import (
DocStatus,
DocStatusStorage,
)
+from ..exceptions import DataMigrationError
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..kg.shared_storage import get_data_init_lock
@@ -39,9 +40,12 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
if not pm.is_installed("pgvector"):
pm.install("pgvector")
import asyncpg # type: ignore
from asyncpg import Pool # type: ignore
from pgvector.asyncpg import register_vector # type: ignore
from dotenv import load_dotenv
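register_vector works per connection, so it has to be applied to every connection that touches vector columns. One common alternative to calling it inside each operation (which is what this diff does) is asyncpg's pool init hook; an illustrative sketch, not what this module does:

import asyncpg
from pgvector.asyncpg import register_vector

async def make_pool(dsn: str) -> asyncpg.Pool:
    # `init` runs once for every new connection the pool opens.
    return await asyncpg.create_pool(dsn, init=register_vector)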
@@ -2191,9 +2195,7 @@ async def _pg_create_table(
ddl_template = TABLES[base_table]["ddl"]
# Replace embedding dimension placeholder if exists
-ddl = ddl_template.replace(
-f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
-)
+ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
# Replace table name
ddl = ddl.replace(base_table, table_name)
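The DDL templates (see the bottom of this file) now carry a literal VECTOR(dimension) placeholder, and _pg_create_table swaps in the model's actual dimension instead of baking in the EMBEDDING_DIM environment variable at import time. A tiny illustration with made-up values:

# Made-up template and dimension, just to show the substitution.
ddl_template = """CREATE TABLE DEMO_VDB (
    id VARCHAR(255),
    content_vector VECTOR(dimension)
)"""
embedding_dim = 768  # would come from embedding_func.embedding_dim
ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
print(ddl)  # ... content_vector VECTOR(768) ...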
@@ -2209,7 +2211,11 @@ async def _pg_migrate_workspace_data(
expected_count: int,
embedding_dim: int,
) -> int:
"""Migrate workspace data from legacy table to new table"""
"""Migrate workspace data from legacy table to new table using batch insert.
This function uses asyncpg's executemany for efficient batch insertion,
reducing database round-trips from N to 1 per batch.
"""
migrated_count = 0
offset = 0
batch_size = 500
@@ -2227,20 +2233,34 @@ async def _pg_migrate_workspace_data(
if not rows:
break
+# Batch insert optimization: use executemany instead of individual inserts
+# Get column names from the first row
+first_row = dict(rows[0])
+columns = list(first_row.keys())
+columns_str = ", ".join(columns)
+placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
+insert_query = f"""
+INSERT INTO {new_table_name} ({columns_str})
+VALUES ({placeholders})
+ON CONFLICT (workspace, id) DO NOTHING
+"""
+# Prepare batch data: convert rows to list of tuples
+batch_values = []
for row in rows:
row_dict = dict(row)
-columns = list(row_dict.keys())
-columns_str = ", ".join(columns)
-placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
-insert_query = f"""
-INSERT INTO {new_table_name} ({columns_str})
-VALUES ({placeholders})
-ON CONFLICT (workspace, id) DO NOTHING
-"""
-# Rebuild dict in columns order to ensure values() matches placeholders order
-# Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values())
-values = {col: row_dict[col] for col in columns}
-await db.execute(insert_query, values)
+# Extract values in column order to match placeholders
+values_tuple = tuple(row_dict[col] for col in columns)
+batch_values.append(values_tuple)
+# Use executemany for batch execution - significantly reduces DB round-trips
+# Register pgvector codec to handle vector fields alongside other fields seamlessly
+async def _batch_insert(connection: asyncpg.Connection) -> None:
+await register_vector(connection)
+await connection.executemany(insert_query, batch_values)
+await db._run_with_retry(_batch_insert)
migrated_count += len(rows)
workspace_info = f" for workspace '{workspace}'" if workspace else ""
@@ -2284,395 +2304,208 @@ class PGVectorStorage(BaseVectorStorage):
# Fallback: use base table name if model_suffix is unavailable
self.table_name = base_table
logger.warning(
f"Model suffix unavailable, using base table name '{base_table}'. "
f"Ensure embedding_func has model_name for proper model isolation."
"Missing collection suffix. Ensure embedding_func has model_name for proper model isolation."
)
# Legacy table name (without suffix, for migration)
self.legacy_table_name = base_table
-logger.debug(
-f"PostgreSQL table naming: "
-f"new='{self.table_name}', "
-f"legacy='{self.legacy_table_name}', "
-f"model_suffix='{self.model_suffix}'"
-)
+logger.info(f"PostgreSQL table name: {self.table_name}")
@staticmethod
async def setup_table(
db: PostgreSQLDB,
table_name: str,
-legacy_table_name: str = None,
-base_table: str = None,
-embedding_dim: int = None,
-workspace: str = None,
+workspace: str,
+embedding_dim: int,
+legacy_table_name: str,
+base_table: str,
):
"""
Setup PostgreSQL table with migration support from legacy tables.
-This method mirrors Qdrant's setup_collection approach to maintain consistency.
+Ensure final table has workspace isolation index.
+Check vector dimension compatibility before new table creation.
+Drop legacy table if it exists and is empty.
+Only migrate data from legacy table to new table when new table first created and legacy table is not empty.
Args:
db: PostgreSQLDB instance
table_name: Name of the new table
-legacy_table_name: Name of the legacy table (if exists)
workspace: Workspace to filter records for migration
+legacy_table_name: Name of the legacy table to check for migration
base_table: Base table name for DDL template lookup
embedding_dim: Embedding dimension for vector column
"""
+if not workspace:
+raise ValueError("workspace must be provided")
new_table_exists = await _pg_table_exists(db, table_name)
legacy_exists = legacy_table_name and await _pg_table_exists(
db, legacy_table_name
)
-# Case 1: Both new and legacy tables exist
-if new_table_exists and legacy_exists:
-if table_name.lower() == legacy_table_name.lower():
-logger.debug(
-f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup."
-)
-return
-try:
-workspace_info = f" for workspace '{workspace}'" if workspace else ""
-if workspace:
-count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
-count_result = await db.query(count_query, [workspace])
-else:
-count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
-count_result = await db.query(count_query, [])
-workspace_count = count_result.get("count", 0) if count_result else 0
-if workspace_count > 0:
-logger.info(
-f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..."
-)
-legacy_dim = None
-try:
-dim_query = """
-SELECT
-CASE
-WHEN typname = 'vector' THEN
-COALESCE(atttypmod, -1)
-ELSE -1
-END as vector_dim
-FROM pg_attribute a
-JOIN pg_type t ON a.atttypid = t.oid
-WHERE a.attrelid = $1::regclass
-AND a.attname = 'content_vector'
-"""
-dim_result = await db.query(dim_query, [legacy_table_name])
-legacy_dim = (
-dim_result.get("vector_dim", -1) if dim_result else -1
-)
-if legacy_dim <= 0:
-sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
-sample_result = await db.query(sample_query, [])
-if sample_result and sample_result.get("content_vector"):
-vector_data = sample_result["content_vector"]
-if isinstance(vector_data, (list, tuple)):
-legacy_dim = len(vector_data)
-elif isinstance(vector_data, str):
-import json
-vector_list = json.loads(vector_data)
-legacy_dim = len(vector_list)
-if (
-legacy_dim > 0
-and embedding_dim
-and legacy_dim != embedding_dim
-):
-logger.warning(
-f"PostgreSQL: Dimension mismatch - "
-f"legacy table has {legacy_dim}d vectors, "
-f"new embedding model expects {embedding_dim}d. "
-f"Skipping migration{workspace_info}."
-)
-await db._create_vector_index(table_name, embedding_dim)
-return
-except Exception as e:
-logger.warning(
-f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..."
-)
-migrated_count = await _pg_migrate_workspace_data(
-db,
-legacy_table_name,
-table_name,
-workspace,
-workspace_count,
-embedding_dim,
-)
-if workspace:
-new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
-new_count_result = await db.query(new_count_query, [workspace])
-else:
-new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
-new_count_result = await db.query(new_count_query, [])
-new_count = (
-new_count_result.get("count", 0) if new_count_result else 0
-)
-if new_count < workspace_count:
-logger.warning(
-f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. "
-f"Some records may have been skipped due to conflicts."
-)
-else:
-logger.info(
-f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}"
-)
-if workspace:
-delete_query = (
-f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
-)
-await db.execute(delete_query, {"workspace": workspace})
-logger.info(
-f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
-)
-total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
-total_count_result = await db.query(total_count_query, [])
-total_count = (
-total_count_result.get("count", 0) if total_count_result else 0
-)
-if total_count == 0:
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..."
-)
-drop_query = f"DROP TABLE {legacy_table_name}"
-await db.execute(drop_query, None)
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
-)
-else:
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' preserved "
-f"({total_count} records from other workspaces remain)"
-)
-except Exception as e:
-logger.warning(
-f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured."
-)
+# Case 1: Only new table exists or new table is the same as legacy table
+# No data migration needed, ensuring index is created then return
+if (new_table_exists and not legacy_exists) or (
+table_name.lower() == legacy_table_name.lower()
+):
+await db._create_vector_index(table_name, embedding_dim)
+return
-# Case 2: Only new table exists - Already migrated or newly created
-if new_table_exists:
-logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
-# Ensure vector index exists with correct embedding dimension
-await db._create_vector_index(table_name, embedding_dim)
-return
-# Case 3: Neither exists - Create new table
-if not legacy_exists:
-logger.info(f"PostgreSQL: Creating new table '{table_name}'")
-await _pg_create_table(db, table_name, base_table, embedding_dim)
-logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
-# Create vector index with correct embedding dimension
-await db._create_vector_index(table_name, embedding_dim)
-return
-# Case 4: Only legacy exists - Migrate data
-logger.info(
-f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
-)
-try:
-# Get legacy table count (with workspace filtering)
-if workspace:
+legacy_count = None
+if not new_table_exists:
+# Check vector dimension compatibility before creating new table
+if legacy_exists:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
count_result = await db.query(count_query, [workspace])
-else:
-count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
-count_result = await db.query(count_query, [])
-logger.warning(
-"PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!"
-)
-legacy_count = count_result.get("count", 0) if count_result else 0
+legacy_count = count_result.get("count", 0) if count_result else 0
-workspace_info = f" for workspace '{workspace}'" if workspace else ""
-logger.info(
-f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}"
+if legacy_count > 0:
+legacy_dim = None
+try:
+sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1"
+sample_result = await db.query(sample_query, [workspace])
+if sample_result and sample_result.get("content_vector"):
+vector_data = sample_result["content_vector"]
+# pgvector returns list directly
+if isinstance(vector_data, (list, tuple)):
+legacy_dim = len(vector_data)
+elif isinstance(vector_data, str):
+import json
+vector_list = json.loads(vector_data)
+legacy_dim = len(vector_list)
+if legacy_dim and legacy_dim != embedding_dim:
+logger.error(
+f"PostgreSQL: Dimension mismatch detected! "
+f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
+f"but new embedding model expects {embedding_dim}d."
+)
+raise DataMigrationError(
+f"Dimension mismatch between legacy table '{legacy_table_name}' "
+f"and new embedding model. Expected {embedding_dim} but got {legacy_dim}."
+)
+except DataMigrationError:
+# Re-raise DataMigrationError as-is to preserve specific error messages
+raise
+except Exception as e:
+raise DataMigrationError(
+f"Could not verify legacy table vector dimension: {e}. "
+f"Proceeding with caution..."
+)
+await _pg_create_table(db, table_name, base_table, embedding_dim)
+logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
+# Ensure vector index is created
+await db._create_vector_index(table_name, embedding_dim)
+# Case 2: Legacy table exist
+if legacy_exists:
+workspace_info = f" for workspace '{workspace}'"
+# Only drop legacy table if entire table is empty
+total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
+total_count_result = await db.query(total_count_query, [])
+total_count = (
+total_count_result.get("count", 0) if total_count_result else 0
+)
-if legacy_count == 0:
-logger.info("PostgreSQL: Legacy table is empty, skipping migration")
-await _pg_create_table(db, table_name, base_table, embedding_dim)
-# Create vector index with correct embedding dimension
-await db._create_vector_index(table_name, embedding_dim)
+if total_count == 0:
+logger.info(
+f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully"
+)
+drop_query = f"DROP TABLE {legacy_table_name}"
+await db.execute(drop_query, None)
+return
-# Check vector dimension compatibility before migration
-legacy_dim = None
-try:
-# Try to get vector dimension from pg_attribute metadata
-dim_query = """
-SELECT
-CASE
-WHEN typname = 'vector' THEN
-COALESCE(atttypmod, -1)
-ELSE -1
-END as vector_dim
-FROM pg_attribute a
-JOIN pg_type t ON a.atttypid = t.oid
-WHERE a.attrelid = $1::regclass
-AND a.attname = 'content_vector'
-"""
-dim_result = await db.query(dim_query, [legacy_table_name])
-legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1
+# No data migration needed if legacy workspace is empty
+if legacy_count is None:
+count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
+count_result = await db.query(count_query, [workspace])
+legacy_count = count_result.get("count", 0) if count_result else 0
-if legacy_dim <= 0:
-# Alternative: Try to detect by sampling a vector
-logger.info(
-"PostgreSQL: Metadata dimension check failed, trying vector sampling..."
-)
-sample_query = (
-f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
-)
-sample_result = await db.query(sample_query, [])
-if sample_result and sample_result.get("content_vector"):
-vector_data = sample_result["content_vector"]
-# pgvector returns list directly
-if isinstance(vector_data, (list, tuple)):
-legacy_dim = len(vector_data)
-elif isinstance(vector_data, str):
-import json
+if legacy_count == 0:
+logger.info(
+f"PostgreSQL: No records{workspace_info} found in legacy table. "
+f"No data migration needed."
+)
+return
-vector_list = json.loads(vector_data)
-legacy_dim = len(vector_list)
+new_count_query = (
+f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
+)
+new_count_result = await db.query(new_count_query, [workspace])
+new_table_workspace_count = (
+new_count_result.get("count", 0) if new_count_result else 0
+)
-if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim:
-logger.warning(
-f"PostgreSQL: Dimension mismatch detected! "
-f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
-f"but new embedding model expects {embedding_dim}d. "
-f"Migration skipped to prevent data loss. "
-f"Legacy table preserved as '{legacy_table_name}'. "
-f"Creating new empty table '{table_name}' for new data."
-)
-# Create new table but skip migration
-await _pg_create_table(db, table_name, base_table, embedding_dim)
-await db._create_vector_index(table_name, embedding_dim)
-logger.info(
-f"PostgreSQL: New table '{table_name}' created. "
-f"To query legacy data, please use a {legacy_dim}d embedding model."
-)
-return
-except Exception as e:
+if new_table_workspace_count > 0:
logger.warning(
-f"PostgreSQL: Could not verify legacy table vector dimension: {e}. "
-f"Proceeding with caution..."
+f"PostgreSQL: New table '{table_name}' already has "
+f"{new_table_workspace_count} records{workspace_info}. "
+"Data migration skipped to avoid duplicates."
)
+return
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
await _pg_create_table(db, table_name, base_table, embedding_dim)
migrated_count = await _pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,
workspace,
legacy_count,
embedding_dim,
)
logger.info("PostgreSQL: Verifying migration...")
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
new_count_result = await db.query(new_count_query, [])
new_count = new_count_result.get("count", 0) if new_count_result else 0
if new_count != legacy_count:
error_msg = (
f"PostgreSQL: Migration verification failed, "
f"expected {legacy_count} records, got {new_count} in new table"
)
logger.error(error_msg)
raise PostgreSQLMigrationError(error_msg)
# Case 3: Legacy has workspace data and new table is empty for workspace
logger.info(
f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated"
f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}."
)
+logger.info(
+f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'"
+)
+try:
+migrated_count = await _pg_migrate_workspace_data(
+db,
+legacy_table_name,
+table_name,
+workspace,
+legacy_count,
+embedding_dim,
+)
+if migrated_count != legacy_count:
+logger.warning(
+"PostgreSQL: Read %s legacy records%s during migration, expected %s.",
+migrated_count,
+workspace_info,
+legacy_count,
+)
+new_count_result = await db.query(new_count_query, [workspace])
+new_table_count_after = (
+new_count_result.get("count", 0) if new_count_result else 0
+)
+inserted_count = new_table_count_after - new_table_workspace_count
+if inserted_count != legacy_count:
+error_msg = (
+"PostgreSQL: Migration verification failed, "
+f"expected {legacy_count} inserted records, got {inserted_count}."
+)
+logger.error(error_msg)
+raise DataMigrationError(error_msg)
+except DataMigrationError:
+# Re-raise DataMigrationError as-is to preserve specific error messages
+raise
+except Exception as e:
+logger.error(
+f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}"
+)
+raise DataMigrationError(
+f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'"
+) from e
+logger.info(
+f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
+)
-await db._create_vector_index(table_name, embedding_dim)
-try:
-if workspace:
-logger.info(
-f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..."
-)
-delete_query = (
-f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
-)
-await db.execute(delete_query, {"workspace": workspace})
-logger.info(
-f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
-)
-remaining_query = (
-f"SELECT COUNT(*) as count FROM {legacy_table_name}"
-)
-remaining_result = await db.query(remaining_query, [])
-remaining_count = (
-remaining_result.get("count", 0) if remaining_result else 0
-)
-if remaining_count == 0:
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..."
-)
-drop_query = f"DROP TABLE {legacy_table_name}"
-await db.execute(drop_query, None)
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
-)
-else:
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)"
-)
-else:
-logger.warning(
-f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..."
-)
-drop_query = f"DROP TABLE {legacy_table_name}"
-await db.execute(drop_query, None)
-logger.info(
-f"PostgreSQL: Legacy table '{legacy_table_name}' deleted"
-)
-except Exception as delete_error:
-# If cleanup fails, log warning but don't fail migration
-logger.warning(
-f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. "
-"Migration succeeded, but manual cleanup may be needed."
-)
-except PostgreSQLMigrationError:
-# Re-raise migration errors without wrapping
-raise
-except Exception as e:
-error_msg = f"PostgreSQL: Migration failed with error: {e}"
-logger.error(error_msg)
-# Mirror Qdrant behavior: no automatic rollback
-# Reason: partial data can be continued by re-running migration
-raise PostgreSQLMigrationError(error_msg) from e
+logger.info(
+"PostgreSQL: Manual deletion is required after data migration verification."
+)
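Both the old and the refined code probe the legacy table's vector dimension before migrating, first via pg_attribute metadata and then by sampling a stored vector. For pgvector columns the declared dimension is stored directly in atttypmod, which is what the metadata query exploits; a condensed sketch, assuming an open asyncpg connection:

DIM_QUERY = """
SELECT a.atttypmod AS vector_dim
FROM pg_attribute a
JOIN pg_type t ON a.atttypid = t.oid
WHERE a.attrelid = $1::regclass
  AND a.attname = 'content_vector'
  AND t.typname = 'vector'
"""

async def legacy_vector_dim(conn, table_name: str):
    # For pgvector columns, atttypmod holds the declared dimension itself
    # (unlike varchar, there is no header offset); None means not found.
    row = await conn.fetchrow(DIM_QUERY, table_name)
    return row["vector_dim"] if row else None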
async def initialize(self):
async with get_data_init_lock():
@@ -2694,10 +2527,10 @@ class PGVectorStorage(BaseVectorStorage):
await PGVectorStorage.setup_table(
self.db,
self.table_name,
+self.workspace, # CRITICAL: Filter migration by workspace
+embedding_dim=self.embedding_func.embedding_dim,
legacy_table_name=self.legacy_table_name,
base_table=self.legacy_table_name, # base_table for DDL template lookup
-embedding_dim=self.embedding_func.embedding_dim,
-workspace=self.workspace, # CRITICAL: Filter migration by workspace
)
async def finalize(self):
@@ -2707,34 +2540,45 @@ class PGVectorStorage(BaseVectorStorage):
def _upsert_chunks(
self, item: dict[str, Any], current_time: datetime.datetime
-) -> tuple[str, dict[str, Any]]:
+) -> tuple[str, tuple[Any, ...]]:
+"""Prepare upsert data for chunks.
+Returns:
+Tuple of (SQL template, values tuple for executemany)
+"""
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
table_name=self.table_name
)
-data: dict[str, Any] = {
-"workspace": self.workspace,
-"id": item["__id__"],
-"tokens": item["tokens"],
-"chunk_order_index": item["chunk_order_index"],
-"full_doc_id": item["full_doc_id"],
-"content": item["content"],
-"content_vector": json.dumps(item["__vector__"].tolist()),
-"file_path": item["file_path"],
-"create_time": current_time,
-"update_time": current_time,
-}
+# Return tuple in the exact order of SQL parameters ($1, $2, ...)
+values: tuple[Any, ...] = (
+self.workspace, # $1
+item["__id__"], # $2
+item["tokens"], # $3
+item["chunk_order_index"], # $4
+item["full_doc_id"], # $5
+item["content"], # $6
+item["__vector__"], # $7 - numpy array, handled by pgvector codec
+item["file_path"], # $8
+current_time, # $9
+current_time, # $10
+)
except Exception as e:
logger.error(
f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}"
f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}"
)
raise
-return upsert_sql, data
+return upsert_sql, values
def _upsert_entities(
self, item: dict[str, Any], current_time: datetime.datetime
-) -> tuple[str, dict[str, Any]]:
+) -> tuple[str, tuple[Any, ...]]:
+"""Prepare upsert data for entities.
+Returns:
+Tuple of (SQL template, values tuple for executemany)
+"""
upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
source_id = item["source_id"]
if isinstance(source_id, str) and "<SEP>" in source_id:
@@ -2742,22 +2586,28 @@ class PGVectorStorage(BaseVectorStorage):
else:
chunk_ids = [source_id]
-data: dict[str, Any] = {
-"workspace": self.workspace,
-"id": item["__id__"],
-"entity_name": item["entity_name"],
-"content": item["content"],
-"content_vector": json.dumps(item["__vector__"].tolist()),
-"chunk_ids": chunk_ids,
-"file_path": item.get("file_path", None),
-"create_time": current_time,
-"update_time": current_time,
-}
-return upsert_sql, data
+# Return tuple in the exact order of SQL parameters ($1, $2, ...)
+values: tuple[Any, ...] = (
+self.workspace, # $1
+item["__id__"], # $2
+item["entity_name"], # $3
+item["content"], # $4
+item["__vector__"], # $5 - numpy array, handled by pgvector codec
+chunk_ids, # $6
+item.get("file_path", None), # $7
+current_time, # $8
+current_time, # $9
+)
+return upsert_sql, values
def _upsert_relationships(
self, item: dict[str, Any], current_time: datetime.datetime
-) -> tuple[str, dict[str, Any]]:
+) -> tuple[str, tuple[Any, ...]]:
+"""Prepare upsert data for relationships.
+Returns:
+Tuple of (SQL template, values tuple for executemany)
+"""
upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
table_name=self.table_name
)
@@ -2767,19 +2617,20 @@ class PGVectorStorage(BaseVectorStorage):
else:
chunk_ids = [source_id]
-data: dict[str, Any] = {
-"workspace": self.workspace,
-"id": item["__id__"],
-"source_id": item["src_id"],
-"target_id": item["tgt_id"],
-"content": item["content"],
-"content_vector": json.dumps(item["__vector__"].tolist()),
-"chunk_ids": chunk_ids,
-"file_path": item.get("file_path", None),
-"create_time": current_time,
-"update_time": current_time,
-}
-return upsert_sql, data
+# Return tuple in the exact order of SQL parameters ($1, $2, ...)
+values: tuple[Any, ...] = (
+self.workspace, # $1
+item["__id__"], # $2
+item["src_id"], # $3
+item["tgt_id"], # $4
+item["content"], # $5
+item["__vector__"], # $6 - numpy array, handled by pgvector codec
+chunk_ids, # $7
+item.get("file_path", None), # $8
+current_time, # $9
+current_time, # $10
+)
+return upsert_sql, values
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@@ -2807,17 +2658,34 @@ class PGVectorStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
+# Prepare batch values for executemany
+batch_values: list[tuple[Any, ...]] = []
+upsert_sql = None
for item in list_data:
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
-upsert_sql, data = self._upsert_chunks(item, current_time)
+upsert_sql, values = self._upsert_chunks(item, current_time)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
-upsert_sql, data = self._upsert_entities(item, current_time)
+upsert_sql, values = self._upsert_entities(item, current_time)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
-upsert_sql, data = self._upsert_relationships(item, current_time)
+upsert_sql, values = self._upsert_relationships(item, current_time)
else:
raise ValueError(f"{self.namespace} is not supported")
-await self.db.execute(upsert_sql, data)
+batch_values.append(values)
+# Use executemany for batch execution - significantly reduces DB round-trips
+if batch_values and upsert_sql:
+async def _batch_upsert(connection: asyncpg.Connection) -> None:
+await register_vector(connection)
+await connection.executemany(upsert_sql, batch_values)
+await self.db._run_with_retry(_batch_upsert)
+logger.debug(
+f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace}"
+)
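A side effect of registering the codec: the _upsert_* helpers now pass the numpy embedding straight through as a parameter, where the old dict-based version had to serialize it with json.dumps(vector.tolist()). A one-row sketch of the same idea (hypothetical table and DSN, not this module's code):

import numpy as np
import asyncpg
from pgvector.asyncpg import register_vector

async def upsert_one(dsn: str, vec: np.ndarray) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        await register_vector(conn)
        # No json.dumps(vec.tolist()) round-trip: the codec encodes the array.
        await conn.execute(
            "INSERT INTO demo_vectors (id, content_vector) VALUES ($1, $2)",
            "doc-1", vec,
        )
    finally:
        await conn.close()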
#################### query method ###############
async def query(
@@ -3658,12 +3526,6 @@ class PGDocStatusStorage(DocStatusStorage):
return {"status": "error", "message": str(e)}
-class PostgreSQLMigrationError(Exception):
-"""Exception for PostgreSQL table migration errors."""
-pass
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
@@ -5263,14 +5125,14 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_CHUNKS": {
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_CHUNKS (
"ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
id VARCHAR(255),
workspace VARCHAR(255),
full_doc_id VARCHAR(256),
chunk_order_index INTEGER,
tokens INTEGER,
content TEXT,
-content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+content_vector VECTOR(dimension),
file_path TEXT NULL,
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
@@ -5278,12 +5140,12 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_ENTITY": {
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_ENTITY (
"ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
id VARCHAR(255),
workspace VARCHAR(255),
entity_name VARCHAR(512),
content TEXT,
-content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+content_vector VECTOR(dimension),
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
chunk_ids VARCHAR(255)[] NULL,
@@ -5292,13 +5154,13 @@ TABLES = {
)"""
},
"LIGHTRAG_VDB_RELATION": {
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_RELATION (
"ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
id VARCHAR(255),
workspace VARCHAR(255),
source_id VARCHAR(512),
target_id VARCHAR(512),
content TEXT,
-content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
+content_vector VECTOR(dimension),
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
chunk_ids VARCHAR(255)[] NULL,


@@ -72,7 +72,6 @@ api = [
# API-specific dependencies
"aiofiles",
"ascii_colors",
"asyncpg",
"distro",
"fastapi",
"httpcore",
@@ -108,7 +107,8 @@ offline-storage = [
"neo4j>=5.0.0,<7.0.0",
"pymilvus>=2.6.2,<3.0.0",
"pymongo>=4.0.0,<5.0.0",
"asyncpg>=0.29.0,<1.0.0",
"asyncpg>=0.31.0,<1.0.0",
"pgvector>=0.4.2,<1.0.0",
"qdrant-client>=1.11.0,<2.0.0",
]


@@ -8,8 +8,9 @@
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt
# Storage backend dependencies (with version constraints matching pyproject.toml)
-asyncpg>=0.29.0,<1.0.0
+asyncpg>=0.31.0,<1.0.0
neo4j>=5.0.0,<7.0.0
+pgvector>=0.4.2,<1.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
qdrant-client>=1.11.0,<2.0.0


@@ -7,20 +7,17 @@
# Recommended: Use pip install lightrag-hku[offline] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt
# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0
# Storage backend dependencies
-asyncpg>=0.29.0,<1.0.0
+asyncpg>=0.31.0,<1.0.0
google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0
# Document processing dependencies
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
+pgvector>=0.4.2,<1.0.0
pycryptodome>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0


@@ -123,10 +123,21 @@ async def test_postgres_migration_trigger(
{"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
for i in range(100)
]
migration_state = {"new_table_count": 0}
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
return {"count": 100}
sql_upper = sql.upper()
legacy_table = storage.legacy_table_name.upper()
new_table = storage.table_name.upper()
is_new_table = new_table in sql_upper
is_legacy_table = legacy_table in sql_upper and not is_new_table
if is_new_table:
return {"count": migration_state["new_table_count"]}
if is_legacy_table:
return {"count": 100}
return {"count": 0}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
@@ -145,6 +156,17 @@ async def test_postgres_migration_trigger(
mock_pg_db.query = AsyncMock(side_effect=mock_query)
# Track migration through _run_with_retry calls
migration_executed = []
+async def mock_run_with_retry(operation, **kwargs):
+# Track that migration batch operation was called
+migration_executed.append(True)
+migration_state["new_table_count"] = 100
+return None
+mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
@@ -154,9 +176,9 @@ async def test_postgres_migration_trigger(
# Initialize storage (should trigger migration)
await storage.initialize()
-# Verify migration was executed
-# Check that execute was called for inserting rows
-assert mock_pg_db.execute.call_count > 0
+# Verify migration was executed by checking _run_with_retry was called
+# (batch migration uses _run_with_retry with executemany)
+assert len(migration_executed) > 0, "Migration should have been executed"
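The tests below share one idiom: instead of counting execute() calls, they intercept _run_with_retry (which now wraps every executemany batch) and record each invocation. A condensed sketch of that mocking pattern, with simplified names:

from unittest.mock import AsyncMock

batches = []

async def fake_run_with_retry(operation, **kwargs):
    batches.append(operation)  # one entry per migrated batch
    return None

mock_db = AsyncMock()
mock_db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
# ... exercise the code under test, then:
# assert len(batches) > 0, "Migration should have been executed"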
@pytest.mark.asyncio
@@ -291,6 +313,7 @@ async def test_scenario_2_legacy_upgrade_migration(
# Track which queries have been made for proper response
query_history = []
migration_state = {"new_table_count": 0}
async def mock_query(sql, params=None, multirows=False, **kwargs):
query_history.append(sql)
@@ -303,27 +326,20 @@ async def test_scenario_2_legacy_upgrade_migration(
base_name = storage.legacy_table_name.upper()
# Check if this is querying the new table (has model suffix)
-has_model_suffix = any(
-suffix in sql_upper
-for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"]
-)
+has_model_suffix = storage.table_name.upper() in sql_upper
is_legacy_table = base_name in sql_upper and not has_model_suffix
is_new_table = has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy_table and has_workspace_filter:
# Count for legacy table with workspace filter (before migration)
return {"count": 50}
elif is_legacy_table and not has_workspace_filter:
-# Total count for legacy table (after deletion, checking remaining)
-return {"count": 0}
-elif is_new_table:
-# Count for new table (verification after migration)
+# Total count for legacy table
+return {"count": 50}
-else:
-# Fallback
-return {"count": 0}
+# New table count (before/after migration)
+return {"count": migration_state["new_table_count"]}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
@@ -342,6 +358,17 @@ async def test_scenario_2_legacy_upgrade_migration(
mock_pg_db.query = AsyncMock(side_effect=mock_query)
# Track migration through _run_with_retry calls
migration_executed = []
+async def mock_run_with_retry(operation, **kwargs):
+# Track that migration batch operation was called
+migration_executed.append(True)
+migration_state["new_table_count"] = 50
+return None
+mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
@@ -353,26 +380,10 @@ async def test_scenario_2_legacy_upgrade_migration(
# Verify table name contains ada-002
assert "text_embedding_ada_002_1536d" in storage.table_name
-# Verify migration was executed
-assert mock_pg_db.execute.call_count >= 50 # At least one execute per row
+# Verify migration was executed (batch migration uses _run_with_retry)
+assert len(migration_executed) > 0, "Migration should have been executed"
mock_create.assert_called_once()
-# Verify legacy table was automatically deleted after successful migration
-# This prevents Case 1 warnings on next startup
-delete_calls = [
-call
-for call in mock_pg_db.execute.call_args_list
-if call[0][0] and "DROP TABLE" in call[0][0]
-]
-assert (
-len(delete_calls) >= 1
-), "Legacy table should be deleted after successful migration"
-# Check if legacy table was dropped
-dropped_table = storage.legacy_table_name
-assert any(
-dropped_table in str(call) for call in delete_calls
-), f"Expected to drop '{dropped_table}'"
@pytest.mark.asyncio
async def test_scenario_3_multi_model_coexistence(
@@ -586,13 +597,12 @@ async def test_case1_sequential_workspace_migration(
Critical bug fix verification:
Timeline:
1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
-2. Workspace A initializes first → Case 4 (only legacy exists) → migrates A's data
-3. Workspace B initializes later → Case 1 (both tables exist) → should migrate B's data
+2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data
+3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data
4. Verify workspace B's data is correctly migrated to new table
5. Verify legacy table is cleaned up after both workspaces migrate
-This test verifies the fix where Case 1 now checks and migrates current
-workspace's data instead of just checking if legacy table is empty globally.
+This test verifies the migration logic correctly handles multi-tenant scenarios
+where different workspaces migrate sequentially.
"""
config = {
"embedding_batch_num": 10,
@@ -616,9 +626,14 @@ async def test_case1_sequential_workspace_migration(
]
# Track migration state
migration_state = {"new_table_exists": False, "workspace_a_migrated": False}
migration_state = {
"new_table_exists": False,
"workspace_a_migrated": False,
"workspace_a_migration_count": 0,
"workspace_b_migration_count": 0,
}
# Step 1: Simulate workspace_a initialization (Case 4)
# Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists)
# CRITICAL: Set db.workspace to workspace_a
mock_pg_db.workspace = "workspace_a"
@@ -637,16 +652,7 @@ async def test_case1_sequential_workspace_migration(
return migration_state["new_table_exists"]
return False
-# Track inserted records count for verification
-inserted_count = {"workspace_a": 0}
-# Mock execute to track inserts
-async def mock_execute_a(sql, data=None, **kwargs):
-if sql and "INSERT INTO" in sql.upper():
-inserted_count["workspace_a"] += 1
-return None
-# Mock query for workspace_a (Case 4)
+# Mock query for workspace_a (Case 3)
async def mock_query_a(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_a.legacy_table_name.upper()
@@ -659,17 +665,23 @@ async def test_case1_sequential_workspace_migration(
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
-# After migration starts, pretend legacy is empty for this workspace
-return {"count": 3 - inserted_count["workspace_a"]}
+return {"count": 3}
elif workspace == "workspace_b":
return {"count": 3}
elif is_legacy and not has_workspace_filter:
-# Global count in legacy table
-remaining = 6 - inserted_count["workspace_a"]
-return {"count": remaining}
+return {"count": 6}
elif has_model_suffix:
-# New table count (for verification)
-return {"count": inserted_count["workspace_a"]}
+if has_workspace_filter:
+workspace = params[0] if params and len(params) > 0 else None
+if workspace == "workspace_a":
+return {"count": migration_state["workspace_a_migration_count"]}
+if workspace == "workspace_b":
+return {"count": migration_state["workspace_b_migration_count"]}
+return {
+"count": migration_state["workspace_a_migration_count"]
++ migration_state["workspace_b_migration_count"]
+}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
@@ -680,9 +692,18 @@ async def test_case1_sequential_workspace_migration(
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
-mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a)
-# Initialize workspace_a (Case 4)
+# Track migration via _run_with_retry (batch migration uses this)
+migration_a_executed = []
+async def mock_run_with_retry_a(operation, **kwargs):
+migration_a_executed.append(True)
+migration_state["workspace_a_migration_count"] = len(mock_rows_a)
+return None
+mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a)
+# Initialize workspace_a (Case 3)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
@@ -694,11 +715,14 @@ async def test_case1_sequential_workspace_migration(
migration_state["new_table_exists"] = True
migration_state["workspace_a_migrated"] = True
print("✅ Step 1: Workspace A initialized (Case 4)")
assert mock_pg_db.execute.call_count >= 3
print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls")
print("✅ Step 1: Workspace A initialized")
# Verify migration was executed via _run_with_retry (batch migration uses executemany)
assert (
len(migration_a_executed) > 0
), "Migration should have been executed for workspace_a"
print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)")
# Step 2: Simulate workspace_b initialization (Case 1)
# Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data)
# CRITICAL: Set db.workspace to workspace_b
mock_pg_db.workspace = "workspace_b"
@@ -710,22 +734,12 @@ async def test_case1_sequential_workspace_migration(
)
mock_pg_db.reset_mock()
migration_state["workspace_b_migrated"] = False
# Mock table_exists for workspace_b (both exist)
async def mock_table_exists_b(db, table_name):
-return True
+return True # Both tables exist
-# Track inserted records count for workspace_b
-inserted_count["workspace_b"] = 0
-# Mock execute for workspace_b to track inserts
-async def mock_execute_b(sql, data=None, **kwargs):
-if sql and "INSERT INTO" in sql.upper():
-inserted_count["workspace_b"] += 1
-return None
-# Mock query for workspace_b (Case 1)
+# Mock query for workspace_b (Case 3)
async def mock_query_b(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_b.legacy_table_name.upper()
@@ -738,24 +752,21 @@ async def test_case1_sequential_workspace_migration(
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
-# After migration starts, pretend legacy is empty for this workspace
-return {"count": 3 - inserted_count["workspace_b"]}
+return {"count": 3} # workspace_b still has data in legacy
elif workspace == "workspace_a":
-return {"count": 0} # Already migrated
+return {"count": 0} # workspace_a already migrated
elif is_legacy and not has_workspace_filter:
-# Global count: only workspace_b data remains
-return {"count": 3 - inserted_count["workspace_b"]}
+return {"count": 3}
elif has_model_suffix:
-# New table total count (workspace_a: 3 + workspace_b: inserted)
if has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
-return {"count": inserted_count["workspace_b"]}
+return {"count": migration_state["workspace_b_migration_count"]}
elif workspace == "workspace_a":
return {"count": 3}
else:
-# Total count in new table (for verification)
-return {"count": 3 + inserted_count["workspace_b"]}
+return {"count": 3 + migration_state["workspace_b_migration_count"]}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
@@ -766,40 +777,32 @@ async def test_case1_sequential_workspace_migration(
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
-mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b)
-# Initialize workspace_b (Case 1)
+# Track migration via _run_with_retry for workspace_b
+migration_b_executed = []
+async def mock_run_with_retry_b(operation, **kwargs):
+migration_b_executed.append(True)
+migration_state["workspace_b_migration_count"] = len(mock_rows_b)
+return None
+mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b)
+# Initialize workspace_b (Case 3 - both tables exist)
with patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b
):
await storage_b.initialize()
migration_state["workspace_b_migrated"] = True
-print("✅ Step 2: Workspace B initialized (Case 1)")
+print("✅ Step 2: Workspace B initialized")
-# Verify workspace_b migration happened
-execute_calls = mock_pg_db.execute.call_args_list
-insert_calls = [
-call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0]
-]
-assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}"
-print(f"✅ Step 2: {len(insert_calls)} insert calls")
+# Verify workspace_b migration happens when new table has no workspace_b data
+# but legacy table still has workspace_b data.
+assert (
+len(migration_b_executed) > 0
+), "Migration should have been executed for workspace_b"
+print("✅ Step 2: Migration executed for workspace_b")
-# Verify DELETE and DROP TABLE
-delete_calls = [
-call
-for call in execute_calls
-if call[0][0]
-and "DELETE FROM" in call[0][0]
-and "WHERE workspace" in call[0][0]
-]
-assert len(delete_calls) >= 1, "Expected DELETE workspace_b data"
-print("✅ Step 2: DELETE workspace_b from legacy")
-drop_calls = [
-call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0]
-]
-assert len(drop_calls) >= 1, "Expected DROP TABLE"
-print("✅ Step 2: Legacy table dropped")
-print("\n🎉 Case 1c: Sequential workspace migration verified!")
+print("\n🎉 Case 1c: Sequential workspace migration verification complete!")
print(" - Workspace A: Migrated successfully (only legacy existed)")
print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")