From cb3bf3291c1e2d41de0d596849584c04290a6f41 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 20 Jul 2025 00:26:27 +0800 Subject: [PATCH] Fix: rename rerank parameter from top_k to top_n The change aligns with the API parameter naming used by Jina and Cohere rerank services, ensuring consistency and clarity. --- examples/rerank_example.py | 6 +++--- lightrag/api/lightrag_server.py | 4 ++-- lightrag/operate.py | 14 ++++++------ lightrag/rerank.py | 38 ++++++++++++++++----------------- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/examples/rerank_example.py b/examples/rerank_example.py index 42b4dd38..5754a750 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray: ) -async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): +async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): """Custom rerank function with all settings included""" return await custom_rerank( query=query, @@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg model="BAAI/bge-reranker-v2-m3", base_url="https://api.your-rerank-provider.com/v1/rerank", api_key="your_rerank_api_key_here", - top_k=top_k or 10, # Default top_k if not provided + top_n=top_n or 10, **kwargs, ) @@ -217,7 +217,7 @@ async def test_direct_rerank(): model="BAAI/bge-reranker-v2-m3", base_url="https://api.your-rerank-provider.com/v1/rerank", api_key="your_rerank_api_key_here", - top_k=3, + top_n=3, ) print("\n✅ Rerank Results:") diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 573455e5..17bbaaed 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -298,7 +298,7 @@ def create_app(args): from lightrag.rerank import custom_rerank async def server_rerank_func( - query: str, documents: list, top_k: int = None, **kwargs + query: str, documents: list, top_n: int = None, **kwargs ): """Server rerank function with configuration from environment variables""" return await custom_rerank( @@ -307,7 +307,7 @@ def create_app(args): model=args.rerank_model, base_url=args.rerank_binding_host, api_key=args.rerank_binding_api_key, - top_k=top_k, + top_n=top_n, **kwargs, ) diff --git a/lightrag/operate.py b/lightrag/operate.py index a0418174..5228383a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled( retrieved_docs: list[dict], global_config: dict, enable_rerank: bool = True, - top_k: int = None, + top_n: int = None, ) -> list[dict]: """ Apply reranking to retrieved documents if rerank is enabled. @@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled( retrieved_docs: List of retrieved documents global_config: Global configuration containing rerank settings enable_rerank: Whether to enable reranking from query parameter - top_k: Number of top documents to return after reranking + top_n: Number of top documents to return after reranking Returns: Reranked documents if rerank is enabled, otherwise original documents @@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled( try: logger.debug( - f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}" + f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}" ) # Apply reranking - let rerank_model_func handle top_k internally reranked_docs = await rerank_func( query=query, documents=retrieved_docs, - top_k=top_k, + top_n=top_n, ) if reranked_docs and len(reranked_docs) > 0: - if len(reranked_docs) > top_k: - reranked_docs = reranked_docs[:top_k] + if len(reranked_docs) > top_n: + reranked_docs = reranked_docs[:top_n] logger.info( f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" ) @@ -3263,7 +3263,7 @@ async def process_chunks_unified( retrieved_docs=unique_chunks, global_config=global_config, enable_rerank=query_param.enable_rerank, - top_k=rerank_top_k, + top_n=rerank_top_k, ) logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 297fa053..5ed1ca68 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -41,13 +41,13 @@ class RerankModel(BaseModel): Or define a custom function directly: ```python - async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): return await jina_rerank( query=query, documents=documents, model="BAAI/bge-reranker-v2-m3", api_key="your_api_key_here", - top_k=top_k or 10, + top_n=top_n or 10, **kwargs ) @@ -71,14 +71,14 @@ class RerankModel(BaseModel): self, query: str, documents: List[Dict[str, Any]], - top_k: Optional[int] = None, + top_n: Optional[int] = None, **extra_kwargs, ) -> List[Dict[str, Any]]: """Rerank documents using the configured model function.""" # Merge extra kwargs with model kwargs kwargs = {**self.kwargs, **extra_kwargs} return await self.rerank_func( - query=query, documents=documents, top_k=top_k, **kwargs + query=query, documents=documents, top_n=top_n, **kwargs ) @@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel): query: str, documents: List[Dict[str, Any]], mode: str = "default", - top_k: Optional[int] = None, + top_n: Optional[int] = None, **kwargs, ) -> List[Dict[str, Any]]: """Rerank using the appropriate model based on mode.""" @@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel): logger.warning(f"No rerank model available for mode: {mode}") return documents - return await model.rerank(query, documents, top_k, **kwargs) + return await model.rerank(query, documents, top_n, **kwargs) async def generic_rerank_api( @@ -125,7 +125,7 @@ async def generic_rerank_api( model: str, base_url: str, api_key: str, - top_k: Optional[int] = None, + top_n: Optional[int] = None, **kwargs, ) -> List[Dict[str, Any]]: """ @@ -137,7 +137,7 @@ async def generic_rerank_api( model: Model identifier base_url: API endpoint URL api_key: API authentication key - top_k: Number of top results to return + top_n: Number of top results to return **kwargs: Additional API-specific parameters Returns: @@ -165,8 +165,8 @@ async def generic_rerank_api( data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} - if top_k is not None: - data["top_k"] = min(top_k, len(prepared_docs)) + if top_n is not None: + data["top_n"] = min(top_n, len(prepared_docs)) try: async with aiohttp.ClientSession() as session: @@ -206,7 +206,7 @@ async def jina_rerank( query: str, documents: List[Dict[str, Any]], model: str = "BAAI/bge-reranker-v2-m3", - top_k: Optional[int] = None, + top_n: Optional[int] = None, base_url: str = "https://api.jina.ai/v1/rerank", api_key: Optional[str] = None, **kwargs, @@ -218,7 +218,7 @@ async def jina_rerank( query: The search query documents: List of documents to rerank model: Jina rerank model name - top_k: Number of top results to return + top_n: Number of top results to return base_url: Jina API endpoint api_key: Jina API key **kwargs: Additional parameters @@ -235,7 +235,7 @@ async def jina_rerank( model=model, base_url=base_url, api_key=api_key, - top_k=top_k, + top_n=top_n, **kwargs, ) @@ -244,7 +244,7 @@ async def cohere_rerank( query: str, documents: List[Dict[str, Any]], model: str = "rerank-english-v2.0", - top_k: Optional[int] = None, + top_n: Optional[int] = None, base_url: str = "https://api.cohere.ai/v1/rerank", api_key: Optional[str] = None, **kwargs, @@ -256,7 +256,7 @@ async def cohere_rerank( query: The search query documents: List of documents to rerank model: Cohere rerank model name - top_k: Number of top results to return + top_n: Number of top results to return base_url: Cohere API endpoint api_key: Cohere API key **kwargs: Additional parameters @@ -273,7 +273,7 @@ async def cohere_rerank( model=model, base_url=base_url, api_key=api_key, - top_k=top_k, + top_n=top_n, **kwargs, ) @@ -285,7 +285,7 @@ async def custom_rerank( model: str, base_url: str, api_key: str, - top_k: Optional[int] = None, + top_n: Optional[int] = None, **kwargs, ) -> List[Dict[str, Any]]: """ @@ -298,7 +298,7 @@ async def custom_rerank( model=model, base_url=base_url, api_key=api_key, - top_k=top_k, + top_n=top_n, **kwargs, ) @@ -317,7 +317,7 @@ if __name__ == "__main__": query = "What is the capital of France?" result = await jina_rerank( - query=query, documents=docs, top_k=2, api_key="your-api-key-here" + query=query, documents=docs, top_n=2, api_key="your-api-key-here" ) print(result)