From 7aaa51cda9a9b2bb4810e89e9b03cd42cb7eda85 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 24 Nov 2025 22:28:15 +0800 Subject: [PATCH] Add retry decorators to Neo4j read operations for resilience --- lightrag/kg/neo4j_impl.py | 140 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 256656d8..d3d6c4eb 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -352,6 +352,20 @@ class Neo4JStorage(BaseGraphStorage): # Neo4J handles persistence automatically pass + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def has_node(self, node_id: str) -> bool: """ Check if a node with the given label exists in the database @@ -385,6 +399,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if an edge exists between two nodes @@ -426,6 +454,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -479,6 +521,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. @@ -515,6 +571,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Make sure to consume the result fully return nodes + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. @@ -563,6 +633,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. @@ -647,6 +731,20 @@ class Neo4JStorage(BaseGraphStorage): edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) return edge_degrees + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -734,6 +832,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_edges_batch( self, pairs: list[dict[str, str]] ) -> dict[tuple[str, str], dict]: @@ -784,6 +896,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() return edges_dict + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Retrieves all edges (relationships) for a particular node identified by its label. @@ -851,6 +977,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: