From 271722405f8e48c6e2641e7e665da672a4d74874 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 2 Jul 2025 16:11:53 +0800 Subject: [PATCH] feat: Flatten LLM cache structure for improved recall efficiency Refactored the LLM cache to a flat Key-Value (KV) structure, replacing the previous nested format. The old structure used the 'mode' as a key and stored specific cache content as JSON nested under it. This change significantly enhances cache recall efficiency. --- .../copy_llm_cache_to_another_storage.py | 52 +- lightrag/kg/__init__.py | 3 + lightrag/kg/chroma_impl.py | 3 +- lightrag/kg/json_kv_impl.py | 111 +++- lightrag/kg/milvus_impl.py | 2 +- lightrag/kg/mongo_impl.py | 68 +-- lightrag/kg/postgres_impl.py | 229 +++++--- lightrag/kg/qdrant_impl.py | 2 +- lightrag/kg/redis_impl.py | 511 ++++++++++++++++-- lightrag/kg/tidb_impl.py | 6 +- lightrag/operate.py | 14 +- lightrag/utils.py | 210 ++----- 12 files changed, 836 insertions(+), 375 deletions(-) diff --git a/examples/unofficial-sample/copy_llm_cache_to_another_storage.py b/examples/unofficial-sample/copy_llm_cache_to_another_storage.py index 60fa6192..1671b5d5 100644 --- a/examples/unofficial-sample/copy_llm_cache_to_another_storage.py +++ b/examples/unofficial-sample/copy_llm_cache_to_another_storage.py @@ -52,18 +52,23 @@ async def copy_from_postgres_to_json(): embedding_func=None, ) + # Get all cache data using the new flattened structure + all_data = await from_llm_response_cache.get_all() + + # Convert flattened data to hierarchical structure for JsonKVStorage kv = {} - for c_id in await from_llm_response_cache.all_keys(): - print(f"Copying {c_id}") - workspace = c_id["workspace"] - mode = c_id["mode"] - _id = c_id["id"] - postgres_db.workspace = workspace - obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id) - if mode not in kv: - kv[mode] = {} - kv[mode][_id] = obj[_id] - print(f"Object {obj}") + for flattened_key, cache_entry in all_data.items(): + # Parse flattened key: {mode}:{cache_type}:{hash} + parts = flattened_key.split(":", 2) + if len(parts) == 3: + mode, cache_type, hash_value = parts + if mode not in kv: + kv[mode] = {} + kv[mode][hash_value] = cache_entry + print(f"Copying {flattened_key} -> {mode}[{hash_value}]") + else: + print(f"Skipping invalid key format: {flattened_key}") + await to_llm_response_cache.upsert(kv) await to_llm_response_cache.index_done_callback() print("Mission accomplished!") @@ -85,13 +90,24 @@ async def copy_from_json_to_postgres(): db=postgres_db, ) - for mode in await from_llm_response_cache.all_keys(): - print(f"Copying {mode}") - caches = await from_llm_response_cache.get_by_id(mode) - for k, v in caches.items(): - item = {mode: {k: v}} - print(f"\tCopying {item}") - await to_llm_response_cache.upsert(item) + # Get all cache data from JsonKVStorage (hierarchical structure) + all_data = await from_llm_response_cache.get_all() + + # Convert hierarchical data to flattened structure for PGKVStorage + flattened_data = {} + for mode, mode_data in all_data.items(): + print(f"Processing mode: {mode}") + for hash_value, cache_entry in mode_data.items(): + # Determine cache_type from cache entry or use default + cache_type = cache_entry.get("cache_type", "extract") + # Create flattened key: {mode}:{cache_type}:{hash} + flattened_key = f"{mode}:{cache_type}:{hash_value}" + flattened_data[flattened_key] = cache_entry + print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}") + + # Upsert the flattened data + await to_llm_response_cache.upsert(flattened_data) + print("Mission 
accomplished!") if __name__ == "__main__": diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index b4ba0983..697750e7 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = { "DOC_STATUS_STORAGE": { "implementations": [ "JsonDocStatusStorage", + "RedisDocStatusStorage", "PGDocStatusStorage", "MongoDocStatusStorage", ], @@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], + "RedisDocStatusStorage": ["REDIS_URI"], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "MongoDocStatusStorage": [], } @@ -96,6 +98,7 @@ STORAGES = { "MongoGraphStorage": ".kg.mongo_impl", "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", + "RedisDocStatusStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", # "TiDBKVStorage": ".kg.tidb_impl", # "TiDBVectorDBStorage": ".kg.tidb_impl", diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index c3927a19..ebdd4593 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): raise async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return @@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") self._collection.delete(ids=ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index fa819d4a..d6e2cb70 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage): if need_init: loaded_data = load_json(self._file_name) or {} async with self._storage_lock: - self._data.update(loaded_data) - - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # For cache namespaces, sum the cache entries across all cache types - data_count = sum( - len(first_level_dict) - for first_level_dict in loaded_data.values() - if isinstance(first_level_dict, dict) + # Migrate legacy cache structure if needed + if self.namespace.endswith("_cache"): + loaded_data = await self._migrate_legacy_cache_structure( + loaded_data ) - else: - # For non-cache namespaces, use the original count method - data_count = len(loaded_data) + + self._data.update(loaded_data) + data_count = len(loaded_data) logger.info( f"Process {os.getpid()} KV load {self.namespace} with {data_count} records" @@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage): dict(self._data) if hasattr(self._data, "_getvalue") else self._data ) - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # # For cache namespaces, sum the cache entries across all cache types - data_count = sum( - len(first_level_dict) - for first_level_dict in data_dict.values() - if isinstance(first_level_dict, dict) - ) - else: - # For non-cache namespaces, use the original count method - data_count = len(data_dict) + # Calculate data count - all data is now flattened + data_count = len(data_dict) logger.debug( f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}" @@ -150,14 +136,14 @@ class 
JsonKVStorage(BaseKVStorage): await set_all_update_flags(self.namespace) async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by by cache mode + """Delete specific records from storage by cache mode Importance notes for in-memory storage: 1. Changes will be persisted to disk during the next index_done_callback 2. update flags to notify other processes that data persistence is needed Args: - ids (list[str]): List of cache mode to be drop from storage + modes (list[str]): List of cache modes to be dropped from storage Returns: True: if the cache drop successfully @@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage): return False try: - await self.delete(modes) + async with self._storage_lock: + keys_to_delete = [] + modes_set = set(modes) # Convert to set for efficient lookup + + for key in list(self._data.keys()): + # Parse flattened cache key: mode:cache_type:hash + parts = key.split(":", 2) + if len(parts) == 3 and parts[0] in modes_set: + keys_to_delete.append(key) + + # Batch delete + for key in keys_to_delete: + self._data.pop(key, None) + + if keys_to_delete: + await set_all_update_flags(self.namespace) + logger.info( + f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}" + ) + return True - except Exception: + except Exception as e: + logger.error(f"Error dropping cache by modes: {e}") return False # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: @@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage): logger.error(f"Error dropping {self.namespace}: {e}") return {"status": "error", "message": str(e)} + async def _migrate_legacy_cache_structure(self, data: dict) -> dict: + """Migrate legacy nested cache structure to flattened structure + + Args: + data: Original data dictionary that may contain legacy structure + + Returns: + Migrated data dictionary with flattened cache keys + """ + from lightrag.utils import generate_cache_key + + # Early return if data is empty + if not data: + return data + + # Check first entry to see if it's already in new format + first_key = next(iter(data.keys())) + if ":" in first_key and len(first_key.split(":")) == 3: + # Already in flattened format, return as-is + return data + + migrated_data = {} + migration_count = 0 + + for key, value in data.items(): + # Check if this is a legacy nested cache structure + if isinstance(value, dict) and all( + isinstance(v, dict) and "return" in v for v in value.values() + ): + # This looks like a legacy cache mode with nested structure + mode = key + for cache_hash, cache_entry in value.items(): + cache_type = cache_entry.get("cache_type", "extract") + flattened_key = generate_cache_key(mode, cache_type, cache_hash) + migrated_data[flattened_key] = cache_entry + migration_count += 1 + else: + # Keep non-cache data or already flattened cache data as-is + migrated_data[key] = value + + if migration_count > 0: + logger.info( + f"Migrated {migration_count} legacy cache entries to flattened structure" + ) + # Persist migrated data immediately + write_json(migrated_data, self._file_name) + + return migrated_data + async def finalize(self): """Finalize storage resources Persistence cache data to disk before exiting """ - if self.namespace.endswith("cache"): + if self.namespace.endswith("_cache"): await self.index_done_callback() diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 23e178bc..70a793f7 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -75,7 +75,7 
@@ class MilvusVectorDBStorage(BaseVectorStorage): ) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index a8dda1b9..38baff5c 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -15,7 +15,6 @@ from ..base import ( DocStatus, DocStatusStorage, ) -from ..namespace import NameSpace, is_namespace from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP @@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage): self._data = None async def get_by_id(self, id: str) -> dict[str, Any] | None: - if id == "default": - # Find all documents with _id starting with "default_" - cursor = self._data.find({"_id": {"$regex": "^default_"}}) - result = {} - async for doc in cursor: - # Use the complete _id as key - result[doc["_id"]] = doc - return result if result else None - else: - # Original behavior for non-"default" ids - return await self._data.find_one({"_id": id}) + # Unified handling for flattened keys + return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: cursor = self._data.find({"_id": {"$in": ids}}) @@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage): return result async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - update_tasks: list[Any] = [] - for mode, items in data.items(): - for k, v in items.items(): - key = f"{mode}_{k}" - data[mode][k]["_id"] = f"{mode}_{k}" - update_tasks.append( - self._data.update_one( - {"_id": key}, {"$setOnInsert": v}, upsert=True - ) - ) - await asyncio.gather(*update_tasks) - else: - update_tasks = [] - for k, v in data.items(): - data[k]["_id"] = k - update_tasks.append( - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) - ) - await asyncio.gather(*update_tasks) + # Unified handling for all namespaces with flattened keys + # Use bulk_write for better performance + from pymongo import UpdateOne - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - res = {} - v = await self._data.find_one({"_id": mode + "_" + id}) - if v: - res[id] = v - logger.debug(f"llm_response_cache find one by:{id}") - return res - else: - return None - else: - return None + operations = [] + for k, v in data.items(): + v["_id"] = k # Use flattened key as _id + operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True)) + + if operations: + await self._data.bulk_write(operations) async def index_done_callback(self) -> None: # Mongo handles persistence automatically @@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage): return False try: - # Build regex pattern to match documents with the specified modes - pattern = f"^({'|'.join(modes)})_" + # Build regex pattern to match flattened key format: mode:cache_type:hash + pattern = f"^({'|'.join(modes)}):" result = await self._data.delete_many({"_id": {"$regex": pattern}}) logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}") return True @@ -274,7 +242,7 @@ class 
MongoDocStatusStorage(DocStatusStorage): return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return update_tasks: list[Any] = [] @@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.debug("vector index already exist") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return @@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage): Args: ids: List of vector IDs to be deleted """ - logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}") if not ids: return diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7feae36c..f4b533d8 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -247,6 +247,116 @@ class PostgreSQLDB: logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}") # Do not re-raise, to allow the application to start + async def _check_llm_cache_needs_migration(self): + """Check if LLM cache data needs migration by examining the first record""" + try: + # Only query the first record to determine format + check_sql = """ + SELECT id FROM LIGHTRAG_LLM_CACHE + ORDER BY create_time ASC + LIMIT 1 + """ + result = await self.query(check_sql) + + if result and result.get("id"): + # If id doesn't contain colon, it's old format + return ":" not in result["id"] + + return False # No data or already new format + except Exception as e: + logger.warning(f"Failed to check LLM cache migration status: {e}") + return False + + async def _migrate_llm_cache_to_flattened_keys(self): + """Migrate LLM cache to flattened key format, recalculating hash values""" + try: + # Get all old format data + old_data_sql = """ + SELECT id, mode, original_prompt, return_value, chunk_id, + create_time, update_time + FROM LIGHTRAG_LLM_CACHE + WHERE id NOT LIKE '%:%' + """ + + old_records = await self.query(old_data_sql, multirows=True) + + if not old_records: + logger.info("No old format LLM cache data found, skipping migration") + return + + logger.info( + f"Found {len(old_records)} old format cache records, starting migration..." 
+ ) + + # Import hash calculation function + from ..utils import compute_args_hash + + migrated_count = 0 + + # Migrate data in batches + for record in old_records: + try: + # Recalculate hash using correct method + new_hash = compute_args_hash( + record["mode"], record["original_prompt"] + ) + + # Generate new flattened key + cache_type = "extract" # Default type + new_key = f"{record['mode']}:{cache_type}:{new_hash}" + + # Insert new format data + insert_sql = """ + INSERT INTO LIGHTRAG_LLM_CACHE + (workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (workspace, mode, id) DO NOTHING + """ + + await self.execute( + insert_sql, + { + "workspace": self.workspace, + "id": new_key, + "mode": record["mode"], + "original_prompt": record["original_prompt"], + "return_value": record["return_value"], + "chunk_id": record["chunk_id"], + "create_time": record["create_time"], + "update_time": record["update_time"], + }, + ) + + # Delete old data + delete_sql = """ + DELETE FROM LIGHTRAG_LLM_CACHE + WHERE workspace=$1 AND mode=$2 AND id=$3 + """ + await self.execute( + delete_sql, + { + "workspace": self.workspace, + "mode": record["mode"], + "id": record["id"], # Old id + }, + ) + + migrated_count += 1 + + except Exception as e: + logger.warning( + f"Failed to migrate cache record {record['id']}: {e}" + ) + continue + + logger.info( + f"Successfully migrated {migrated_count} cache records to flattened format" + ) + + except Exception as e: + logger.error(f"LLM cache migration failed: {e}") + # Don't raise exception, allow system to continue startup + async def check_tables(self): # First create all tables for k, v in TABLES.items(): @@ -304,6 +414,13 @@ class PostgreSQLDB: except Exception as e: logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}") + # Check and migrate LLM cache to flattened keys if needed + try: + if await self._check_llm_cache_needs_migration(): + await self._migrate_llm_cache_to_flattened_keys() + except Exception as e: + logger.error(f"PostgreSQL, LLM cache migration failed: {e}") + async def query( self, sql: str, @@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage): try: results = await self.db.query(sql, params, multirows=True) - + + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - result_dict = {} + processed_results = {} for row in results: - mode = row["mode"] - if mode not in result_dict: - result_dict[mode] = {} - result_dict[mode][row["id"]] = row - return result_dict - else: - return {row["id"]: row for row in results} + # Parse flattened key to extract cache_type + key_parts = row["id"].split(":") + cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown" + + # Map field names and add cache_type for compatibility + processed_row = { + **row, + "return": row.get("return_value", ""), # Map return_value to return + "cache_type": cache_type, # Add cache_type from key + "original_prompt": row.get("original_prompt", ""), + "chunk_id": row.get("chunk_id"), + "mode": row.get("mode", "default") + } + processed_results[row["id"]] = processed_row + return processed_results + + # For other namespaces, return as-is + return {row["id"]: row for row in results} except Exception as e: logger.error(f"Error retrieving all data from {self.namespace}: {e}") return {} async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get doc_full data 
by id.""" + """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - # For LLM cache, the id parameter actually represents the mode - params = {"workspace": self.db.workspace, "mode": id} - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - # Dynamically add cache_type field based on mode - row_with_cache_type = dict(row) - if id == "default": - row_with_cache_type["cache_type"] = "extract" - else: - row_with_cache_type["cache_type"] = "unknown" - res[row["id"]] = row_with_cache_type - return res if res else None - else: - params = {"workspace": self.db.workspace, "id": id} - response = await self.db.query(sql, params) - return response if response else None - - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] - params = {"workspace": self.db.workspace, "mode": mode, "id": id} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - return res - else: - return None + params = {"workspace": self.db.workspace, "id": id} + response = await self.db.query(sql, params) + return response if response else None # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get doc_chunks data by id""" + """Get data by ids""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - modes = set() - dict_res: dict[str, dict] = {} - for row in array_res: - modes.add(row["mode"]) - for mode in modes: - if mode not in dict_res: - dict_res[mode] = {} - for row in array_res: - dict_res[row["mode"]][row["id"]] = row - return [{k: v} for k, v in dict_res.items()] - else: - return await self.db.query(sql, params, multirows=True) + return await self.db.query(sql, params, multirows=True) async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" @@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - for mode, items in data.items(): - for k, v in items.items(): - upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] - _data = { - "workspace": self.db.workspace, - "id": k, - "original_prompt": v["original_prompt"], - "return_value": v["return"], - "mode": mode, - "chunk_id": v.get("chunk_id"), - } + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, # Use flattened key as id + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "mode": v.get("mode", "default"), # Get mode from data + "chunk_id": v.get("chunk_id"), + } - await self.db.execute(upsert_sql, _data) + await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: # PG handles persistence automatically @@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage): else: exist_keys = [] new_keys = set([s for s in keys if s not in exist_keys]) - print(f"keys: {keys}") - print(f"new_keys: 
{new_keys}") + # print(f"keys: {keys}") + # print(f"new_keys: {new_keys}") return new_keys except Exception as e: logger.error( @@ -2621,7 +2708,7 @@ SQL_TEMPLATES = { FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 """, "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 885a23ca..dada278a 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 65c25bfc..c87a9a4b 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,9 +1,10 @@ import os -from typing import Any, final +from typing import Any, final, Union from dataclasses import dataclass import pipmaster as pm import configparser from contextlib import asynccontextmanager +import threading if not pm.is_installed("redis"): pm.install("redis") @@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore from redis.exceptions import RedisError, ConnectionError # type: ignore from lightrag.utils import logger -from lightrag.base import BaseKVStorage +from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus import json @@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0 SOCKET_CONNECT_TIMEOUT = 3.0 +class RedisConnectionManager: + """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI""" + + _pools = {} + _lock = threading.Lock() + + @classmethod + def get_pool(cls, redis_url: str) -> ConnectionPool: + """Get or create a connection pool for the given Redis URL""" + if redis_url not in cls._pools: + with cls._lock: + if redis_url not in cls._pools: + cls._pools[redis_url] = ConnectionPool.from_url( + redis_url, + max_connections=MAX_CONNECTIONS, + decode_responses=True, + socket_timeout=SOCKET_TIMEOUT, + socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, + ) + logger.info(f"Created shared Redis connection pool for {redis_url}") + return cls._pools[redis_url] + + @classmethod + def close_all_pools(cls): + """Close all connection pools (for cleanup)""" + with cls._lock: + for url, pool in cls._pools.items(): + try: + pool.disconnect() + logger.info(f"Closed Redis connection pool for {url}") + except Exception as e: + logger.error(f"Error closing Redis pool for {url}: {e}") + cls._pools.clear() + + @final @dataclass class RedisKVStorage(BaseKVStorage): @@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage): redis_url = os.environ.get( "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") ) - # Create a connection pool with limits - self._pool = ConnectionPool.from_url( - redis_url, - max_connections=MAX_CONNECTIONS, - decode_responses=True, - socket_timeout=SOCKET_TIMEOUT, - socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, - ) + # Use shared connection pool + self._pool = RedisConnectionManager.get_pool(redis_url) self._redis = 
Redis(connection_pool=self._pool) logger.info( - f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections" + f"Initialized Redis KV storage for {self.namespace} using shared connection pool" ) + async def initialize(self): + """Initialize Redis connection and migrate legacy cache structure if needed""" + # Test connection + try: + async with self._get_redis_connection() as redis: + await redis.ping() + logger.info(f"Connected to Redis for namespace {self.namespace}") + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + raise + + # Migrate legacy cache structure if this is a cache namespace + if self.namespace.endswith("_cache"): + await self._migrate_legacy_cache_structure() + @asynccontextmanager async def _get_redis_connection(self): """Safe context manager for Redis operations.""" @@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage): logger.error(f"JSON decode error in batch get: {e}") return [None] * len(ids) + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + async with self._get_redis_connection() as redis: + try: + # Get all keys for this namespace + keys = await redis.keys(f"{self.namespace}:*") + + if not keys: + return {} + + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Build result dictionary + result = {} + for key, value in zip(keys, values): + if value: + # Extract the ID part (after namespace:) + key_id = key.split(":", 1)[1] + try: + result[key_id] = json.loads(value) + except json.JSONDecodeError as e: + logger.error(f"JSON decode error for key {key}: {e}") + continue + + return result + except Exception as e: + logger.error(f"Error getting all data from Redis: {e}") + return {} + async def filter_keys(self, keys: set[str]) -> set[str]: async with self._get_redis_connection() as redis: pipe = redis.pipeline() - for key in keys: + keys_list = list(keys) # Convert set to list for indexing + for key in keys_list: pipe.exists(f"{self.namespace}:{key}") results = await pipe.execute() - existing_ids = {keys[i] for i, exists in enumerate(results) if exists} + existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} return set(keys) - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: return - - logger.info(f"Inserting {len(data)} items to {self.namespace}") async with self._get_redis_connection() as redis: try: pipe = redis.pipeline() @@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage): ) async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by by cache mode + """Delete specific records from storage by cache mode Importance notes for Redis storage: 1. 
This will immediately delete the specified cache modes from Redis Args: - modes (list[str]): List of cache mode to be drop from storage + modes (list[str]): List of cache modes to be dropped from storage Returns: True: if the cache drop successfully @@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage): return False try: - await self.delete(modes) + async with self._get_redis_connection() as redis: + keys_to_delete = [] + + # Find matching keys for each mode using SCAN + for mode in modes: + # Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash} + pattern = f"{self.namespace}:{mode}:*" + cursor = 0 + mode_keys = [] + + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + mode_keys.extend(keys) + + if cursor == 0: + break + + keys_to_delete.extend(mode_keys) + logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'") + + if keys_to_delete: + # Batch delete + pipe = redis.pipeline() + for key in keys_to_delete: + pipe.delete(key) + results = await pipe.execute() + deleted_count = sum(results) + logger.info( + f"Dropped {deleted_count} cache entries for modes: {modes}" + ) + else: + logger.warning(f"No cache entries found for modes: {modes}") + return True - except Exception: + except Exception as e: + logger.error(f"Error dropping cache by modes in Redis: {e}") return False async def drop(self) -> dict[str, str]: @@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage): """ async with self._get_redis_connection() as redis: try: - keys = await redis.keys(f"{self.namespace}:*") + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.namespace}:*" + cursor = 0 + deleted_count = 0 + + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) + + if cursor == 0: + break - if keys: - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count = sum(results) - - logger.info(f"Dropped {deleted_count} keys from {self.namespace}") - return { - "status": "success", - "message": f"{deleted_count} keys dropped", - } - else: - logger.info(f"No keys found to drop in {self.namespace}") - return {"status": "success", "message": "no keys to drop"} + logger.info(f"Dropped {deleted_count} keys from {self.namespace}") + return { + "status": "success", + "message": f"{deleted_count} keys dropped", + } except Exception as e: logger.error(f"Error dropping keys from {self.namespace}: {e}") return {"status": "error", "message": str(e)} + + async def _migrate_legacy_cache_structure(self): + """Migrate legacy nested cache structure to flattened structure for Redis + + Redis already stores data in a flattened way, but we need to check for + legacy keys that might contain nested JSON structures and migrate them. + + Early exit if any flattened key is found (indicating migration already done). 
+ """ + from lightrag.utils import generate_cache_key + + async with self._get_redis_connection() as redis: + # Get all keys for this namespace + keys = await redis.keys(f"{self.namespace}:*") + + if not keys: + return + + # Check if we have any flattened keys already - if so, skip migration + has_flattened_keys = False + keys_to_migrate = [] + + for key in keys: + # Extract the ID part (after namespace:) + key_id = key.split(":", 1)[1] + + # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash) + if ":" in key_id and len(key_id.split(":")) == 3: + has_flattened_keys = True + break # Early exit - migration already done + + # Get the data to check if it's a legacy nested structure + data = await redis.get(key) + if data: + try: + parsed_data = json.loads(data) + # Check if this looks like a legacy cache mode with nested structure + if isinstance(parsed_data, dict) and all( + isinstance(v, dict) and "return" in v + for v in parsed_data.values() + ): + keys_to_migrate.append((key, key_id, parsed_data)) + except json.JSONDecodeError: + continue + + # If we found any flattened keys, assume migration is already done + if has_flattened_keys: + logger.debug( + f"Found flattened cache keys in {self.namespace}, skipping migration" + ) + return + + if not keys_to_migrate: + return + + # Perform migration + pipe = redis.pipeline() + migration_count = 0 + + for old_key, mode, nested_data in keys_to_migrate: + # Delete the old key + pipe.delete(old_key) + + # Create new flattened keys + for cache_hash, cache_entry in nested_data.items(): + cache_type = cache_entry.get("cache_type", "extract") + flattened_key = generate_cache_key(mode, cache_type, cache_hash) + full_key = f"{self.namespace}:{flattened_key}" + pipe.set(full_key, json.dumps(cache_entry)) + migration_count += 1 + + await pipe.execute() + + if migration_count > 0: + logger.info( + f"Migrated {migration_count} legacy cache entries to flattened structure in Redis" + ) + + +@final +@dataclass +class RedisDocStatusStorage(DocStatusStorage): + """Redis implementation of document status storage""" + + def __post_init__(self): + redis_url = os.environ.get( + "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") + ) + # Use shared connection pool + self._pool = RedisConnectionManager.get_pool(redis_url) + self._redis = Redis(connection_pool=self._pool) + logger.info( + f"Initialized Redis doc status storage for {self.namespace} using shared connection pool" + ) + + async def initialize(self): + """Initialize Redis connection""" + try: + async with self._get_redis_connection() as redis: + await redis.ping() + logger.info(f"Connected to Redis for doc status namespace {self.namespace}") + except Exception as e: + logger.error(f"Failed to connect to Redis for doc status: {e}") + raise + + @asynccontextmanager + async def _get_redis_connection(self): + """Safe context manager for Redis operations.""" + try: + yield self._redis + except ConnectionError as e: + logger.error(f"Redis connection error in doc status {self.namespace}: {e}") + raise + except RedisError as e: + logger.error(f"Redis operation error in doc status {self.namespace}: {e}") + raise + except Exception as e: + logger.error( + f"Unexpected error in Redis doc status operation for {self.namespace}: {e}" + ) + raise + + async def close(self): + """Close the Redis connection.""" + if hasattr(self, "_redis") and self._redis: + await self._redis.close() + logger.debug(f"Closed Redis connection for doc status {self.namespace}") + + async 
def __aenter__(self): + """Support for async context manager.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Ensure Redis resources are cleaned up when exiting context.""" + await self.close() + + async def filter_keys(self, keys: set[str]) -> set[str]: + """Return keys that should be processed (not in storage or not successfully processed)""" + async with self._get_redis_connection() as redis: + pipe = redis.pipeline() + keys_list = list(keys) + for key in keys_list: + pipe.exists(f"{self.namespace}:{key}") + results = await pipe.execute() + + existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} + return set(keys) - existing_ids + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + result: list[dict[str, Any]] = [] + async with self._get_redis_connection() as redis: + try: + pipe = redis.pipeline() + for id in ids: + pipe.get(f"{self.namespace}:{id}") + results = await pipe.execute() + + for result_data in results: + if result_data: + try: + result.append(json.loads(result_data)) + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in get_by_ids: {e}") + continue + except Exception as e: + logger.error(f"Error in get_by_ids: {e}") + return result + + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + counts = {status.value: 0 for status in DocStatus} + async with self._get_redis_connection() as redis: + try: + # Use SCAN to iterate through all keys in the namespace + cursor = 0 + while True: + cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000) + if keys: + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Count statuses + for value in values: + if value: + try: + doc_data = json.loads(value) + status = doc_data.get("status") + if status in counts: + counts[status] += 1 + except json.JSONDecodeError: + continue + + if cursor == 0: + break + except Exception as e: + logger.error(f"Error getting status counts: {e}") + + return counts + + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" + result = {} + async with self._get_redis_connection() as redis: + try: + # Use SCAN to iterate through all keys in the namespace + cursor = 0 + while True: + cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000) + if keys: + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Filter by status and create DocProcessingStatus objects + for key, value in zip(keys, values): + if value: + try: + doc_data = json.loads(value) + if doc_data.get("status") == status.value: + # Extract document ID from key + doc_id = key.split(":", 1)[1] + + # Make a copy of the data to avoid modifying the original + data = doc_data.copy() + # If content is missing, use content_summary as content + if "content" not in data and "content_summary" in data: + data["content"] = data["content_summary"] + # If file_path is not in data, use document id as file path + if "file_path" not in data: + data["file_path"] = "no-file-path" + + result[doc_id] = DocProcessingStatus(**data) + except (json.JSONDecodeError, KeyError) as e: + logger.error(f"Error processing document {key}: {e}") + continue + + if cursor == 0: + break + except Exception as e: + logger.error(f"Error getting docs by status: {e}") + + return 
result + + async def index_done_callback(self) -> None: + """Redis handles persistence automatically""" + pass + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Insert or update document status data""" + if not data: + return + + logger.debug(f"Inserting {len(data)} records to {self.namespace}") + async with self._get_redis_connection() as redis: + try: + pipe = redis.pipeline() + for k, v in data.items(): + pipe.set(f"{self.namespace}:{k}", json.dumps(v)) + await pipe.execute() + except json.JSONEncodeError as e: + logger.error(f"JSON encode error during upsert: {e}") + raise + + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async with self._get_redis_connection() as redis: + try: + data = await redis.get(f"{self.namespace}:{id}") + return json.loads(data) if data else None + except json.JSONDecodeError as e: + logger.error(f"JSON decode error for id {id}: {e}") + return None + + async def delete(self, doc_ids: list[str]) -> None: + """Delete specific records from storage by their IDs""" + if not doc_ids: + return + + async with self._get_redis_connection() as redis: + pipe = redis.pipeline() + for doc_id in doc_ids: + pipe.delete(f"{self.namespace}:{doc_id}") + + results = await pipe.execute() + deleted_count = sum(results) + logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}") + + async def drop(self) -> dict[str, str]: + """Drop all document status data from storage and clean up resources""" + try: + async with self._get_redis_connection() as redis: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.namespace}:*" + cursor = 0 + deleted_count = 0 + + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) + + if cursor == 0: + break + + logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}") + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping doc status {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 9b9d17a9..06ec1cd5 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage): ################ INSERT full_doc AND chunks ################ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return left_data = {k: v for k, v in data.items() if k not in self._data} @@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): ###### INSERT entities And relationships ###### async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") # Get current time as UNIX timestamp import time diff --git a/lightrag/operate.py b/lightrag/operate.py index b41b3de9..bd70ceed 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -399,10 +399,10 @@ async def _get_cached_extraction_results( """ cached_results = {} - # Get all cached data for "default" mode (entity extraction cache) - 
default_cache = await llm_response_cache.get_by_id("default") or {} + # Get all cached data (flattened cache structure) + all_cache = await llm_response_cache.get_all() - for cache_key, cache_entry in default_cache.items(): + for cache_key, cache_entry in all_cache.items(): if ( isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" @@ -1387,7 +1387,7 @@ async def kg_query( use_model_func = partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -1546,7 +1546,7 @@ async def extract_keywords_only( """ # 1. Handle cache if needed - add cache type for keywords - args_hash = compute_args_hash(param.mode, text, cache_type="keywords") + args_hash = compute_args_hash(param.mode, text) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, text, param.mode, cache_type="keywords" ) @@ -2413,7 +2413,7 @@ async def naive_query( use_model_func = partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -2529,7 +2529,7 @@ async def kg_query_with_keywords( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) diff --git a/lightrag/utils.py b/lightrag/utils.py index 06b7a468..6c40407b 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -14,7 +14,6 @@ from functools import wraps from hashlib import md5 from typing import Any, Protocol, Callable, TYPE_CHECKING, List import numpy as np -from lightrag.prompt import PROMPTS from dotenv import load_dotenv from lightrag.constants import ( DEFAULT_LOG_MAX_BYTES, @@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]: raise e from None -def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: +def compute_args_hash(*args: Any) -> str: """Compute a hash for the given arguments. 
Args: *args: Arguments to hash - cache_type: Type of cache (e.g., 'keywords', 'query', 'extract') Returns: str: Hash string """ @@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: # Convert all arguments to strings and join them args_str = "".join([str(arg) for arg in args]) - if cache_type: - args_str = f"{cache_type}:{args_str}" # Compute MD5 hash return hashlib.md5(args_str.encode()).hexdigest() +def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str: + """Generate a flattened cache key in the format {mode}:{cache_type}:{hash} + + Args: + mode: Cache mode (e.g., 'default', 'local', 'global') + cache_type: Type of cache (e.g., 'extract', 'query', 'keywords') + hash_value: Hash value from compute_args_hash + + Returns: + str: Flattened cache key + """ + return f"{mode}:{cache_type}:{hash_value}" + + +def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: + """Parse a flattened cache key back into its components + + Args: + cache_key: Flattened cache key in format {mode}:{cache_type}:{hash} + + Returns: + tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format + """ + parts = cache_key.split(":", 2) + if len(parts) == 3: + return parts[0], parts[1], parts[2] + return None + + def compute_mdhash_id(content: str, prefix: str = "") -> str: """ Compute a unique ID for a given content string. @@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists): return combined_data -async def get_best_cached_response( - hashing_kv, - current_embedding, - similarity_threshold=0.95, - mode="default", - use_llm_check=False, - llm_func=None, - original_prompt=None, - cache_type=None, -) -> str | None: - logger.debug( - f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}" - ) - mode_cache = await hashing_kv.get_by_id(mode) - if not mode_cache: - return None - - best_similarity = -1 - best_response = None - best_prompt = None - best_cache_id = None - - # Only iterate through cache entries for this mode - for cache_id, cache_data in mode_cache.items(): - # Skip if cache_type doesn't match - if cache_type and cache_data.get("cache_type") != cache_type: - continue - - # Check if cache data is valid - if cache_data["embedding"] is None: - continue - - try: - # Safely convert cached embedding - cached_quantized = np.frombuffer( - bytes.fromhex(cache_data["embedding"]), dtype=np.uint8 - ).reshape(cache_data["embedding_shape"]) - - # Ensure min_val and max_val are valid float values - embedding_min = cache_data.get("embedding_min") - embedding_max = cache_data.get("embedding_max") - - if ( - embedding_min is None - or embedding_max is None - or embedding_min >= embedding_max - ): - logger.warning( - f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}" - ) - continue - - cached_embedding = dequantize_embedding( - cached_quantized, - embedding_min, - embedding_max, - ) - except Exception as e: - logger.warning(f"Error processing cached embedding: {str(e)}") - continue - - similarity = cosine_similarity(current_embedding, cached_embedding) - if similarity > best_similarity: - best_similarity = similarity - best_response = cache_data["return"] - best_prompt = cache_data["original_prompt"] - best_cache_id = cache_id - - if best_similarity > similarity_threshold: - # If LLM check is enabled and all required parameters are provided - if ( - use_llm_check - and llm_func - and original_prompt - and best_prompt - and best_response is not None - ): - 
compare_prompt = PROMPTS["similarity_check"].format( - original_prompt=original_prompt, cached_prompt=best_prompt - ) - - try: - llm_result = await llm_func(compare_prompt) - llm_result = llm_result.strip() - llm_similarity = float(llm_result) - - # Replace vector similarity with LLM similarity score - best_similarity = llm_similarity - if best_similarity < similarity_threshold: - log_data = { - "event": "cache_rejected_by_llm", - "type": cache_type, - "mode": mode, - "original_question": original_prompt[:100] + "..." - if len(original_prompt) > 100 - else original_prompt, - "cached_question": best_prompt[:100] + "..." - if len(best_prompt) > 100 - else best_prompt, - "similarity_score": round(best_similarity, 4), - "threshold": similarity_threshold, - } - logger.debug(json.dumps(log_data, ensure_ascii=False)) - logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})") - return None - except Exception as e: # Catch all possible exceptions - logger.warning(f"LLM similarity check failed: {e}") - return None # Return None directly when LLM check fails - - prompt_display = ( - best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt - ) - log_data = { - "event": "cache_hit", - "type": cache_type, - "mode": mode, - "similarity": round(best_similarity, 4), - "cache_id": best_cache_id, - "original_prompt": prompt_display, - } - logger.debug(json.dumps(log_data, ensure_ascii=False)) - return best_response - return None - - def cosine_similarity(v1, v2): """Calculate cosine similarity between two vectors""" dot_product = np.dot(v1, v2) @@ -957,7 +857,7 @@ async def handle_cache( mode="default", cache_type=None, ): - """Generic cache handling function""" + """Generic cache handling function with flattened cache keys""" if hashing_kv is None: return None, None, None, None @@ -968,15 +868,14 @@ async def handle_cache( if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"): return None, None, None, None - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} - else: - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})") - return mode_cache[args_hash]["return"], None, None, None + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key(mode, cache_type, args_hash) + cache_entry = await hashing_kv.get_by_id(flattened_key) + if cache_entry: + logger.debug(f"Flattened cache hit(key:{flattened_key})") + return cache_entry["return"], None, None, None - logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})") + logger.debug(f"Cache missed(mode:{mode} type:{cache_type})") return None, None, None, None @@ -994,7 +893,7 @@ class CacheData: async def save_to_cache(hashing_kv, cache_data: CacheData): - """Save data to cache, with improved handling for streaming responses and duplicate content. + """Save data to cache using flattened key structure. 
Args: hashing_kv: The key-value storage for caching @@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): logger.debug("Streaming response detected, skipping cache") return - # Get existing cache data - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = ( - await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash) - or {} - ) - else: - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key( + cache_data.mode, cache_data.cache_type, cache_data.args_hash + ) # Check if we already have identical content cached - if cache_data.args_hash in mode_cache: - existing_content = mode_cache[cache_data.args_hash].get("return") + existing_cache = await hashing_kv.get_by_id(flattened_key) + if existing_cache: + existing_content = existing_cache.get("return") if existing_content == cache_data.content: - logger.info( - f"Cache content unchanged for {cache_data.args_hash}, skipping update" - ) + logger.info(f"Cache content unchanged for {flattened_key}, skipping update") return - # Update cache with new content - mode_cache[cache_data.args_hash] = { + # Create cache entry with flattened structure + cache_entry = { "return": cache_data.content, "cache_type": cache_data.cache_type, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, @@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): "original_prompt": cache_data.prompt, } - logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}") + logger.info(f" == LLM cache == saving: {flattened_key}") - # Only upsert if there's actual new content - await hashing_kv.upsert({cache_data.mode: mode_cache}) + # Save using flattened key + await hashing_kv.upsert({flattened_key: cache_entry}) def safe_unicode_decode(content):
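
Usage sketch: the core of this change is the flattened cache key format {mode}:{cache_type}:{hash}. The snippet below is illustrative only (not part of the applied diff); it mirrors the generate_cache_key, parse_cache_key, and compute_args_hash helpers added in lightrag/utils.py and shows how a legacy nested cache dict maps onto flattened keys, the same transformation performed by JsonKVStorage._migrate_legacy_cache_structure. The sample prompt and return value are hypothetical.

    # Standalone sketch of the flattened LLM cache key scheme introduced by this patch.
    from hashlib import md5


    def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
        """Build a flattened cache key: {mode}:{cache_type}:{hash} (as in lightrag/utils.py)."""
        return f"{mode}:{cache_type}:{hash_value}"


    def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
        """Split a flattened key back into (mode, cache_type, hash), or None if malformed."""
        parts = cache_key.split(":", 2)
        return (parts[0], parts[1], parts[2]) if len(parts) == 3 else None


    def compute_args_hash(*args) -> str:
        """Hash the joined string form of the arguments (as in lightrag/utils.py)."""
        return md5("".join(str(a) for a in args).encode()).hexdigest()


    # Legacy structure: mode -> {hash -> cache entry} (hypothetical sample data)
    legacy = {
        "default": {
            compute_args_hash("default", "some prompt"): {
                "return": "extracted entities ...",
                "cache_type": "extract",
                "original_prompt": "some prompt",
            }
        }
    }

    # Flatten it the same way _migrate_legacy_cache_structure does
    flattened = {}
    for mode, entries in legacy.items():
        for cache_hash, entry in entries.items():
            key = generate_cache_key(mode, entry.get("cache_type", "extract"), cache_hash)
            flattened[key] = entry

    for key in flattened:
        print(key, "->", parse_cache_key(key))

Because every key now begins with its mode, storages can recall or drop cache entries with a simple prefix match (for example the Redis pattern {namespace}:{mode}:* used by drop_cache_by_modes) instead of loading, mutating, and re-serializing a nested JSON blob per mode.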