diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md index fdaebfa5..4e4d433f 100644 --- a/docs/rerank_integration.md +++ b/docs/rerank_integration.md @@ -1,36 +1,24 @@ -# Rerank Integration in LightRAG +# Rerank Integration Guide -This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality. +LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True). -## Overview - -Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix). - -## Architecture - -The rerank integration follows a simplified design pattern: - -- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) 
are contained within the rerank function -- **Async Processing**: Non-blocking rerank operations -- **Error Handling**: Graceful fallback to original results -- **Optional Feature**: Can be enabled/disabled via configuration -- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs - -## Configuration +## Quick Start ### Environment Variables -Set this variable in your `.env` file or environment: +Set these variables in your `.env` file or environment for rerank model configuration: ```bash -# Enable/disable reranking -ENABLE_RERANK=True +# Rerank model configuration (required when enable_rerank=True in queries) +RERANK_MODEL=BAAI/bge-reranker-v2-m3 +RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank +RERANK_BINDING_API_KEY=your_api_key_here ``` ### Programmatic Configuration ```python -from lightrag import LightRAG +from lightrag import LightRAG, QueryParam from lightrag.rerank import custom_rerank, RerankModel # Method 1: Using a custom rerank function with all settings included @@ -49,8 +37,19 @@ rag = LightRAG( working_dir="./rag_storage", llm_model_func=your_llm_func, embedding_func=your_embedding_func, - enable_rerank=True, - rerank_model_func=my_rerank_func, + rerank_model_func=my_rerank_func, # Configure rerank function +) + +# Query with rerank enabled (default) +result = await rag.aquery( + "your query", + param=QueryParam(enable_rerank=True) # Control rerank per query +) + +# Query with rerank disabled +result = await rag.aquery( + "your query", + param=QueryParam(enable_rerank=False) ) # Method 2: Using RerankModel wrapper @@ -67,9 +66,17 @@ rag = LightRAG( working_dir="./rag_storage", llm_model_func=your_llm_func, embedding_func=your_embedding_func, - enable_rerank=True, rerank_model_func=rerank_model.rerank, ) + +# Control rerank per query +result = await rag.aquery( + "your query", + param=QueryParam( + enable_rerank=True, # Enable rerank for this query + chunk_top_k=5 # Number of chunks to keep after reranking + ) +) 
``` ## Supported Providers @@ -164,7 +171,6 @@ async def main(): working_dir="./rag_storage", llm_model_func=gpt_4o_mini_complete, embedding_func=openai_embedding, - enable_rerank=True, rerank_model_func=my_rerank_func, ) @@ -180,7 +186,7 @@ async def main(): # Query with rerank (automatically applied) result = await rag.aquery( "Your question here", - param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function + param=QueryParam(enable_rerank=True) # Rerank is enabled for this query (this is also the default) ) print(result) diff --git a/examples/rerank_example.py b/examples/rerank_example.py index e0e361a5..42b4dd38 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -9,7 +9,11 @@ Configuration Required: 2. Set your embedding API key and base URL in embedding_func() 3. Set your rerank API key and base URL in the rerank configuration 4. Or use environment variables (.env file): - - ENABLE_RERANK=True + - RERANK_MODEL=your_rerank_model + - RERANK_BINDING_HOST=your_rerank_endpoint + - RERANK_BINDING_API_KEY=your_rerank_api_key + +Note: Rerank is now controlled per query via the 'enable_rerank' parameter (default: True) """ import asyncio @@ -83,8 +87,7 @@ async def create_rag_with_rerank(): max_token_size=8192, func=embedding_func, ), - # Simplified Rerank Configuration - enable_rerank=True, + # Rerank Configuration - provide the rerank function rerank_model_func=my_rerank_func, ) @@ -120,7 +123,6 @@ async def create_rag_with_rerank_model(): max_token_size=8192, func=embedding_func, ), - enable_rerank=True, rerank_model_func=rerank_model.rerank, ) @@ -130,9 +132,9 @@ async def create_rag_with_rerank_model(): return rag -async def test_rerank_with_different_topk(): +async def test_rerank_with_different_settings(): """ - Test rerank functionality with different top_k settings + Test rerank functionality with different enable_rerank settings """ print("šŸš€ Setting up LightRAG with Rerank functionality...") @@ -154,16 +156,41 @@ async def 
test_rerank_with_different_topk(): print(f"\nšŸ” Testing query: '{query}'") print("=" * 80) - # Test different top_k values to show parameter priority - top_k_values = [2, 5, 10] + # Test with rerank enabled (default) + print("\nšŸ“Š Testing with enable_rerank=True (default):") + result_with_rerank = await rag.aquery( + query, + param=QueryParam( + mode="naive", + top_k=10, + chunk_top_k=5, + enable_rerank=True, # Explicitly enable rerank + ), + ) + print(f" Result length: {len(result_with_rerank)} characters") + print(f" Preview: {result_with_rerank[:100]}...") - for top_k in top_k_values: - print(f"\nšŸ“Š Testing with QueryParam(top_k={top_k}):") + # Test with rerank disabled + print("\nšŸ“Š Testing with enable_rerank=False:") + result_without_rerank = await rag.aquery( + query, + param=QueryParam( + mode="naive", + top_k=10, + chunk_top_k=5, + enable_rerank=False, # Disable rerank + ), + ) + print(f" Result length: {len(result_without_rerank)} characters") + print(f" Preview: {result_without_rerank[:100]}...") - # Test naive mode with specific top_k - result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k)) - print(f" Result length: {len(result)} characters") - print(f" Preview: {result[:100]}...") + # Test with default settings (enable_rerank defaults to True) + print("\nšŸ“Š Testing with default settings (enable_rerank defaults to True):") + result_default = await rag.aquery( + query, param=QueryParam(mode="naive", top_k=10, chunk_top_k=5) + ) + print(f" Result length: {len(result_default)} characters") + print(f" Preview: {result_default[:100]}...") async def test_direct_rerank(): @@ -209,17 +236,21 @@ async def main(): print("=" * 60) try: - # Test rerank with different top_k values - await test_rerank_with_different_topk() + # Test rerank with different enable_rerank settings + await test_rerank_with_different_settings() # Test direct rerank await test_direct_rerank() print("\nāœ… Example completed successfully!") print("\nšŸ’” Key 
Points:") - print(" āœ“ All rerank configurations are contained within rerank_model_func") - print(" āœ“ Rerank improves document relevance ordering") - print(" āœ“ Configure API keys within your rerank function") + print(" āœ“ Rerank is now controlled per query via 'enable_rerank' parameter") + print(" āœ“ Default value for enable_rerank is True") + print(" āœ“ Rerank function is configured at LightRAG initialization") + print(" āœ“ Per-query enable_rerank setting overrides default behavior") + print( + " āœ“ If enable_rerank=True but no rerank model is configured, a warning is issued" + ) print(" āœ“ Monitor API usage and costs when using rerank services") except Exception as e: diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 59719bc9..297fa053 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -10,55 +10,58 @@ from .utils import logger class RerankModel(BaseModel): """ - Pydantic model class for defining a custom rerank model. - - This class provides a convenient wrapper for rerank functions, allowing you to - encapsulate all rerank configurations (API keys, model settings, etc.) in one place. - - Attributes: - rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. - The function should take query and documents as input and return reranked results. - kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. - This should include all necessary configurations such as model name, API key, base_url, etc. + Wrapper for rerank functions that can be used with LightRAG. 
Example usage: - Rerank model example with Jina: - ```python - rerank_model = RerankModel( - rerank_func=jina_rerank, - kwargs={ - "model": "BAAI/bge-reranker-v2-m3", - "api_key": "your_api_key_here", - "base_url": "https://api.jina.ai/v1/rerank" - } + ```python + from lightrag.rerank import RerankModel, jina_rerank + + # Create rerank model + rerank_model = RerankModel( + rerank_func=jina_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "api_key": "your_api_key_here", + "base_url": "https://api.jina.ai/v1/rerank" + } + ) + + # Use in LightRAG + rag = LightRAG( + rerank_model_func=rerank_model.rerank, + # ... other configurations + ) + + # Query with rerank enabled (default) + result = await rag.aquery( + "your query", + param=QueryParam(enable_rerank=True) + ) + ``` + + Or define a custom function directly: + ```python + async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_api_key_here", + top_k=top_k or 10, + **kwargs ) - # Use in LightRAG - rag = LightRAG( - enable_rerank=True, - rerank_model_func=rerank_model.rerank, - # ... other configurations - ) - ``` + rag = LightRAG( + rerank_model_func=my_rerank_func, + # ... other configurations + ) - Or define a custom function directly: - ```python - async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): - return await jina_rerank( - query=query, - documents=documents, - model="BAAI/bge-reranker-v2-m3", - api_key="your_api_key_here", - top_k=top_k or 10, - **kwargs - ) - - rag = LightRAG( - enable_rerank=True, - rerank_model_func=my_rerank_func, - # ... other configurations - ) - ``` + # Control rerank per query + result = await rag.aquery( + "your query", + param=QueryParam(enable_rerank=True) # Enable rerank for this query + ) + ``` """ rerank_func: Callable[[Any], List[Dict]]