feat: Flatten LLM cache structure for improved recall efficiency
Refactored the LLM cache to a flat key-value (KV) structure, replacing the previous nested format in which 'mode' served as the top-level key and individual cache entries were stored as JSON nested beneath it. Cache entries are now addressed directly by a flattened key of the form {mode}:{cache_type}:{hash}, which significantly improves cache recall efficiency. Legacy data in the JSON, Redis, and PostgreSQL backends is migrated to the new layout automatically.
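For illustration, a minimal sketch of the flattened key scheme follows. The helper names mirror the generate_cache_key/parse_cache_key utilities this commit adds to lightrag.utils; the standalone re-implementation below is only an approximation for readers, not the shipped code:

from hashlib import md5

def compute_args_hash(*args) -> str:
    # Hash of the joined arguments; cache_type is no longer mixed into the hash
    return md5("".join(str(a) for a in args).encode()).hexdigest()

def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
    # Flattened cache key: {mode}:{cache_type}:{hash}
    return f"{mode}:{cache_type}:{hash_value}"

def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
    # Split back into (mode, cache_type, hash); None if the key is not flattened
    parts = cache_key.split(":", 2)
    return (parts[0], parts[1], parts[2]) if len(parts) == 3 else None

key = generate_cache_key("local", "query", compute_args_hash("local", "What is LightRAG?"))
# -> "local:query:<md5>"; a lookup is now a single get_by_id(key) instead of
# fetching the whole "local" mode dict and indexing into it by hash.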
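Backends that still hold the legacy nested layout ({mode: {hash: entry}}) are migrated on startup; conceptually, the JSON and Redis migrations in the diff below reduce to the following transformation (simplified sketch, assuming the default cache_type "extract" when an entry does not record one):

def flatten_legacy_cache(nested: dict) -> dict:
    # Convert {mode: {hash: entry}} into {"mode:cache_type:hash": entry}
    flat = {}
    for mode, entries in nested.items():
        is_legacy_mode = isinstance(entries, dict) and all(
            isinstance(e, dict) and "return" in e for e in entries.values()
        )
        if is_legacy_mode:
            for cache_hash, entry in entries.items():
                cache_type = entry.get("cache_type", "extract")
                flat[f"{mode}:{cache_type}:{cache_hash}"] = entry
        else:
            # Non-cache or already-flattened records are kept unchanged
            flat[mode] = entries
    return flat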
@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
embedding_func=None,
)

# Get all cache data using the new flattened structure
all_data = await from_llm_response_cache.get_all()

# Convert flattened data to hierarchical structure for JsonKVStorage
kv = {}
for c_id in await from_llm_response_cache.all_keys():
print(f"Copying {c_id}")
workspace = c_id["workspace"]
mode = c_id["mode"]
_id = c_id["id"]
postgres_db.workspace = workspace
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
if mode not in kv:
kv[mode] = {}
kv[mode][_id] = obj[_id]
print(f"Object {obj}")
for flattened_key, cache_entry in all_data.items():
# Parse flattened key: {mode}:{cache_type}:{hash}
parts = flattened_key.split(":", 2)
if len(parts) == 3:
mode, cache_type, hash_value = parts
if mode not in kv:
kv[mode] = {}
kv[mode][hash_value] = cache_entry
print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
else:
print(f"Skipping invalid key format: {flattened_key}")

await to_llm_response_cache.upsert(kv)
await to_llm_response_cache.index_done_callback()
print("Mission accomplished!")
@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
db=postgres_db,
)

for mode in await from_llm_response_cache.all_keys():
print(f"Copying {mode}")
caches = await from_llm_response_cache.get_by_id(mode)
for k, v in caches.items():
item = {mode: {k: v}}
print(f"\tCopying {item}")
await to_llm_response_cache.upsert(item)
# Get all cache data from JsonKVStorage (hierarchical structure)
all_data = await from_llm_response_cache.get_all()

# Convert hierarchical data to flattened structure for PGKVStorage
flattened_data = {}
for mode, mode_data in all_data.items():
print(f"Processing mode: {mode}")
for hash_value, cache_entry in mode_data.items():
# Determine cache_type from cache entry or use default
cache_type = cache_entry.get("cache_type", "extract")
# Create flattened key: {mode}:{cache_type}:{hash}
flattened_key = f"{mode}:{cache_type}:{hash_value}"
flattened_data[flattened_key] = cache_entry
print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")

# Upsert the flattened data
await to_llm_response_cache.upsert(flattened_data)
print("Mission accomplished!")

if __name__ == "__main__":
@@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = {
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"RedisDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],

@@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"RedisDocStatusStorage": ["REDIS_URI"],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}

@@ -96,6 +98,7 @@ STORAGES = {
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"RedisDocStatusStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
# "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl",
@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)

# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in loaded_data.values()
if isinstance(first_level_dict, dict)
# Migrate legacy cache structure if needed
if self.namespace.endswith("_cache"):
loaded_data = await self._migrate_legacy_cache_structure(
loaded_data
)
else:
# For non-cache namespaces, use the original count method
data_count = len(loaded_data)

self._data.update(loaded_data)
data_count = len(loaded_data)

logger.info(
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"

@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)

# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# # For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in data_dict.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(data_dict)
# Calculate data count - all data is now flattened
data_count = len(data_dict)

logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@@ -150,14 +136,14 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace)

async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode

Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed

Args:
ids (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage

Returns:
True: if the cache drop successfully
@@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage):
return False

try:
await self.delete(modes)
async with self._storage_lock:
keys_to_delete = []
modes_set = set(modes) # Convert to set for efficient lookup

for key in list(self._data.keys()):
# Parse flattened cache key: mode:cache_type:hash
parts = key.split(":", 2)
if len(parts) == 3 and parts[0] in modes_set:
keys_to_delete.append(key)

# Batch delete
for key in keys_to_delete:
self._data.pop(key, None)

if keys_to_delete:
await set_all_update_flags(self.namespace)
logger.info(
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
)

return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes: {e}")
return False

# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage):
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
"""Migrate legacy nested cache structure to flattened structure

Args:
data: Original data dictionary that may contain legacy structure

Returns:
Migrated data dictionary with flattened cache keys
"""
from lightrag.utils import generate_cache_key

# Early return if data is empty
if not data:
return data

# Check first entry to see if it's already in new format
first_key = next(iter(data.keys()))
if ":" in first_key and len(first_key.split(":")) == 3:
# Already in flattened format, return as-is
return data

migrated_data = {}
migration_count = 0

for key, value in data.items():
# Check if this is a legacy nested cache structure
if isinstance(value, dict) and all(
isinstance(v, dict) and "return" in v for v in value.values()
):
# This looks like a legacy cache mode with nested structure
mode = key
for cache_hash, cache_entry in value.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
migrated_data[flattened_key] = cache_entry
migration_count += 1
else:
# Keep non-cache data or already flattened cache data as-is
migrated_data[key] = value

if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)

return migrated_data

async def finalize(self):
"""Finalize storage resources
Persistence cache data to disk before exiting
"""
if self.namespace.endswith("cache"):
if self.namespace.endswith("_cache"):
await self.index_done_callback()
@@ -75,7 +75,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -15,7 +15,6 @@ from ..base import (
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP

@@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage):
self._data = None

async def get_by_id(self, id: str) -> dict[str, Any] | None:
if id == "default":
# Find all documents with _id starting with "default_"
cursor = self._data.find({"_id": {"$regex": "^default_"}})
result = {}
async for doc in cursor:
# Use the complete _id as key
result[doc["_id"]] = doc
return result if result else None
else:
# Original behavior for non-"default" ids
return await self._data.find_one({"_id": id})
# Unified handling for flattened keys
return await self._data.find_one({"_id": id})

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}})

@@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage):
return result

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks: list[Any] = []
for mode, items in data.items():
for k, v in items.items():
key = f"{mode}_{k}"
data[mode][k]["_id"] = f"{mode}_{k}"
update_tasks.append(
self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
)
await asyncio.gather(*update_tasks)
else:
update_tasks = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)
# Unified handling for all namespaces with flattened keys
# Use bulk_write for better performance
from pymongo import UpdateOne

async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
return res
else:
return None
else:
return None
operations = []
for k, v in data.items():
v["_id"] = k # Use flattened key as _id
operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))

if operations:
await self._data.bulk_write(operations)

async def index_done_callback(self) -> None:
# Mongo handles persistence automatically

@@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage):
return False

try:
# Build regex pattern to match documents with the specified modes
pattern = f"^({'|'.join(modes)})_"
# Build regex pattern to match flattened key format: mode:cache_type:hash
pattern = f"^({'|'.join(modes)}):"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True

@@ -274,7 +242,7 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []

@@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist")

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

@@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args:
ids: List of vector IDs to be deleted
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids:
return
@@ -247,6 +247,116 @@ class PostgreSQLDB:
logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
# Do not re-raise, to allow the application to start

async def _check_llm_cache_needs_migration(self):
"""Check if LLM cache data needs migration by examining the first record"""
try:
# Only query the first record to determine format
check_sql = """
SELECT id FROM LIGHTRAG_LLM_CACHE
ORDER BY create_time ASC
LIMIT 1
"""
result = await self.query(check_sql)

if result and result.get("id"):
# If id doesn't contain colon, it's old format
return ":" not in result["id"]

return False # No data or already new format
except Exception as e:
logger.warning(f"Failed to check LLM cache migration status: {e}")
return False

async def _migrate_llm_cache_to_flattened_keys(self):
"""Migrate LLM cache to flattened key format, recalculating hash values"""
try:
# Get all old format data
old_data_sql = """
SELECT id, mode, original_prompt, return_value, chunk_id,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE
WHERE id NOT LIKE '%:%'
"""

old_records = await self.query(old_data_sql, multirows=True)

if not old_records:
logger.info("No old format LLM cache data found, skipping migration")
return

logger.info(
f"Found {len(old_records)} old format cache records, starting migration..."
)

# Import hash calculation function
from ..utils import compute_args_hash

migrated_count = 0

# Migrate data in batches
for record in old_records:
try:
# Recalculate hash using correct method
new_hash = compute_args_hash(
record["mode"], record["original_prompt"]
)

# Generate new flattened key
cache_type = "extract" # Default type
new_key = f"{record['mode']}:{cache_type}:{new_hash}"

# Insert new format data
insert_sql = """
INSERT INTO LIGHTRAG_LLM_CACHE
(workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (workspace, mode, id) DO NOTHING
"""

await self.execute(
insert_sql,
{
"workspace": self.workspace,
"id": new_key,
"mode": record["mode"],
"original_prompt": record["original_prompt"],
"return_value": record["return_value"],
"chunk_id": record["chunk_id"],
"create_time": record["create_time"],
"update_time": record["update_time"],
},
)

# Delete old data
delete_sql = """
DELETE FROM LIGHTRAG_LLM_CACHE
WHERE workspace=$1 AND mode=$2 AND id=$3
"""
await self.execute(
delete_sql,
{
"workspace": self.workspace,
"mode": record["mode"],
"id": record["id"], # Old id
},
)

migrated_count += 1

except Exception as e:
logger.warning(
f"Failed to migrate cache record {record['id']}: {e}"
)
continue

logger.info(
f"Successfully migrated {migrated_count} cache records to flattened format"
)

except Exception as e:
logger.error(f"LLM cache migration failed: {e}")
# Don't raise exception, allow system to continue startup

async def check_tables(self):
# First create all tables
for k, v in TABLES.items():

@@ -304,6 +414,13 @@ class PostgreSQLDB:
except Exception as e:
logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")

# Check and migrate LLM cache to flattened keys if needed
try:
if await self._check_llm_cache_needs_migration():
await self._migrate_llm_cache_to_flattened_keys()
except Exception as e:
logger.error(f"PostgreSQL, LLM cache migration failed: {e}")

async def query(
self,
sql: str,
@@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage):

try:
results = await self.db.query(sql, params, multirows=True)

# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
result_dict = {}
processed_results = {}
for row in results:
mode = row["mode"]
if mode not in result_dict:
result_dict[mode] = {}
result_dict[mode][row["id"]] = row
return result_dict
else:
return {row["id"]: row for row in results}
# Parse flattened key to extract cache_type
key_parts = row["id"].split(":")
cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"

# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""), # Map return_value to return
"cache_type": cache_type, # Add cache_type from key
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default")
}
processed_results[row["id"]] = processed_row
return processed_results

# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e:
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
return {}

async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id."""
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
# For LLM cache, the id parameter actually represents the mode
params = {"workspace": self.db.workspace, "mode": id}
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
# Dynamically add cache_type field based on mode
row_with_cache_type = dict(row)
if id == "default":
row_with_cache_type["cache_type"] = "extract"
else:
row_with_cache_type["cache_type"] = "unknown"
res[row["id"]] = row_with_cache_type
return res if res else None
else:
params = {"workspace": self.db.workspace, "id": id}
response = await self.db.query(sql, params)
return response if response else None

async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "mode": mode, "id": id}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
params = {"workspace": self.db.workspace, "id": id}
response = await self.db.query(sql, params)
return response if response else None

# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
for row in array_res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in array_res:
dict_res[row["mode"]][row["id"]] = row
return [{k: v} for k, v in dict_res.items()]
else:
return await self.db.query(sql, params, multirows=True)
return await self.db.query(sql, params, multirows=True)

async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
@@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage):
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": mode,
"chunk_id": v.get("chunk_id"),
}
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k, # Use flattened key as id
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": v.get("mode", "default"), # Get mode from data
"chunk_id": v.get("chunk_id"),
}

await self.db.execute(upsert_sql, _data)
await self.db.execute(upsert_sql, _data)

async def index_done_callback(self) -> None:
# PG handles persistence automatically

@@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage):
else:
exist_keys = []
new_keys = set([s for s in keys if s not in exist_keys])
print(f"keys: {keys}")
print(f"new_keys: {new_keys}")
# print(f"keys: {keys}")
# print(f"new_keys: {new_keys}")
return new_keys
except Exception as e:
logger.error(
@@ -2621,7 +2708,7 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -1,9 +1,10 @@
import os
from typing import Any, final
from typing import Any, final, Union
from dataclasses import dataclass
import pipmaster as pm
import configparser
from contextlib import asynccontextmanager
import threading

if not pm.is_installed("redis"):
pm.install("redis")

@@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
from redis.exceptions import RedisError, ConnectionError # type: ignore
from lightrag.utils import logger

from lightrag.base import BaseKVStorage
from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
import json

@@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0

class RedisConnectionManager:
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""

_pools = {}
_lock = threading.Lock()

@classmethod
def get_pool(cls, redis_url: str) -> ConnectionPool:
"""Get or create a connection pool for the given Redis URL"""
if redis_url not in cls._pools:
with cls._lock:
if redis_url not in cls._pools:
cls._pools[redis_url] = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
logger.info(f"Created shared Redis connection pool for {redis_url}")
return cls._pools[redis_url]

@classmethod
def close_all_pools(cls):
"""Close all connection pools (for cleanup)"""
with cls._lock:
for url, pool in cls._pools.items():
try:
pool.disconnect()
logger.info(f"Closed Redis connection pool for {url}")
except Exception as e:
logger.error(f"Error closing Redis pool for {url}: {e}")
cls._pools.clear()

@final
@dataclass
class RedisKVStorage(BaseKVStorage):

@@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Create a connection pool with limits
self._pool = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
)

async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed"""
# Test connection
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise

# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
await self._migrate_legacy_cache_structure()

@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
@@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage):
logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids)

async def get_all(self) -> dict[str, Any]:
"""Get all data from storage

Returns:
Dictionary containing all stored data
"""
async with self._get_redis_connection() as redis:
try:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")

if not keys:
return {}

# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()

# Build result dictionary
result = {}
for key, value in zip(keys, values):
if value:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
try:
result[key_id] = json.loads(value)
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
continue

return result
except Exception as e:
logger.error(f"Error getting all data from Redis: {e}")
return {}

async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for key in keys:
keys_list = list(keys) # Convert set to list for indexing
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()

existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
return

logger.info(f"Inserting {len(data)} items to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
@@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage):
)

async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode

Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis

Args:
modes (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage

Returns:
True: if the cache drop successfully
@@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage):
return False

try:
await self.delete(modes)
async with self._get_redis_connection() as redis:
keys_to_delete = []

# Find matching keys for each mode using SCAN
for mode in modes:
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
pattern = f"{self.namespace}:{mode}:*"
cursor = 0
mode_keys = []

while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
mode_keys.extend(keys)

if cursor == 0:
break

keys_to_delete.extend(mode_keys)
logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")

if keys_to_delete:
# Batch delete
pipe = redis.pipeline()
for key in keys_to_delete:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Dropped {deleted_count} cache entries for modes: {modes}"
)
else:
logger.warning(f"No cache entries found for modes: {modes}")

return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes in Redis: {e}")
return False

async def drop(self) -> dict[str, str]:
@@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage):
"""
async with self._get_redis_connection() as redis:
try:
keys = await redis.keys(f"{self.namespace}:*")
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0

while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)

if cursor == 0:
break

if keys:
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)

logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
else:
logger.info(f"No keys found to drop in {self.namespace}")
return {"status": "success", "message": "no keys to drop"}
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}

except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis

Redis already stores data in a flattened way, but we need to check for
legacy keys that might contain nested JSON structures and migrate them.

Early exit if any flattened key is found (indicating migration already done).
"""
from lightrag.utils import generate_cache_key

async with self._get_redis_connection() as redis:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")

if not keys:
return

# Check if we have any flattened keys already - if so, skip migration
has_flattened_keys = False
keys_to_migrate = []

for key in keys:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]

# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
if ":" in key_id and len(key_id.split(":")) == 3:
has_flattened_keys = True
break # Early exit - migration already done

# Get the data to check if it's a legacy nested structure
data = await redis.get(key)
if data:
try:
parsed_data = json.loads(data)
# Check if this looks like a legacy cache mode with nested structure
if isinstance(parsed_data, dict) and all(
isinstance(v, dict) and "return" in v
for v in parsed_data.values()
):
keys_to_migrate.append((key, key_id, parsed_data))
except json.JSONDecodeError:
continue

# If we found any flattened keys, assume migration is already done
if has_flattened_keys:
logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration"
)
return

if not keys_to_migrate:
return

# Perform migration
pipe = redis.pipeline()
migration_count = 0

for old_key, mode, nested_data in keys_to_migrate:
# Delete the old key
pipe.delete(old_key)

# Create new flattened keys
for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1

await pipe.execute()

if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
)
@final
@dataclass
class RedisDocStatusStorage(DocStatusStorage):
"""Redis implementation of document status storage"""

def __post_init__(self):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
)

async def initialize(self):
"""Initialize Redis connection"""
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}")
raise

@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
try:
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
raise
except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
)
raise

async def close(self):
"""Close the Redis connection."""
if hasattr(self, "_redis") and self._redis:
await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}")

async def __aenter__(self):
"""Support for async context manager."""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure Redis resources are cleaned up when exiting context."""
await self.close()

async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
keys_list = list(keys)
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()

existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids

async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()

for result_data in results:
if result_data:
try:
result.append(json.loads(result_data))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}")
continue
except Exception as e:
logger.error(f"Error in get_by_ids: {e}")
return result

async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()

# Count statuses
for value in values:
if value:
try:
doc_data = json.loads(value)
status = doc_data.get("status")
if status in counts:
counts[status] += 1
except json.JSONDecodeError:
continue

if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting status counts: {e}")

return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()

# Filter by status and create DocProcessingStatus objects
for key, value in zip(keys, values):
if value:
try:
doc_data = json.loads(value)
if doc_data.get("status") == status.value:
# Extract document ID from key
doc_id = key.split(":", 1)[1]

# Make a copy of the data to avoid modifying the original
data = doc_data.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"

result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Error processing document {key}: {e}")
continue

if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by status: {e}")

return result

async def index_done_callback(self) -> None:
"""Redis handles persistence automatically"""
pass

async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update document status data"""
if not data:
return

logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
except json.JSONEncodeError as e:
logger.error(f"JSON encode error during upsert: {e}")
raise

async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None

async def delete(self, doc_ids: list[str]) -> None:
"""Delete specific records from storage by their IDs"""
if not doc_ids:
return

async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}")

results = await pipe.execute()
deleted_count = sum(results)
logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")

async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources"""
try:
async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0

while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)

if cursor == 0:
break

logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):

################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}

@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):

###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")

# Get current time as UNIX timestamp
import time
@@ -399,10 +399,10 @@ async def _get_cached_extraction_results(
"""
cached_results = {}

# Get all cached data for "default" mode (entity extraction cache)
default_cache = await llm_response_cache.get_by_id("default") or {}
# Get all cached data (flattened cache structure)
all_cache = await llm_response_cache.get_all()

for cache_key, cache_entry in default_cache.items():
for cache_key, cache_entry in all_cache.items():
if (
isinstance(cache_entry, dict)
and cache_entry.get("cache_type") == "extract"

@@ -1387,7 +1387,7 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5)

# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)

@@ -1546,7 +1546,7 @@ async def extract_keywords_only(
"""

# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
args_hash = compute_args_hash(param.mode, text)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)

@@ -2413,7 +2413,7 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5)

# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)

@@ -2529,7 +2529,7 @@ async def kg_query_with_keywords(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)

args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@@ -14,7 +14,6 @@ from functools import wraps
from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,

@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None

def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
def compute_args_hash(*args: Any) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""

@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:

# Convert all arguments to strings and join them
args_str = "".join([str(arg) for arg in args])
if cache_type:
args_str = f"{cache_type}:{args_str}"

# Compute MD5 hash
return hashlib.md5(args_str.encode()).hexdigest()

def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
"""Generate a flattened cache key in the format {mode}:{cache_type}:{hash}

Args:
mode: Cache mode (e.g., 'default', 'local', 'global')
cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
hash_value: Hash value from compute_args_hash

Returns:
str: Flattened cache key
"""
return f"{mode}:{cache_type}:{hash_value}"

def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
"""Parse a flattened cache key back into its components

Args:
cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}

Returns:
tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
"""
parts = cache_key.split(":", 2)
if len(parts) == 3:
return parts[0], parts[1], parts[2]
return None

def compute_mdhash_id(content: str, prefix: str = "") -> str:
"""
Compute a unique ID for a given content string.
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
return combined_data

async def get_best_cached_response(
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
cache_type=None,
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None

best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None

# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
# Skip if cache_type doesn't match
if cache_type and cache_data.get("cache_type") != cache_type:
continue

# Check if cache data is valid
if cache_data["embedding"] is None:
continue

try:
# Safely convert cached embedding
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])

# Ensure min_val and max_val are valid float values
embedding_min = cache_data.get("embedding_min")
embedding_max = cache_data.get("embedding_max")

if (
embedding_min is None
or embedding_max is None
or embedding_min >= embedding_max
):
logger.warning(
f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
)
continue

cached_embedding = dequantize_embedding(
cached_quantized,
embedding_min,
embedding_max,
)
except Exception as e:
logger.warning(f"Error processing cached embedding: {str(e)}")
continue

similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)

try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)

# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails

prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None

def cosine_similarity(v1, v2):
"""Calculate cosine similarity between two vectors"""
dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
mode="default",
cache_type=None,
):
"""Generic cache handling function"""
"""Generic cache handling function with flattened cache keys"""
if hashing_kv is None:
return None, None, None, None

@@ -968,15 +868,14 @@ async def handle_cache(
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None

if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(mode, cache_type, args_hash)
cache_entry = await hashing_kv.get_by_id(flattened_key)
if cache_entry:
logger.debug(f"Flattened cache hit(key:{flattened_key})")
return cache_entry["return"], None, None, None

logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
return None, None, None, None
@@ -994,7 +893,7 @@ class CacheData:

async def save_to_cache(hashing_kv, cache_data: CacheData):
"""Save data to cache, with improved handling for streaming responses and duplicate content.
"""Save data to cache using flattened key structure.

Args:
hashing_kv: The key-value storage for caching

@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
logger.debug("Streaming response detected, skipping cache")
return

# Get existing cache data
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(
cache_data.mode, cache_data.cache_type, cache_data.args_hash
)

# Check if we already have identical content cached
if cache_data.args_hash in mode_cache:
existing_content = mode_cache[cache_data.args_hash].get("return")
existing_cache = await hashing_kv.get_by_id(flattened_key)
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
logger.info(
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
)
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
return

# Update cache with new content
mode_cache[cache_data.args_hash] = {
# Create cache entry with flattened structure
cache_entry = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,

@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"original_prompt": cache_data.prompt,
}

logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
logger.info(f" == LLM cache == saving: {flattened_key}")

# Only upsert if there's actual new content
await hashing_kv.upsert({cache_data.mode: mode_cache})
# Save using flattened key
await hashing_kv.upsert({flattened_key: cache_entry})

def safe_unicode_decode(content):