Optimize Postgres batch operations and refine workspace migration logic
- Use executemany for efficient upserts
- Optimize data migration with batching
- Refine multi-workspace migration logic
- Add pgvector dependency
- Update DDL templates for dynamic dims
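For orientation, a minimal sketch of the batching pattern the message refers to (illustrative only, not code from this repository; table and column names are placeholders):

# Build one parameterized INSERT and send a whole batch with executemany,
# turning N round-trips into one per batch.
import asyncpg


async def batch_insert(conn: asyncpg.Connection, table: str, rows: list[dict]) -> None:
    if not rows:
        return
    columns = list(rows[0].keys())
    placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
    sql = (
        f"INSERT INTO {table} ({', '.join(columns)}) "
        f"VALUES ({placeholders}) ON CONFLICT (workspace, id) DO NOTHING"
    )
    await conn.executemany(sql, [tuple(r[c] for c in columns) for r in rows])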
@@ -31,6 +31,7 @@ from ..base import (
    DocStatus,
    DocStatusStorage,
)
from ..exceptions import DataMigrationError
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..kg.shared_storage import get_data_init_lock
@@ -39,9 +40,12 @@ import pipmaster as pm

if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")
if not pm.is_installed("pgvector"):
    pm.install("pgvector")

import asyncpg # type: ignore
from asyncpg import Pool # type: ignore
from pgvector.asyncpg import register_vector # type: ignore

from dotenv import load_dotenv

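The register_vector import above is what lets vector columns round-trip as native Python/numpy values. A minimal usage sketch (illustrative; the DSN and table are placeholders, not part of this repo):

import asyncpg
from pgvector.asyncpg import register_vector


async def insert_one(dsn: str) -> None:
    conn = await asyncpg.connect(dsn)  # placeholder DSN
    try:
        # Install pgvector's encode/decode codec on this connection.
        await register_vector(conn)
        await conn.execute(
            "INSERT INTO demo_items (id, embedding) VALUES ($1, $2)",  # hypothetical table
            "doc-1",
            [0.1, 0.2, 0.3],  # bound directly to a VECTOR column, no json.dumps needed
        )
    finally:
        await conn.close()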
@@ -2191,9 +2195,7 @@ async def _pg_create_table(
|
||||
ddl_template = TABLES[base_table]["ddl"]
|
||||
|
||||
# Replace embedding dimension placeholder if exists
|
||||
ddl = ddl_template.replace(
|
||||
f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
|
||||
)
|
||||
ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
|
||||
|
||||
# Replace table name
|
||||
ddl = ddl.replace(base_table, table_name)
|
||||
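The hunk above swaps the environment-derived placeholder for a literal "VECTOR(dimension)" token in the DDL templates. A small sketch of how that substitution plays out at table-creation time (values are examples, not the repo's defaults):

ddl_template = """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
    id VARCHAR(255),
    workspace VARCHAR(255),
    content_vector VECTOR(dimension)
)"""

embedding_dim = 1536  # would come from the embedding function at runtime
table_name = "lightrag_vdb_chunks_example_1536d"  # hypothetical suffixed name

ddl = ddl_template.replace("VECTOR(dimension)", f"VECTOR({embedding_dim})")
ddl = ddl.replace("LIGHTRAG_VDB_CHUNKS", table_name)
print(ddl)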
@@ -2209,7 +2211,11 @@ async def _pg_migrate_workspace_data(
|
||||
expected_count: int,
|
||||
embedding_dim: int,
|
||||
) -> int:
|
||||
"""Migrate workspace data from legacy table to new table"""
|
||||
"""Migrate workspace data from legacy table to new table using batch insert.
|
||||
|
||||
This function uses asyncpg's executemany for efficient batch insertion,
|
||||
reducing database round-trips from N to 1 per batch.
|
||||
"""
|
||||
migrated_count = 0
|
||||
offset = 0
|
||||
batch_size = 500
|
||||
@@ -2227,20 +2233,34 @@ async def _pg_migrate_workspace_data(
|
||||
if not rows:
|
||||
break
|
||||
|
||||
# Batch insert optimization: use executemany instead of individual inserts
|
||||
# Get column names from the first row
|
||||
first_row = dict(rows[0])
|
||||
columns = list(first_row.keys())
|
||||
columns_str = ", ".join(columns)
|
||||
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
|
||||
|
||||
insert_query = f"""
|
||||
INSERT INTO {new_table_name} ({columns_str})
|
||||
VALUES ({placeholders})
|
||||
ON CONFLICT (workspace, id) DO NOTHING
|
||||
"""
|
||||
|
||||
# Prepare batch data: convert rows to list of tuples
|
||||
batch_values = []
|
||||
for row in rows:
|
||||
row_dict = dict(row)
|
||||
columns = list(row_dict.keys())
|
||||
columns_str = ", ".join(columns)
|
||||
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
|
||||
insert_query = f"""
|
||||
INSERT INTO {new_table_name} ({columns_str})
|
||||
VALUES ({placeholders})
|
||||
ON CONFLICT (workspace, id) DO NOTHING
|
||||
"""
|
||||
# Rebuild dict in columns order to ensure values() matches placeholders order
|
||||
# Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values())
|
||||
values = {col: row_dict[col] for col in columns}
|
||||
await db.execute(insert_query, values)
|
||||
# Extract values in column order to match placeholders
|
||||
values_tuple = tuple(row_dict[col] for col in columns)
|
||||
batch_values.append(values_tuple)
|
||||
|
||||
# Use executemany for batch execution - significantly reduces DB round-trips
|
||||
# Register pgvector codec to handle vector fields alongside other fields seamlessly
|
||||
async def _batch_insert(connection: asyncpg.Connection) -> None:
|
||||
await register_vector(connection)
|
||||
await connection.executemany(insert_query, batch_values)
|
||||
|
||||
await db._run_with_retry(_batch_insert)
|
||||
|
||||
migrated_count += len(rows)
|
||||
workspace_info = f" for workspace '{workspace}'" if workspace else ""
|
||||
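Putting the pieces of this hunk together, the migration path has roughly the following shape (a sketch under assumed helper names, not the repository's exact code):

import asyncpg
from pgvector.asyncpg import register_vector


async def migrate_workspace(conn: asyncpg.Connection, legacy_table: str, new_table: str,
                            workspace: str, batch_size: int = 500) -> int:
    # Register the codec once per connection so vector columns copy over untouched.
    await register_vector(conn)
    migrated, offset = 0, 0
    while True:
        rows = await conn.fetch(
            f"SELECT * FROM {legacy_table} WHERE workspace = $1 "
            f"ORDER BY id LIMIT $2 OFFSET $3",
            workspace, batch_size, offset,
        )
        if not rows:
            break
        columns = list(rows[0].keys())
        placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
        insert_sql = (
            f"INSERT INTO {new_table} ({', '.join(columns)}) VALUES ({placeholders}) "
            f"ON CONFLICT (workspace, id) DO NOTHING"
        )
        # One executemany per batch instead of one execute per row.
        await conn.executemany(insert_sql, [tuple(r[c] for c in columns) for r in rows])
        migrated += len(rows)
        offset += batch_size
    return migrated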
@@ -2284,395 +2304,208 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
# Fallback: use base table name if model_suffix is unavailable
|
||||
self.table_name = base_table
|
||||
logger.warning(
|
||||
f"Model suffix unavailable, using base table name '{base_table}'. "
|
||||
f"Ensure embedding_func has model_name for proper model isolation."
|
||||
"Missing collection suffix. Ensure embedding_func has model_name for proper model isolation."
|
||||
)
|
||||
|
||||
# Legacy table name (without suffix, for migration)
|
||||
self.legacy_table_name = base_table
|
||||
|
||||
logger.debug(
|
||||
f"PostgreSQL table naming: "
|
||||
f"new='{self.table_name}', "
|
||||
f"legacy='{self.legacy_table_name}', "
|
||||
f"model_suffix='{self.model_suffix}'"
|
||||
)
|
||||
logger.info(f"PostgreSQL table name: {self.table_name}")
|
||||
|
||||
@staticmethod
|
||||
async def setup_table(
|
||||
db: PostgreSQLDB,
|
||||
table_name: str,
|
||||
legacy_table_name: str = None,
|
||||
base_table: str = None,
|
||||
embedding_dim: int = None,
|
||||
workspace: str = None,
|
||||
workspace: str,
|
||||
embedding_dim: int,
|
||||
legacy_table_name: str,
|
||||
base_table: str,
|
||||
):
|
||||
"""
|
||||
Setup PostgreSQL table with migration support from legacy tables.
|
||||
|
||||
This method mirrors Qdrant's setup_collection approach to maintain consistency.
|
||||
Ensure final table has workspace isolation index.
|
||||
Check vector dimension compatibility before new table creation.
|
||||
Drop legacy table if it exists and is empty.
|
||||
Only migrate data from legacy table to new table when new table first created and legacy table is not empty.
|
||||
|
||||
Args:
|
||||
db: PostgreSQLDB instance
|
||||
table_name: Name of the new table
|
||||
legacy_table_name: Name of the legacy table (if exists)
|
||||
workspace: Workspace to filter records for migration
|
||||
legacy_table_name: Name of the legacy table to check for migration
|
||||
base_table: Base table name for DDL template lookup
|
||||
embedding_dim: Embedding dimension for vector column
|
||||
"""
|
||||
if not workspace:
|
||||
raise ValueError("workspace must be provided")
|
||||
|
||||
new_table_exists = await _pg_table_exists(db, table_name)
|
||||
legacy_exists = legacy_table_name and await _pg_table_exists(
|
||||
db, legacy_table_name
|
||||
)
|
||||
|
||||
# Case 1: Both new and legacy tables exist
|
||||
if new_table_exists and legacy_exists:
|
||||
if table_name.lower() == legacy_table_name.lower():
|
||||
logger.debug(
|
||||
f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
workspace_info = f" for workspace '{workspace}'" if workspace else ""
|
||||
|
||||
if workspace:
|
||||
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
|
||||
count_result = await db.query(count_query, [workspace])
|
||||
else:
|
||||
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
|
||||
count_result = await db.query(count_query, [])
|
||||
|
||||
workspace_count = count_result.get("count", 0) if count_result else 0
|
||||
|
||||
if workspace_count > 0:
|
||||
logger.info(
|
||||
f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..."
|
||||
)
|
||||
|
||||
legacy_dim = None
|
||||
try:
|
||||
dim_query = """
|
||||
SELECT
|
||||
CASE
|
||||
WHEN typname = 'vector' THEN
|
||||
COALESCE(atttypmod, -1)
|
||||
ELSE -1
|
||||
END as vector_dim
|
||||
FROM pg_attribute a
|
||||
JOIN pg_type t ON a.atttypid = t.oid
|
||||
WHERE a.attrelid = $1::regclass
|
||||
AND a.attname = 'content_vector'
|
||||
"""
|
||||
dim_result = await db.query(dim_query, [legacy_table_name])
|
||||
legacy_dim = (
|
||||
dim_result.get("vector_dim", -1) if dim_result else -1
|
||||
)
|
||||
|
||||
if legacy_dim <= 0:
|
||||
sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
|
||||
sample_result = await db.query(sample_query, [])
|
||||
if sample_result and sample_result.get("content_vector"):
|
||||
vector_data = sample_result["content_vector"]
|
||||
if isinstance(vector_data, (list, tuple)):
|
||||
legacy_dim = len(vector_data)
|
||||
elif isinstance(vector_data, str):
|
||||
import json
|
||||
|
||||
vector_list = json.loads(vector_data)
|
||||
legacy_dim = len(vector_list)
|
||||
|
||||
if (
|
||||
legacy_dim > 0
|
||||
and embedding_dim
|
||||
and legacy_dim != embedding_dim
|
||||
):
|
||||
logger.warning(
|
||||
f"PostgreSQL: Dimension mismatch - "
|
||||
f"legacy table has {legacy_dim}d vectors, "
|
||||
f"new embedding model expects {embedding_dim}d. "
|
||||
f"Skipping migration{workspace_info}."
|
||||
)
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..."
|
||||
)
|
||||
|
||||
migrated_count = await _pg_migrate_workspace_data(
|
||||
db,
|
||||
legacy_table_name,
|
||||
table_name,
|
||||
workspace,
|
||||
workspace_count,
|
||||
embedding_dim,
|
||||
)
|
||||
|
||||
if workspace:
|
||||
new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
|
||||
new_count_result = await db.query(new_count_query, [workspace])
|
||||
else:
|
||||
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
|
||||
new_count_result = await db.query(new_count_query, [])
|
||||
|
||||
new_count = (
|
||||
new_count_result.get("count", 0) if new_count_result else 0
|
||||
)
|
||||
|
||||
if new_count < workspace_count:
|
||||
logger.warning(
|
||||
f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. "
|
||||
f"Some records may have been skipped due to conflicts."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}"
|
||||
)
|
||||
|
||||
if workspace:
|
||||
delete_query = (
|
||||
f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
|
||||
)
|
||||
await db.execute(delete_query, {"workspace": workspace})
|
||||
logger.info(
|
||||
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
|
||||
)
|
||||
|
||||
total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
|
||||
total_count_result = await db.query(total_count_query, [])
|
||||
total_count = (
|
||||
total_count_result.get("count", 0) if total_count_result else 0
|
||||
)
|
||||
|
||||
if total_count == 0:
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..."
|
||||
)
|
||||
drop_query = f"DROP TABLE {legacy_table_name}"
|
||||
await db.execute(drop_query, None)
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved "
|
||||
f"({total_count} records from other workspaces remain)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured."
|
||||
)
|
||||
|
||||
# Case 1: Only new table exists or new table is the same as legacy table
|
||||
# No data migration needed, ensuring index is created then return
|
||||
if (new_table_exists and not legacy_exists) or (
|
||||
table_name.lower() == legacy_table_name.lower()
|
||||
):
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
return
|
||||
|
||||
# Case 2: Only new table exists - Already migrated or newly created
|
||||
if new_table_exists:
|
||||
logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
|
||||
# Ensure vector index exists with correct embedding dimension
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
return
|
||||
|
||||
# Case 3: Neither exists - Create new table
|
||||
if not legacy_exists:
|
||||
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
|
||||
await _pg_create_table(db, table_name, base_table, embedding_dim)
|
||||
logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
|
||||
# Create vector index with correct embedding dimension
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
return
|
||||
|
||||
# Case 4: Only legacy exists - Migrate data
|
||||
logger.info(
|
||||
f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get legacy table count (with workspace filtering)
|
||||
if workspace:
|
||||
legacy_count = None
|
||||
if not new_table_exists:
|
||||
# Check vector dimension compatibility before creating new table
|
||||
if legacy_exists:
|
||||
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
|
||||
count_result = await db.query(count_query, [workspace])
|
||||
else:
|
||||
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
|
||||
count_result = await db.query(count_query, [])
|
||||
logger.warning(
|
||||
"PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!"
|
||||
)
|
||||
legacy_count = count_result.get("count", 0) if count_result else 0
|
||||
|
||||
legacy_count = count_result.get("count", 0) if count_result else 0
|
||||
workspace_info = f" for workspace '{workspace}'" if workspace else ""
|
||||
logger.info(
|
||||
f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}"
|
||||
if legacy_count > 0:
|
||||
legacy_dim = None
|
||||
try:
|
||||
sample_query = f"SELECT content_vector FROM {legacy_table_name} WHERE workspace = $1 LIMIT 1"
|
||||
sample_result = await db.query(sample_query, [workspace])
|
||||
if sample_result and sample_result.get("content_vector"):
|
||||
vector_data = sample_result["content_vector"]
|
||||
# pgvector returns list directly
|
||||
if isinstance(vector_data, (list, tuple)):
|
||||
legacy_dim = len(vector_data)
|
||||
elif isinstance(vector_data, str):
|
||||
import json
|
||||
|
||||
vector_list = json.loads(vector_data)
|
||||
legacy_dim = len(vector_list)
|
||||
|
||||
if legacy_dim and legacy_dim != embedding_dim:
|
||||
logger.error(
|
||||
f"PostgreSQL: Dimension mismatch detected! "
|
||||
f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
|
||||
f"but new embedding model expects {embedding_dim}d."
|
||||
)
|
||||
raise DataMigrationError(
|
||||
f"Dimension mismatch between legacy table '{legacy_table_name}' "
|
||||
f"and new embedding model. Expected {embedding_dim} but got {legacy_dim}."
|
||||
)
|
||||
|
||||
except DataMigrationError:
|
||||
# Re-raise DataMigrationError as-is to preserve specific error messages
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DataMigrationError(
|
||||
f"Could not verify legacy table vector dimension: {e}. "
|
||||
f"Proceeding with caution..."
|
||||
)
|
||||
|
||||
await _pg_create_table(db, table_name, base_table, embedding_dim)
|
||||
logger.info(f"PostgreSQL: New table '{table_name}' created successfully")
|
||||
|
||||
# Ensure vector index is created
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
|
||||
# Case 2: Legacy table exists
|
||||
if legacy_exists:
|
||||
workspace_info = f" for workspace '{workspace}'"
|
||||
|
||||
# Only drop legacy table if entire table is empty
|
||||
total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
|
||||
total_count_result = await db.query(total_count_query, [])
|
||||
total_count = (
|
||||
total_count_result.get("count", 0) if total_count_result else 0
|
||||
)
|
||||
|
||||
if legacy_count == 0:
|
||||
logger.info("PostgreSQL: Legacy table is empty, skipping migration")
|
||||
await _pg_create_table(db, table_name, base_table, embedding_dim)
|
||||
# Create vector index with correct embedding dimension
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
if total_count == 0:
|
||||
logger.info(
|
||||
f"PostgreSQL: Empty legacy table '{legacy_table_name}' deleted successfully"
|
||||
)
|
||||
drop_query = f"DROP TABLE {legacy_table_name}"
|
||||
await db.execute(drop_query, None)
|
||||
return
|
||||
|
||||
# Check vector dimension compatibility before migration
|
||||
legacy_dim = None
|
||||
try:
|
||||
# Try to get vector dimension from pg_attribute metadata
|
||||
dim_query = """
|
||||
SELECT
|
||||
CASE
|
||||
WHEN typname = 'vector' THEN
|
||||
COALESCE(atttypmod, -1)
|
||||
ELSE -1
|
||||
END as vector_dim
|
||||
FROM pg_attribute a
|
||||
JOIN pg_type t ON a.atttypid = t.oid
|
||||
WHERE a.attrelid = $1::regclass
|
||||
AND a.attname = 'content_vector'
|
||||
"""
|
||||
dim_result = await db.query(dim_query, [legacy_table_name])
|
||||
legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1
|
||||
# No data migration needed if legacy workspace is empty
|
||||
if legacy_count is None:
|
||||
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
|
||||
count_result = await db.query(count_query, [workspace])
|
||||
legacy_count = count_result.get("count", 0) if count_result else 0
|
||||
|
||||
if legacy_dim <= 0:
|
||||
# Alternative: Try to detect by sampling a vector
|
||||
logger.info(
|
||||
"PostgreSQL: Metadata dimension check failed, trying vector sampling..."
|
||||
)
|
||||
sample_query = (
|
||||
f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
|
||||
)
|
||||
sample_result = await db.query(sample_query, [])
|
||||
if sample_result and sample_result.get("content_vector"):
|
||||
vector_data = sample_result["content_vector"]
|
||||
# pgvector returns list directly
|
||||
if isinstance(vector_data, (list, tuple)):
|
||||
legacy_dim = len(vector_data)
|
||||
elif isinstance(vector_data, str):
|
||||
import json
|
||||
if legacy_count == 0:
|
||||
logger.info(
|
||||
f"PostgreSQL: No records{workspace_info} found in legacy table. "
|
||||
f"No data migration needed."
|
||||
)
|
||||
return
|
||||
|
||||
vector_list = json.loads(vector_data)
|
||||
legacy_dim = len(vector_list)
|
||||
new_count_query = (
|
||||
f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
|
||||
)
|
||||
new_count_result = await db.query(new_count_query, [workspace])
|
||||
new_table_workspace_count = (
|
||||
new_count_result.get("count", 0) if new_count_result else 0
|
||||
)
|
||||
|
||||
if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim:
|
||||
logger.warning(
|
||||
f"PostgreSQL: Dimension mismatch detected! "
|
||||
f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
|
||||
f"but new embedding model expects {embedding_dim}d. "
|
||||
f"Migration skipped to prevent data loss. "
|
||||
f"Legacy table preserved as '{legacy_table_name}'. "
|
||||
f"Creating new empty table '{table_name}' for new data."
|
||||
)
|
||||
|
||||
# Create new table but skip migration
|
||||
await _pg_create_table(db, table_name, base_table, embedding_dim)
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
|
||||
logger.info(
|
||||
f"PostgreSQL: New table '{table_name}' created. "
|
||||
f"To query legacy data, please use a {legacy_dim}d embedding model."
|
||||
)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if new_table_workspace_count > 0:
|
||||
logger.warning(
|
||||
f"PostgreSQL: Could not verify legacy table vector dimension: {e}. "
|
||||
f"Proceeding with caution..."
|
||||
f"PostgreSQL: New table '{table_name}' already has "
|
||||
f"{new_table_workspace_count} records{workspace_info}. "
|
||||
"Data migration skipped to avoid duplicates."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
|
||||
await _pg_create_table(db, table_name, base_table, embedding_dim)
|
||||
|
||||
migrated_count = await _pg_migrate_workspace_data(
|
||||
db,
|
||||
legacy_table_name,
|
||||
table_name,
|
||||
workspace,
|
||||
legacy_count,
|
||||
embedding_dim,
|
||||
)
|
||||
|
||||
logger.info("PostgreSQL: Verifying migration...")
|
||||
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
|
||||
new_count_result = await db.query(new_count_query, [])
|
||||
new_count = new_count_result.get("count", 0) if new_count_result else 0
|
||||
|
||||
if new_count != legacy_count:
|
||||
error_msg = (
|
||||
f"PostgreSQL: Migration verification failed, "
|
||||
f"expected {legacy_count} records, got {new_count} in new table"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise PostgreSQLMigrationError(error_msg)
|
||||
|
||||
# Case 3: Legacy has workspace data and new table is empty for workspace
|
||||
logger.info(
|
||||
f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated"
|
||||
f"PostgreSQL: Found legacy table '{legacy_table_name}' with {legacy_count} records{workspace_info}."
|
||||
)
|
||||
logger.info(
|
||||
f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}' to new table '{table_name}'"
|
||||
)
|
||||
|
||||
try:
|
||||
migrated_count = await _pg_migrate_workspace_data(
|
||||
db,
|
||||
legacy_table_name,
|
||||
table_name,
|
||||
workspace,
|
||||
legacy_count,
|
||||
embedding_dim,
|
||||
)
|
||||
if migrated_count != legacy_count:
|
||||
logger.warning(
|
||||
"PostgreSQL: Read %s legacy records%s during migration, expected %s.",
|
||||
migrated_count,
|
||||
workspace_info,
|
||||
legacy_count,
|
||||
)
|
||||
|
||||
new_count_result = await db.query(new_count_query, [workspace])
|
||||
new_table_count_after = (
|
||||
new_count_result.get("count", 0) if new_count_result else 0
|
||||
)
|
||||
inserted_count = new_table_count_after - new_table_workspace_count
|
||||
|
||||
if inserted_count != legacy_count:
|
||||
error_msg = (
|
||||
"PostgreSQL: Migration verification failed, "
|
||||
f"expected {legacy_count} inserted records, got {inserted_count}."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise DataMigrationError(error_msg)
|
||||
|
||||
except DataMigrationError:
|
||||
# Re-raise DataMigrationError as-is to preserve specific error messages
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"PostgreSQL: Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}': {e}"
|
||||
)
|
||||
raise DataMigrationError(
|
||||
f"Failed to migrate data from legacy table '{legacy_table_name}' to new table '{table_name}'"
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
|
||||
)
|
||||
|
||||
await db._create_vector_index(table_name, embedding_dim)
|
||||
|
||||
try:
|
||||
if workspace:
|
||||
logger.info(
|
||||
f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..."
|
||||
)
|
||||
delete_query = (
|
||||
f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
|
||||
)
|
||||
await db.execute(delete_query, {"workspace": workspace})
|
||||
logger.info(
|
||||
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
|
||||
)
|
||||
|
||||
remaining_query = (
|
||||
f"SELECT COUNT(*) as count FROM {legacy_table_name}"
|
||||
)
|
||||
remaining_result = await db.query(remaining_query, [])
|
||||
remaining_count = (
|
||||
remaining_result.get("count", 0) if remaining_result else 0
|
||||
)
|
||||
|
||||
if remaining_count == 0:
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..."
|
||||
)
|
||||
drop_query = f"DROP TABLE {legacy_table_name}"
|
||||
await db.execute(drop_query, None)
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..."
|
||||
)
|
||||
drop_query = f"DROP TABLE {legacy_table_name}"
|
||||
await db.execute(drop_query, None)
|
||||
logger.info(
|
||||
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted"
|
||||
)
|
||||
|
||||
except Exception as delete_error:
|
||||
# If cleanup fails, log warning but don't fail migration
|
||||
logger.warning(
|
||||
f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. "
|
||||
"Migration succeeded, but manual cleanup may be needed."
|
||||
)
|
||||
|
||||
except PostgreSQLMigrationError:
|
||||
# Re-raise migration errors without wrapping
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"PostgreSQL: Migration failed with error: {e}"
|
||||
logger.error(error_msg)
|
||||
# Mirror Qdrant behavior: no automatic rollback
|
||||
# Reason: partial data can be continued by re-running migration
|
||||
raise PostgreSQLMigrationError(error_msg) from e
|
||||
logger.info(
|
||||
"PostgreSQL: Manual deletion is required after data migration verification."
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
async with get_data_init_lock():
|
||||
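The dimension checks in the hunk above read pg_attribute.atttypmod, which these queries treat as the declared pgvector dimension. A standalone sketch of that lookup (the helper name and default column are assumptions, not the repo's API):

import asyncpg


async def get_vector_dim(conn: asyncpg.Connection, table: str,
                         column: str = "content_vector") -> int:
    row = await conn.fetchrow(
        """
        SELECT CASE WHEN t.typname = 'vector'
                    THEN COALESCE(a.atttypmod, -1) ELSE -1 END AS vector_dim
        FROM pg_attribute a
        JOIN pg_type t ON a.atttypid = t.oid
        WHERE a.attrelid = $1::regclass AND a.attname = $2
        """,
        table, column,
    )
    return row["vector_dim"] if row else -1  # -1 means the dimension could not be read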
@@ -2694,10 +2527,10 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
await PGVectorStorage.setup_table(
|
||||
self.db,
|
||||
self.table_name,
|
||||
self.workspace, # CRITICAL: Filter migration by workspace
|
||||
embedding_dim=self.embedding_func.embedding_dim,
|
||||
legacy_table_name=self.legacy_table_name,
|
||||
base_table=self.legacy_table_name, # base_table for DDL template lookup
|
||||
embedding_dim=self.embedding_func.embedding_dim,
|
||||
workspace=self.workspace, # CRITICAL: Filter migration by workspace
|
||||
)
|
||||
|
||||
async def finalize(self):
|
||||
@@ -2707,34 +2540,45 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
|
||||
def _upsert_chunks(
|
||||
self, item: dict[str, Any], current_time: datetime.datetime
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
) -> tuple[str, tuple[Any, ...]]:
|
||||
"""Prepare upsert data for chunks.
|
||||
|
||||
Returns:
|
||||
Tuple of (SQL template, values tuple for executemany)
|
||||
"""
|
||||
try:
|
||||
upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
|
||||
table_name=self.table_name
|
||||
)
|
||||
data: dict[str, Any] = {
|
||||
"workspace": self.workspace,
|
||||
"id": item["__id__"],
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||
"file_path": item["file_path"],
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
# Return tuple in the exact order of SQL parameters ($1, $2, ...)
|
||||
values: tuple[Any, ...] = (
|
||||
self.workspace, # $1
|
||||
item["__id__"], # $2
|
||||
item["tokens"], # $3
|
||||
item["chunk_order_index"], # $4
|
||||
item["full_doc_id"], # $5
|
||||
item["content"], # $6
|
||||
item["__vector__"], # $7 - numpy array, handled by pgvector codec
|
||||
item["file_path"], # $8
|
||||
current_time, # $9
|
||||
current_time, # $10
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{self.workspace}] Error to prepare upsert,\nsql: {e}\nitem: {item}"
|
||||
f"[{self.workspace}] Error to prepare upsert,\nerror: {e}\nitem: {item}"
|
||||
)
|
||||
raise
|
||||
|
||||
return upsert_sql, data
|
||||
return upsert_sql, values
|
||||
|
||||
def _upsert_entities(
|
||||
self, item: dict[str, Any], current_time: datetime.datetime
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
) -> tuple[str, tuple[Any, ...]]:
|
||||
"""Prepare upsert data for entities.
|
||||
|
||||
Returns:
|
||||
Tuple of (SQL template, values tuple for executemany)
|
||||
"""
|
||||
upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
|
||||
source_id = item["source_id"]
|
||||
if isinstance(source_id, str) and "<SEP>" in source_id:
|
||||
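The "$1, $2, ..." comments above matter because asyncpg binds parameters positionally: the tuple must mirror the placeholder order exactly. A tiny illustration (names and values are placeholders):

from datetime import datetime, timezone


def to_params(row: dict, column_order: list[str]) -> tuple:
    # column_order must match the order of $1..$N in the SQL statement.
    return tuple(row[name] for name in column_order)


params = to_params(
    {"workspace": "w1", "id": "chunk-1", "tokens": 42,
     "create_time": datetime.now(timezone.utc)},
    ["workspace", "id", "tokens", "create_time"],
)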
@@ -2742,22 +2586,28 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
else:
|
||||
chunk_ids = [source_id]
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"workspace": self.workspace,
|
||||
"id": item["__id__"],
|
||||
"entity_name": item["entity_name"],
|
||||
"content": item["content"],
|
||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||
"chunk_ids": chunk_ids,
|
||||
"file_path": item.get("file_path", None),
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
return upsert_sql, data
|
||||
# Return tuple in the exact order of SQL parameters ($1, $2, ...)
|
||||
values: tuple[Any, ...] = (
|
||||
self.workspace, # $1
|
||||
item["__id__"], # $2
|
||||
item["entity_name"], # $3
|
||||
item["content"], # $4
|
||||
item["__vector__"], # $5 - numpy array, handled by pgvector codec
|
||||
chunk_ids, # $6
|
||||
item.get("file_path", None), # $7
|
||||
current_time, # $8
|
||||
current_time, # $9
|
||||
)
|
||||
return upsert_sql, values
|
||||
|
||||
def _upsert_relationships(
|
||||
self, item: dict[str, Any], current_time: datetime.datetime
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
) -> tuple[str, tuple[Any, ...]]:
|
||||
"""Prepare upsert data for relationships.
|
||||
|
||||
Returns:
|
||||
Tuple of (SQL template, values tuple for executemany)
|
||||
"""
|
||||
upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
|
||||
table_name=self.table_name
|
||||
)
|
||||
@@ -2767,19 +2617,20 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
else:
|
||||
chunk_ids = [source_id]
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"workspace": self.workspace,
|
||||
"id": item["__id__"],
|
||||
"source_id": item["src_id"],
|
||||
"target_id": item["tgt_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": json.dumps(item["__vector__"].tolist()),
|
||||
"chunk_ids": chunk_ids,
|
||||
"file_path": item.get("file_path", None),
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
return upsert_sql, data
|
||||
# Return tuple in the exact order of SQL parameters ($1, $2, ...)
|
||||
values: tuple[Any, ...] = (
|
||||
self.workspace, # $1
|
||||
item["__id__"], # $2
|
||||
item["src_id"], # $3
|
||||
item["tgt_id"], # $4
|
||||
item["content"], # $5
|
||||
item["__vector__"], # $6 - numpy array, handled by pgvector codec
|
||||
chunk_ids, # $7
|
||||
item.get("file_path", None), # $8
|
||||
current_time, # $9
|
||||
current_time, # $10
|
||||
)
|
||||
return upsert_sql, values
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
||||
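The switch above from json.dumps(...tolist()) to passing the vector itself works because the registered pgvector codec encodes numpy arrays natively. In miniature (illustrative values):

import json

import numpy as np

vec = np.random.rand(4).astype(np.float32)

legacy_param = json.dumps(vec.tolist())  # old style: serialize to a JSON string first
native_param = vec                       # new style: hand the array to the codec as-is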
@@ -2807,17 +2658,34 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
|
||||
# Prepare batch values for executemany
|
||||
batch_values: list[tuple[Any, ...]] = []
|
||||
upsert_sql = None
|
||||
|
||||
for item in list_data:
|
||||
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
|
||||
upsert_sql, data = self._upsert_chunks(item, current_time)
|
||||
upsert_sql, values = self._upsert_chunks(item, current_time)
|
||||
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
|
||||
upsert_sql, data = self._upsert_entities(item, current_time)
|
||||
upsert_sql, values = self._upsert_entities(item, current_time)
|
||||
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
|
||||
upsert_sql, data = self._upsert_relationships(item, current_time)
|
||||
upsert_sql, values = self._upsert_relationships(item, current_time)
|
||||
else:
|
||||
raise ValueError(f"{self.namespace} is not supported")
|
||||
|
||||
await self.db.execute(upsert_sql, data)
|
||||
batch_values.append(values)
|
||||
|
||||
# Use executemany for batch execution - significantly reduces DB round-trips
|
||||
if batch_values and upsert_sql:
|
||||
|
||||
async def _batch_upsert(connection: asyncpg.Connection) -> None:
|
||||
await register_vector(connection)
|
||||
await connection.executemany(upsert_sql, batch_values)
|
||||
|
||||
await self.db._run_with_retry(_batch_upsert)
|
||||
logger.debug(
|
||||
f"[{self.workspace}] Batch upserted {len(batch_values)} records to {self.namespace}"
|
||||
)
|
||||
|
||||
#################### query method ###############
|
||||
async def query(
|
||||
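The batch paths above funnel through a retry helper (db._run_with_retry) whose implementation is not part of this diff; the general shape of such a helper looks roughly like this (hypothetical sketch, not the repo's code):

import asyncio

import asyncpg


async def run_with_retry(pool: asyncpg.Pool, operation, retries: int = 3) -> None:
    # 'operation' is an async callable taking a connection, e.g. the _batch_upsert above.
    for attempt in range(1, retries + 1):
        try:
            async with pool.acquire() as conn:
                await operation(conn)
            return
        except (ConnectionError, asyncpg.PostgresError):  # broad catch for brevity
            if attempt == retries:
                raise
            await asyncio.sleep(0.5 * attempt)  # simple linear backoff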
@@ -3658,12 +3526,6 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
class PostgreSQLMigrationError(Exception):
|
||||
"""Exception for PostgreSQL table migration errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PGGraphQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
|
||||
@@ -5263,14 +5125,14 @@ TABLES = {
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_VDB_CHUNKS": {
|
||||
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_CHUNKS (
|
||||
"ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
|
||||
id VARCHAR(255),
|
||||
workspace VARCHAR(255),
|
||||
full_doc_id VARCHAR(256),
|
||||
chunk_order_index INTEGER,
|
||||
tokens INTEGER,
|
||||
content TEXT,
|
||||
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
||||
content_vector VECTOR(dimension),
|
||||
file_path TEXT NULL,
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -5278,12 +5140,12 @@ TABLES = {
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_VDB_ENTITY": {
|
||||
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_ENTITY (
|
||||
"ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
|
||||
id VARCHAR(255),
|
||||
workspace VARCHAR(255),
|
||||
entity_name VARCHAR(512),
|
||||
content TEXT,
|
||||
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
||||
content_vector VECTOR(dimension),
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
chunk_ids VARCHAR(255)[] NULL,
|
||||
@@ -5292,13 +5154,13 @@ TABLES = {
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_VDB_RELATION": {
|
||||
"ddl": f"""CREATE TABLE LIGHTRAG_VDB_RELATION (
|
||||
"ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
|
||||
id VARCHAR(255),
|
||||
workspace VARCHAR(255),
|
||||
source_id VARCHAR(512),
|
||||
target_id VARCHAR(512),
|
||||
content TEXT,
|
||||
content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}),
|
||||
content_vector VECTOR(dimension),
|
||||
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||
chunk_ids VARCHAR(255)[] NULL,
|
||||
|
||||
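For reference, the kind of SQL these templates expand to once the dimension is filled in, plus a typical ANN index (generic pgvector syntax, not copied from this repository; table name, index type, and opclass are illustrative):

SETUP_SQL = """
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS demo_chunks (
    id VARCHAR(255),
    workspace VARCHAR(255),
    content TEXT,
    content_vector VECTOR(1536),
    PRIMARY KEY (workspace, id)
);

CREATE INDEX IF NOT EXISTS demo_chunks_vec_idx
    ON demo_chunks USING hnsw (content_vector vector_cosine_ops);
"""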
@@ -72,7 +72,6 @@ api = [
    # API-specific dependencies
    "aiofiles",
    "ascii_colors",
    "asyncpg",
    "distro",
    "fastapi",
    "httpcore",
@@ -108,7 +107,8 @@ offline-storage = [
    "neo4j>=5.0.0,<7.0.0",
    "pymilvus>=2.6.2,<3.0.0",
    "pymongo>=4.0.0,<5.0.0",
    "asyncpg>=0.29.0,<1.0.0",
    "asyncpg>=0.31.0,<1.0.0",
    "pgvector>=0.4.2,<1.0.0",
    "qdrant-client>=1.11.0,<2.0.0",
]


@@ -8,8 +8,9 @@
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-storage.txt

# Storage backend dependencies (with version constraints matching pyproject.toml)
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
neo4j>=5.0.0,<7.0.0
pgvector>=0.4.2,<1.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0
qdrant-client>=1.11.0,<2.0.0

@@ -7,20 +7,17 @@
# Recommended: Use pip install lightrag-hku[offline] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline.txt

# LLM provider dependencies (with version constraints matching pyproject.toml)
aioboto3>=12.0.0,<16.0.0
anthropic>=0.18.0,<1.0.0

# Storage backend dependencies
asyncpg>=0.29.0,<1.0.0
asyncpg>=0.31.0,<1.0.0
google-api-core>=2.0.0,<3.0.0
google-genai>=1.0.0,<2.0.0

# Document processing dependencies
llama-index>=0.9.0,<1.0.0
neo4j>=5.0.0,<7.0.0
ollama>=0.1.0,<1.0.0
openai>=2.0.0,<3.0.0
openpyxl>=3.0.0,<4.0.0
pgvector>=0.4.2,<1.0.0
pycryptodome>=3.0.0,<4.0.0
pymilvus>=2.6.2,<3.0.0
pymongo>=4.0.0,<5.0.0

@@ -123,10 +123,21 @@ async def test_postgres_migration_trigger(
|
||||
{"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
|
||||
for i in range(100)
|
||||
]
|
||||
migration_state = {"new_table_count": 0}
|
||||
|
||||
async def mock_query(sql, params=None, multirows=False, **kwargs):
|
||||
if "COUNT(*)" in sql:
|
||||
return {"count": 100}
|
||||
sql_upper = sql.upper()
|
||||
legacy_table = storage.legacy_table_name.upper()
|
||||
new_table = storage.table_name.upper()
|
||||
is_new_table = new_table in sql_upper
|
||||
is_legacy_table = legacy_table in sql_upper and not is_new_table
|
||||
|
||||
if is_new_table:
|
||||
return {"count": migration_state["new_table_count"]}
|
||||
if is_legacy_table:
|
||||
return {"count": 100}
|
||||
return {"count": 0}
|
||||
elif multirows and "SELECT *" in sql:
|
||||
# Mock batch fetch for migration
|
||||
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
|
||||
@@ -145,6 +156,17 @@ async def test_postgres_migration_trigger(
|
||||
|
||||
mock_pg_db.query = AsyncMock(side_effect=mock_query)
|
||||
|
||||
# Track migration through _run_with_retry calls
|
||||
migration_executed = []
|
||||
|
||||
async def mock_run_with_retry(operation, **kwargs):
|
||||
# Track that migration batch operation was called
|
||||
migration_executed.append(True)
|
||||
migration_state["new_table_count"] = 100
|
||||
return None
|
||||
|
||||
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
|
||||
@@ -154,9 +176,9 @@ async def test_postgres_migration_trigger(
|
||||
# Initialize storage (should trigger migration)
|
||||
await storage.initialize()
|
||||
|
||||
# Verify migration was executed
|
||||
# Check that execute was called for inserting rows
|
||||
assert mock_pg_db.execute.call_count > 0
|
||||
# Verify migration was executed by checking _run_with_retry was called
|
||||
# (batch migration uses _run_with_retry with executemany)
|
||||
assert len(migration_executed) > 0, "Migration should have been executed"
|
||||
|
||||
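Because batch migration now flows through a single retry helper rather than per-row execute() calls, the tests above assert on that helper instead. A condensed sketch of the mocking pattern (fixture and variable names are placeholders):

from unittest.mock import AsyncMock

migration_executed = []


async def fake_run_with_retry(operation, **kwargs):
    migration_executed.append(True)  # one entry per flushed batch
    return None


mock_db = AsyncMock()
mock_db._run_with_retry = AsyncMock(side_effect=fake_run_with_retry)
# after storage.initialize(): assert len(migration_executed) > 0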
|
||||
@pytest.mark.asyncio
|
||||
@@ -291,6 +313,7 @@ async def test_scenario_2_legacy_upgrade_migration(
|
||||
|
||||
# Track which queries have been made for proper response
|
||||
query_history = []
|
||||
migration_state = {"new_table_count": 0}
|
||||
|
||||
async def mock_query(sql, params=None, multirows=False, **kwargs):
|
||||
query_history.append(sql)
|
||||
@@ -303,27 +326,20 @@ async def test_scenario_2_legacy_upgrade_migration(
|
||||
base_name = storage.legacy_table_name.upper()
|
||||
|
||||
# Check if this is querying the new table (has model suffix)
|
||||
has_model_suffix = any(
|
||||
suffix in sql_upper
|
||||
for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"]
|
||||
)
|
||||
has_model_suffix = storage.table_name.upper() in sql_upper
|
||||
|
||||
is_legacy_table = base_name in sql_upper and not has_model_suffix
|
||||
is_new_table = has_model_suffix
|
||||
has_workspace_filter = "WHERE workspace" in sql
|
||||
|
||||
if is_legacy_table and has_workspace_filter:
|
||||
# Count for legacy table with workspace filter (before migration)
|
||||
return {"count": 50}
|
||||
elif is_legacy_table and not has_workspace_filter:
|
||||
# Total count for legacy table (after deletion, checking remaining)
|
||||
return {"count": 0}
|
||||
elif is_new_table:
|
||||
# Count for new table (verification after migration)
|
||||
# Total count for legacy table
|
||||
return {"count": 50}
|
||||
else:
|
||||
# Fallback
|
||||
return {"count": 0}
|
||||
# New table count (before/after migration)
|
||||
return {"count": migration_state["new_table_count"]}
|
||||
elif multirows and "SELECT *" in sql:
|
||||
# Mock batch fetch for migration
|
||||
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
|
||||
@@ -342,6 +358,17 @@ async def test_scenario_2_legacy_upgrade_migration(
|
||||
|
||||
mock_pg_db.query = AsyncMock(side_effect=mock_query)
|
||||
|
||||
# Track migration through _run_with_retry calls
|
||||
migration_executed = []
|
||||
|
||||
async def mock_run_with_retry(operation, **kwargs):
|
||||
# Track that migration batch operation was called
|
||||
migration_executed.append(True)
|
||||
migration_state["new_table_count"] = 50
|
||||
return None
|
||||
|
||||
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
|
||||
@@ -353,26 +380,10 @@ async def test_scenario_2_legacy_upgrade_migration(
|
||||
# Verify table name contains ada-002
|
||||
assert "text_embedding_ada_002_1536d" in storage.table_name
|
||||
|
||||
# Verify migration was executed
|
||||
assert mock_pg_db.execute.call_count >= 50 # At least one execute per row
|
||||
# Verify migration was executed (batch migration uses _run_with_retry)
|
||||
assert len(migration_executed) > 0, "Migration should have been executed"
|
||||
mock_create.assert_called_once()
|
||||
|
||||
# Verify legacy table was automatically deleted after successful migration
|
||||
# This prevents Case 1 warnings on next startup
|
||||
delete_calls = [
|
||||
call
|
||||
for call in mock_pg_db.execute.call_args_list
|
||||
if call[0][0] and "DROP TABLE" in call[0][0]
|
||||
]
|
||||
assert (
|
||||
len(delete_calls) >= 1
|
||||
), "Legacy table should be deleted after successful migration"
|
||||
# Check if legacy table was dropped
|
||||
dropped_table = storage.legacy_table_name
|
||||
assert any(
|
||||
dropped_table in str(call) for call in delete_calls
|
||||
), f"Expected to drop '{dropped_table}'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_3_multi_model_coexistence(
|
||||
@@ -586,13 +597,12 @@ async def test_case1_sequential_workspace_migration(
|
||||
Critical bug fix verification:
|
||||
Timeline:
|
||||
1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
|
||||
2. Workspace A initializes first → Case 4 (only legacy exists) → migrates A's data
|
||||
3. Workspace B initializes later → Case 1 (both tables exist) → should migrate B's data
|
||||
2. Workspace A initializes first → Case 3 (only legacy exists) → migrates A's data
|
||||
3. Workspace B initializes later → Case 3 (both tables exist, legacy has B's data) → should migrate B's data
|
||||
4. Verify workspace B's data is correctly migrated to new table
|
||||
5. Verify legacy table is cleaned up after both workspaces migrate
|
||||
|
||||
This test verifies the fix where Case 1 now checks and migrates current
|
||||
workspace's data instead of just checking if legacy table is empty globally.
|
||||
This test verifies the migration logic correctly handles multi-tenant scenarios
|
||||
where different workspaces migrate sequentially.
|
||||
"""
|
||||
config = {
|
||||
"embedding_batch_num": 10,
|
||||
@@ -616,9 +626,14 @@ async def test_case1_sequential_workspace_migration(
|
||||
]
|
||||
|
||||
# Track migration state
|
||||
migration_state = {"new_table_exists": False, "workspace_a_migrated": False}
|
||||
migration_state = {
|
||||
"new_table_exists": False,
|
||||
"workspace_a_migrated": False,
|
||||
"workspace_a_migration_count": 0,
|
||||
"workspace_b_migration_count": 0,
|
||||
}
|
||||
|
||||
# Step 1: Simulate workspace_a initialization (Case 4)
|
||||
# Step 1: Simulate workspace_a initialization (Case 3 - only legacy exists)
|
||||
# CRITICAL: Set db.workspace to workspace_a
|
||||
mock_pg_db.workspace = "workspace_a"
|
||||
|
||||
@@ -637,16 +652,7 @@ async def test_case1_sequential_workspace_migration(
|
||||
return migration_state["new_table_exists"]
|
||||
return False
|
||||
|
||||
# Track inserted records count for verification
|
||||
inserted_count = {"workspace_a": 0}
|
||||
|
||||
# Mock execute to track inserts
|
||||
async def mock_execute_a(sql, data=None, **kwargs):
|
||||
if sql and "INSERT INTO" in sql.upper():
|
||||
inserted_count["workspace_a"] += 1
|
||||
return None
|
||||
|
||||
# Mock query for workspace_a (Case 4)
|
||||
# Mock query for workspace_a (Case 3)
|
||||
async def mock_query_a(sql, params=None, multirows=False, **kwargs):
|
||||
sql_upper = sql.upper()
|
||||
base_name = storage_a.legacy_table_name.upper()
|
||||
@@ -659,17 +665,23 @@ async def test_case1_sequential_workspace_migration(
|
||||
if is_legacy and has_workspace_filter:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
if workspace == "workspace_a":
|
||||
# After migration starts, pretend legacy is empty for this workspace
|
||||
return {"count": 3 - inserted_count["workspace_a"]}
|
||||
return {"count": 3}
|
||||
elif workspace == "workspace_b":
|
||||
return {"count": 3}
|
||||
elif is_legacy and not has_workspace_filter:
|
||||
# Global count in legacy table
|
||||
remaining = 6 - inserted_count["workspace_a"]
|
||||
return {"count": remaining}
|
||||
return {"count": 6}
|
||||
elif has_model_suffix:
|
||||
# New table count (for verification)
|
||||
return {"count": inserted_count["workspace_a"]}
|
||||
if has_workspace_filter:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
if workspace == "workspace_a":
|
||||
return {"count": migration_state["workspace_a_migration_count"]}
|
||||
if workspace == "workspace_b":
|
||||
return {"count": migration_state["workspace_b_migration_count"]}
|
||||
return {
|
||||
"count": migration_state["workspace_a_migration_count"]
|
||||
+ migration_state["workspace_b_migration_count"]
|
||||
}
|
||||
elif multirows and "SELECT *" in sql:
|
||||
if "WHERE workspace" in sql:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
@@ -680,9 +692,18 @@ async def test_case1_sequential_workspace_migration(
|
||||
return {}
|
||||
|
||||
mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
|
||||
mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a)
|
||||
|
||||
# Initialize workspace_a (Case 4)
|
||||
# Track migration via _run_with_retry (batch migration uses this)
|
||||
migration_a_executed = []
|
||||
|
||||
async def mock_run_with_retry_a(operation, **kwargs):
|
||||
migration_a_executed.append(True)
|
||||
migration_state["workspace_a_migration_count"] = len(mock_rows_a)
|
||||
return None
|
||||
|
||||
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_a)
|
||||
|
||||
# Initialize workspace_a (Case 3)
|
||||
with (
|
||||
patch(
|
||||
"lightrag.kg.postgres_impl._pg_table_exists",
|
||||
@@ -694,11 +715,14 @@ async def test_case1_sequential_workspace_migration(
|
||||
migration_state["new_table_exists"] = True
|
||||
migration_state["workspace_a_migrated"] = True
|
||||
|
||||
print("✅ Step 1: Workspace A initialized (Case 4)")
|
||||
assert mock_pg_db.execute.call_count >= 3
|
||||
print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls")
|
||||
print("✅ Step 1: Workspace A initialized")
|
||||
# Verify migration was executed via _run_with_retry (batch migration uses executemany)
|
||||
assert (
|
||||
len(migration_a_executed) > 0
|
||||
), "Migration should have been executed for workspace_a"
|
||||
print(f"✅ Step 1: Migration executed {len(migration_a_executed)} batch(es)")
|
||||
|
||||
# Step 2: Simulate workspace_b initialization (Case 1)
|
||||
# Step 2: Simulate workspace_b initialization (Case 3 - both exist, but legacy has B's data)
|
||||
# CRITICAL: Set db.workspace to workspace_b
|
||||
mock_pg_db.workspace = "workspace_b"
|
||||
|
||||
@@ -710,22 +734,12 @@ async def test_case1_sequential_workspace_migration(
|
||||
)
|
||||
|
||||
mock_pg_db.reset_mock()
|
||||
migration_state["workspace_b_migrated"] = False
|
||||
|
||||
# Mock table_exists for workspace_b (both exist)
|
||||
async def mock_table_exists_b(db, table_name):
|
||||
return True
|
||||
return True # Both tables exist
|
||||
|
||||
# Track inserted records count for workspace_b
|
||||
inserted_count["workspace_b"] = 0
|
||||
|
||||
# Mock execute for workspace_b to track inserts
|
||||
async def mock_execute_b(sql, data=None, **kwargs):
|
||||
if sql and "INSERT INTO" in sql.upper():
|
||||
inserted_count["workspace_b"] += 1
|
||||
return None
|
||||
|
||||
# Mock query for workspace_b (Case 1)
|
||||
# Mock query for workspace_b (Case 3)
|
||||
async def mock_query_b(sql, params=None, multirows=False, **kwargs):
|
||||
sql_upper = sql.upper()
|
||||
base_name = storage_b.legacy_table_name.upper()
|
||||
@@ -738,24 +752,21 @@ async def test_case1_sequential_workspace_migration(
|
||||
if is_legacy and has_workspace_filter:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
if workspace == "workspace_b":
|
||||
# After migration starts, pretend legacy is empty for this workspace
|
||||
return {"count": 3 - inserted_count["workspace_b"]}
|
||||
return {"count": 3} # workspace_b still has data in legacy
|
||||
elif workspace == "workspace_a":
|
||||
return {"count": 0} # Already migrated
|
||||
return {"count": 0} # workspace_a already migrated
|
||||
elif is_legacy and not has_workspace_filter:
|
||||
# Global count: only workspace_b data remains
|
||||
return {"count": 3 - inserted_count["workspace_b"]}
|
||||
return {"count": 3}
|
||||
elif has_model_suffix:
|
||||
# New table total count (workspace_a: 3 + workspace_b: inserted)
|
||||
if has_workspace_filter:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
if workspace == "workspace_b":
|
||||
return {"count": inserted_count["workspace_b"]}
|
||||
return {"count": migration_state["workspace_b_migration_count"]}
|
||||
elif workspace == "workspace_a":
|
||||
return {"count": 3}
|
||||
else:
|
||||
# Total count in new table (for verification)
|
||||
return {"count": 3 + inserted_count["workspace_b"]}
|
||||
return {"count": 3 + migration_state["workspace_b_migration_count"]}
|
||||
elif multirows and "SELECT *" in sql:
|
||||
if "WHERE workspace" in sql:
|
||||
workspace = params[0] if params and len(params) > 0 else None
|
||||
@@ -766,40 +777,32 @@ async def test_case1_sequential_workspace_migration(
|
||||
return {}
|
||||
|
||||
mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
|
||||
mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b)
|
||||
|
||||
# Initialize workspace_b (Case 1)
|
||||
# Track migration via _run_with_retry for workspace_b
|
||||
migration_b_executed = []
|
||||
|
||||
async def mock_run_with_retry_b(operation, **kwargs):
|
||||
migration_b_executed.append(True)
|
||||
migration_state["workspace_b_migration_count"] = len(mock_rows_b)
|
||||
return None
|
||||
|
||||
mock_pg_db._run_with_retry = AsyncMock(side_effect=mock_run_with_retry_b)
|
||||
|
||||
# Initialize workspace_b (Case 3 - both tables exist)
|
||||
with patch(
|
||||
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b
|
||||
):
|
||||
await storage_b.initialize()
|
||||
migration_state["workspace_b_migrated"] = True
|
||||
|
||||
print("✅ Step 2: Workspace B initialized (Case 1)")
|
||||
print("✅ Step 2: Workspace B initialized")
|
||||
|
||||
# Verify workspace_b migration happened
|
||||
execute_calls = mock_pg_db.execute.call_args_list
|
||||
insert_calls = [
|
||||
call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0]
|
||||
]
|
||||
assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}"
|
||||
print(f"✅ Step 2: {len(insert_calls)} insert calls")
|
||||
# Verify workspace_b migration happens when new table has no workspace_b data
|
||||
# but legacy table still has workspace_b data.
|
||||
assert (
|
||||
len(migration_b_executed) > 0
|
||||
), "Migration should have been executed for workspace_b"
|
||||
print("✅ Step 2: Migration executed for workspace_b")
|
||||
|
||||
# Verify DELETE and DROP TABLE
|
||||
delete_calls = [
|
||||
call
|
||||
for call in execute_calls
|
||||
if call[0][0]
|
||||
and "DELETE FROM" in call[0][0]
|
||||
and "WHERE workspace" in call[0][0]
|
||||
]
|
||||
assert len(delete_calls) >= 1, "Expected DELETE workspace_b data"
|
||||
print("✅ Step 2: DELETE workspace_b from legacy")
|
||||
|
||||
drop_calls = [
|
||||
call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0]
|
||||
]
|
||||
assert len(drop_calls) >= 1, "Expected DROP TABLE"
|
||||
print("✅ Step 2: Legacy table dropped")
|
||||
|
||||
print("\n🎉 Case 1c: Sequential workspace migration verified!")
|
||||
print("\n🎉 Case 1c: Sequential workspace migration verification complete!")
|
||||
print(" - Workspace A: Migrated successfully (only legacy existed)")
|
||||
print(" - Workspace B: Migrated successfully (new table empty for workspace_b)")
|