From ff1927d362f5a14656ed9df9a521af295604e26f Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:15:56 +0200 Subject: [PATCH 01/30] add Memgraph graph storage backend --- config.ini.example | 3 + examples/graph_visual_with_neo4j.py | 2 +- examples/lightrag_openai_demo.py | 1 + lightrag/kg/__init__.py | 3 + lightrag/kg/memgraph_impl.py | 423 ++++++++++++++++++++++++++++ 5 files changed, 431 insertions(+), 1 deletion(-) create mode 100644 lightrag/kg/memgraph_impl.py diff --git a/config.ini.example b/config.ini.example index 63d9c2c0..94d300a1 100644 --- a/config.ini.example +++ b/config.ini.example @@ -21,3 +21,6 @@ password = your_password database = your_database workspace = default # 可选,默认为default max_connections = 12 + +[memgraph] +uri = bolt://localhost:7687 diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 1cd2e7a3..e06c248c 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100 # Neo4j connection credentials NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" -NEO4J_PASSWORD = "your_password" +NEO4J_PASSWORD = "david123" def xml_to_json(xml_file): diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index fa0b37f1..e573ec41 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -82,6 +82,7 @@ async def initialize_rag(): working_dir=WORKING_DIR, embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, + graph_storage="MemgraphStorage", ) await rag.initialize_storages() diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index b4ba0983..3398b135 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = { "Neo4JStorage", "PGGraphStorage", "MongoGraphStorage", + "MemgraphStorage", # "AGEStorage", # "TiDBGraphStorage", # "GremlinStorage", @@ -56,6 +57,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], + "MemgraphStorage": ["MEMGRAPH_URI"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", @@ -108,6 +110,7 @@ STORAGES = { "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", + "MemgraphStorage": ".kg.memgraph_impl", } diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py new file mode 100644 index 00000000..cb46cdc6 --- /dev/null +++ b/lightrag/kg/memgraph_impl.py @@ -0,0 +1,423 @@ +import os +import re +from dataclasses import dataclass +from typing import final +import configparser + +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +import pipmaster as pm + +if not pm.is_installed("neo4j"): + pm.install("neo4j") + +from neo4j import ( + AsyncGraphDatabase, + AsyncManagedTransaction, +) + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +load_dotenv(dotenv_path=".env", override=False) + +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + +@final +@dataclass +class MemgraphStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func): + super().__init__( + namespace=namespace, 
+ global_config=global_config, + embedding_func=embedding_func, + ) + self._driver = None + + async def initialize(self): + URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687")) + USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")) + PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")) + DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")) + + self._driver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + ) + self._DATABASE = DATABASE + try: + async with self._driver.session(database=DATABASE) as session: + # Create index for base nodes on entity_id if it doesn't exist + try: + await session.run("""CREATE INDEX ON :base(entity_id)""") + logger.info("Created index on :base(entity_id) in Memgraph.") + except Exception as e: + # Index may already exist, which is not an error + logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}") + await session.run("RETURN 1") + logger.info(f"Connected to Memgraph at {URI}") + except Exception as e: + logger.error(f"Failed to connect to Memgraph at {URI}: {e}") + raise + + async def finalize(self): + if self._driver is not None: + await self._driver.close() + self._driver = None + + async def __aexit__(self, exc_type, exc, tb): + await self.finalize() + + async def index_done_callback(self): + # Memgraph handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() + return single_result["node_exists"] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = ( + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + single_result = await result.single() + await result.consume() + return single_result["edgeExists"] + + async def get_node(self, node_id: str) -> dict[str, str] | None: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) + records = await result.fetch(2) + await result.consume() + if records: + node = records[0]["n"] + node_dict = dict(node) + if "labels" in node_dict: + node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + return node_dict + return None + + async def get_all_labels(self) -> list[str]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + async 
with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected + """ + results = await session.run(query, entity_id=source_node_id) + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + if not source_node or not connected_node: + continue + source_label = source_node.get("entity_id") + target_label = connected_node.get("entity_id") + if source_label and target_label: + edges.append((source_label, target_label)) + await results.consume() + return edges + + async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + return edge_result + return None + + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + properties = node_data + entity_type = properties.get("entity_type", "base") + if "entity_id" not in properties: + raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): + query = ( + f""" + MERGE (n:base {{entity_id: $entity_id}}) + SET n += $properties + SET n:`{entity_type}` + """ + ) + result = await tx.run(query, entity_id=node_id, properties=properties) + await result.consume() + await session.execute_write(execute_upsert) + + async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): + query = """ + MATCH (source:base {entity_id: $source_entity_id}) + WITH source + MATCH (target:base {entity_id: $target_entity_id}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + await result.consume() + await session.execute_write(execute_upsert) + + async def delete_node(self, node_id: str) -> None: + async def _do_delete(tx: AsyncManagedTransaction): + query = """ + MATCH (n:base {entity_id: $entity_id}) + DETACH DELETE n + """ + result = await tx.run(query, entity_id=node_id) + await result.consume() + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + + async def remove_nodes(self, nodes: list[str]): + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): + query = """ + MATCH 
(source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) + DELETE r + """ + result = await tx.run( + query, source_entity_id=source, target_entity_id=target + ) + await result.consume() + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + + async def drop(self) -> dict[str, str]: + try: + async with self._driver.session(database=self._DATABASE) as session: + query = "MATCH (n) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}") + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") + return {"status": "error", "message": str(e)} + + async def node_degree(self, node_id: str) -> int: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + record = await result.single() + await result.consume() + if not record: + return 0 + return record["degree"] + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + return int(src_degree) + int(trg_degree) + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + if node_label == "*": + count_query = "MATCH (n) RETURN count(n) as total" + count_result = await session.run(count_query) + count_record = await count_result.single() + await count_result.consume() + if count_record and count_record["total"] > max_nodes: + 
result.is_truncated = True + logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}") + main_query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships + """ + result_set = await session.run(main_query, {"max_nodes": max_nodes}) + record = await result_set.single() + await result_set.consume() + else: + # BFS fallback for Memgraph (no APOC) + from collections import deque + # Get the starting node + start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + node_result = await session.run(start_query, entity_id=node_label) + node_record = await node_result.single() + await node_result.consume() + if not node_record: + return result + start_node = node_record["n"] + start_node_id = start_node.get("entity_id") + queue = deque([(start_node, 0)]) + visited = set() + bfs_nodes = [] + while queue and len(bfs_nodes) < max_nodes: + current_node, depth = queue.popleft() + node_id = current_node.get("entity_id") + if node_id in visited: + continue + visited.add(node_id) + bfs_nodes.append(current_node) + if depth < max_depth: + # Get neighbors + neighbor_query = """ + MATCH (n:base {entity_id: $entity_id})-[]-(m:base) + RETURN m + """ + neighbors_result = await session.run(neighbor_query, entity_id=node_id) + neighbors = [rec["m"] for rec in await neighbors_result.to_list()] + await neighbors_result.consume() + for neighbor in neighbors: + neighbor_id = neighbor.get("entity_id") + if neighbor_id not in visited: + queue.append((neighbor, depth + 1)) + # Build subgraph + subgraph_ids = [n.get("entity_id") for n in bfs_nodes] + # Nodes + for n in bfs_nodes: + node_id = n.get("entity_id") + if node_id not in seen_nodes: + result.nodes.append(KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties=dict(n), + )) + seen_nodes.add(node_id) + # Edges + if subgraph_ids: + edge_query = """ + MATCH (a:base)-[r]-(b:base) + WHERE a.entity_id IN $ids AND b.entity_id IN $ids + RETURN DISTINCT r, a, b + """ + edge_result = await session.run(edge_query, ids=subgraph_ids) + async for record in edge_result: + r = record["r"] + a = record["a"] + b = record["b"] + edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" + if edge_id not in seen_edges: + result.edges.append(KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=a.get("entity_id"), + target=b.get("entity_id"), + properties=dict(r), + )) + seen_edges.add(edge_id) + await edge_result.consume() + logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}") + return result \ No newline at end of file From 0d6bd3bac2c9f0e5befe955628838e975610b6f2 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:18:25 +0200 Subject: [PATCH 02/30] Revert changes made to graph_visual_with_neo4j.py --- examples/graph_visual_with_neo4j.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index e06c248c..1cd2e7a3 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100 # Neo4j connection credentials NEO4J_URI = "bolt://localhost:7687" 
NEO4J_USERNAME = "neo4j" -NEO4J_PASSWORD = "david123" +NEO4J_PASSWORD = "your_password" def xml_to_json(xml_file): From 80d4d5b0d50056cd89e347789d489896f0b39275 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:26:51 +0200 Subject: [PATCH 03/30] Add Memgraph into README.md --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 617dc5e6..2068f205 100644 --- a/README.md +++ b/README.md @@ -854,6 +854,41 @@ rag = LightRAG( +
+ Using Memgraph for Storage
+
+* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+* You can run Memgraph locally with Docker for easy testing; see https://memgraph.com/download for setup instructions.
+
+```python
+import os
+
+# Point LightRAG at your Memgraph instance. This can also be provided as a
+# shell environment variable (export MEMGRAPH_URI=...) or in the [memgraph]
+# section of config.ini.
+os.environ["MEMGRAPH_URI"] = "bolt://localhost:7687"
+
+# Setup logger for LightRAG
+setup_logger("lightrag", level="INFO")
+
+# Default settings use NetworkX. Override the default KG by passing
+# graph_storage="MemgraphStorage" when initializing LightRAG.
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="MemgraphStorage",  # <-- override the default KG (NetworkX)
+    )
+
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
+```
+
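+To sanity-check the connection before launching LightRAG, here is a minimal sketch using the
+synchronous `neo4j` driver (the same Bolt driver the Memgraph backend installs and uses;
+the backend defaults to empty-string credentials for a local instance):
+
+```python
+from neo4j import GraphDatabase
+
+# Same connection defaults as the MemgraphStorage backend
+driver = GraphDatabase.driver("bolt://localhost:7687", auth=("", ""))
+with driver.session() as session:
+    # Trivial round-trip query; prints 1 if Memgraph is reachable
+    print(session.run("RETURN 1 AS ok").single()["ok"])
+driver.close()
+```
+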
+ ## Edit Entities and Relations LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. From 7118b23ca2851d933384612d87cd34c25b4bba5e Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:33:19 +0200 Subject: [PATCH 04/30] reformatting --- lightrag/kg/memgraph_impl.py | 142 ++++++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 42 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index cb46cdc6..df28b8b2 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -1,5 +1,4 @@ import os -import re from dataclasses import dataclass from typing import final import configparser @@ -28,6 +27,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @final @dataclass class MemgraphStorage(BaseGraphStorage): @@ -40,10 +40,19 @@ class MemgraphStorage(BaseGraphStorage): self._driver = None async def initialize(self): - URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687")) - USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")) - PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")) - DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")) + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), + ) + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") + ) self._driver = AsyncGraphDatabase.driver( URI, @@ -58,7 +67,9 @@ class MemgraphStorage(BaseGraphStorage): logger.info("Created index on :base(entity_id) in Memgraph.") except Exception as e: # Index may already exist, which is not an error - logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}") + logger.warning( + f"Index creation on :base(entity_id) may have failed or already exists: {e}" + ) await session.run("RETURN 1") logger.info(f"Connected to Memgraph at {URI}") except Exception as e: @@ -78,7 +89,9 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() @@ -86,7 +99,9 @@ class MemgraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -101,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): return 
single_result["edgeExists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" result = await session.run(query, entity_id=node_id) records = await result.fetch(2) @@ -110,12 +127,16 @@ class MemgraphStorage(BaseGraphStorage): node = records[0]["n"] node_dict = dict(node) if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] return node_dict return None async def get_all_labels(self) -> list[str]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base) WHERE n.entity_id IS NOT NULL @@ -130,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): return labels async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-(connected:base) @@ -151,8 +174,12 @@ class MemgraphStorage(BaseGraphStorage): await results.consume() return edges - async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) RETURN properties(r) as edge_properties @@ -181,23 +208,28 @@ class MemgraphStorage(BaseGraphStorage): properties = node_data entity_type = properties.get("entity_type", "base") if "entity_id" not in properties: - raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + raise ValueError( + "Memgraph: node properties must contain an 'entity_id' field" + ) async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:base {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run(query, entity_id=node_id, properties=properties) await result.consume() + await session.execute_write(execute_upsert) - async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id}) @@ -214,6 +246,7 @@ class MemgraphStorage(BaseGraphStorage): properties=edge_properties, ) await result.consume() + await session.execute_write(execute_upsert) async def delete_node(self, 
node_id: str) -> None: @@ -224,6 +257,7 @@ class MemgraphStorage(BaseGraphStorage): """ result = await tx.run(query, entity_id=node_id) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete) @@ -233,6 +267,7 @@ class MemgraphStorage(BaseGraphStorage): async def remove_edges(self, edges: list[tuple[str, str]]): for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) @@ -242,6 +277,7 @@ class MemgraphStorage(BaseGraphStorage): query, source_entity_id=source, target_entity_id=target ) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete_edge) @@ -251,14 +287,18 @@ class MemgraphStorage(BaseGraphStorage): query = "MATCH (n) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}") + logger.info( + f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "data dropped"} except Exception as e: logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} async def node_degree(self, node_id: str) -> int: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-() @@ -279,7 +319,9 @@ class MemgraphStorage(BaseGraphStorage): return int(src_degree) + int(trg_degree) async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (n:base) @@ -297,7 +339,9 @@ class MemgraphStorage(BaseGraphStorage): return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) @@ -323,7 +367,9 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = await session.run(count_query) @@ -331,7 +377,9 @@ class MemgraphStorage(BaseGraphStorage): await count_result.consume() if count_record and count_record["total"] > max_nodes: result.is_truncated = True - logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}") + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() @@ -352,6 +400,7 @@ class MemgraphStorage(BaseGraphStorage): else: # BFS fallback for Memgraph (no APOC) from collections 
import deque + # Get the starting node start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" node_result = await session.run(start_query, entity_id=node_label) @@ -360,7 +409,6 @@ class MemgraphStorage(BaseGraphStorage): if not node_record: return result start_node = node_record["n"] - start_node_id = start_node.get("entity_id") queue = deque([(start_node, 0)]) visited = set() bfs_nodes = [] @@ -377,8 +425,12 @@ class MemgraphStorage(BaseGraphStorage): MATCH (n:base {entity_id: $entity_id})-[]-(m:base) RETURN m """ - neighbors_result = await session.run(neighbor_query, entity_id=node_id) - neighbors = [rec["m"] for rec in await neighbors_result.to_list()] + neighbors_result = await session.run( + neighbor_query, entity_id=node_id + ) + neighbors = [ + rec["m"] for rec in await neighbors_result.to_list() + ] await neighbors_result.consume() for neighbor in neighbors: neighbor_id = neighbor.get("entity_id") @@ -390,11 +442,13 @@ class MemgraphStorage(BaseGraphStorage): for n in bfs_nodes: node_id = n.get("entity_id") if node_id not in seen_nodes: - result.nodes.append(KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - )) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties=dict(n), + ) + ) seen_nodes.add(node_id) # Edges if subgraph_ids: @@ -410,14 +464,18 @@ class MemgraphStorage(BaseGraphStorage): b = record["b"] edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" if edge_id not in seen_edges: - result.edges.append(KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), - )) + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=a.get("entity_id"), + target=b.get("entity_id"), + properties=dict(r), + ) + ) seen_edges.add(edge_id) await edge_result.consume() - logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}") - return result \ No newline at end of file + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result From bd158d096bb26215b283496b4bf2aebe4d0e2292 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 14:47:23 +0200 Subject: [PATCH 05/30] polish Memgraph implementation --- lightrag/kg/memgraph_impl.py | 803 ++++++++++++++++++++++++----------- 1 file changed, 551 insertions(+), 252 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index df28b8b2..bf870154 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -89,183 +89,419 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: + """ + Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + + Raises: + Exception: If there is an error checking the node existence. 
+ """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" - result = await session.run(query, entity_id=node_id) - single_result = await result.single() - await result.consume() - return single_result["node_exists"] + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["node_exists"] + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes in the graph. + + Args: + source_node_id: The ID of the source node. + target_node_id: The ID of the target node. + + Returns: + bool: True if the edge exists, False otherwise. + + Raises: + Exception: If there is an error checking the edge existence. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " - "RETURN COUNT(r) > 0 AS edgeExists" - ) - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - single_result = await result.single() - await result.consume() - return single_result["edgeExists"] + try: + query = ( + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["edgeExists"] + except Exception as e: + logger.error( + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - result = await session.run(query, entity_id=node_id) - records = await result.fetch(2) - await result.consume() - if records: - node = records[0]["n"] - node_dict = dict(node) - if "labels" in node_dict: - node_dict["labels"] = [ - label for label in node_dict["labels"] if label != "base" - ] - return node_dict - return None + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) + try: + records = await result.fetch( + 2 + ) # Get 2 records for duplication check + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{node_id}'. Using first node." 
+ ) + if records: + node = records[0]["n"] + node_dict = dict(node) + # Remove base label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != "base" + ] + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + Exception: If there is an error executing the query + """ + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + try: + record = await result.single() + + if not record: + logger.warning(f"No node found with label '{node_id}'") + return 0 + + degree = record["degree"] + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (n:base) - WHERE n.entity_id IS NOT NULL - RETURN DISTINCT n.entity_id AS label - ORDER BY label - """ - result = await session.run(query) - labels = [] - async for record in result: - labels.append(record["label"]) - await result.consume() - return labels + try: + query = """ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) - WHERE connected.entity_id IS NOT NULL - RETURN n, r, connected - """ - results = await session.run(query, entity_id=source_node_id) - edges = [] - async for record in results: - source_node = record["n"] - connected_node = record["connected"] - if not source_node or not connected_node: - continue - source_label = source_node.get("entity_id") - target_label = connected_node.get("entity_id") - if source_label and target_label: - edges.append((source_label, target_label)) - await results.consume() - return edges + """Retrieves all edges (relationships) for a particular node identified by its label. 
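+        The relationship match is undirected, so both incoming and outgoing edges of the node are returned.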
+ + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + Exception: If there is an error executing the query + """ + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = ( + source_node.get("entity_id") + if source_node.get("entity_id") + else None + ) + target_label = ( + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges + except Exception as e: + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) - RETURN properties(r) as edge_properties - """ - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - records = await result.fetch(2) - await result.consume() - if records: - edge_result = dict(records[0]["edge_properties"]) - for key, default_value in { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - }.items(): - if key not in edge_result: - edge_result[key] = default_value - return edge_result - return None + try: + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. 
Using default value: {default_value}"
+                            )
+                    return edge_result
+                return None
+            except Exception as e:
+                logger.error(
+                    f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
+                )
+                await result.consume()  # Ensure the result is consumed even on error
+                raise
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+        """
+        Upsert a node in the Memgraph database.
+
+        Args:
+            node_id: The unique identifier for the node (used as label)
+            node_data: Dictionary of node properties
+        """
         properties = node_data
-        entity_type = properties.get("entity_type", "base")
+        entity_type = properties["entity_type"]
         if "entity_id" not in properties:
             raise ValueError(
                 "Memgraph: node properties must contain an 'entity_id' field"
             )
-        async with self._driver.session(database=self._DATABASE) as session:
 
-            async def execute_upsert(tx: AsyncManagedTransaction):
-                query = f"""
-                MERGE (n:base {{entity_id: $entity_id}})
-                SET n += $properties
-                SET n:`{entity_type}`
-                """
-                result = await tx.run(query, entity_id=node_id, properties=properties)
-                await result.consume()
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
 
-            await session.execute_write(execute_upsert)
+                async def execute_upsert(tx: AsyncManagedTransaction):
+                    query = (
+                        """
+            MERGE (n:base {entity_id: $entity_id})
+            SET n += $properties
+            SET n:`%s`
+            """
+                        % entity_type
+                    )
+                    result = await tx.run(
+                        query, entity_id=node_id, properties=properties
+                    )
+                    await result.consume()  # Ensure result is fully consumed
+
+                await session.execute_write(execute_upsert)
+        except Exception as e:
+            logger.error(f"Error during upsert: {str(e)}")
+            raise
 
     async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
-        edge_properties = edge_data
-        async with self._driver.session(database=self._DATABASE) as session:
+        """
+        Upsert an edge and its properties between two nodes identified by their labels.
+        Ensures both source and target nodes exist and are unique before creating the edge.
+        Uses entity_id property to uniquely identify nodes.
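+        Note: the edge is merged with the generic :DIRECTED relationship type via an
+        undirected MERGE pattern, so a single stored edge covers both directions of a node pair.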
- async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) - WITH source - MATCH (target:base {entity_id: $target_entity_id}) - MERGE (source)-[r:DIRECTED]-(target) - SET r += $properties - RETURN r, source, target - """ - result = await tx.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - properties=edge_properties, - ) - await result.consume() + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge - await session.execute_write(execute_upsert) + Raises: + Exception: If there is an error executing the query + """ + try: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + query = """ + MATCH (source:base {entity_id: $source_entity_id}) + WITH source + MATCH (target:base {entity_id: $target_entity_id}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + try: + await result.fetch(2) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + + Raises: + Exception: If there is an error executing the query + """ + async def _do_delete(tx: AsyncManagedTransaction): query = """ MATCH (n:base {entity_id: $entity_id}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label {node_id}") await result.consume() - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ for node in nodes: await self.delete_node(node) async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + + Raises: + Exception: If there is an error executing the query + """ for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -276,15 +512,32 @@ class MemgraphStorage(BaseGraphStorage): result = await tx.run( query, source_entity_id=source, target_entity_id=target ) - await result.consume() + logger.debug(f"Deleted edge from '{source}' to '{target}'") + await result.consume() # Ensure result is fully consumed - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete_edge) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up 
resources + + This method will delete all nodes and relationships in the Neo4j database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + + Raises: + Exception: If there is an error executing the query + """ try: async with self._driver.session(database=self._DATABASE) as session: - query = "MATCH (n) DETACH DELETE n" + query = "DROP GRAPH" result = await session.run(query) await result.consume() logger.info( @@ -295,30 +548,36 @@ class MemgraphStorage(BaseGraphStorage): logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} - async def node_degree(self, node_id: str) -> int: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-() - RETURN COUNT(r) AS degree - """ - result = await session.run(query, entity_id=node_id) - record = await result.single() - await result.consume() - if not record: - return 0 - return record["degree"] - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree trg_degree = 0 if trg_degree is None else trg_degree - return int(src_degree) + int(trg_degree) + + degrees = int(src_degree) + int(trg_degree) + return degrees async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated nodes for + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -335,10 +594,19 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated edges for + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -364,118 +632,149 @@ class MemgraphStorage(BaseGraphStorage): max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. 
+ + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + Raises: + Exception: If there is an error executing the query + """ result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - if node_label == "*": - count_query = "MATCH (n) RETURN count(n) as total" - count_result = await session.run(count_query) - count_record = await count_result.single() - await count_result.consume() - if count_record and count_record["total"] > max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" - ) - main_query = """ - MATCH (n) - OPTIONAL MATCH (n)-[r]-() - WITH n, COALESCE(count(r), 0) AS degree - ORDER BY degree DESC - LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) - WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships - """ - result_set = await session.run(main_query, {"max_nodes": max_nodes}) - record = await result_set.single() - await result_set.consume() - else: - # BFS fallback for Memgraph (no APOC) - from collections import deque + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + if node_label == "*": + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() - # Get the starting node - start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - node_result = await session.run(start_query, entity_id=node_label) - node_record = await node_result.single() - await node_result.consume() - if not node_record: - return result - start_node = node_record["n"] - queue = deque([(start_node, 0)]) - visited = set() - bfs_nodes = [] - while queue and len(bfs_nodes) < max_nodes: - current_node, depth = queue.popleft() - node_id = current_node.get("entity_id") - if node_id in visited: - continue - visited.add(node_id) - bfs_nodes.append(current_node) - if depth < max_depth: - # Get neighbors - neighbor_query = """ - MATCH (n:base {entity_id: $entity_id})-[]-(m:base) - RETURN m - """ - neighbors_result = await session.run( - neighbor_query, entity_id=node_id - ) - neighbors = [ - rec["m"] for rec in await neighbors_result.to_list() - ] - await neighbors_result.consume() - for neighbor in neighbors: - neighbor_id = neighbor.get("entity_id") - if neighbor_id not in visited: - queue.append((neighbor, depth + 1)) - # Build subgraph - subgraph_ids = [n.get("entity_id") for n in bfs_nodes] - # Nodes - for n in bfs_nodes: - node_id = n.get("entity_id") - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - ) - ) - 
seen_nodes.add(node_id) - # Edges - if subgraph_ids: - edge_query = """ - MATCH (a:base)-[r]-(b:base) - WHERE a.entity_id IN $ids AND b.entity_id IN $ids - RETURN DISTINCT r, a, b + # Run the main query to get nodes with highest degree + main_query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships """ - edge_result = await session.run(edge_query, ids=subgraph_ids) - async for record in edge_result: - r = record["r"] - a = record["a"] - b = record["b"] - edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), + result_set = None + try: + result_set = await session.run( + main_query, {"max_nodes": max_nodes} + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + + else: + bfs_query = """ + MATCH (start) WHERE start.entity_id = $entity_id + WITH start + CALL { + WITH start + MATCH path = (start)-[*0..$max_depth]-(node) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS n + WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists + WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + RETURN all_nodes, all_rels + } + WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes + + // Apply node limiting here + WITH CASE + WHEN total_nodes <= $max_nodes THEN nodes + ELSE nodes[0..$max_nodes] + END AS limited_nodes, + relationships, + total_nodes, + total_nodes > $max_nodes AS is_truncated + UNWIND limited_nodes AS node + WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated + RETURN node_info, relationships, total_nodes, is_truncated + """ + result_set = None + try: + result_set = await session.run( + bfs_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, + ) + record = await result_set.single() + if not record: + logger.debug(f"No record found for node {node_label}") + return result + + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) ) - ) - seen_edges.add(edge_id) - await edge_result.consume() - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + finally: + if result_set: + await result_set.consume() + + except Exception as e: + logger.error(f"Error getting knowledge 
graph: {str(e)}") + return result From eed43e071cd887c7241f970d15ec91c3eeafe0a4 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 14:49:57 +0200 Subject: [PATCH 06/30] revert lightrag_openai_demo.py changes --- examples/lightrag_openai_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index e573ec41..fa0b37f1 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -82,7 +82,6 @@ async def initialize_rag(): working_dir=WORKING_DIR, embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, - graph_storage="MemgraphStorage", ) await rag.initialize_storages() From 9aaa7d2dd3e6386e6f61ef3a93e0e69077949f1a Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 15:09:22 +0200 Subject: [PATCH 07/30] fix drop function in Memgraph implementation --- lightrag/kg/memgraph_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index bf870154..36f0186b 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -537,7 +537,7 @@ class MemgraphStorage(BaseGraphStorage): """ try: async with self._driver.session(database=self._DATABASE) as session: - query = "DROP GRAPH" + query = "MATCH (n) DETACH DELETE n" result = await session.run(query) await result.consume() logger.info( From c0a3638d011ab2b3df586fbc0aaf970d37aeee2c Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 15:35:20 +0200 Subject: [PATCH 08/30] fix memgraph_impl.py according to test_graph_storage.py --- lightrag/kg/memgraph_impl.py | 96 +++++++++++++++++++----------------- tests/test_graph_storage.py | 1 + 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 36f0186b..41a1129b 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: """Get all edges that are associated with the given chunk_ids. 
@@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage): UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id + // Ensure we only return each unique edge once by ordering the source and target + WITH a, b, r, + CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, + CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target + RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties """ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) edges = [] @@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - try: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = None @@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage): logger.debug(f"No record found for node {node_label}") return result - for node_info in record["node_info"]: - node = node_info["node"] - node_id = node.id - if node_id not in seen_nodes: - seen_nodes.add(node_id) - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=[node.get("entity_id")], - properties=dict(node), - ) - ) - - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) - ) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - return result - finally: if result_set: await result_set.consume() - except Exception as e: - logger.error(f"Error getting knowledge graph: {str(e)}") - return result + if record: + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + return result diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 64e66f48..3fd1abbc 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -9,6 +9,7 @@ - NetworkXStorage - Neo4JStorage - PGGraphStorage +- MemgraphStorage """ import asyncio From 4ea38456f060892fed2953fdc843760a920c1db5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 00:31:52 +0800 Subject: [PATCH 09/30] 
Improve graph query robustness and error handling --- lightrag/kg/memgraph_impl.py | 74 +++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 41a1129b..397e5a99 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -660,16 +660,23 @@ class MemgraphStorage(BaseGraphStorage): ) as session: try: if node_label == "*": + # First check if database has any nodes count_query = "MATCH (n) RETURN count(n) as total" count_result = None + total_count = 0 try: count_result = await session.run(count_query) count_record = await count_result.single() - if count_record and count_record["total"] > max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" - ) + if count_record: + total_count = count_record["total"] + if total_count == 0: + logger.debug("No nodes found in database") + return result + if total_count > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {total_count} nodes found, limited to {max_nodes}" + ) finally: if count_result: await count_result.consume() @@ -695,6 +702,9 @@ class MemgraphStorage(BaseGraphStorage): main_query, {"max_nodes": max_nodes} ) record = await result_set.single() + if not record: + logger.debug("No record returned from main query") + return result finally: if result_set: await result_set.consume() @@ -738,14 +748,22 @@ class MemgraphStorage(BaseGraphStorage): ) record = await result_set.single() if not record: - logger.debug(f"No record found for node {node_label}") + logger.debug(f"No nodes found for entity_id: {node_label}") return result + # Check if the query indicates truncation + if "is_truncated" in record and record["is_truncated"]: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + ) + finally: if result_set: await result_set.consume() - if record: + # Process the record if it exists + if record and record["node_info"]: for node_info in record["node_info"]: node = node_info["node"] node_id = node.id @@ -759,28 +777,30 @@ class MemgraphStorage(BaseGraphStorage): ) ) - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), + if "relationships" in record and record["relationships"]: + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) ) - ) - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - return result + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) except Exception as e: logger.error(f"Error getting knowledge graph: {str(e)}") - return result + # Return empty but properly initialized KnowledgeGraph on error + return KnowledgeGraph() + + return result From 2f7cef968d49a0986d0f14f3903862947d208812 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 00:32:55 +0800 Subject: [PATCH 10/30] fix: ensure 
Milvus collections are loaded before operations - Resolves "collection not loaded" MilvusException errors --- lightrag/kg/milvus_impl.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 6cffae88..eecf679a 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -539,6 +539,23 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) raise + def _ensure_collection_loaded(self): + """Ensure the collection is loaded into memory for search operations""" + try: + # Check if collection exists first + if not self._client.has_collection(self.namespace): + logger.error(f"Collection {self.namespace} does not exist") + raise ValueError(f"Collection {self.namespace} does not exist") + + # Load the collection if it's not already loaded + # In Milvus, collections need to be loaded before they can be searched + self._client.load_collection(self.namespace) + logger.debug(f"Collection {self.namespace} loaded successfully") + + except Exception as e: + logger.error(f"Failed to load collection {self.namespace}: {e}") + raise + def _create_collection_if_not_exist(self): """Create collection if not exists and check existing collection compatibility""" @@ -565,6 +582,8 @@ class MilvusVectorDBStorage(BaseVectorStorage): f"Collection '{self.namespace}' confirmed to exist, validating compatibility..." ) self._validate_collection_compatibility() + # Ensure the collection is loaded after validation + self._ensure_collection_loaded() return except Exception as describe_error: logger.warning( @@ -587,6 +606,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Then create indexes self._create_indexes_after_collection() + # Load the newly created collection + self._ensure_collection_loaded() + logger.info(f"Successfully created Milvus collection: {self.namespace}") except Exception as e: @@ -615,6 +637,10 @@ class MilvusVectorDBStorage(BaseVectorStorage): collection_name=self.namespace, schema=schema ) self._create_indexes_after_collection() + + # Load the newly created collection + self._ensure_collection_loaded() + logger.info(f"Successfully force-created collection {self.namespace}") except Exception as create_error: @@ -670,6 +696,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): if not data: return + # Ensure collection is loaded before upserting + self._ensure_collection_loaded() + import time current_time = int(time.time()) @@ -700,6 +729,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def query( self, query: str, top_k: int, ids: list[str] | None = None ) -> list[dict[str, Any]]: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + embedding = await self.embedding_func( [query], _priority=5 ) # higher priority for query @@ -764,6 +796,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): entity_name: The name of the entity whose relations should be deleted """ try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' @@ -802,6 +837,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: + # Ensure collection is loaded before deleting + self._ensure_collection_loaded() + # Delete vectors by IDs result = self._client.delete(collection_name=self.namespace, pks=ids) @@ -825,6 +863,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): The vector data if 
found, or None if not found """ try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] @@ -856,6 +897,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): return [] try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] From fb979be9ff9f19aad60030e509764938540e97ea Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 07:09:33 +0800 Subject: [PATCH 11/30] Fix linting --- lightrag/kg/milvus_impl.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index eecf679a..2226784f 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -546,12 +546,12 @@ class MilvusVectorDBStorage(BaseVectorStorage): if not self._client.has_collection(self.namespace): logger.error(f"Collection {self.namespace} does not exist") raise ValueError(f"Collection {self.namespace} does not exist") - + # Load the collection if it's not already loaded # In Milvus, collections need to be loaded before they can be searched self._client.load_collection(self.namespace) logger.debug(f"Collection {self.namespace} loaded successfully") - + except Exception as e: logger.error(f"Failed to load collection {self.namespace}: {e}") raise @@ -637,10 +637,10 @@ class MilvusVectorDBStorage(BaseVectorStorage): collection_name=self.namespace, schema=schema ) self._create_indexes_after_collection() - + # Load the newly created collection self._ensure_collection_loaded() - + logger.info(f"Successfully force-created collection {self.namespace}") except Exception as create_error: @@ -731,7 +731,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + embedding = await self.embedding_func( [query], _priority=5 ) # higher priority for query @@ -798,7 +798,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' @@ -839,7 +839,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before deleting self._ensure_collection_loaded() - + # Delete vectors by IDs result = self._client.delete(collection_name=self.namespace, pks=ids) @@ -865,7 +865,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] @@ -899,7 +899,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] From 75dd4f3498d06d754f9ddff62a6e650d639823e7 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Mon, 7 Jul 2025 22:44:59 +0800 Subject: [PATCH 12/30] add rerank model --- docs/rerank_integration.md | 271 ++++++++++++++++++++++++++++++++ env.example | 11 ++ examples/rerank_example.py | 193 
+++++++++++++++++++++++ lightrag/lightrag.py | 45 ++++++ lightrag/operate.py | 85 ++++++++++ lightrag/rerank.py | 307 +++++++++++++++++++++++++++++++++++++ 6 files changed, 912 insertions(+) create mode 100644 docs/rerank_integration.md create mode 100644 examples/rerank_example.py create mode 100644 lightrag/rerank.py diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md new file mode 100644 index 00000000..647c0f91 --- /dev/null +++ b/docs/rerank_integration.md @@ -0,0 +1,271 @@ +# Rerank Integration in LightRAG + +This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality. + +## ⚠️ Important: Parameter Priority + +**QueryParam.top_k has higher priority than rerank_top_k configuration:** + +- When you set `QueryParam(top_k=5)`, it will override the `rerank_top_k=10` setting in LightRAG configuration +- This means the actual number of documents sent to rerank will be determined by QueryParam.top_k +- For optimal rerank performance, always consider the top_k value in your QueryParam calls +- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` will use 20, not rerank_top_k + +## Overview + +Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix). + +## Architecture + +The rerank integration follows the same design pattern as the LLM integration: + +- **Configurable Models**: Support for multiple rerank providers through a generic API +- **Async Processing**: Non-blocking rerank operations +- **Error Handling**: Graceful fallback to original results +- **Optional Feature**: Can be enabled/disabled via configuration +- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs + +## Configuration + +### Environment Variables + +Set these variables in your `.env` file or environment: + +```bash +# Enable/disable reranking +ENABLE_RERANK=True + +# Rerank model configuration +RERANK_MODEL=BAAI/bge-reranker-v2-m3 +RERANK_MAX_ASYNC=4 +RERANK_TOP_K=10 + +# API configuration +RERANK_API_KEY=your_rerank_api_key_here +RERANK_BASE_URL=https://api.your-provider.com/v1/rerank + +# Provider-specific keys (optional alternatives) +JINA_API_KEY=your_jina_api_key_here +COHERE_API_KEY=your_cohere_api_key_here +``` + +### Programmatic Configuration + +```python +from lightrag import LightRAG +from lightrag.rerank import custom_rerank, RerankModel + +# Method 1: Using environment variables (recommended) +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + # Rerank automatically configured from environment variables +) + +# Method 2: Explicit configuration +rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-provider.com/v1/rerank", + "api_key": "your_api_key_here", + } +) + +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + rerank_top_k=10, +) +``` + +## Supported Providers + +### 1. 
Custom/Generic API (Recommended) + +For Jina/Cohere compatible APIs: + +```python +from lightrag.rerank import custom_rerank + +# Your custom API endpoint +result = await custom_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=10 +) +``` + +### 2. Jina AI + +```python +from lightrag.rerank import jina_rerank + +result = await jina_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key" +) +``` + +### 3. Cohere + +```python +from lightrag.rerank import cohere_rerank + +result = await cohere_rerank( + query="your query", + documents=documents, + model="rerank-english-v2.0", + api_key="your_cohere_api_key" +) +``` + +## Integration Points + +Reranking is automatically applied at these key retrieval stages: + +1. **Naive Mode**: After vector similarity search in `_get_vector_context` +2. **Local Mode**: After entity retrieval in `_get_node_data` +3. **Global Mode**: After relationship retrieval in `_get_edge_data` +4. **Hybrid/Mix Modes**: Applied to all relevant components + +## Configuration Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `enable_rerank` | bool | False | Enable/disable reranking | +| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier | +| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls | +| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** | +| `rerank_model_func` | callable | None | Custom rerank function | +| `rerank_model_kwargs` | dict | {} | Additional rerank parameters | + +## Example Usage + +### Basic Usage + +```python +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding + +async def main(): + # Initialize with rerank enabled + rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=gpt_4o_mini_complete, + embedding_func=openai_embedding, + enable_rerank=True, + ) + + # Insert documents + await rag.ainsert([ + "Document 1 content...", + "Document 2 content...", + ]) + + # Query with rerank (automatically applied) + result = await rag.aquery( + "Your question here", + param=QueryParam(mode="hybrid", top_k=5) # ⚠️ This top_k=5 overrides rerank_top_k + ) + + print(result) + +asyncio.run(main()) +``` + +### Direct Rerank Usage + +```python +from lightrag.rerank import custom_rerank + +async def test_rerank(): + documents = [ + {"content": "Text about topic A"}, + {"content": "Text about topic B"}, + {"content": "Text about topic C"}, + ] + + reranked = await custom_rerank( + query="Tell me about topic A", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=2 + ) + + for doc in reranked: + print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}") +``` + +## Best Practices + +1. **Parameter Priority Awareness**: Remember that QueryParam.top_k always overrides rerank_top_k configuration +2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff +3. **API Limits**: Monitor API usage and implement rate limiting if needed +4. **Fallback**: Always handle rerank failures gracefully (returns original results) +5. **Top-k Selection**: Choose appropriate `top_k` values in QueryParam based on your use case +6. 
**Cost Management**: Consider rerank API costs in your budget planning + +## Troubleshooting + +### Common Issues + +1. **API Key Missing**: Ensure `RERANK_API_KEY` or provider-specific keys are set +2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity +3. **Model Errors**: Verify the rerank model name is supported by your API +4. **Document Format**: Ensure documents have `content` or `text` fields + +### Debug Mode + +Enable debug logging to see rerank operations: + +```python +import logging +logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG) +``` + +### Error Handling + +The rerank integration includes automatic fallback: + +```python +# If rerank fails, original documents are returned +# No exceptions are raised to the user +# Errors are logged for debugging +``` + +## API Compatibility + +The generic rerank API expects this response format: + +```json +{ + "results": [ + { + "index": 0, + "relevance_score": 0.95 + }, + { + "index": 2, + "relevance_score": 0.87 + } + ] +} +``` + +This is compatible with: +- Jina AI Rerank API +- Cohere Rerank API +- Custom APIs following the same format \ No newline at end of file diff --git a/env.example b/env.example index 1efe4830..49546343 100644 --- a/env.example +++ b/env.example @@ -179,3 +179,14 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +# Rerank Configuration +ENABLE_RERANK=False +RERANK_MODEL=BAAI/bge-reranker-v2-m3 +RERANK_MAX_ASYNC=4 +RERANK_TOP_K=10 +# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting + +# Rerank API Configuration +RERANK_API_KEY=your_rerank_api_key_here +RERANK_BASE_URL=https://api.your-provider.com/v1/rerank diff --git a/examples/rerank_example.py b/examples/rerank_example.py new file mode 100644 index 00000000..30ad794d --- /dev/null +++ b/examples/rerank_example.py @@ -0,0 +1,193 @@ +""" +LightRAG Rerank Integration Example + +This example demonstrates how to use rerank functionality with LightRAG +to improve retrieval quality across different query modes. + +IMPORTANT: Parameter Priority +- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration +- If you set QueryParam(top_k=5), it will override rerank_top_k setting +- For optimal rerank performance, use appropriate top_k values in QueryParam + +Configuration Required: +1. Set your LLM API key and base URL in llm_model_func() +2. Set your embedding API key and base URL in embedding_func() +3. Set your rerank API key and base URL in the rerank configuration +4. 
Or use environment variables (.env file): + - RERANK_API_KEY=your_actual_rerank_api_key + - RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank + - RERANK_MODEL=your_rerank_model_name +""" + +import asyncio +import os +import numpy as np + +from lightrag import LightRAG, QueryParam +from lightrag.rerank import custom_rerank, RerankModel +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc, setup_logger + +# Set up your working directory +WORKING_DIR = "./test_rerank" +setup_logger("test_rerank") + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key="your_llm_api_key_here", + base_url="https://api.your-llm-provider.com/v1", + **kwargs, + ) + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts, + model="text-embedding-3-large", + api_key="your_embedding_api_key_here", + base_url="https://api.your-embedding-provider.com/v1", + ) + +async def create_rag_with_rerank(): + """Create LightRAG instance with rerank configuration""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Create rerank model + rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-rerank-provider.com/v1/rerank", + "api_key": "your_rerank_api_key_here", + } + ) + + # Initialize LightRAG with rerank + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + # Rerank Configuration + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + rerank_top_k=10, # Note: QueryParam.top_k will override this + ) + + return rag + +async def test_rerank_with_different_topk(): + """ + Test rerank functionality with different top_k settings to demonstrate parameter priority + """ + print("🚀 Setting up LightRAG with Rerank functionality...") + + rag = await create_rag_with_rerank() + + # Insert sample documents + sample_docs = [ + "Reranking improves retrieval quality by re-ordering documents based on relevance.", + "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.", + "Vector databases enable efficient similarity search in high-dimensional embedding spaces.", + "Natural language processing has evolved with large language models and transformers.", + "Machine learning algorithms can learn patterns from data without explicit programming." + ] + + print("📄 Inserting sample documents...") + await rag.ainsert(sample_docs) + + query = "How does reranking improve retrieval quality?" 
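+
+    # NOTE: the per-query top_k passed via QueryParam below is forwarded to
+    # the rerank function, so it takes precedence over the instance-level
+    # rerank_top_k=10 configured above.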
+ print(f"\n🔍 Testing query: '{query}'") + print("=" * 80) + + # Test different top_k values to show parameter priority + top_k_values = [2, 5, 10] + + for top_k in top_k_values: + print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:") + + # Test naive mode with specific top_k + result = await rag.aquery( + query, + param=QueryParam(mode="naive", top_k=top_k) + ) + print(f" Result length: {len(result)} characters") + print(f" Preview: {result[:100]}...") + +async def test_direct_rerank(): + """Test rerank function directly""" + print("\n🔧 Direct Rerank API Test") + print("=" * 40) + + documents = [ + {"content": "Reranking significantly improves retrieval quality"}, + {"content": "LightRAG supports advanced reranking capabilities"}, + {"content": "Vector search finds semantically similar documents"}, + {"content": "Natural language processing with modern transformers"}, + {"content": "The quick brown fox jumps over the lazy dog"} + ] + + query = "rerank improve quality" + print(f"Query: '{query}'") + print(f"Documents: {len(documents)}") + + try: + reranked_docs = await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-rerank-provider.com/v1/rerank", + api_key="your_rerank_api_key_here", + top_k=3 + ) + + print("\n✅ Rerank Results:") + for i, doc in enumerate(reranked_docs): + score = doc.get("rerank_score", "N/A") + content = doc.get("content", "")[:60] + print(f" {i+1}. Score: {score:.4f} | {content}...") + + except Exception as e: + print(f"❌ Rerank failed: {e}") + +async def main(): + """Main example function""" + print("🎯 LightRAG Rerank Integration Example") + print("=" * 60) + + try: + # Test rerank with different top_k values + await test_rerank_with_different_topk() + + # Test direct rerank + await test_direct_rerank() + + print("\n✅ Example completed successfully!") + print("\n💡 Key Points:") + print(" ✓ QueryParam.top_k has higher priority than rerank_top_k") + print(" ✓ Rerank improves document relevance ordering") + print(" ✓ Configure API keys in your .env file for production") + print(" ✓ Monitor API usage and costs when using rerank services") + + except Exception as e: + print(f"\n❌ Example failed: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5d96aeba..cee08373 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -240,6 +240,35 @@ class LightRAG: llm_model_kwargs: dict[str, Any] = field(default_factory=dict) """Additional keyword arguments passed to the LLM model function.""" + # Rerank Configuration + # --- + + enable_rerank: bool = field( + default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true") + ) + """Enable reranking for improved retrieval quality. Defaults to False.""" + + rerank_model_func: Callable[..., object] | None = field(default=None) + """Function for reranking retrieved documents. 
Optional.""" + + rerank_model_name: str = field( + default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") + ) + """Name of the rerank model used for reranking documents.""" + + rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4))) + """Maximum number of concurrent rerank calls.""" + + rerank_model_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments passed to the rerank model function.""" + + rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10))) + """Number of top documents to return after reranking. + + Note: This value will be overridden by QueryParam.top_k in query calls. + Example: QueryParam(top_k=5) will override rerank_top_k=10 setting. + """ + # Storage # --- @@ -444,6 +473,22 @@ class LightRAG: ) ) + # Init Rerank + if self.enable_rerank and self.rerank_model_func: + self.rerank_model_func = priority_limit_async_func_call( + self.rerank_model_max_async + )( + partial( + self.rerank_model_func, # type: ignore + **self.rerank_model_kwargs, + ) + ) + logger.info("Rerank model initialized for improved retrieval quality") + elif self.enable_rerank and not self.rerank_model_func: + logger.warning( + "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped." + ) + self._storages_status = StoragesStatus.CREATED if self.auto_manage_storages_states: diff --git a/lightrag/operate.py b/lightrag/operate.py index 88837435..b5d74c55 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1783,6 +1783,15 @@ async def _get_vector_context( if not valid_chunks: return [], [], [] + # Apply reranking if enabled + global_config = chunks_vdb.global_config + valid_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=valid_chunks, + global_config=global_config, + top_k=query_param.top_k, + ) + maybe_trun_chunks = truncate_list_by_token_size( valid_chunks, key=lambda x: x["content"], @@ -1966,6 +1975,15 @@ async def _get_node_data( if not len(results): return "", "", "" + # Apply reranking if enabled for entity results + global_config = entities_vdb.global_config + results = await apply_rerank_if_enabled( + query=query, + retrieved_docs=results, + global_config=global_config, + top_k=query_param.top_k, + ) + # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -2269,6 +2287,15 @@ async def _get_edge_data( if not len(results): return "", "", "" + # Apply reranking if enabled for relationship results + global_config = relationships_vdb.global_config + results = await apply_rerank_if_enabled( + query=keywords, + retrieved_docs=results, + global_config=global_config, + top_k=query_param.top_k, + ) + # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] @@ -2806,3 +2833,61 @@ async def query_with_keywords( ) else: raise ValueError(f"Unknown mode {param.mode}") + + +async def apply_rerank_if_enabled( + query: str, + retrieved_docs: list[dict], + global_config: dict, + top_k: int = None, +) -> list[dict]: + """ + Apply reranking to retrieved documents if rerank is enabled. 
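+
+    Example (illustrative sketch; assumes ``my_rerank`` is a configured rerank
+    function, e.g. one built from RerankModel):
+        docs = await apply_rerank_if_enabled(
+            query="What is LightRAG?",
+            retrieved_docs=[{"content": "LightRAG is a RAG framework."}],
+            global_config={"enable_rerank": True, "rerank_model_func": my_rerank},
+            top_k=5,
+        )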
+ + Args: + query: The search query + retrieved_docs: List of retrieved documents + global_config: Global configuration containing rerank settings + top_k: Number of top documents to return after reranking + + Returns: + Reranked documents if rerank is enabled, otherwise original documents + """ + if not global_config.get("enable_rerank", False) or not retrieved_docs: + return retrieved_docs + + rerank_func = global_config.get("rerank_model_func") + if not rerank_func: + logger.debug( + "Rerank is enabled but no rerank function provided, skipping rerank" + ) + return retrieved_docs + + try: + # Determine top_k for reranking + rerank_top_k = top_k or global_config.get("rerank_top_k", 10) + rerank_top_k = min(rerank_top_k, len(retrieved_docs)) + + logger.debug( + f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}" + ) + + # Apply reranking + reranked_docs = await rerank_func( + query=query, + documents=retrieved_docs, + top_k=rerank_top_k, + ) + + if reranked_docs and len(reranked_docs) > 0: + logger.info( + f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" + ) + return reranked_docs + else: + logger.warning("Rerank returned empty results, using original documents") + return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs + + except Exception as e: + logger.error(f"Error during reranking: {e}, using original documents") + return retrieved_docs diff --git a/lightrag/rerank.py b/lightrag/rerank.py new file mode 100644 index 00000000..d25a8485 --- /dev/null +++ b/lightrag/rerank.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import os +import json +import aiohttp +import numpy as np +from typing import Callable, Any, List, Dict, Optional +from pydantic import BaseModel, Field +from dataclasses import asdict + +from .utils import logger + + +class RerankModel(BaseModel): + """ + Pydantic model class for defining a custom rerank model. + + Attributes: + rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. + The function should take query and documents as input and return reranked results. + kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. + This could include parameters such as the model name, API key, etc. 
+ + Example usage: + Rerank model example from jina: + ```python + rerank_model = RerankModel( + rerank_func=jina_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "api_key": "your_api_key_here", + "base_url": "https://api.jina.ai/v1/rerank" + } + ) + ``` + """ + + rerank_func: Callable[[Any], List[Dict]] + kwargs: Dict[str, Any] = Field(default_factory=dict) + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + top_k: Optional[int] = None, + **extra_kwargs + ) -> List[Dict[str, Any]]: + """Rerank documents using the configured model function.""" + # Merge extra kwargs with model kwargs + kwargs = {**self.kwargs, **extra_kwargs} + return await self.rerank_func( + query=query, + documents=documents, + top_k=top_k, + **kwargs + ) + + +class MultiRerankModel(BaseModel): + """Multiple rerank models for different modes/scenarios.""" + + # Primary rerank model (used if mode-specific models are not defined) + rerank_model: Optional[RerankModel] = None + + # Mode-specific rerank models + entity_rerank_model: Optional[RerankModel] = None + relation_rerank_model: Optional[RerankModel] = None + chunk_rerank_model: Optional[RerankModel] = None + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + mode: str = "default", + top_k: Optional[int] = None, + **kwargs + ) -> List[Dict[str, Any]]: + """Rerank using the appropriate model based on mode.""" + + # Select model based on mode + if mode == "entity" and self.entity_rerank_model: + model = self.entity_rerank_model + elif mode == "relation" and self.relation_rerank_model: + model = self.relation_rerank_model + elif mode == "chunk" and self.chunk_rerank_model: + model = self.chunk_rerank_model + elif self.rerank_model: + model = self.rerank_model + else: + logger.warning(f"No rerank model available for mode: {mode}") + return documents + + return await model.rerank(query, documents, top_k, **kwargs) + + +async def generic_rerank_api( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Generic rerank function that works with Jina/Cohere compatible APIs. 
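+
+    Note: the endpoint is assumed to follow the Jina/Cohere compatible response
+    schema, e.g.:
+        {"results": [{"index": 0, "relevance_score": 0.95}, ...]}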
+ + Args: + query: The search query + documents: List of documents to rerank + model: Model identifier + base_url: API endpoint URL + api_key: API authentication key + top_k: Number of top results to return + **kwargs: Additional API-specific parameters + + Returns: + List of reranked documents with relevance scores + """ + if not api_key: + logger.warning("No API key provided for rerank service") + return documents + + if not documents: + return documents + + # Prepare documents for reranking - handle both text and dict formats + prepared_docs = [] + for doc in documents: + if isinstance(doc, dict): + # Use 'content' field if available, otherwise use 'text' or convert to string + text = doc.get('content') or doc.get('text') or str(doc) + else: + text = str(doc) + prepared_docs.append(text) + + # Prepare request + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + data = { + "model": model, + "query": query, + "documents": prepared_docs, + **kwargs + } + + if top_k is not None: + data["top_k"] = min(top_k, len(prepared_docs)) + + try: + async with aiohttp.ClientSession() as session: + async with session.post(base_url, headers=headers, json=data) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Rerank API error {response.status}: {error_text}") + return documents + + result = await response.json() + + # Extract reranked results + if "results" in result: + # Standard format: results contain index and relevance_score + reranked_docs = [] + for item in result["results"]: + if "index" in item: + doc_idx = item["index"] + if 0 <= doc_idx < len(documents): + reranked_doc = documents[doc_idx].copy() + if "relevance_score" in item: + reranked_doc["rerank_score"] = item["relevance_score"] + reranked_docs.append(reranked_doc) + return reranked_docs + else: + logger.warning("Unexpected rerank API response format") + return documents + + except Exception as e: + logger.error(f"Error during reranking: {e}") + return documents + + +async def jina_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str = "BAAI/bge-reranker-v2-m3", + top_k: Optional[int] = None, + base_url: str = "https://api.jina.ai/v1/rerank", + api_key: Optional[str] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Rerank documents using Jina AI API. + + Args: + query: The search query + documents: List of documents to rerank + model: Jina rerank model name + top_k: Number of top results to return + base_url: Jina API endpoint + api_key: Jina API key + **kwargs: Additional parameters + + Returns: + List of reranked documents with relevance scores + """ + if api_key is None: + api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs + ) + + +async def cohere_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str = "rerank-english-v2.0", + top_k: Optional[int] = None, + base_url: str = "https://api.cohere.ai/v1/rerank", + api_key: Optional[str] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Rerank documents using Cohere API. 
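+
+    Example (sketch; assumes COHERE_API_KEY is set in the environment):
+        ranked = await cohere_rerank(
+            query="What is the capital of France?",
+            documents=[{"content": "Paris is the capital of France."}],
+            top_k=1,
+        )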
+ + Args: + query: The search query + documents: List of documents to rerank + model: Cohere rerank model name + top_k: Number of top results to return + base_url: Cohere API endpoint + api_key: Cohere API key + **kwargs: Additional parameters + + Returns: + List of reranked documents with relevance scores + """ + if api_key is None: + api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs + ) + + +# Convenience function for custom API endpoints +async def custom_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Rerank documents using a custom API endpoint. + This is useful for self-hosted or custom rerank services. + """ + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs + ) + + +if __name__ == "__main__": + import asyncio + + async def main(): + # Example usage + docs = [ + {"content": "The capital of France is Paris."}, + {"content": "Tokyo is the capital of Japan."}, + {"content": "London is the capital of England."}, + ] + + query = "What is the capital of France?" + + result = await jina_rerank( + query=query, + documents=docs, + top_k=2, + api_key="your-api-key-here" + ) + print(result) + + asyncio.run(main()) \ No newline at end of file From 3eaadb8a4432b03f67432774fd18684206208215 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 8 Jul 2025 03:06:19 +0800 Subject: [PATCH 13/30] Update env.example --- env.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/env.example b/env.example index 1efe4830..ef52bd53 100644 --- a/env.example +++ b/env.example @@ -159,7 +159,7 @@ NEO4J_PASSWORD='your_password' ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ -#MONGO_URI=mongodb+srv://root:rooot@cluster0.xxxx.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0 +#MONGO_URI=mongodb+srv://xxxx MONGO_DATABASE=LightRAG # MONGODB_WORKSPACE=forced_workspace_name From f5c80d7cde3ffcf4962a0c7a894820f01558aded Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 11:16:34 +0800 Subject: [PATCH 14/30] Simplify Configuration --- docs/rerank_integration.md | 96 +++++++++++++------------- env.example | 8 --- examples/rerank_example.py | 137 +++++++++++++++++++++++-------------- lightrag/lightrag.py | 28 +------- lightrag/operate.py | 12 ++-- lightrag/rerank.py | 130 +++++++++++++++++++---------------- 6 files changed, 210 insertions(+), 201 deletions(-) diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md index 647c0f91..f216a8c8 100644 --- a/docs/rerank_integration.md +++ b/docs/rerank_integration.md @@ -2,24 +2,15 @@ This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality. 
-## ⚠️ Important: Parameter Priority - -**QueryParam.top_k has higher priority than rerank_top_k configuration:** - -- When you set `QueryParam(top_k=5)`, it will override the `rerank_top_k=10` setting in LightRAG configuration -- This means the actual number of documents sent to rerank will be determined by QueryParam.top_k -- For optimal rerank performance, always consider the top_k value in your QueryParam calls -- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` will use 20, not rerank_top_k - ## Overview Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix). ## Architecture -The rerank integration follows the same design pattern as the LLM integration: +The rerank integration follows a simplified design pattern: -- **Configurable Models**: Support for multiple rerank providers through a generic API +- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function - **Async Processing**: Non-blocking rerank operations - **Error Handling**: Graceful fallback to original results - **Optional Feature**: Can be enabled/disabled via configuration @@ -29,24 +20,11 @@ The rerank integration follows the same design pattern as the LLM integration: ### Environment Variables -Set these variables in your `.env` file or environment: +Set this variable in your `.env` file or environment: ```bash # Enable/disable reranking ENABLE_RERANK=True - -# Rerank model configuration -RERANK_MODEL=BAAI/bge-reranker-v2-m3 -RERANK_MAX_ASYNC=4 -RERANK_TOP_K=10 - -# API configuration -RERANK_API_KEY=your_rerank_api_key_here -RERANK_BASE_URL=https://api.your-provider.com/v1/rerank - -# Provider-specific keys (optional alternatives) -JINA_API_KEY=your_jina_api_key_here -COHERE_API_KEY=your_cohere_api_key_here ``` ### Programmatic Configuration @@ -55,15 +33,27 @@ COHERE_API_KEY=your_cohere_api_key_here from lightrag import LightRAG from lightrag.rerank import custom_rerank, RerankModel -# Method 1: Using environment variables (recommended) +# Method 1: Using a custom rerank function with all settings included +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=top_k or 10, # Handle top_k within the function + **kwargs + ) + rag = LightRAG( working_dir="./rag_storage", llm_model_func=your_llm_func, embedding_func=your_embedding_func, - # Rerank automatically configured from environment variables + enable_rerank=True, + rerank_model_func=my_rerank_func, ) -# Method 2: Explicit configuration +# Method 2: Using RerankModel wrapper rerank_model = RerankModel( rerank_func=custom_rerank, kwargs={ @@ -79,7 +69,6 @@ rag = LightRAG( embedding_func=your_embedding_func, enable_rerank=True, rerank_model_func=rerank_model.rerank, - rerank_top_k=10, ) ``` @@ -112,7 +101,8 @@ result = await jina_rerank( query="your query", documents=documents, model="BAAI/bge-reranker-v2-m3", - api_key="your_jina_api_key" + api_key="your_jina_api_key", + top_k=10 ) ``` @@ -125,7 +115,8 @@ result = await cohere_rerank( query="your query", documents=documents, model="rerank-english-v2.0", - api_key="your_cohere_api_key" 
+ api_key="your_cohere_api_key", + top_k=10 ) ``` @@ -143,11 +134,7 @@ Reranking is automatically applied at these key retrieval stages: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `enable_rerank` | bool | False | Enable/disable reranking | -| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier | -| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls | -| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** | -| `rerank_model_func` | callable | None | Custom rerank function | -| `rerank_model_kwargs` | dict | {} | Additional rerank parameters | +| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) | ## Example Usage @@ -157,6 +144,18 @@ Reranking is automatically applied at these key retrieval stages: import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding +from lightrag.rerank import jina_rerank + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs + ) async def main(): # Initialize with rerank enabled @@ -165,20 +164,21 @@ async def main(): llm_model_func=gpt_4o_mini_complete, embedding_func=openai_embedding, enable_rerank=True, + rerank_model_func=my_rerank_func, ) - + # Insert documents await rag.ainsert([ "Document 1 content...", "Document 2 content...", ]) - + # Query with rerank (automatically applied) result = await rag.aquery( "Your question here", - param=QueryParam(mode="hybrid", top_k=5) # ⚠️ This top_k=5 overrides rerank_top_k + param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function ) - + print(result) asyncio.run(main()) @@ -195,7 +195,7 @@ async def test_rerank(): {"content": "Text about topic B"}, {"content": "Text about topic C"}, ] - + reranked = await custom_rerank( query="Tell me about topic A", documents=documents, @@ -204,26 +204,26 @@ async def test_rerank(): api_key="your_api_key_here", top_k=2 ) - + for doc in reranked: print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}") ``` ## Best Practices -1. **Parameter Priority Awareness**: Remember that QueryParam.top_k always overrides rerank_top_k configuration +1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function 2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff -3. **API Limits**: Monitor API usage and implement rate limiting if needed +3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function 4. **Fallback**: Always handle rerank failures gracefully (returns original results) -5. **Top-k Selection**: Choose appropriate `top_k` values in QueryParam based on your use case +5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function 6. **Cost Management**: Consider rerank API costs in your budget planning ## Troubleshooting ### Common Issues -1. **API Key Missing**: Ensure `RERANK_API_KEY` or provider-specific keys are set -2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity +1. 
**API Key Missing**: Ensure API keys are properly configured within your rerank function +2. **Network Issues**: Check API endpoints and network connectivity 3. **Model Errors**: Verify the rerank model name is supported by your API 4. **Document Format**: Ensure documents have `content` or `text` fields @@ -268,4 +268,4 @@ The generic rerank API expects this response format: This is compatible with: - Jina AI Rerank API - Cohere Rerank API -- Custom APIs following the same format \ No newline at end of file +- Custom APIs following the same format diff --git a/env.example b/env.example index 49546343..c4a09cad 100644 --- a/env.example +++ b/env.example @@ -182,11 +182,3 @@ REDIS_URI=redis://localhost:6379 # Rerank Configuration ENABLE_RERANK=False -RERANK_MODEL=BAAI/bge-reranker-v2-m3 -RERANK_MAX_ASYNC=4 -RERANK_TOP_K=10 -# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting - -# Rerank API Configuration -RERANK_API_KEY=your_rerank_api_key_here -RERANK_BASE_URL=https://api.your-provider.com/v1/rerank diff --git a/examples/rerank_example.py b/examples/rerank_example.py index 30ad794d..74ec85bc 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -4,19 +4,12 @@ LightRAG Rerank Integration Example This example demonstrates how to use rerank functionality with LightRAG to improve retrieval quality across different query modes. -IMPORTANT: Parameter Priority -- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration -- If you set QueryParam(top_k=5), it will override rerank_top_k setting -- For optimal rerank performance, use appropriate top_k values in QueryParam - Configuration Required: 1. Set your LLM API key and base URL in llm_model_func() -2. Set your embedding API key and base URL in embedding_func() +2. Set your embedding API key and base URL in embedding_func() 3. Set your rerank API key and base URL in the rerank configuration 4. 
Or use environment variables (.env file): - - RERANK_API_KEY=your_actual_rerank_api_key - - RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank - - RERANK_MODEL=your_rerank_model_name + - ENABLE_RERANK=True """ import asyncio @@ -35,6 +28,7 @@ setup_logger("test_rerank") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -48,6 +42,7 @@ async def llm_model_func( **kwargs, ) + async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embed( texts, @@ -56,25 +51,29 @@ async def embedding_func(texts: list[str]) -> np.ndarray: base_url="https://api.your-embedding-provider.com/v1", ) + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-rerank-provider.com/v1/rerank", + api_key="your_rerank_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs, + ) + + async def create_rag_with_rerank(): """Create LightRAG instance with rerank configuration""" - + # Get embedding dimension test_embedding = await embedding_func(["test"]) embedding_dim = test_embedding.shape[1] print(f"Detected embedding dimension: {embedding_dim}") - # Create rerank model - rerank_model = RerankModel( - rerank_func=custom_rerank, - kwargs={ - "model": "BAAI/bge-reranker-v2-m3", - "base_url": "https://api.your-rerank-provider.com/v1/rerank", - "api_key": "your_rerank_api_key_here", - } - ) - - # Initialize LightRAG with rerank + # Method 1: Using custom rerank function rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, @@ -83,69 +82,100 @@ async def create_rag_with_rerank(): max_token_size=8192, func=embedding_func, ), - # Rerank Configuration + # Simplified Rerank Configuration enable_rerank=True, - rerank_model_func=rerank_model.rerank, - rerank_top_k=10, # Note: QueryParam.top_k will override this + rerank_model_func=my_rerank_func, ) return rag + +async def create_rag_with_rerank_model(): + """Alternative: Create LightRAG instance using RerankModel wrapper""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Method 2: Using RerankModel wrapper + rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-rerank-provider.com/v1/rerank", + "api_key": "your_rerank_api_key_here", + }, + ) + + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + ) + + return rag + + async def test_rerank_with_different_topk(): """ - Test rerank functionality with different top_k settings to demonstrate parameter priority + Test rerank functionality with different top_k settings """ print("🚀 Setting up LightRAG with Rerank functionality...") - + rag = await create_rag_with_rerank() - + # Insert sample documents sample_docs = [ "Reranking improves retrieval quality by re-ordering documents based on relevance.", "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.", "Vector databases enable efficient similarity 
search in high-dimensional embedding spaces.", "Natural language processing has evolved with large language models and transformers.", - "Machine learning algorithms can learn patterns from data without explicit programming." + "Machine learning algorithms can learn patterns from data without explicit programming.", ] - + print("📄 Inserting sample documents...") await rag.ainsert(sample_docs) - + query = "How does reranking improve retrieval quality?" print(f"\n🔍 Testing query: '{query}'") print("=" * 80) - + # Test different top_k values to show parameter priority top_k_values = [2, 5, 10] - + for top_k in top_k_values: - print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:") - + print(f"\n📊 Testing with QueryParam(top_k={top_k}):") + # Test naive mode with specific top_k - result = await rag.aquery( - query, - param=QueryParam(mode="naive", top_k=top_k) - ) + result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k)) print(f" Result length: {len(result)} characters") print(f" Preview: {result[:100]}...") + async def test_direct_rerank(): """Test rerank function directly""" print("\n🔧 Direct Rerank API Test") print("=" * 40) - + documents = [ {"content": "Reranking significantly improves retrieval quality"}, {"content": "LightRAG supports advanced reranking capabilities"}, {"content": "Vector search finds semantically similar documents"}, {"content": "Natural language processing with modern transformers"}, - {"content": "The quick brown fox jumps over the lazy dog"} + {"content": "The quick brown fox jumps over the lazy dog"}, ] - + query = "rerank improve quality" print(f"Query: '{query}'") print(f"Documents: {len(documents)}") - + try: reranked_docs = await custom_rerank( query=query, @@ -153,41 +183,44 @@ async def test_direct_rerank(): model="BAAI/bge-reranker-v2-m3", base_url="https://api.your-rerank-provider.com/v1/rerank", api_key="your_rerank_api_key_here", - top_k=3 + top_k=3, ) - + print("\n✅ Rerank Results:") for i, doc in enumerate(reranked_docs): score = doc.get("rerank_score", "N/A") content = doc.get("content", "")[:60] print(f" {i+1}. Score: {score:.4f} | {content}...") - + except Exception as e: print(f"❌ Rerank failed: {e}") + async def main(): """Main example function""" print("🎯 LightRAG Rerank Integration Example") print("=" * 60) - + try: # Test rerank with different top_k values await test_rerank_with_different_topk() - + # Test direct rerank await test_direct_rerank() - + print("\n✅ Example completed successfully!") print("\n💡 Key Points:") - print(" ✓ QueryParam.top_k has higher priority than rerank_top_k") + print(" ✓ All rerank configurations are contained within rerank_model_func") print(" ✓ Rerank improves document relevance ordering") - print(" ✓ Configure API keys in your .env file for production") + print(" ✓ Configure API keys within your rerank function") print(" ✓ Monitor API usage and costs when using rerank services") - + except Exception as e: print(f"\n❌ Example failed: {e}") import traceback + traceback.print_exc() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cee08373..63a2f531 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -249,25 +249,7 @@ class LightRAG: """Enable reranking for improved retrieval quality. Defaults to False.""" rerank_model_func: Callable[..., object] | None = field(default=None) - """Function for reranking retrieved documents. 
Optional.""" - - rerank_model_name: str = field( - default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") - ) - """Name of the rerank model used for reranking documents.""" - - rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4))) - """Maximum number of concurrent rerank calls.""" - - rerank_model_kwargs: dict[str, Any] = field(default_factory=dict) - """Additional keyword arguments passed to the rerank model function.""" - - rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10))) - """Number of top documents to return after reranking. - - Note: This value will be overridden by QueryParam.top_k in query calls. - Example: QueryParam(top_k=5) will override rerank_top_k=10 setting. - """ + """Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional.""" # Storage # --- @@ -475,14 +457,6 @@ class LightRAG: # Init Rerank if self.enable_rerank and self.rerank_model_func: - self.rerank_model_func = priority_limit_async_func_call( - self.rerank_model_max_async - )( - partial( - self.rerank_model_func, # type: ignore - **self.rerank_model_kwargs, - ) - ) logger.info("Rerank model initialized for improved retrieval quality") elif self.enable_rerank and not self.rerank_model_func: logger.warning( diff --git a/lightrag/operate.py b/lightrag/operate.py index b5d74c55..645c1e85 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2864,19 +2864,15 @@ async def apply_rerank_if_enabled( return retrieved_docs try: - # Determine top_k for reranking - rerank_top_k = top_k or global_config.get("rerank_top_k", 10) - rerank_top_k = min(rerank_top_k, len(retrieved_docs)) - logger.debug( - f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}" + f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}" ) - # Apply reranking + # Apply reranking - let rerank_model_func handle top_k internally reranked_docs = await rerank_func( query=query, documents=retrieved_docs, - top_k=rerank_top_k, + top_k=top_k, ) if reranked_docs and len(reranked_docs) > 0: @@ -2886,7 +2882,7 @@ async def apply_rerank_if_enabled( return reranked_docs else: logger.warning("Rerank returned empty results, using original documents") - return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs + return retrieved_docs except Exception as e: logger.error(f"Error during reranking: {e}, using original documents") diff --git a/lightrag/rerank.py b/lightrag/rerank.py index d25a8485..59719bc9 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -1,12 +1,9 @@ from __future__ import annotations import os -import json import aiohttp -import numpy as np from typing import Callable, Any, List, Dict, Optional from pydantic import BaseModel, Field -from dataclasses import asdict from .utils import logger @@ -15,14 +12,17 @@ class RerankModel(BaseModel): """ Pydantic model class for defining a custom rerank model. + This class provides a convenient wrapper for rerank functions, allowing you to + encapsulate all rerank configurations (API keys, model settings, etc.) in one place. + Attributes: rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. The function should take query and documents as input and return reranked results. kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. - This could include parameters such as the model name, API key, etc. 
+ This should include all necessary configurations such as model name, API key, base_url, etc. Example usage: - Rerank model example from jina: + Rerank model example with Jina: ```python rerank_model = RerankModel( rerank_func=jina_rerank, @@ -32,6 +32,32 @@ class RerankModel(BaseModel): "base_url": "https://api.jina.ai/v1/rerank" } ) + + # Use in LightRAG + rag = LightRAG( + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + # ... other configurations + ) + ``` + + Or define a custom function directly: + ```python + async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_api_key_here", + top_k=top_k or 10, + **kwargs + ) + + rag = LightRAG( + enable_rerank=True, + rerank_model_func=my_rerank_func, + # ... other configurations + ) ``` """ @@ -43,25 +69,22 @@ class RerankModel(BaseModel): query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None, - **extra_kwargs + **extra_kwargs, ) -> List[Dict[str, Any]]: """Rerank documents using the configured model function.""" # Merge extra kwargs with model kwargs kwargs = {**self.kwargs, **extra_kwargs} return await self.rerank_func( - query=query, - documents=documents, - top_k=top_k, - **kwargs + query=query, documents=documents, top_k=top_k, **kwargs ) class MultiRerankModel(BaseModel): """Multiple rerank models for different modes/scenarios.""" - + # Primary rerank model (used if mode-specific models are not defined) rerank_model: Optional[RerankModel] = None - + # Mode-specific rerank models entity_rerank_model: Optional[RerankModel] = None relation_rerank_model: Optional[RerankModel] = None @@ -73,10 +96,10 @@ class MultiRerankModel(BaseModel): documents: List[Dict[str, Any]], mode: str = "default", top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """Rerank using the appropriate model based on mode.""" - + # Select model based on mode if mode == "entity" and self.entity_rerank_model: model = self.entity_rerank_model @@ -89,7 +112,7 @@ class MultiRerankModel(BaseModel): else: logger.warning(f"No rerank model available for mode: {mode}") return documents - + return await model.rerank(query, documents, top_k, **kwargs) @@ -100,11 +123,11 @@ async def generic_rerank_api( base_url: str, api_key: str, top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Generic rerank function that works with Jina/Cohere compatible APIs. 
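
`MultiRerankModel`, defined above, routes by `mode` and falls back to the primary model; when nothing matches, documents pass through unchanged. A sketch of wiring it up, with illustrative (not project-default) model names and keys:

```python
from lightrag.rerank import MultiRerankModel, RerankModel, jina_rerank, cohere_rerank

multi = MultiRerankModel(
    rerank_model=RerankModel(
        rerank_func=jina_rerank,
        kwargs={"model": "BAAI/bge-reranker-v2-m3", "api_key": "your_api_key_here"},
    ),
    entity_rerank_model=RerankModel(
        rerank_func=cohere_rerank,
        kwargs={"model": "rerank-english-v3.0", "api_key": "your_api_key_here"},
    ),
)

# mode="entity" selects the Cohere-backed model; other modes fall back to
# rerank_model, and with no usable model the documents are returned as-is.
# reranked = await multi.rerank(query, documents, mode="entity", top_k=5)
```
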
- + Args: query: The search query documents: List of documents to rerank @@ -113,43 +136,35 @@ async def generic_rerank_api( api_key: API authentication key top_k: Number of top results to return **kwargs: Additional API-specific parameters - + Returns: List of reranked documents with relevance scores """ if not api_key: logger.warning("No API key provided for rerank service") return documents - + if not documents: return documents - + # Prepare documents for reranking - handle both text and dict formats prepared_docs = [] for doc in documents: if isinstance(doc, dict): # Use 'content' field if available, otherwise use 'text' or convert to string - text = doc.get('content') or doc.get('text') or str(doc) + text = doc.get("content") or doc.get("text") or str(doc) else: text = str(doc) prepared_docs.append(text) - + # Prepare request - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - data = { - "model": model, - "query": query, - "documents": prepared_docs, - **kwargs - } - + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} + if top_k is not None: data["top_k"] = min(top_k, len(prepared_docs)) - + try: async with aiohttp.ClientSession() as session: async with session.post(base_url, headers=headers, json=data) as response: @@ -157,9 +172,9 @@ async def generic_rerank_api( error_text = await response.text() logger.error(f"Rerank API error {response.status}: {error_text}") return documents - + result = await response.json() - + # Extract reranked results if "results" in result: # Standard format: results contain index and relevance_score @@ -170,13 +185,15 @@ async def generic_rerank_api( if 0 <= doc_idx < len(documents): reranked_doc = documents[doc_idx].copy() if "relevance_score" in item: - reranked_doc["rerank_score"] = item["relevance_score"] + reranked_doc["rerank_score"] = item[ + "relevance_score" + ] reranked_docs.append(reranked_doc) return reranked_docs else: logger.warning("Unexpected rerank API response format") return documents - + except Exception as e: logger.error(f"Error during reranking: {e}") return documents @@ -189,11 +206,11 @@ async def jina_rerank( top_k: Optional[int] = None, base_url: str = "https://api.jina.ai/v1/rerank", api_key: Optional[str] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using Jina AI API. - + Args: query: The search query documents: List of documents to rerank @@ -202,13 +219,13 @@ async def jina_rerank( base_url: Jina API endpoint api_key: Jina API key **kwargs: Additional parameters - + Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") - + return await generic_rerank_api( query=query, documents=documents, @@ -216,7 +233,7 @@ async def jina_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -227,11 +244,11 @@ async def cohere_rerank( top_k: Optional[int] = None, base_url: str = "https://api.cohere.ai/v1/rerank", api_key: Optional[str] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. 
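
Both `jina_rerank` above and `cohere_rerank` below delegate to `generic_rerank_api`, so the wire format is identical; roughly, with illustrative values:

```python
# Request body generic_rerank_api POSTs (illustrative values):
request_body = {
    "model": "BAAI/bge-reranker-v2-m3",
    "query": "rerank improve quality",
    "documents": ["doc one text", "doc two text"],  # extracted from 'content'/'text'
    "top_k": 2,  # clamped to min(top_k, len(documents)) before sending
}
# Expected reply shape: {"results": [{"index": 1, "relevance_score": 0.93}, ...]}.
# Each index is mapped back to the original document, which is returned with an
# added "rerank_score"; any other shape (or an HTTP error) falls back to the
# original, unranked documents.
```
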
- + Args: query: The search query documents: List of documents to rerank @@ -240,13 +257,13 @@ async def cohere_rerank( base_url: Cohere API endpoint api_key: Cohere API key **kwargs: Additional parameters - + Returns: List of reranked documents with relevance scores """ if api_key is None: api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") - + return await generic_rerank_api( query=query, documents=documents, @@ -254,7 +271,7 @@ async def cohere_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -266,7 +283,7 @@ async def custom_rerank( base_url: str, api_key: str, top_k: Optional[int] = None, - **kwargs + **kwargs, ) -> List[Dict[str, Any]]: """ Rerank documents using a custom API endpoint. @@ -279,7 +296,7 @@ async def custom_rerank( base_url=base_url, api_key=api_key, top_k=top_k, - **kwargs + **kwargs, ) @@ -293,15 +310,12 @@ if __name__ == "__main__": {"content": "Tokyo is the capital of Japan."}, {"content": "London is the capital of England."}, ] - + query = "What is the capital of France?" - + result = await jina_rerank( - query=query, - documents=docs, - top_k=2, - api_key="your-api-key-here" + query=query, documents=docs, top_k=2, api_key="your-api-key-here" ) print(result) - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) From 04a57445da4d1e1c75776e392ebec89185755c30 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:31:05 +0800 Subject: [PATCH 15/30] update chunks truncation method --- README-zh.md | 10 ++ README.md | 12 +- env.example | 4 +- lightrag/base.py | 27 ++-- lightrag/operate.py | 338 +++++++++++++++++++++++--------------------- 5 files changed, 211 insertions(+), 180 deletions(-) diff --git a/README-zh.md b/README-zh.md index 45335489..7dd7e975 100644 --- a/README-zh.md +++ b/README-zh.md @@ -294,6 +294,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" diff --git a/README.md b/README.md index e812e8df..79479da9 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d python examples/lightrag_openai_demo.py ``` -For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code’s LLM and embedding configurations accordingly. +For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly. **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. 
If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory. @@ -300,6 +300,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" diff --git a/env.example b/env.example index c4a09cad..e09494b8 100644 --- a/env.example +++ b/env.example @@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest # HISTORY_TURNS=3 # COSINE_THRESHOLD=0.2 # TOP_K=60 -# MAX_TOKEN_TEXT_CHUNK=4000 +# CHUNK_TOP_K=5 +# CHUNK_RERANK_TOP_K=5 +# MAX_TOKEN_TEXT_CHUNK=6000 # MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000 diff --git a/lightrag/base.py b/lightrag/base.py index 57cb2ac6..97564ac2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -60,7 +60,17 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" - max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000")) """Maximum number of tokens allowed for each retrieved text chunk.""" max_token_for_global_context: int = int( @@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC): False: if the cache drop failed, or the cache mode is not supported """ - # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: - # """Delete specific cache records from storage by chunk IDs - - # Importance notes for in-memory storage: - # 1. Changes will be persisted to disk during the next index_done_callback - # 2. update flags to notify other processes that data persistence is needed - - # Args: - # chunk_ids (list[str]): List of chunk IDs to be dropped from storage - - # Returns: - # True: if the cache drop successfully - # False: if the cache drop failed, or the operation is not supported - # """ - @dataclass class BaseGraphStorage(StorageNameSpace, ABC): diff --git a/lightrag/operate.py b/lightrag/operate.py index 645c1e85..f9f53285 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1526,6 +1526,7 @@ async def kg_query( # Build context context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -1744,93 +1745,52 @@ async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, - tokenizer: Tokenizer, -) -> tuple[list, list, list] | None: +) -> list[dict]: """ - Retrieve vector context from the vector database. 
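
The two new knobs compose with the existing token budget rather than replacing it; a brief usage sketch:

```python
from lightrag import QueryParam

# Pull 20 candidate chunks from the vector index, keep the 5 best after
# reranking; max_token_for_text_unit truncation is still applied last.
param = QueryParam(mode="mix", chunk_top_k=20, chunk_rerank_top_k=5)
# answer = await rag.aquery("How does reranking improve retrieval?", param=param)
```
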
+ Retrieve text chunks from the vector database without reranking or truncation. - This function performs vector search to find relevant text chunks for a query, - formats them with file path and creation time information. + This function performs vector search to find relevant text chunks for a query. + Reranking and truncation will be handled later in the unified processing. Args: query: The query string to search for chunks_vdb: Vector database containing document chunks - query_param: Query parameters including top_k and ids - tokenizer: Tokenizer for counting tokens + query_param: Query parameters including chunk_top_k and ids Returns: - Tuple (empty_entities, empty_relations, text_units) for combine_contexts, - compatible with _get_edge_data and _get_node_data format + List of text chunks with metadata """ try: - results = await chunks_vdb.query( - query, top_k=query_param.top_k, ids=query_param.ids - ) + # Use chunk_top_k if specified, otherwise fall back to top_k + search_top_k = query_param.chunk_top_k or query_param.top_k + + results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) if not results: - return [], [], [] + return [] valid_chunks = [] for result in results: if "content" in result: - # Directly use content from chunks_vdb.query result - chunk_with_time = { + chunk_with_metadata = { "content": result["content"], "created_at": result.get("created_at", None), "file_path": result.get("file_path", "unknown_source"), + "source_type": "vector", # Mark the source type } - valid_chunks.append(chunk_with_time) - - if not valid_chunks: - return [], [], [] - - # Apply reranking if enabled - global_config = chunks_vdb.global_config - valid_chunks = await apply_rerank_if_enabled( - query=query, - retrieved_docs=valid_chunks, - global_config=global_config, - top_k=query_param.top_k, - ) - - maybe_trun_chunks = truncate_list_by_token_size( - valid_chunks, - key=lambda x: x["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) + valid_chunks.append(chunk_with_metadata) logger.debug( - f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" - ) - logger.info( - f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" + f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})" ) + return valid_chunks - if not maybe_trun_chunks: - return [], [], [] - - # Create empty entities and relations contexts - entities_context = [] - relations_context = [] - - # Create text_units_context directly as a list of dictionaries - text_units_context = [] - for i, chunk in enumerate(maybe_trun_chunks): - text_units_context.append( - { - "id": i + 1, - "content": chunk["content"], - "file_path": chunk["file_path"], - } - ) - - return entities_context, relations_context, text_units_context except Exception as e: logger.error(f"Error in _get_vector_context: {e}") - return [], [], [] + return [] async def _build_query_context( + query: str, ll_keywords: str, hl_keywords: str, knowledge_graph_inst: BaseGraphStorage, @@ -1838,27 +1798,36 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode + chunks_vdb: BaseVectorStorage = None, ): logger.info(f"Process {os.getpid()} building query context...") - # Handle local and global modes as before + # Collect all chunks from different sources + 
all_chunks = [] + entities_context = [] + relations_context = [] + + # Handle local and global modes if query_param.mode == "local": - entities_context, relations_context, text_units_context = await _get_node_data( + entities_context, relations_context, entity_chunks = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) + all_chunks.extend(entity_chunks) + elif query_param.mode == "global": - entities_context, relations_context, text_units_context = await _get_edge_data( + entities_context, relations_context, relationship_chunks = await _get_edge_data( hl_keywords, knowledge_graph_inst, relationships_vdb, text_chunks_db, query_param, ) + all_chunks.extend(relationship_chunks) + else: # hybrid or mix mode ll_data = await _get_node_data( ll_keywords, @@ -1875,61 +1844,58 @@ async def _build_query_context( query_param, ) - ( - ll_entities_context, - ll_relations_context, - ll_text_units_context, - ) = ll_data + (ll_entities_context, ll_relations_context, ll_chunks) = ll_data + (hl_entities_context, hl_relations_context, hl_chunks) = hl_data - ( - hl_entities_context, - hl_relations_context, - hl_text_units_context, - ) = hl_data + # Collect chunks from entity and relationship sources + all_chunks.extend(ll_chunks) + all_chunks.extend(hl_chunks) - # Initialize vector data with empty lists - vector_entities_context, vector_relations_context, vector_text_units_context = ( - [], - [], - [], - ) - - # Only get vector data if in mix mode - if query_param.mode == "mix" and hasattr(query_param, "original_query"): - # Get tokenizer from text_chunks_db - tokenizer = text_chunks_db.global_config.get("tokenizer") - - # Get vector context in triple format - vector_data = await _get_vector_context( - query_param.original_query, # We need to pass the original query + # Get vector chunks if in mix mode + if query_param.mode == "mix" and chunks_vdb: + vector_chunks = await _get_vector_context( + query, chunks_vdb, query_param, - tokenizer, ) + all_chunks.extend(vector_chunks) - # If vector_data is not None, unpack it - if vector_data is not None: - ( - vector_entities_context, - vector_relations_context, - vector_text_units_context, - ) = vector_data - - # Combine and deduplicate the entities, relationships, and sources + # Combine entities and relations contexts entities_context = process_combine_contexts( - hl_entities_context, ll_entities_context, vector_entities_context + hl_entities_context, ll_entities_context ) relations_context = process_combine_contexts( - hl_relations_context, ll_relations_context, vector_relations_context + hl_relations_context, ll_relations_context ) - text_units_context = process_combine_contexts( - hl_text_units_context, ll_text_units_context, vector_text_units_context + + # Process all chunks uniformly: deduplication, reranking, and token truncation + processed_chunks = await process_chunks_unified( + query=query, + chunks=all_chunks, + query_param=query_param, + global_config=text_chunks_db.global_config, + source_type="mixed", + ) + + # Build final text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } ) + + logger.info( + f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks" + ) + # not necessary to use LLM to generate a response if not entities_context and not 
relations_context: return None - # 转换为 JSON 字符串 entities_str = json.dumps(entities_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False) @@ -1975,15 +1941,6 @@ async def _get_node_data( if not len(results): return "", "", "" - # Apply reranking if enabled for entity results - global_config = entities_vdb.global_config - results = await apply_rerank_if_enabled( - query=query, - retrieved_docs=results, - global_config=global_config, - top_k=query_param.top_k, - ) - # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -2085,16 +2042,7 @@ async def _get_node_data( } ) - text_units_context = [] - for i, t in enumerate(use_text_units): - text_units_context.append( - { - "id": i + 1, - "content": t["content"], - "file_path": t.get("file_path", "unknown_source"), - } - ) - return entities_context, relations_context, text_units_context + return entities_context, relations_context, use_text_units async def _find_most_related_text_unit_from_entities( @@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities( logger.warning("No valid text units found") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") + # Sort by relation counts and order, but don't truncate all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) - all_text_units = truncate_list_by_token_size( - all_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) - logger.debug( - f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + logger.debug(f"Found {len(all_text_units)} entity-related chunks") - all_text_units = [t["data"] for t in all_text_units] - return all_text_units + # Add source type marking and return chunk data + result_chunks = [] + for t in all_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "entity" + result_chunks.append(chunk_data) + + return result_chunks async def _find_most_related_edges_from_entities( @@ -2287,15 +2233,6 @@ async def _get_edge_data( if not len(results): return "", "", "" - # Apply reranking if enabled for relationship results - global_config = relationships_vdb.global_config - results = await apply_rerank_if_enabled( - query=keywords, - retrieved_docs=results, - global_config=global_config, - top_k=query_param.top_k, - ) - # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. 
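
For reference, every retrieval path now hands back plain chunk dicts tagged with their origin, deferring rerank and truncation to the unified processor; a representative record, with illustrative field values:

```python
# Shape of a chunk record emitted by the retrieval helpers above:
chunk = {
    "content": "Reranking improves retrieval quality by re-ordering documents...",
    "file_path": "unknown_source",
    "created_at": None,             # set on the vector path; may be absent elsewhere
    "source_type": "relationship",  # "entity" / "relationship" / "vector"
}
```
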
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] @@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships( logger.warning("No valid text chunks after filtering") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") - truncated_text_units = truncate_list_by_token_size( - valid_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) + logger.debug(f"Found {len(valid_text_units)} relationship-related chunks") - logger.debug( - f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + # Add source type marking and return chunk data + result_chunks = [] + for t in valid_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "relationship" + result_chunks.append(chunk_data) - all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] - - return all_text_units + return result_chunks async def naive_query( @@ -2552,12 +2484,30 @@ async def naive_query( tokenizer: Tokenizer = global_config["tokenizer"] - _, _, text_units_context = await _get_vector_context( - query, chunks_vdb, query_param, tokenizer + chunks = await _get_vector_context(query, chunks_vdb, query_param) + + if chunks is None or len(chunks) == 0: + return PROMPTS["fail_response"] + + # Process chunks using unified processing + processed_chunks = await process_chunks_unified( + query=query, + chunks=chunks, + query_param=query_param, + global_config=global_config, + source_type="vector", ) - if text_units_context is None or len(text_units_context) == 0: - return PROMPTS["fail_response"] + # Build text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) text_units_str = json.dumps(text_units_context, ensure_ascii=False) if query_param.only_need_context: @@ -2683,6 +2633,7 @@ async def kg_query_with_keywords( hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -2805,8 +2756,6 @@ async def query_with_keywords( f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" ) - param.original_query = query - # Use appropriate query method based on mode if param.mode in ["local", "global", "hybrid", "mix"]: return await kg_query_with_keywords( @@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled( except Exception as e: logger.error(f"Error during reranking: {e}, using original documents") return retrieved_docs + + +async def process_chunks_unified( + query: str, + chunks: list[dict], + query_param: QueryParam, + global_config: dict, + source_type: str = "mixed", +) -> list[dict]: + """ + Unified processing for text chunks: deduplication, reranking, and token truncation. + + Args: + query: Search query for reranking + chunks: List of text chunks to process + query_param: Query parameters containing configuration + global_config: Global configuration dictionary + source_type: Source type for logging ("vector", "entity", "relationship", "mixed") + + Returns: + Processed and filtered list of text chunks + """ + if not chunks: + return [] + + # 1. 
Deduplication based on content + seen_content = set() + unique_chunks = [] + for chunk in chunks: + content = chunk.get("content", "") + if content and content not in seen_content: + seen_content.add(content) + unique_chunks.append(chunk) + + logger.debug( + f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" + ) + + # 2. Apply reranking if enabled and query is provided + if global_config.get("enable_rerank", False) and query and unique_chunks: + rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks) + unique_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=unique_chunks, + global_config=global_config, + top_k=rerank_top_k, + ) + logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") + + # 3. Token-based final truncation + tokenizer = global_config.get("tokenizer") + if tokenizer and unique_chunks: + original_count = len(unique_chunks) + unique_chunks = truncate_list_by_token_size( + unique_chunks, + key=lambda x: x.get("content", ""), + max_token_size=query_param.max_token_for_text_unit, + tokenizer=tokenizer, + ) + logger.debug( + f"Token truncation: {len(unique_chunks)} chunks from {original_count} " + f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})" + ) + + return unique_chunks From c295d355a0525871971c8e19bc9fb75e6a50a5d6 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 15:05:30 +0800 Subject: [PATCH 16/30] fix chunk_top_k limiting --- examples/rerank_example.py | 7 +++++++ lightrag/operate.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/rerank_example.py b/examples/rerank_example.py index 74ec85bc..e0e361a5 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam from lightrag.rerank import custom_rerank, RerankModel from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc, setup_logger +from lightrag.kg.shared_storage import initialize_pipeline_status # Set up your working directory WORKING_DIR = "./test_rerank" @@ -87,6 +88,9 @@ async def create_rag_with_rerank(): rerank_model_func=my_rerank_func, ) + await rag.initialize_storages() + await initialize_pipeline_status() + return rag @@ -120,6 +124,9 @@ async def create_rag_with_rerank_model(): rerank_model_func=rerank_model.rerank, ) + await rag.initialize_storages() + await initialize_pipeline_status() + return rag diff --git a/lightrag/operate.py b/lightrag/operate.py index f9f53285..05fef78e 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled( documents=retrieved_docs, top_k=top_k, ) - if reranked_docs and len(reranked_docs) > 0: + if len(reranked_docs) > top_k: + reranked_docs = reranked_docs[:top_k] logger.info( f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" ) @@ -2846,7 +2847,7 @@ async def process_chunks_unified( source_type: str = "mixed", ) -> list[dict]: """ - Unified processing for text chunks: deduplication, reranking, and token truncation. + Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. Args: query: Search query for reranking @@ -2874,7 +2875,15 @@ async def process_chunks_unified( f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" ) - # 2. Apply reranking if enabled and query is provided + # 2. 
Apply chunk_top_k limiting if specified + if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: + if len(unique_chunks) > query_param.chunk_top_k: + unique_chunks = unique_chunks[: query_param.chunk_top_k] + logger.debug( + f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})" + ) + + # 3. Apply reranking if enabled and query is provided if global_config.get("enable_rerank", False) and query and unique_chunks: rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks) unique_chunks = await apply_rerank_if_enabled( @@ -2885,7 +2894,7 @@ async def process_chunks_unified( ) logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") - # 3. Token-based final truncation + # 4. Token-based final truncation tokenizer = global_config.get("tokenizer") if tokenizer and unique_chunks: original_count = len(unique_chunks) From cf26e52d89009713c9ec8bb3b44e773df80f6b45 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 15:13:09 +0800 Subject: [PATCH 17/30] fix init --- docs/rerank_integration.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md index f216a8c8..fdaebfa5 100644 --- a/docs/rerank_integration.md +++ b/docs/rerank_integration.md @@ -144,6 +144,7 @@ Reranking is automatically applied at these key retrieval stages: import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding +from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.rerank import jina_rerank async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): @@ -167,6 +168,9 @@ async def main(): rerank_model_func=my_rerank_func, ) + await rag.initialize_storages() + await initialize_pipeline_status() + # Insert documents await rag.ainsert([ "Document 1 content...", From d4651d59c13d3ff75f203f145219ef9331c89338 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 8 Jul 2025 21:44:20 +0800 Subject: [PATCH 18/30] Add rerank to server --- env.example | 14 ++++++++--- lightrag/api/config.py | 23 ++++++++++++++++++ lightrag/api/lightrag_server.py | 36 ++++++++++++++++++++++++++++ lightrag/api/routers/query_routes.py | 12 ++++++++++ 4 files changed, 82 insertions(+), 3 deletions(-) diff --git a/env.example b/env.example index 4447f5f0..874ebf3c 100644 --- a/env.example +++ b/env.example @@ -46,8 +46,19 @@ OLLAMA_EMULATING_MODEL_TAG=latest # HISTORY_TURNS=3 # COSINE_THRESHOLD=0.2 # TOP_K=60 +### Number of text chunks to retrieve initially from vector search # CHUNK_TOP_K=5 + +### Rerank Configuration +### Enable rerank functionality to improve retrieval quality +# ENABLE_RERANK=False +### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K) # CHUNK_RERANK_TOP_K=5 +### Rerank model configuration (required when ENABLE_RERANK=True) +# RERANK_MODEL=BAAI/bge-reranker-v2-m3 +# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here + # MAX_TOKEN_TEXT_CHUNK=6000 # MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000 @@ -181,6 +192,3 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name - -# Rerank Configuration -ENABLE_RERANK=False diff --git a/lightrag/api/config.py b/lightrag/api/config.py index ad0e670b..8c3fbff4 100644 --- a/lightrag/api/config.py +++ 
b/lightrag/api/config.py @@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace: default=get_env_value("TOP_K", 60, int), help="Number of most similar results to return (default: from env or 60)", ) + parser.add_argument( + "--chunk-top-k", + type=int, + default=get_env_value("CHUNK_TOP_K", 5, int), + help="Number of text chunks to retrieve initially from vector search (default: from env or 5)", + ) + parser.add_argument( + "--chunk-rerank-top-k", + type=int, + default=get_env_value("CHUNK_RERANK_TOP_K", 5, int), + help="Number of text chunks to keep after reranking (default: from env or 5)", + ) + parser.add_argument( + "--enable-rerank", + action="store_true", + default=get_env_value("ENABLE_RERANK", False, bool), + help="Enable rerank functionality (default: from env or False)", + ) parser.add_argument( "--cosine-threshold", type=float, @@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace: args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int) args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256") + # Rerank model configuration + args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") + args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None) + args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name return args diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index cd87af22..b43c66d9 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -291,6 +291,32 @@ def create_app(args): ), ) + # Configure rerank function if enabled + rerank_model_func = None + if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host: + from lightrag.rerank import custom_rerank + + async def server_rerank_func( + query: str, documents: list, top_k: int = None, **kwargs + ): + """Server rerank function with configuration from environment variables""" + return await custom_rerank( + query=query, + documents=documents, + model=args.rerank_model, + base_url=args.rerank_binding_host, + api_key=args.rerank_binding_api_key, + top_k=top_k, + **kwargs, + ) + + rerank_model_func = server_rerank_func + logger.info(f"Rerank enabled with model: {args.rerank_model}") + elif args.enable_rerank: + logger.warning( + "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled." 
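
The server-side wiring above is driven purely by environment variables; a minimal sketch of what it reads, where the endpoint and key values are deployment-specific placeholders:

```python
import os

# All three must be set for the server to build a rerank function; otherwise
# it logs the warning above and continues with reranking disabled.
enable_rerank = os.getenv("ENABLE_RERANK", "False").lower() == "true"
rerank_model = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
rerank_host = os.getenv("RERANK_BINDING_HOST")        # Jina/Cohere-style /v1/rerank URL
rerank_api_key = os.getenv("RERANK_BINDING_API_KEY")
```
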
+ ) + # Initialize RAG if args.llm_binding in ["lollms", "ollama", "openai"]: rag = LightRAG( @@ -324,6 +350,8 @@ def create_app(args): }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, + enable_rerank=args.enable_rerank, + rerank_model_func=rerank_model_func, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, @@ -352,6 +380,8 @@ def create_app(args): }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, + enable_rerank=args.enable_rerank, + rerank_model_func=rerank_model_func, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, @@ -478,6 +508,12 @@ def create_app(args): "enable_llm_cache": args.enable_llm_cache, "workspace": args.workspace, "max_graph_nodes": args.max_graph_nodes, + # Rerank configuration + "enable_rerank": args.enable_rerank, + "rerank_model": args.rerank_model if args.enable_rerank else None, + "rerank_binding_host": args.rerank_binding_host + if args.enable_rerank + else None, }, "auth_mode": auth_mode, "pipeline_busy": pipeline_status.get("busy", False), diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 69aa32d8..0a0c6227 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -49,6 +49,18 @@ class QueryRequest(BaseModel): description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", ) + chunk_top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of text chunks to retrieve initially from vector search.", + ) + + chunk_rerank_top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of text chunks to keep after reranking.", + ) + max_token_for_text_unit: Optional[int] = Field( gt=1, default=None, From 4438897b6bf5e3db4fe3ff1a872805dce8377751 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 8 Jul 2025 16:27:38 +0200 Subject: [PATCH 19/30] add changes based on review --- env.example | 7 ++ lightrag/kg/memgraph_impl.py | 223 +++++++++++++++++++++-------------- 2 files changed, 143 insertions(+), 87 deletions(-) diff --git a/env.example b/env.example index ef52bd53..df88a518 100644 --- a/env.example +++ b/env.example @@ -179,3 +179,10 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +### Memgraph Configuration +MEMGRAPH_URI=bolt://localhost:7687 +MEMGRAPH_USERNAME= +MEMGRAPH_PASSWORD= +MEMGRAPH_DATABASE=memgraph +# MEMGRAPH_WORKSPACE=forced_workspace_name diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 397e5a99..4c16b843 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -31,14 +31,23 @@ config.read("config.ini", "utf-8") @final @dataclass class MemgraphStorage(BaseGraphStorage): - def __init__(self, namespace, global_config, embedding_func): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") + if memgraph_workspace and memgraph_workspace.strip(): + workspace = memgraph_workspace super().__init__( namespace=namespace, + workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self._driver = None + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for 
compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + async def initialize(self): URI = os.environ.get( "MEMGRAPH_URI", @@ -63,12 +72,13 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session(database=DATABASE) as session: # Create index for base nodes on entity_id if it doesn't exist try: - await session.run("""CREATE INDEX ON :base(entity_id)""") - logger.info("Created index on :base(entity_id) in Memgraph.") + workspace_label = self._get_workspace_label() + await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""") + logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.") except Exception as e: # Index may already exist, which is not an error logger.warning( - f"Index creation on :base(entity_id) may have failed or already exists: {e}" + f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" ) await session.run("RETURN 1") logger.info(f"Connected to Memgraph at {URI}") @@ -101,15 +111,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error checking the node existence. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] + return single_result["node_exists"] if single_result is not None else False except Exception as e: logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure the result is consumed even on error @@ -129,22 +142,21 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error checking the edge existence. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: + workspace_label = self._get_workspace_label() query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) + result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] + return single_result["edgeExists"] if single_result is not None else False except Exception as e: logger.error( f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" @@ -165,11 +177,14 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -183,12 +198,12 @@ class MemgraphStorage(BaseGraphStorage): if records: node = records[0]["n"] node_dict = dict(node) - # Remove base label from labels list if it exists + # Remove workspace label from labels list if it exists if "labels" in node_dict: node_dict["labels"] = [ label for label in node_dict["labels"] - if label != "base" + if label != workspace_label ] return node_dict return None @@ -212,12 +227,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (n:base {entity_id: $entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) OPTIONAL MATCH (n)-[r]-() RETURN COUNT(r) AS degree """ @@ -246,12 +264,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (n:base) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) WHERE n.entity_id IS NOT NULL RETURN DISTINCT n.entity_id AS label ORDER BY label @@ -280,13 +301,16 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) WHERE connected.entity_id IS NOT NULL RETURN n, r, connected""" results = await session.run(query, entity_id=source_node_id) @@ -341,12 +365,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) RETURN properties(r) as edge_properties """ result = await session.run( @@ -386,6 +413,8 @@ class MemgraphStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -393,15 +422,14 @@ class MemgraphStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: - + workspace_label = self._get_workspace_label() async def execute_upsert(tx: AsyncManagedTransaction): query = ( - """ - MERGE (n:base {entity_id: $entity_id}) + f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties - SET n:`%s` + SET n:`{entity_type}` """ - % entity_type ) result = await tx.run( query, entity_id=node_id, properties=properties @@ -429,15 +457,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") try: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) WITH source - MATCH (target:base {entity_id: $target_entity_id}) + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r, source, target @@ -467,10 +498,13 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async def _do_delete(tx: AsyncManagedTransaction): - query = """ - MATCH (n:base {entity_id: $entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) @@ -490,6 +524,8 @@ class MemgraphStorage(BaseGraphStorage): Args: nodes: List of node labels to be deleted """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") for node in nodes: await self.delete_node(node) @@ -502,11 +538,14 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) DELETE r """ result = await tx.run( @@ -523,9 +562,9 @@ class MemgraphStorage(BaseGraphStorage): raise async def drop(self) -> dict[str, str]: - """Drop all data from storage and clean up resources + """Drop all data from the current workspace and clean up resources - This method will delete all nodes and relationships in the Neo4j database. + This method will delete all nodes and relationships in the Memgraph database. Returns: dict[str, str]: Operation status and message @@ -535,17 +574,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") try: async with self._driver.session(database=self._DATABASE) as session: - query = "MATCH (n) DETACH DELETE n" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info( - f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" - ) - return {"status": "success", "message": "data dropped"} + logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}") + return {"status": "success", "message": "workspace data dropped"} except Exception as e: - logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") + logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -558,6 +598,8 @@ class MemgraphStorage(BaseGraphStorage): Returns: int: Sum of the degrees of both nodes """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) @@ -578,12 +620,15 @@ class MemgraphStorage(BaseGraphStorage): list[dict]: A list of nodes, where each node is a dictionary of its properties. An empty list if no matching nodes are found. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (n:base) + MATCH (n:`{workspace_label}`) WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) RETURN DISTINCT n """ @@ -607,12 +652,15 @@ class MemgraphStorage(BaseGraphStorage): list[dict]: A list of edges, where each edge is a dictionary of its properties. An empty list if no matching edges are found. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (a:base)-[r]-(b:base) + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id // Ensure we only return each unique edge once by ordering the source and target @@ -652,9 +700,13 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + result = KnowledgeGraph() seen_nodes = set() seen_edges = set() + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -682,19 +734,17 @@ class MemgraphStorage(BaseGraphStorage): await count_result.consume() # Run the main query to get nodes with highest degree - main_query = """ - MATCH (n) + main_query = f""" + MATCH (n:`{workspace_label}`) OPTIONAL MATCH (n)-[r]-() WITH n, COALESCE(count(r), 0) AS degree ORDER BY degree DESC LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) + WITH collect(n) AS kept_nodes + MATCH (a)-[r]-(b) WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships + RETURN [node IN kept_nodes | {{node: node}}] AS node_info, + collect(DISTINCT r) AS relationships """ result_set = None try: @@ -710,31 +760,33 @@ class MemgraphStorage(BaseGraphStorage): await result_set.consume() else: - bfs_query = """ - MATCH (start) WHERE start.entity_id = $entity_id + bfs_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id WITH start - CALL { + CALL {{ WITH start - MATCH path = (start)-[*0..$max_depth]-(node) + MATCH path = (start)-[*0..{max_depth}]-(node) WITH nodes(path) AS path_nodes, relationships(path) AS path_rels UNWIND path_nodes AS n WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels RETURN all_nodes, all_rels - } + }} WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes - - // Apply node limiting here - WITH CASE - WHEN total_nodes <= $max_nodes THEN nodes - ELSE nodes[0..$max_nodes] + WITH + CASE + WHEN total_nodes <= {max_nodes} THEN nodes + ELSE nodes[0..{max_nodes}] END AS limited_nodes, relationships, total_nodes, - total_nodes > $max_nodes AS is_truncated - UNWIND limited_nodes AS node - WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated - RETURN node_info, relationships, total_nodes, is_truncated + total_nodes > {max_nodes} AS is_truncated + RETURN + [node IN limited_nodes | {{node: node}}] AS node_info, + relationships, + total_nodes, + is_truncated """ result_set = None try: @@ -742,8 +794,6 @@ class MemgraphStorage(BaseGraphStorage): bfs_query, { "entity_id": node_label, - "max_depth": max_depth, - "max_nodes": max_nodes, }, ) record = await result_set.single() @@ -777,22 +827,21 @@ class MemgraphStorage(BaseGraphStorage): ) ) - if "relationships" in record and record["relationships"]: - for rel in 
record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), ) + ) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" From 08eb68b8ed0f774843fd682e3de3cc4749b5b6e8 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 8 Jul 2025 20:21:20 +0200 Subject: [PATCH 20/30] run pre-commit --- lightrag/kg/memgraph_impl.py | 113 +++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 31 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 4c16b843..8c6d6574 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage): # Create index for base nodes on entity_id if it doesn't exist try: workspace_label = self._get_workspace_label() - await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""") - logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.") + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"Created index on :{workspace_label}(entity_id) in Memgraph." + ) except Exception as e: # Index may already exist, which is not an error logger.warning( @@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the node existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage): result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] if single_result is not None else False + return ( + single_result["node_exists"] if single_result is not None else False + ) except Exception as e: logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure the result is consumed even on error @@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the edge existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage): f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) # type: ignore single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] if single_result is not None else False + return ( + single_result["edgeExists"] if single_result is not None else False + ) except Exception as e: logger.error( f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" @@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: workspace_label = self._get_workspace_label() - query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage): node_data: Dictionary of node properties """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run( query, entity_id=node_id, properties=properties ) @@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: @@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async def _do_delete(tx: AsyncManagedTransaction): workspace_label = self._get_workspace_label() @@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage): nodes: List of node labels to be deleted """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for node in nodes: await self.delete_node(node) @@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}") + logger.info( + f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "workspace data dropped"} except Exception as e: - logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}") + logger.error( + f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage): int: Sum of the degrees of both nodes """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) @@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching nodes are found. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching edges are found. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) result = KnowledgeGraph() seen_nodes = set() @@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage): else: bfs_query = f""" - MATCH (start:`{workspace_label}`) + MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start CALL {{ @@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage): RETURN all_nodes, all_rels }} WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes - WITH + WITH CASE WHEN total_nodes <= {max_nodes} THEN nodes ELSE nodes[0..{max_nodes}] @@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage): relationships, total_nodes, total_nodes > {max_nodes} AS is_truncated - RETURN + RETURN [node IN limited_nodes | {{node: node}}] AS node_info, relationships, total_nodes, From 3a0249a6b9bc09e2584f316f0b58a0b020ec0465 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 03:36:17 +0800 Subject: [PATCH 21/30] Update env.example --- env.example | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/env.example b/env.example index df88a518..32a9f3ed 100644 --- a/env.example +++ b/env.example @@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage ### Graph Storage (Recommended for production deployment) # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage +# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage #################################################################### ### Default workspace for all storage types ### For the purpose of isolation of data for each LightRAG instance ### Valid characters: a-z, A-Z, 0-9, and _ #################################################################### -# WORKSPACE=doc— +# WORKSPACE=space1 ### PostgreSQL Configuration POSTGRES_HOST=localhost From e9c3503f7724f28db0aa21e8db844329a5e221ef Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 04:36:52 +0800 Subject: [PATCH 22/30] Update logger info --- lightrag/operate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index e2251fc7..4b2aa250 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1781,8 +1781,8 @@ async def _get_vector_context( } valid_chunks.append(chunk_with_metadata) - logger.debug( - f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})" + logger.info( + f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})" ) return valid_chunks From 78033edabba28a633a68c31adac6626e938ef78f Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 04:37:04 +0800 Subject: [PATCH 23/30] Update env.example --- env.example | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/env.example b/env.example index d8a4dfd6..4515fe34 100644 --- a/env.example +++ b/env.example @@ -42,15 +42,24 @@ OLLAMA_EMULATING_MODEL_TAG=latest ### Logfile location (defaults to current working directory) # LOG_DIR=/path/to/log/directory -### Settings for RAG query +### RAG Configuration +### Chunk size for document splitting, 500~1500 is recommended +# CHUNK_SIZE=1200 +# CHUNK_OVERLAP_SIZE=100 +# MAX_TOKEN_SUMMARY=500 + +### RAG Query Configuration # HISTORY_TURNS=3 +# MAX_TOKEN_TEXT_CHUNK=6000 +# MAX_TOKEN_RELATION_DESC=4000 +# MAX_TOKEN_ENTITY_DESC=4000 # COSINE_THRESHOLD=0.2 +### Number of entities or relations to retrieve from KG # TOP_K=60 ### Number of text chunks to retrieve initially from vector search # CHUNK_TOP_K=5 ### Rerank Configuration -### Enable rerank functionality to improve retrieval quality # ENABLE_RERANK=False ### Number of 
text chunks to keep after reranking (should be <= CHUNK_TOP_K) # CHUNK_RERANK_TOP_K=5 @@ -59,10 +68,6 @@ OLLAMA_EMULATING_MODEL_TAG=latest # RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank # RERANK_BINDING_API_KEY=your_rerank_api_key_here -# MAX_TOKEN_TEXT_CHUNK=6000 -# MAX_TOKEN_RELATION_DESC=4000 -# MAX_TOKEN_ENTITY_DESC=4000 - ### Entity and relation summarization configuration ### Language: English, Chinese, French, German ... SUMMARY_LANGUAGE=English @@ -75,9 +80,6 @@ SUMMARY_LANGUAGE=English ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended) # MAX_PARALLEL_INSERT=2 -### Chunk size for document splitting, 500~1500 is recommended -# CHUNK_SIZE=1200 -# CHUNK_OVERLAP_SIZE=100 ### LLM Configuration ENABLE_LLM_CACHE=true From 2056c3c8095825c086d1ee7775c0d07fd580a5ec Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 04:41:51 +0800 Subject: [PATCH 24/30] Increase default CHUNK_TOP_K from 5 to 15 --- lightrag/api/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 8c3fbff4..70147bde 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -168,8 +168,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--chunk-top-k", type=int, - default=get_env_value("CHUNK_TOP_K", 5, int), - help="Number of text chunks to retrieve initially from vector search (default: from env or 5)", + default=get_env_value("CHUNK_TOP_K", 15, int), + help="Number of text chunks to retrieve initially from vector search (default: from env or 15)", ) parser.add_argument( "--chunk-rerank-top-k", From 4705a228611a3aed0ad081cc5523d41e793b12a0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 04:43:20 +0800 Subject: [PATCH 25/30] Bump core version to 1.4.0 --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 392b3f60..e72f906a 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.3.10" +__version__ = "1.4.0" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From b192f8c9a3daf20ec81c89f41c78fc9c866bd361 Mon Sep 17 00:00:00 2001 From: Anton Vice Date: Tue, 8 Jul 2025 19:35:22 -0300 Subject: [PATCH 26/30] Fix: Handle NoneType error when processing documents without a file path The document processing pipeline would crash with a TypeError when a document was submitted as raw text via the API, as the file_path attribute would be None. This change adds a check to handle the None case gracefully, preventing the crash and allowing text-based documents to be indexed correctly. --- lightrag/lightrag.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 1f61a42e..69ef8811 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -900,9 +900,15 @@ class LightRAG: # Get first document's file path and total count for job name first_doc_id, first_doc = next(iter(to_process_docs.items())) first_doc_path = first_doc.file_path - path_prefix = first_doc_path[:20] + ( - "..." if len(first_doc_path) > 20 else "" - ) + + # Handle cases where first_doc_path is None + if first_doc_path: + path_prefix = first_doc_path[:20] + ( + "..." 
if len(first_doc_path) > 20 else ""
+                )
+            else:
+                path_prefix = "unknown_source"
+
             total_files = len(to_process_docs)
             job_name = f"{path_prefix}[{total_files} files]"
             pipeline_status["job_name"] = job_name

From e1541caea9a344c8fb4baf8d6bb0b65c339486a6 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Wed, 9 Jul 2025 12:10:06 +0800
Subject: [PATCH 27/30] Update webui setting

---
 lightrag/operate.py                   | 2 ++
 lightrag_webui/src/stores/settings.ts | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index 4b2aa250..a27e19f4 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2500,6 +2500,8 @@ async def naive_query(
         source_type="vector",
     )

+    logger.info(f"Final context: {len(processed_chunks)} chunks")
+
     # Build text_units_context from processed chunks
     text_units_context = []
     for i, chunk in enumerate(processed_chunks):
diff --git a/lightrag_webui/src/stores/settings.ts b/lightrag_webui/src/stores/settings.ts
index 203502dc..5942ddca 100644
--- a/lightrag_webui/src/stores/settings.ts
+++ b/lightrag_webui/src/stores/settings.ts
@@ -111,7 +111,7 @@ const useSettingsStoreBase = create()(
       mode: 'global',
       response_type: 'Multiple Paragraphs',
       top_k: 10,
-      max_token_for_text_unit: 4000,
+      max_token_for_text_unit: 6000,
       max_token_for_global_context: 4000,
       max_token_for_local_context: 4000,
       only_need_context: false,

From bfa0844ecb2f4cec3aaf4fbd431493909c8fbd7c Mon Sep 17 00:00:00 2001
From: yangdx
Date: Wed, 9 Jul 2025 15:17:05 +0800
Subject: [PATCH 28/30] Update README

---
 README-zh.md | 30 +++++++++++++++++++++---------
 README.md    | 30 +++++++++++++++++++++---------
 2 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/README-zh.md b/README-zh.md
index e9599099..678d727b 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -824,7 +824,7 @@ rag = LightRAG(
   create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
   CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties);
   ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;
-
+
   -- 如有必要可以删除
   drop INDEX entity_p_idx;
   drop INDEX vertex_p_idx;
@@ -849,6 +849,18 @@ rag = LightRAG(

+### LightRAG实例间的数据隔离
+
+通过 workspace 参数可以实现不同LightRAG实例之间的存储数据隔离。LightRAG在初始化后workspace就已经确定,之后修改workspace是无效的。下面是不同类型的存储实现工作空间的方式:
+
+- **对于本地基于文件的数据库,数据隔离通过工作空间子目录实现:** JsonKVStorage, JsonDocStatusStorage, NetworkXStorage, NanoVectorDBStorage, FaissVectorDBStorage。
+- **对于将数据存储在集合(collection)中的数据库,通过在集合名称前添加工作空间前缀来实现:** RedisKVStorage, RedisDocStatusStorage, MilvusVectorDBStorage, QdrantVectorDBStorage, MongoKVStorage, MongoDocStatusStorage, MongoVectorDBStorage, MongoGraphStorage, PGGraphStorage。
+- **对于关系型数据库,数据隔离通过向表中添加 `workspace` 字段进行数据的逻辑隔离:** PGKVStorage, PGVectorStorage, PGDocStatusStorage。
+
+- **对于Neo4j图数据库,通过label来实现数据的逻辑隔离:** Neo4JStorage
+
+为了保持对遗留数据的兼容,在未配置工作空间时PostgreSQL的默认工作空间为`default`,Neo4j的默认工作空间为`base`。对于所有的外部存储,系统都提供了专用的工作空间环境变量,用于覆盖公共的 `WORKSPACE`环境变量配置。这些适用于指定存储类型的工作空间环境变量为:`REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`。
+
 ## 编辑实体和关系

 LightRAG现在支持全面的知识图谱管理功能,允许您在知识图谱中创建、编辑和删除实体和关系。
@@ -1170,17 +1182,17 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现
     from lightrag.llm.openai import openai_complete_if_cache, openai_embed
     from lightrag.utils import EmbeddingFunc
     import os
-    
+
     async def load_existing_lightrag():
         # 首先,创建或加载现有的 LightRAG 实例
        lightrag_working_dir = "./existing_lightrag_storage"
-        
+
         # 检查是否存在之前的 LightRAG 实例
         if os.path.exists(lightrag_working_dir) and os.listdir(lightrag_working_dir):
             print("✅ Found existing LightRAG instance, loading...")
         else:
             print("❌ No existing LightRAG instance found, will create new one")
-        
+
         # 使用您的配置创建/加载 LightRAG 实例
         lightrag_instance = LightRAG(
             working_dir=lightrag_working_dir,
@@ -1203,10 +1215,10 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现
             ),
         )
     )
-    
+
     # 初始化存储(如果有现有数据,这将加载现有数据)
     await lightrag_instance.initialize_storages()
-    
+
     # 现在使用现有的 LightRAG 实例初始化 RAGAnything
     rag = RAGAnything(
         lightrag=lightrag_instance,  # 传递现有的 LightRAG 实例
@@ -1235,20 +1247,20 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现
         )
         # 注意:working_dir、llm_model_func、embedding_func 等都从 lightrag_instance 继承
     )
-    
+
     # 查询现有的知识库
     result = await rag.query_with_multimodal(
         "What data has been processed in this LightRAG instance?",
         mode="hybrid"
     )
     print("Query result:", result)
-    
+
     # 向现有的 LightRAG 实例添加新的多模态文档
     await rag.process_document_complete(
         file_path="path/to/new/multimodal_document.pdf",
         output_dir="./output"
     )
-    
+
     if __name__ == "__main__":
         asyncio.run(load_existing_lightrag())
     ```
diff --git a/README.md b/README.md
index 39353ef8..6650ada8 100644
--- a/README.md
+++ b/README.md
@@ -239,6 +239,7 @@ A full list of LightRAG init parameters:
 | **Parameter** | **Type** | **Explanation** | **Default** |
 |--------------|----------|-----------------|-------------|
 | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
+| **workspace** | `str` | Workspace name for data isolation between different LightRAG instances | |
 | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
 | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
 | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -796,7 +797,7 @@ For production level scenarios you will most likely want to leverage an enterpri
   create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
   CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties);
   ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;
-  
+
   -- drop if necessary
   drop INDEX entity_p_idx;
   drop INDEX vertex_p_idx;
@@ -895,6 +896,17 @@ async def initialize_rag():

+### Data Isolation Between LightRAG Instances
+
+The `workspace` parameter ensures data isolation between different LightRAG instances. Once initialized, the `workspace` is immutable and cannot be changed. Here is how workspaces are implemented for different types of storage:
+
+- **For local file-based databases, data isolation is achieved through workspace subdirectories:** `JsonKVStorage`, `JsonDocStatusStorage`, `NetworkXStorage`, `NanoVectorDBStorage`, `FaissVectorDBStorage`.
+- **For databases that store data in collections, it's done by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`. +- **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`. +- **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage` + +To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`. + ## Edit Entities and Relations LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. @@ -1219,17 +1231,17 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc import os - + async def load_existing_lightrag(): # First, create or load an existing LightRAG instance lightrag_working_dir = "./existing_lightrag_storage" - + # Check if previous LightRAG instance exists if os.path.exists(lightrag_working_dir) and os.listdir(lightrag_working_dir): print("✅ Found existing LightRAG instance, loading...") else: print("❌ No existing LightRAG instance found, will create new one") - + # Create/Load LightRAG instance with your configurations lightrag_instance = LightRAG( working_dir=lightrag_working_dir, @@ -1252,10 +1264,10 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ ), ) ) - + # Initialize storage (this will load existing data if available) await lightrag_instance.initialize_storages() - + # Now initialize RAGAnything with the existing LightRAG instance rag = RAGAnything( lightrag=lightrag_instance, # Pass the existing LightRAG instance @@ -1284,20 +1296,20 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ ) # Note: working_dir, llm_model_func, embedding_func, etc. 
are inherited from lightrag_instance ) - + # Query the existing knowledge base result = await rag.query_with_multimodal( "What data has been processed in this LightRAG instance?", mode="hybrid" ) print("Query result:", result) - + # Add new multimodal documents to the existing LightRAG instance await rag.process_document_complete( file_path="path/to/new/multimodal_document.pdf", output_dir="./output" ) - + if __name__ == "__main__": asyncio.run(load_existing_lightrag()) ``` From e457374224e8d105a0e698adf15cd804a2598519 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 15:33:05 +0800 Subject: [PATCH 29/30] Fix linting --- README-zh.md | 18 +++++++++--------- README.md | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README-zh.md b/README-zh.md index 678d727b..b6cc0a8c 100644 --- a/README-zh.md +++ b/README-zh.md @@ -824,7 +824,7 @@ rag = LightRAG( create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype)); CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties); ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx; - + -- 如有必要可以删除 drop INDEX entity_p_idx; drop INDEX vertex_p_idx; @@ -1182,17 +1182,17 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现 from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc import os - + async def load_existing_lightrag(): # 首先,创建或加载现有的 LightRAG 实例 lightrag_working_dir = "./existing_lightrag_storage" - + # 检查是否存在之前的 LightRAG 实例 if os.path.exists(lightrag_working_dir) and os.listdir(lightrag_working_dir): print("✅ Found existing LightRAG instance, loading...") else: print("❌ No existing LightRAG instance found, will create new one") - + # 使用您的配置创建/加载 LightRAG 实例 lightrag_instance = LightRAG( working_dir=lightrag_working_dir, @@ -1215,10 +1215,10 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现 ), ) ) - + # 初始化存储(如果有现有数据,这将加载现有数据) await lightrag_instance.initialize_storages() - + # 现在使用现有的 LightRAG 实例初始化 RAGAnything rag = RAGAnything( lightrag=lightrag_instance, # 传递现有的 LightRAG 实例 @@ -1247,20 +1247,20 @@ LightRAG 现已与 [RAG-Anything](https://github.com/HKUDS/RAG-Anything) 实现 ) # 注意:working_dir、llm_model_func、embedding_func 等都从 lightrag_instance 继承 ) - + # 查询现有的知识库 result = await rag.query_with_multimodal( "What data has been processed in this LightRAG instance?", mode="hybrid" ) print("Query result:", result) - + # 向现有的 LightRAG 实例添加新的多模态文档 await rag.process_document_complete( file_path="path/to/new/multimodal_document.pdf", output_dir="./output" ) - + if __name__ == "__main__": asyncio.run(load_existing_lightrag()) ``` diff --git a/README.md b/README.md index 6650ada8..1cb5e6f7 100644 --- a/README.md +++ b/README.md @@ -797,7 +797,7 @@ For production level scenarios you will most likely want to leverage an enterpri create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype)); CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties); ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx; - + -- drop if necessary drop INDEX entity_p_idx; drop INDEX vertex_p_idx; @@ -1231,17 +1231,17 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc import os - + 
async def load_existing_lightrag(): # First, create or load an existing LightRAG instance lightrag_working_dir = "./existing_lightrag_storage" - + # Check if previous LightRAG instance exists if os.path.exists(lightrag_working_dir) and os.listdir(lightrag_working_dir): print("✅ Found existing LightRAG instance, loading...") else: print("❌ No existing LightRAG instance found, will create new one") - + # Create/Load LightRAG instance with your configurations lightrag_instance = LightRAG( working_dir=lightrag_working_dir, @@ -1264,10 +1264,10 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ ), ) ) - + # Initialize storage (this will load existing data if available) await lightrag_instance.initialize_storages() - + # Now initialize RAGAnything with the existing LightRAG instance rag = RAGAnything( lightrag=lightrag_instance, # Pass the existing LightRAG instance @@ -1296,20 +1296,20 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/ ) # Note: working_dir, llm_model_func, embedding_func, etc. are inherited from lightrag_instance ) - + # Query the existing knowledge base result = await rag.query_with_multimodal( "What data has been processed in this LightRAG instance?", mode="hybrid" ) print("Query result:", result) - + # Add new multimodal documents to the existing LightRAG instance await rag.process_document_complete( file_path="path/to/new/multimodal_document.pdf", output_dir="./output" ) - + if __name__ == "__main__": asyncio.run(load_existing_lightrag()) ``` From b0479c078a0edb125891dd62906a3be45ffcb546 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 9 Jul 2025 15:55:38 +0800 Subject: [PATCH 30/30] fix process_chunks_unified() --- lightrag/operate.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index a27e19f4..be4499ab 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2879,15 +2879,7 @@ async def process_chunks_unified( f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" ) - # 2. Apply chunk_top_k limiting if specified - if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: - if len(unique_chunks) > query_param.chunk_top_k: - unique_chunks = unique_chunks[: query_param.chunk_top_k] - logger.debug( - f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})" - ) - - # 3. Apply reranking if enabled and query is provided + # 2. Apply reranking if enabled and query is provided if global_config.get("enable_rerank", False) and query and unique_chunks: rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks) unique_chunks = await apply_rerank_if_enabled( @@ -2898,6 +2890,14 @@ async def process_chunks_unified( ) logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") + # 3. Apply chunk_top_k limiting if specified + if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: + if len(unique_chunks) > query_param.chunk_top_k: + unique_chunks = unique_chunks[: query_param.chunk_top_k] + logger.debug( + f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})" + ) + # 4. Token-based final truncation tokenizer = global_config.get("tokenizer") if tokenizer and unique_chunks:
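A note on the ordering change in PATCH 30: `process_chunks_unified()` now deduplicates, reranks the full candidate set, applies the `chunk_top_k` cut, and only then token-truncates. The sketch below is a minimal illustration of that pipeline, not LightRAG's actual code; the `score` callable stands in for a rerank model and is an assumption for this example. Cutting to `chunk_top_k` before reranking, as the old order did, would let raw vector-search ranking decide which chunks the reranker ever sees.

```python
# Minimal sketch of the chunk pipeline ordering fixed in PATCH 30.
# `score` is a hypothetical rerank scorer, not a LightRAG API; the final
# token-based truncation step is omitted for brevity.
def order_chunks(chunks: list[dict], score, chunk_top_k: int) -> list[dict]:
    unique = list({c["chunk_id"]: c for c in chunks}.values())  # 1. deduplicate by chunk_id
    reranked = sorted(unique, key=score, reverse=True)          # 2. rerank every unique chunk
    return reranked[:chunk_top_k]                               # 3. only then keep the top chunk_top_k
```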
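Taken together, the series leaves `MemgraphStorage` usable as a drop-in graph backend. The sketch below exercises the contract these patches enforce: every method raises `RuntimeError` until `await initialize()` has run, nodes are MERGEd on `entity_id` under the workspace label and must carry an `entity_type`, and `drop()` clears only the current workspace rather than the whole database. The constructor arguments, the bare `embedding_func=None`, and the `MEMGRAPH_URI` fallback are illustrative assumptions rather than a documented invocation; adapt them to your setup.

```python
import asyncio
import os

from lightrag.kg.memgraph_impl import MemgraphStorage

# Assumed local deployment; treat this URI as an example value.
os.environ.setdefault("MEMGRAPH_URI", "bolt://localhost:7687")


async def main() -> None:
    # Constructor arguments are assumptions for illustration.
    storage = MemgraphStorage(
        namespace="chunk_entity_relation",
        global_config={},
        embedding_func=None,
    )
    await storage.initialize()  # required: every method guards on the driver being ready
    try:
        # Nodes are keyed by entity_id and must include an entity_type property.
        await storage.upsert_node("Alice", {"entity_id": "Alice", "entity_type": "Person"})
        await storage.upsert_node("Bob", {"entity_id": "Bob", "entity_type": "Person"})
        # Edges are MERGEd as :DIRECTED relationships between the two entities.
        await storage.upsert_edge("Alice", "Bob", {"description": "knows", "weight": 1.0})

        # BFS subgraph around a node, capped by max_depth and max_nodes.
        kg = await storage.get_knowledge_graph("Alice", max_depth=2, max_nodes=100)
        print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")
    finally:
        await storage.drop()  # deletes only this workspace's nodes and relationships


asyncio.run(main())
```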