feat: Flatten LLM cache structure for improved recall efficiency

Refactored the LLM cache to a flat key-value (KV) structure, replacing the previous nested format in which the cache 'mode' was the top-level key and individual cache entries were stored as JSON nested beneath it. Cache entries are now keyed as {mode}:{cache_type}:{hash}, so a cache lookup is a single direct key fetch instead of loading and scanning a whole mode's nested JSON, which makes cache recall more efficient.
yangdx
2025-07-02 16:11:53 +08:00
parent b32c3825cc
commit 271722405f
12 changed files with 836 additions and 375 deletions
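
For illustration, a minimal sketch (Python, with abbreviated fields and placeholder hashes) of the two layouts this change converts between:

# Legacy nested structure: the cache mode is the top-level key and
# individual entries hang under it as nested JSON.
legacy_cache = {
    "default": {
        "a3f1c2d4e5b6a7c8": {
            "return": "...extraction result...",
            "cache_type": "extract",
            "original_prompt": "...",
        },
    },
}

# New flattened structure: one flat KV entry per cache item,
# keyed as {mode}:{cache_type}:{hash}.
flattened_cache = {
    "default:extract:a3f1c2d4e5b6a7c8": {
        "return": "...extraction result...",
        "cache_type": "extract",
        "original_prompt": "...",
    },
}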

View File

@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
embedding_func=None,
)
# Get all cache data using the new flattened structure
all_data = await from_llm_response_cache.get_all()
# Convert flattened data to hierarchical structure for JsonKVStorage
kv = {}
for c_id in await from_llm_response_cache.all_keys():
print(f"Copying {c_id}")
workspace = c_id["workspace"]
mode = c_id["mode"]
_id = c_id["id"]
postgres_db.workspace = workspace
obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
if mode not in kv:
kv[mode] = {}
kv[mode][_id] = obj[_id]
print(f"Object {obj}")
for flattened_key, cache_entry in all_data.items():
# Parse flattened key: {mode}:{cache_type}:{hash}
parts = flattened_key.split(":", 2)
if len(parts) == 3:
mode, cache_type, hash_value = parts
if mode not in kv:
kv[mode] = {}
kv[mode][hash_value] = cache_entry
print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
else:
print(f"Skipping invalid key format: {flattened_key}")
await to_llm_response_cache.upsert(kv)
await to_llm_response_cache.index_done_callback()
print("Mission accomplished!")
@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
db=postgres_db,
)
for mode in await from_llm_response_cache.all_keys():
print(f"Copying {mode}")
caches = await from_llm_response_cache.get_by_id(mode)
for k, v in caches.items():
item = {mode: {k: v}}
print(f"\tCopying {item}")
await to_llm_response_cache.upsert(item)
# Get all cache data from JsonKVStorage (hierarchical structure)
all_data = await from_llm_response_cache.get_all()
# Convert hierarchical data to flattened structure for PGKVStorage
flattened_data = {}
for mode, mode_data in all_data.items():
print(f"Processing mode: {mode}")
for hash_value, cache_entry in mode_data.items():
# Determine cache_type from cache entry or use default
cache_type = cache_entry.get("cache_type", "extract")
# Create flattened key: {mode}:{cache_type}:{hash}
flattened_key = f"{mode}:{cache_type}:{hash_value}"
flattened_data[flattened_key] = cache_entry
print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
# Upsert the flattened data
await to_llm_response_cache.upsert(flattened_data)
print("Mission accomplished!")
if __name__ == "__main__":

View File

@@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = {
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"RedisDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
@@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"RedisDocStatusStorage": ["REDIS_URI"],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
@@ -96,6 +98,7 @@ STORAGES = {
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"RedisDocStatusStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
# "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl",

View File

@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
self._collection.delete(ids=ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"

View File

@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in loaded_data.values()
if isinstance(first_level_dict, dict)
# Migrate legacy cache structure if needed
if self.namespace.endswith("_cache"):
loaded_data = await self._migrate_legacy_cache_structure(
loaded_data
)
else:
# For non-cache namespaces, use the original count method
data_count = len(loaded_data)
self._data.update(loaded_data)
data_count = len(loaded_data)
logger.info(
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
# Calculate data count based on namespace
if self.namespace.endswith("cache"):
# # For cache namespaces, sum the cache entries across all cache types
data_count = sum(
len(first_level_dict)
for first_level_dict in data_dict.values()
if isinstance(first_level_dict, dict)
)
else:
# For non-cache namespaces, use the original count method
data_count = len(data_dict)
# Calculate data count - all data is now flattened
data_count = len(data_dict)
logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@@ -150,14 +136,14 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache drop successfully
@@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage):
return False
try:
await self.delete(modes)
async with self._storage_lock:
keys_to_delete = []
modes_set = set(modes) # Convert to set for efficient lookup
for key in list(self._data.keys()):
# Parse flattened cache key: mode:cache_type:hash
parts = key.split(":", 2)
if len(parts) == 3 and parts[0] in modes_set:
keys_to_delete.append(key)
# Batch delete
for key in keys_to_delete:
self._data.pop(key, None)
if keys_to_delete:
await set_all_update_flags(self.namespace)
logger.info(
f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
)
return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes: {e}")
return False
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage):
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
"""Migrate legacy nested cache structure to flattened structure
Args:
data: Original data dictionary that may contain legacy structure
Returns:
Migrated data dictionary with flattened cache keys
"""
from lightrag.utils import generate_cache_key
# Early return if data is empty
if not data:
return data
# Check first entry to see if it's already in new format
first_key = next(iter(data.keys()))
if ":" in first_key and len(first_key.split(":")) == 3:
# Already in flattened format, return as-is
return data
migrated_data = {}
migration_count = 0
for key, value in data.items():
# Check if this is a legacy nested cache structure
if isinstance(value, dict) and all(
isinstance(v, dict) and "return" in v for v in value.values()
):
# This looks like a legacy cache mode with nested structure
mode = key
for cache_hash, cache_entry in value.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
migrated_data[flattened_key] = cache_entry
migration_count += 1
else:
# Keep non-cache data or already flattened cache data as-is
migrated_data[key] = value
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)
return migrated_data
async def finalize(self):
"""Finalize storage resources
Persistence cache data to disk before exiting
"""
if self.namespace.endswith("cache"):
if self.namespace.endswith("_cache"):
await self.index_done_callback()
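
As a rough illustration of the detection heuristic in _migrate_legacy_cache_structure (sample data is assumed, with placeholder hashes): only keys whose values are dicts of dicts containing a "return" field are treated as legacy modes; already-flattened entries are kept as-is.

legacy_file = {
    # Legacy mode: every nested value is a dict with a "return" field -> flattened.
    "global": {
        "1a2b3c": {"return": "...", "cache_type": "query", "original_prompt": "..."},
    },
    # Already-flattened key (mode:cache_type:hash) -> left untouched.
    "local:extract:4d5e6f": {"return": "...", "cache_type": "extract"},
}

# Expected result after migration:
migrated_file = {
    "global:query:1a2b3c": {"return": "...", "cache_type": "query", "original_prompt": "..."},
    "local:extract:4d5e6f": {"return": "...", "cache_type": "extract"},
}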

View File

@@ -75,7 +75,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

View File

@@ -15,7 +15,6 @@ from ..base import (
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
@@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage):
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
if id == "default":
# Find all documents with _id starting with "default_"
cursor = self._data.find({"_id": {"$regex": "^default_"}})
result = {}
async for doc in cursor:
# Use the complete _id as key
result[doc["_id"]] = doc
return result if result else None
else:
# Original behavior for non-"default" ids
return await self._data.find_one({"_id": id})
# Unified handling for flattened keys
return await self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}})
@@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage):
return result
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
update_tasks: list[Any] = []
for mode, items in data.items():
for k, v in items.items():
key = f"{mode}_{k}"
data[mode][k]["_id"] = f"{mode}_{k}"
update_tasks.append(
self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
)
await asyncio.gather(*update_tasks)
else:
update_tasks = []
for k, v in data.items():
data[k]["_id"] = k
update_tasks.append(
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
)
await asyncio.gather(*update_tasks)
# Unified handling for all namespaces with flattened keys
# Use bulk_write for better performance
from pymongo import UpdateOne
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
res = {}
v = await self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
return res
else:
return None
else:
return None
operations = []
for k, v in data.items():
v["_id"] = k # Use flattened key as _id
operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
if operations:
await self._data.bulk_write(operations)
async def index_done_callback(self) -> None:
# Mongo handles persistence automatically
@@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage):
return False
try:
# Build regex pattern to match documents with the specified modes
pattern = f"^({'|'.join(modes)})_"
# Build regex pattern to match flattened key format: mode:cache_type:hash
pattern = f"^({'|'.join(modes)}):"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True
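
For clarity, a small sketch of what the new pattern matches (modes and keys here are made up):

import re

modes = ["default", "local"]
pattern = re.compile(f"^({'|'.join(modes)}):")  # same pattern handed to the $regex filter

print(bool(pattern.match("default:extract:0a1b2c3d")))  # True  -> dropped
print(bool(pattern.match("global:query:4e5f6a7b")))     # False -> kept
print(bool(pattern.match("default_0a1b2c3d")))          # False -> old-style key, not matched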
@@ -274,7 +242,7 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []
@@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args:
ids: List of vector IDs to be deleted
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
if not ids:
return

View File

@@ -247,6 +247,116 @@ class PostgreSQLDB:
logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
# Do not re-raise, to allow the application to start
async def _check_llm_cache_needs_migration(self):
"""Check if LLM cache data needs migration by examining the first record"""
try:
# Only query the first record to determine format
check_sql = """
SELECT id FROM LIGHTRAG_LLM_CACHE
ORDER BY create_time ASC
LIMIT 1
"""
result = await self.query(check_sql)
if result and result.get("id"):
# If id doesn't contain colon, it's old format
return ":" not in result["id"]
return False # No data or already new format
except Exception as e:
logger.warning(f"Failed to check LLM cache migration status: {e}")
return False
async def _migrate_llm_cache_to_flattened_keys(self):
"""Migrate LLM cache to flattened key format, recalculating hash values"""
try:
# Get all old format data
old_data_sql = """
SELECT id, mode, original_prompt, return_value, chunk_id,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE
WHERE id NOT LIKE '%:%'
"""
old_records = await self.query(old_data_sql, multirows=True)
if not old_records:
logger.info("No old format LLM cache data found, skipping migration")
return
logger.info(
f"Found {len(old_records)} old format cache records, starting migration..."
)
# Import hash calculation function
from ..utils import compute_args_hash
migrated_count = 0
# Migrate data in batches
for record in old_records:
try:
# Recalculate hash using correct method
new_hash = compute_args_hash(
record["mode"], record["original_prompt"]
)
# Generate new flattened key
cache_type = "extract" # Default type
new_key = f"{record['mode']}:{cache_type}:{new_hash}"
# Insert new format data
insert_sql = """
INSERT INTO LIGHTRAG_LLM_CACHE
(workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (workspace, mode, id) DO NOTHING
"""
await self.execute(
insert_sql,
{
"workspace": self.workspace,
"id": new_key,
"mode": record["mode"],
"original_prompt": record["original_prompt"],
"return_value": record["return_value"],
"chunk_id": record["chunk_id"],
"create_time": record["create_time"],
"update_time": record["update_time"],
},
)
# Delete old data
delete_sql = """
DELETE FROM LIGHTRAG_LLM_CACHE
WHERE workspace=$1 AND mode=$2 AND id=$3
"""
await self.execute(
delete_sql,
{
"workspace": self.workspace,
"mode": record["mode"],
"id": record["id"], # Old id
},
)
migrated_count += 1
except Exception as e:
logger.warning(
f"Failed to migrate cache record {record['id']}: {e}"
)
continue
logger.info(
f"Successfully migrated {migrated_count} cache records to flattened format"
)
except Exception as e:
logger.error(f"LLM cache migration failed: {e}")
# Don't raise exception, allow system to continue startup
async def check_tables(self):
# First create all tables
for k, v in TABLES.items():
@@ -304,6 +414,13 @@ class PostgreSQLDB:
except Exception as e:
logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
# Check and migrate LLM cache to flattened keys if needed
try:
if await self._check_llm_cache_needs_migration():
await self._migrate_llm_cache_to_flattened_keys()
except Exception as e:
logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
async def query(
self,
sql: str,
@@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage):
try:
results = await self.db.query(sql, params, multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
result_dict = {}
processed_results = {}
for row in results:
mode = row["mode"]
if mode not in result_dict:
result_dict[mode] = {}
result_dict[mode][row["id"]] = row
return result_dict
else:
return {row["id"]: row for row in results}
# Parse flattened key to extract cache_type
key_parts = row["id"].split(":")
cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
# Map field names and add cache_type for compatibility
processed_row = {
**row,
"return": row.get("return_value", ""), # Map return_value to return
"cache_type": cache_type, # Add cache_type from key
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default")
}
processed_results[row["id"]] = processed_row
return processed_results
# For other namespaces, return as-is
return {row["id"]: row for row in results}
except Exception as e:
logger.error(f"Error retrieving all data from {self.namespace}: {e}")
return {}
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get doc_full data by id."""
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
# For LLM cache, the id parameter actually represents the mode
params = {"workspace": self.db.workspace, "mode": id}
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
# Dynamically add cache_type field based on mode
row_with_cache_type = dict(row)
if id == "default":
row_with_cache_type["cache_type"] = "extract"
else:
row_with_cache_type["cache_type"] = "unknown"
res[row["id"]] = row_with_cache_type
return res if res else None
else:
params = {"workspace": self.db.workspace, "id": id}
response = await self.db.query(sql, params)
return response if response else None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "mode": mode, "id": id}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
params = {"workspace": self.db.workspace, "id": id}
response = await self.db.query(sql, params)
return response if response else None
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get doc_chunks data by id"""
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
array_res = await self.db.query(sql, params, multirows=True)
modes = set()
dict_res: dict[str, dict] = {}
for row in array_res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in array_res:
dict_res[row["mode"]][row["id"]] = row
return [{k: v} for k, v in dict_res.items()]
else:
return await self.db.query(sql, params, multirows=True)
return await self.db.query(sql, params, multirows=True)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
@@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage):
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": mode,
"chunk_id": v.get("chunk_id"),
}
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k, # Use flattened key as id
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"mode": v.get("mode", "default"), # Get mode from data
"chunk_id": v.get("chunk_id"),
}
await self.db.execute(upsert_sql, _data)
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self) -> None:
# PG handles persistence automatically
@@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage):
else:
exist_keys = []
new_keys = set([s for s in keys if s not in exist_keys])
print(f"keys: {keys}")
print(f"new_keys: {new_keys}")
# print(f"keys: {keys}")
# print(f"new_keys: {new_keys}")
return new_keys
except Exception as e:
logger.error(
@@ -2621,7 +2708,7 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3

View File

@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return

View File

@@ -1,9 +1,10 @@
import os
from typing import Any, final
from typing import Any, final, Union
from dataclasses import dataclass
import pipmaster as pm
import configparser
from contextlib import asynccontextmanager
import threading
if not pm.is_installed("redis"):
pm.install("redis")
@@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
from redis.exceptions import RedisError, ConnectionError # type: ignore
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
import json
@@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0
SOCKET_CONNECT_TIMEOUT = 3.0
class RedisConnectionManager:
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
_pools = {}
_lock = threading.Lock()
@classmethod
def get_pool(cls, redis_url: str) -> ConnectionPool:
"""Get or create a connection pool for the given Redis URL"""
if redis_url not in cls._pools:
with cls._lock:
if redis_url not in cls._pools:
cls._pools[redis_url] = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
logger.info(f"Created shared Redis connection pool for {redis_url}")
return cls._pools[redis_url]
@classmethod
def close_all_pools(cls):
"""Close all connection pools (for cleanup)"""
with cls._lock:
for url, pool in cls._pools.items():
try:
pool.disconnect()
logger.info(f"Closed Redis connection pool for {url}")
except Exception as e:
logger.error(f"Error closing Redis pool for {url}: {e}")
cls._pools.clear()
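
A minimal sketch of the pool sharing this manager provides (the URL is an example; the import path assumes the module registration shown earlier):

from lightrag.kg.redis_impl import RedisConnectionManager

pool_a = RedisConnectionManager.get_pool("redis://localhost:6379")
pool_b = RedisConnectionManager.get_pool("redis://localhost:6379")
assert pool_a is pool_b  # one ConnectionPool per Redis URL, shared by all storages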
@final
@dataclass
class RedisKVStorage(BaseKVStorage):
@@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Create a connection pool with limits
self._pool = ConnectionPool.from_url(
redis_url,
max_connections=MAX_CONNECTIONS,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
)
async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed"""
# Test connection
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise
# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
await self._migrate_legacy_cache_structure()
@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
@@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage):
logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids)
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._get_redis_connection() as redis:
try:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return {}
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Build result dictionary
result = {}
for key, value in zip(keys, values):
if value:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
try:
result[key_id] = json.loads(value)
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
continue
return result
except Exception as e:
logger.error(f"Error getting all data from Redis: {e}")
return {}
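
Note how the namespace prefix and the flattened cache key compose in Redis (the namespace string and hash below are illustrative):

# Full Redis key layout for cache entries after this change:
#   {namespace}:{mode}:{cache_type}:{hash}
full_key = "llm_response_cache:global:query:0a1b2c3d4e5f"
namespace, key_id = full_key.split(":", 1)
# key_id == "global:query:0a1b2c3d4e5f" -> the dict key returned by get_all()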
async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for key in keys:
keys_list = list(keys) # Convert set to list for indexing
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
return
logger.info(f"Inserting {len(data)} items to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
@@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage):
)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
"""Delete specific records from storage by cache mode
Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis
Args:
modes (list[str]): List of cache mode to be drop from storage
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache drop successfully
@@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage):
return False
try:
await self.delete(modes)
async with self._get_redis_connection() as redis:
keys_to_delete = []
# Find matching keys for each mode using SCAN
for mode in modes:
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
pattern = f"{self.namespace}:{mode}:*"
cursor = 0
mode_keys = []
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
mode_keys.extend(keys)
if cursor == 0:
break
keys_to_delete.extend(mode_keys)
logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
if keys_to_delete:
# Batch delete
pipe = redis.pipeline()
for key in keys_to_delete:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Dropped {deleted_count} cache entries for modes: {modes}"
)
else:
logger.warning(f"No cache entries found for modes: {modes}")
return True
except Exception:
except Exception as e:
logger.error(f"Error dropping cache by modes in Redis: {e}")
return False
async def drop(self) -> dict[str, str]:
@@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage):
"""
async with self._get_redis_connection() as redis:
try:
keys = await redis.keys(f"{self.namespace}:*")
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0:
break
if keys:
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
else:
logger.info(f"No keys found to drop in {self.namespace}")
return {"status": "success", "message": "no keys to drop"}
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis
Redis already stores data in a flattened way, but we need to check for
legacy keys that might contain nested JSON structures and migrate them.
Early exit if any flattened key is found (indicating migration already done).
"""
from lightrag.utils import generate_cache_key
async with self._get_redis_connection() as redis:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
if not keys:
return
# Check if we have any flattened keys already - if so, skip migration
has_flattened_keys = False
keys_to_migrate = []
for key in keys:
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
if ":" in key_id and len(key_id.split(":")) == 3:
has_flattened_keys = True
break # Early exit - migration already done
# Get the data to check if it's a legacy nested structure
data = await redis.get(key)
if data:
try:
parsed_data = json.loads(data)
# Check if this looks like a legacy cache mode with nested structure
if isinstance(parsed_data, dict) and all(
isinstance(v, dict) and "return" in v
for v in parsed_data.values()
):
keys_to_migrate.append((key, key_id, parsed_data))
except json.JSONDecodeError:
continue
# If we found any flattened keys, assume migration is already done
if has_flattened_keys:
logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration"
)
return
if not keys_to_migrate:
return
# Perform migration
pipe = redis.pipeline()
migration_count = 0
for old_key, mode, nested_data in keys_to_migrate:
# Delete the old key
pipe.delete(old_key)
# Create new flattened keys
for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1
await pipe.execute()
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
)
@final
@dataclass
class RedisDocStatusStorage(DocStatusStorage):
"""Redis implementation of document status storage"""
def __post_init__(self):
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
# Use shared connection pool
self._pool = RedisConnectionManager.get_pool(redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
)
async def initialize(self):
"""Initialize Redis connection"""
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}")
raise
@asynccontextmanager
async def _get_redis_connection(self):
"""Safe context manager for Redis operations."""
try:
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
raise
except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
)
raise
async def close(self):
"""Close the Redis connection."""
if hasattr(self, "_redis") and self._redis:
await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
async def __aenter__(self):
"""Support for async context manager."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Ensure Redis resources are cleaned up when exiting context."""
await self.close()
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
keys_list = list(keys)
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
return set(keys) - existing_ids
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
for result_data in results:
if result_data:
try:
result.append(json.loads(result_data))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}")
continue
except Exception as e:
logger.error(f"Error in get_by_ids: {e}")
return result
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Count statuses
for value in values:
if value:
try:
doc_data = json.loads(value)
status = doc_data.get("status")
if status in counts:
counts[status] += 1
except json.JSONDecodeError:
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting status counts: {e}")
return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
async with self._get_redis_connection() as redis:
try:
# Use SCAN to iterate through all keys in the namespace
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
if keys:
# Get all values in batch
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
# Filter by status and create DocProcessingStatus objects
for key, value in zip(keys, values):
if value:
try:
doc_data = json.loads(value)
if doc_data.get("status") == status.value:
# Extract document ID from key
doc_id = key.split(":", 1)[1]
# Make a copy of the data to avoid modifying the original
data = doc_data.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
# If file_path is not in data, use document id as file path
if "file_path" not in data:
data["file_path"] = "no-file-path"
result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Error processing document {key}: {e}")
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by status: {e}")
return result
async def index_done_callback(self) -> None:
"""Redis handles persistence automatically"""
pass
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update document status data"""
if not data:
return
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._get_redis_connection() as redis:
try:
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
except (TypeError, ValueError) as e:
# The json module has no JSONEncodeError; serialization failures surface as TypeError/ValueError
logger.error(f"JSON encode error during upsert: {e}")
raise
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None
async def delete(self, doc_ids: list[str]) -> None:
"""Delete specific records from storage by their IDs"""
if not doc_ids:
return
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources"""
try:
async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
if keys:
# Delete keys in batches
pipe = redis.pipeline()
for key in keys:
pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0:
break
logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
# Get current time as UNIX timestamp
import time

View File

@@ -399,10 +399,10 @@ async def _get_cached_extraction_results(
"""
cached_results = {}
# Get all cached data for "default" mode (entity extraction cache)
default_cache = await llm_response_cache.get_by_id("default") or {}
# Get all cached data (flattened cache structure)
all_cache = await llm_response_cache.get_all()
for cache_key, cache_entry in default_cache.items():
for cache_key, cache_entry in all_cache.items():
if (
isinstance(cache_entry, dict)
and cache_entry.get("cache_type") == "extract"
@@ -1387,7 +1387,7 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@@ -1546,7 +1546,7 @@ async def extract_keywords_only(
"""
# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
args_hash = compute_args_hash(param.mode, text)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
@@ -2413,7 +2413,7 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@@ -2529,7 +2529,7 @@ async def kg_query_with_keywords(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
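
To make the change above concrete: cache_type no longer feeds the hash; it is carried only by the flattened key that handle_cache and save_to_cache build (a sketch with a placeholder result):

from lightrag.utils import compute_args_hash, generate_cache_key

args_hash = compute_args_hash("local", "What are the top themes?")  # no cache_type argument
key = generate_cache_key("local", "query", args_hash)
# -> "local:query:<md5 hex digest>", looked up directly via get_by_id(key)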

View File

@@ -14,7 +14,6 @@ from functools import wraps
from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
import numpy as np
from lightrag.prompt import PROMPTS
from dotenv import load_dotenv
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
def compute_args_hash(*args: Any) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
# Convert all arguments to strings and join them
args_str = "".join([str(arg) for arg in args])
if cache_type:
args_str = f"{cache_type}:{args_str}"
# Compute MD5 hash
return hashlib.md5(args_str.encode()).hexdigest()
def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
"""Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
Args:
mode: Cache mode (e.g., 'default', 'local', 'global')
cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
hash_value: Hash value from compute_args_hash
Returns:
str: Flattened cache key
"""
return f"{mode}:{cache_type}:{hash_value}"
def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
"""Parse a flattened cache key back into its components
Args:
cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
Returns:
tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
"""
parts = cache_key.split(":", 2)
if len(parts) == 3:
return parts[0], parts[1], parts[2]
return None
def compute_mdhash_id(content: str, prefix: str = "") -> str:
"""
Compute a unique ID for a given content string.
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
return combined_data
async def get_best_cached_response(
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
use_llm_check=False,
llm_func=None,
original_prompt=None,
cache_type=None,
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None
# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
# Skip if cache_type doesn't match
if cache_type and cache_data.get("cache_type") != cache_type:
continue
# Check if cache data is valid
if cache_data["embedding"] is None:
continue
try:
# Safely convert cached embedding
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
# Ensure min_val and max_val are valid float values
embedding_min = cache_data.get("embedding_min")
embedding_max = cache_data.get("embedding_max")
if (
embedding_min is None
or embedding_max is None
or embedding_min >= embedding_max
):
logger.warning(
f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
)
continue
cached_embedding = dequantize_embedding(
cached_quantized,
embedding_min,
embedding_max,
)
except Exception as e:
logger.warning(f"Error processing cached embedding: {str(e)}")
continue
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
try:
llm_result = await llm_func(compare_prompt)
llm_result = llm_result.strip()
llm_similarity = float(llm_result)
# Replace vector similarity with LLM similarity score
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
"cached_question": best_prompt[:100] + "..."
if len(best_prompt) > 100
else best_prompt,
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
return None # Return None directly when LLM check fails
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
def cosine_similarity(v1, v2):
"""Calculate cosine similarity between two vectors"""
dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
mode="default",
cache_type=None,
):
"""Generic cache handling function"""
"""Generic cache handling function with flattened cache keys"""
if hashing_kv is None:
return None, None, None, None
@@ -968,15 +868,14 @@ async def handle_cache(
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(mode, cache_type, args_hash)
cache_entry = await hashing_kv.get_by_id(flattened_key)
if cache_entry:
logger.debug(f"Flattened cache hit(key:{flattened_key})")
return cache_entry["return"], None, None, None
logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
return None, None, None, None
@@ -994,7 +893,7 @@ class CacheData:
async def save_to_cache(hashing_kv, cache_data: CacheData):
"""Save data to cache, with improved handling for streaming responses and duplicate content.
"""Save data to cache using flattened key structure.
Args:
hashing_kv: The key-value storage for caching
@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
logger.debug("Streaming response detected, skipping cache")
return
# Get existing cache data
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
# Use flattened cache key format: {mode}:{cache_type}:{hash}
flattened_key = generate_cache_key(
cache_data.mode, cache_data.cache_type, cache_data.args_hash
)
# Check if we already have identical content cached
if cache_data.args_hash in mode_cache:
existing_content = mode_cache[cache_data.args_hash].get("return")
existing_cache = await hashing_kv.get_by_id(flattened_key)
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
logger.info(
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
)
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
return
# Update cache with new content
mode_cache[cache_data.args_hash] = {
# Create cache entry with flattened structure
cache_entry = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"original_prompt": cache_data.prompt,
}
logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
logger.info(f" == LLM cache == saving: {flattened_key}")
# Only upsert if there's actual new content
await hashing_kv.upsert({cache_data.mode: mode_cache})
# Save using flattened key
await hashing_kv.upsert({flattened_key: cache_entry})
def safe_unicode_decode(content):