From 580cb7906cb2189b994dcb2decb9fa6286a3d86d Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 22 Aug 2025 19:29:45 +0800 Subject: [PATCH] feat: Add multiple rerank provider support to LightRAG Server by adding new env vars and cli params - Add --enable-rerank CLI argument and ENABLE_RERANK env var - Simplify rerank configuration logic to only check enable flag and binding - Update health endpoint to show enable_rerank and rerank_configured status - Improve logging messages for rerank enable/disable states - Maintain backward compatibility with default value True --- docs/rerank_integration.md | 281 ------------------ env.example | 30 +- lightrag/api/config.py | 14 + lightrag/api/lightrag_server.py | 43 ++- lightrag/rerank.py | 493 +++++++++++++++++--------------- lightrag/utils.py | 75 ++--- 6 files changed, 368 insertions(+), 568 deletions(-) delete mode 100644 docs/rerank_integration.md diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md deleted file mode 100644 index 0e6c5169..00000000 --- a/docs/rerank_integration.md +++ /dev/null @@ -1,281 +0,0 @@ -# Rerank Integration Guide - -LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True). - -## Quick Start - -### Environment Variables - -Set these variables in your `.env` file or environment for rerank model configuration: - -```bash -# Rerank model configuration (required when enable_rerank=True in queries) -RERANK_MODEL=BAAI/bge-reranker-v2-m3 -RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank -RERANK_BINDING_API_KEY=your_api_key_here -``` - -### Programmatic Configuration - -```python -from lightrag import LightRAG, QueryParam -from lightrag.rerank import custom_rerank, RerankModel - -# Method 1: Using a custom rerank function with all settings included -async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): - return await custom_rerank( - query=query, - documents=documents, - model="BAAI/bge-reranker-v2-m3", - base_url="https://api.your-provider.com/v1/rerank", - api_key="your_api_key_here", - top_n=top_n or 10, # Handle top_n within the function - **kwargs - ) - -rag = LightRAG( - working_dir="./rag_storage", - llm_model_func=your_llm_func, - embedding_func=your_embedding_func, - rerank_model_func=my_rerank_func, # Configure rerank function -) - -# Query with rerank enabled (default) -result = await rag.aquery( - "your query", - param=QueryParam(enable_rerank=True) # Control rerank per query -) - -# Query with rerank disabled -result = await rag.aquery( - "your query", - param=QueryParam(enable_rerank=False) -) - -# Method 2: Using RerankModel wrapper -rerank_model = RerankModel( - rerank_func=custom_rerank, - kwargs={ - "model": "BAAI/bge-reranker-v2-m3", - "base_url": "https://api.your-provider.com/v1/rerank", - "api_key": "your_api_key_here", - } -) - -rag = LightRAG( - working_dir="./rag_storage", - llm_model_func=your_llm_func, - embedding_func=your_embedding_func, - rerank_model_func=rerank_model.rerank, -) - -# Control rerank per query -result = await rag.aquery( - "your query", - param=QueryParam( - enable_rerank=True, # Enable rerank for this query - chunk_top_k=5 # Number of chunks to keep after reranking - ) -) -``` - -## Supported Providers - -### 1. 
Custom/Generic API (Recommended) - -For Jina/Cohere compatible APIs: - -```python -from lightrag.rerank import custom_rerank - -# Your custom API endpoint -result = await custom_rerank( - query="your query", - documents=documents, - model="BAAI/bge-reranker-v2-m3", - base_url="https://api.your-provider.com/v1/rerank", - api_key="your_api_key_here", - top_n=10 -) -``` - -### 2. Jina AI - -```python -from lightrag.rerank import jina_rerank - -result = await jina_rerank( - query="your query", - documents=documents, - model="BAAI/bge-reranker-v2-m3", - api_key="your_jina_api_key", - top_n=10 -) -``` - -### 3. Cohere - -```python -from lightrag.rerank import cohere_rerank - -result = await cohere_rerank( - query="your query", - documents=documents, - model="rerank-english-v2.0", - api_key="your_cohere_api_key", - top_n=10 -) -``` - -## Integration Points - -Reranking is automatically applied at these key retrieval stages: - -1. **Naive Mode**: After vector similarity search in `_get_vector_context` -2. **Local Mode**: After entity retrieval in `_get_node_data` -3. **Global Mode**: After relationship retrieval in `_get_edge_data` -4. **Hybrid/Mix Modes**: Applied to all relevant components - -## Configuration Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `enable_rerank` | bool | False | Enable/disable reranking | -| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) | - -## Example Usage - -### Basic Usage - -```python -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding -from lightrag.kg.shared_storage import initialize_pipeline_status -from lightrag.rerank import jina_rerank - -async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): - """Custom rerank function with all settings included""" - return await jina_rerank( - query=query, - documents=documents, - model="BAAI/bge-reranker-v2-m3", - api_key="your_jina_api_key_here", - top_n=top_n or 10, # Default top_n if not provided - **kwargs - ) - -async def main(): - # Initialize with rerank enabled - rag = LightRAG( - working_dir="./rag_storage", - llm_model_func=gpt_4o_mini_complete, - embedding_func=openai_embedding, - rerank_model_func=my_rerank_func, - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - # Insert documents - await rag.ainsert([ - "Document 1 content...", - "Document 2 content...", - ]) - - # Query with rerank (automatically applied) - result = await rag.aquery( - "Your question here", - param=QueryParam(enable_rerank=True) # This top_n is passed to rerank function - ) - - print(result) - -asyncio.run(main()) -``` - -### Direct Rerank Usage - -```python -from lightrag.rerank import custom_rerank - -async def test_rerank(): - documents = [ - {"content": "Text about topic A"}, - {"content": "Text about topic B"}, - {"content": "Text about topic C"}, - ] - - reranked = await custom_rerank( - query="Tell me about topic A", - documents=documents, - model="BAAI/bge-reranker-v2-m3", - base_url="https://api.your-provider.com/v1/rerank", - api_key="your_api_key_here", - top_n=2 - ) - - for doc in reranked: - print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}") -``` - -## Best Practices - -1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function -2. 
**Performance**: Use reranking selectively for better performance vs. quality tradeoff
-3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
-4. **Fallback**: Always handle rerank failures gracefully (returns original results)
-5. **Top-n Handling**: Handle top_n parameter appropriately within your rerank function
-6. **Cost Management**: Consider rerank API costs in your budget planning
-
-## Troubleshooting
-
-### Common Issues
-
-1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
-2. **Network Issues**: Check API endpoints and network connectivity
-3. **Model Errors**: Verify the rerank model name is supported by your API
-4. **Document Format**: Ensure documents have `content` or `text` fields
-
-### Debug Mode
-
-Enable debug logging to see rerank operations:
-
-```python
-import logging
-logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
-```
-
-### Error Handling
-
-The rerank integration includes automatic fallback:
-
-```python
-# If rerank fails, original documents are returned
-# No exceptions are raised to the user
-# Errors are logged for debugging
-```
-
-## API Compatibility
-
-The generic rerank API expects this response format:
-
-```json
-{
-  "results": [
-    {
-      "index": 0,
-      "relevance_score": 0.95
-    },
-    {
-      "index": 2,
-      "relevance_score": 0.87
-    }
-  ]
-}
-```
-
-This is compatible with:
-- Jina AI Rerank API
-- Cohere Rerank API
-- Custom APIs following the same format
diff --git a/env.example b/env.example
index 47a2ff60..c39faa5f 100644
--- a/env.example
+++ b/env.example
@@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
 ### If reranking is enabled, the impact of chunk selection strategies will be diminished.
 # KG_CHUNK_PICK_METHOD=VECTOR
 
+#########################################################
 ### Reranking configuration
-### Reranker Set ENABLE_RERANK to true in reranking model is configed
-# ENABLE_RERANK=True
-### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
+### RERANK_BINDING type: cohere, jina, aliyun
+### For rerank models deployed with vLLM, use the cohere binding
+#########################################################
+ENABLE_RERANK=False
+RERANK_BINDING=cohere
+### Rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
 # MIN_RERANK_SCORE=0.0
-### Rerank model configuration (required when ENABLE_RERANK=True)
-# RERANK_MODEL=jina-reranker-v2-base-multilingual
+
+### For local deployment
+# RERANK_MODEL=BAAI/bge-reranker-v2-m3
+# RERANK_BINDING_HOST=http://localhost:8000
+# RERANK_BINDING_API_KEY=your_rerank_api_key_here
+
+### Default value for Cohere AI
+# RERANK_MODEL=rerank-v3.5
+# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
+# RERANK_BINDING_API_KEY=your_rerank_api_key_here
+
+### Default value for Jina AI
+# RERANK_MODEL=jina-reranker-v2-base-multilingual
 # RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
 # RERANK_BINDING_API_KEY=your_rerank_api_key_here
 
+### Default value for Aliyun
+# RERANK_MODEL=gte-rerank-v2
+# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
+# RERANK_BINDING_API_KEY=your_rerank_api_key_here
+
 ########################################
 ### Document processing configuration
 ########################################
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 83a56f5a..bc24cd70 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -225,6 
+225,19 @@ def parse_args() -> argparse.Namespace: choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"], help="Embedding binding type (default: from env or ollama)", ) + parser.add_argument( + "--rerank-binding", + type=str, + default=get_env_value("RERANK_BINDING", "cohere"), + choices=["cohere", "jina", "aliyun"], + help="Rerank binding type (default: from env or cohere)", + ) + parser.add_argument( + "--enable-rerank", + action="store_true", + default=get_env_value("ENABLE_RERANK", True, bool), + help="Enable rerank functionality (default: from env or True)", + ) # Conditionally add binding options defined in binding_options module # This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx) @@ -340,6 +353,7 @@ def parse_args() -> argparse.Namespace: args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None) args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) + # Note: rerank_binding is already set by argparse, no need to override from env # Min rerank score configuration args.min_rerank_score = get_env_value( diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 708fedd2..8e3f9af1 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -390,33 +390,44 @@ def create_app(args): ), ) - # Configure rerank function if model and API are configured + # Configure rerank function based on enable_rerank parameter rerank_model_func = None - if args.rerank_binding_api_key and args.rerank_binding_host: - from lightrag.rerank import custom_rerank + if args.enable_rerank and args.rerank_binding: + from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank + + # Map rerank binding to corresponding function + rerank_functions = { + "cohere": cohere_rerank, + "jina": jina_rerank, + "aliyun": ali_rerank, + } + + # Select the appropriate rerank function based on binding + selected_rerank_func = rerank_functions.get(args.rerank_binding) + if not selected_rerank_func: + logger.error(f"Unsupported rerank binding: {args.rerank_binding}") + raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}") async def server_rerank_func( - query: str, documents: list, top_n: int = None, **kwargs + query: str, documents: list, top_n: int = None, extra_body: dict = None ): """Server rerank function with configuration from environment variables""" - return await custom_rerank( + return await selected_rerank_func( query=query, documents=documents, model=args.rerank_model, base_url=args.rerank_binding_host, api_key=args.rerank_binding_api_key, top_n=top_n, - **kwargs, + extra_body=extra_body, ) rerank_model_func = server_rerank_func logger.info( - f"Rerank model configured: {args.rerank_model} (can be enabled per query)" + f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider" ) else: - logger.info( - "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." 
- ) + logger.info("Rerank disabled") # Create ollama_server_infos from command line arguments from lightrag.api.config import OllamaServerInfos @@ -622,13 +633,15 @@ def create_app(args): "enable_llm_cache": args.enable_llm_cache, "workspace": args.workspace, "max_graph_nodes": args.max_graph_nodes, - # Rerank configuration (based on whether rerank model is configured) - "enable_rerank": rerank_model_func is not None, - "rerank_model": args.rerank_model - if rerank_model_func is not None + # Rerank configuration + "enable_rerank": args.enable_rerank, + "rerank_configured": rerank_model_func is not None, + "rerank_binding": args.rerank_binding + if args.enable_rerank else None, + "rerank_model": args.rerank_model if args.enable_rerank else None, "rerank_binding_host": args.rerank_binding_host - if rerank_model_func is not None + if args.enable_rerank else None, # Environment variable status (requested configuration) "summary_language": args.summary_language, diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 5ed1ca68..dbac1098 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -2,270 +2,194 @@ from __future__ import annotations import os import aiohttp -from typing import Callable, Any, List, Dict, Optional -from pydantic import BaseModel, Field - +from typing import Any, List, Dict, Optional +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) from .utils import logger +from dotenv import load_dotenv -class RerankModel(BaseModel): - """ - Wrapper for rerank functions that can be used with LightRAG. - - Example usage: - ```python - from lightrag.rerank import RerankModel, jina_rerank - - # Create rerank model - rerank_model = RerankModel( - rerank_func=jina_rerank, - kwargs={ - "model": "BAAI/bge-reranker-v2-m3", - "api_key": "your_api_key_here", - "base_url": "https://api.jina.ai/v1/rerank" - } - ) - - # Use in LightRAG - rag = LightRAG( - rerank_model_func=rerank_model.rerank, - # ... other configurations - ) - - # Query with rerank enabled (default) - result = await rag.aquery( - "your query", - param=QueryParam(enable_rerank=True) - ) - ``` - - Or define a custom function directly: - ```python - async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): - return await jina_rerank( - query=query, - documents=documents, - model="BAAI/bge-reranker-v2-m3", - api_key="your_api_key_here", - top_n=top_n or 10, - **kwargs - ) - - rag = LightRAG( - rerank_model_func=my_rerank_func, - # ... 
other configurations - ) - - # Control rerank per query - result = await rag.aquery( - "your query", - param=QueryParam(enable_rerank=True) # Enable rerank for this query - ) - ``` - """ - - rerank_func: Callable[[Any], List[Dict]] - kwargs: Dict[str, Any] = Field(default_factory=dict) - - async def rerank( - self, - query: str, - documents: List[Dict[str, Any]], - top_n: Optional[int] = None, - **extra_kwargs, - ) -> List[Dict[str, Any]]: - """Rerank documents using the configured model function.""" - # Merge extra kwargs with model kwargs - kwargs = {**self.kwargs, **extra_kwargs} - return await self.rerank_func( - query=query, documents=documents, top_n=top_n, **kwargs - ) - - -class MultiRerankModel(BaseModel): - """Multiple rerank models for different modes/scenarios.""" - - # Primary rerank model (used if mode-specific models are not defined) - rerank_model: Optional[RerankModel] = None - - # Mode-specific rerank models - entity_rerank_model: Optional[RerankModel] = None - relation_rerank_model: Optional[RerankModel] = None - chunk_rerank_model: Optional[RerankModel] = None - - async def rerank( - self, - query: str, - documents: List[Dict[str, Any]], - mode: str = "default", - top_n: Optional[int] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """Rerank using the appropriate model based on mode.""" - - # Select model based on mode - if mode == "entity" and self.entity_rerank_model: - model = self.entity_rerank_model - elif mode == "relation" and self.relation_rerank_model: - model = self.relation_rerank_model - elif mode == "chunk" and self.chunk_rerank_model: - model = self.chunk_rerank_model - elif self.rerank_model: - model = self.rerank_model - else: - logger.warning(f"No rerank model available for mode: {mode}") - return documents - - return await model.rerank(query, documents, top_n, **kwargs) +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=( + retry_if_exception_type(aiohttp.ClientError) + | retry_if_exception_type(aiohttp.ClientResponseError) + ), +) async def generic_rerank_api( query: str, - documents: List[Dict[str, Any]], + documents: List[str], model: str, base_url: str, api_key: str, top_n: Optional[int] = None, - **kwargs, + return_documents: Optional[bool] = None, + extra_body: Optional[Dict[str, Any]] = None, + response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" + request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" ) -> List[Dict[str, Any]]: """ - Generic rerank function that works with Jina/Cohere compatible APIs. + Generic rerank API call for Jina/Cohere/Aliyun models. 
Args: query: The search query - documents: List of documents to rerank - model: Model identifier + documents: List of strings to rerank + model: Model name to use base_url: API endpoint URL - api_key: API authentication key + api_key: API key for authentication top_n: Number of top results to return - **kwargs: Additional API-specific parameters + return_documents: Whether to return document text (Jina only) + extra_body: Additional body parameters + response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun) Returns: - List of reranked documents with relevance scores + List of dictionary of ["index": int, "relevance_score": float] """ if not api_key: - logger.warning("No API key provided for rerank service") - return documents + raise ValueError("API key is required") - if not documents: - return documents + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } - # Prepare documents for reranking - handle both text and dict formats - prepared_docs = [] - for doc in documents: - if isinstance(doc, dict): - # Use 'content' field if available, otherwise use 'text' or convert to string - text = doc.get("content") or doc.get("text") or str(doc) - else: - text = str(doc) - prepared_docs.append(text) + # Build request payload based on request format + if request_format == "aliyun": + # Aliyun format: nested input/parameters structure + payload = { + "model": model, + "input": { + "query": query, + "documents": documents, + }, + "parameters": {}, + } - # Prepare request - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + # Add optional parameters to parameters object + if top_n is not None: + payload["parameters"]["top_n"] = top_n - data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} + if return_documents is not None: + payload["parameters"]["return_documents"] = return_documents - if top_n is not None: - data["top_n"] = min(top_n, len(prepared_docs)) + # Add extra parameters to parameters object + if extra_body: + payload["parameters"].update(extra_body) + else: + # Standard format for Jina/Cohere + payload = { + "model": model, + "query": query, + "documents": documents, + } - try: - async with aiohttp.ClientSession() as session: - async with session.post(base_url, headers=headers, json=data) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Rerank API error {response.status}: {error_text}") - return documents + # Add optional parameters + if top_n is not None: + payload["top_n"] = top_n - result = await response.json() + # Only Jina API supports return_documents parameter + if return_documents is not None: + payload["return_documents"] = return_documents - # Extract reranked results - if "results" in result: - # Standard format: results contain index and relevance_score - reranked_docs = [] - for item in result["results"]: - if "index" in item: - doc_idx = item["index"] - if 0 <= doc_idx < len(documents): - reranked_doc = documents[doc_idx].copy() - if "relevance_score" in item: - reranked_doc["rerank_score"] = item[ - "relevance_score" - ] - reranked_docs.append(reranked_doc) - return reranked_docs - else: - logger.warning("Unexpected rerank API response format") - return documents + # Add extra parameters + if extra_body: + payload.update(extra_body) - except Exception as e: - logger.error(f"Error during reranking: {e}") - return documents - - -async def jina_rerank( - query: str, - documents: List[Dict[str, Any]], - 
model: str = "BAAI/bge-reranker-v2-m3", - top_n: Optional[int] = None, - base_url: str = "https://api.jina.ai/v1/rerank", - api_key: Optional[str] = None, - **kwargs, -) -> List[Dict[str, Any]]: - """ - Rerank documents using Jina AI API. - - Args: - query: The search query - documents: List of documents to rerank - model: Jina rerank model name - top_n: Number of top results to return - base_url: Jina API endpoint - api_key: Jina API key - **kwargs: Additional parameters - - Returns: - List of reranked documents with relevance scores - """ - if api_key is None: - api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") - - return await generic_rerank_api( - query=query, - documents=documents, - model=model, - base_url=base_url, - api_key=api_key, - top_n=top_n, - **kwargs, + logger.debug( + f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}" ) + async with aiohttp.ClientSession() as session: + async with session.post(base_url, headers=headers, json=payload) as response: + if response.status != 200: + error_text = await response.text() + content_type = response.headers.get("content-type", "").lower() + is_html_error = ( + error_text.strip().startswith("") + or "text/html" in content_type + ) + + if is_html_error: + if response.status == 502: + clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes." + elif response.status == 503: + clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later." + elif response.status == 504: + clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again." + else: + clean_error = f"HTTP {response.status} - Rerank service error. Please try again later." + else: + clean_error = error_text + + logger.error(f"Rerank API error {response.status}: {clean_error}") + raise aiohttp.ClientResponseError( + request_info=response.request_info, + history=response.history, + status=response.status, + message=f"Rerank API error: {clean_error}", + ) + + response_json = await response.json() + + # Handle different response formats + if response_format == "aliyun": + # Aliyun format: {"output": {"results": [...]}} + output = response_json.get("output", {}) + results = output.get("results", []) + elif response_format == "standard": + # Standard format: {"results": [...]} + results = response_json.get("results", []) + else: + raise ValueError(f"Unsupported response format: {response_format}") + + if not results: + logger.warning("Rerank API returned empty results") + return [] + + # Standardize return format + return [ + {"index": result["index"], "relevance_score": result["relevance_score"]} + for result in results + ] + async def cohere_rerank( query: str, - documents: List[Dict[str, Any]], - model: str = "rerank-english-v2.0", + documents: List[str], top_n: Optional[int] = None, - base_url: str = "https://api.cohere.ai/v1/rerank", api_key: Optional[str] = None, - **kwargs, + model: str = "rerank-v3.5", + base_url: str = "https://ai.znipower.com:5017/rerank", + extra_body: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. 
Args: query: The search query - documents: List of documents to rerank - model: Cohere rerank model name + documents: List of strings to rerank top_n: Number of top results to return - base_url: Cohere API endpoint - api_key: Cohere API key - **kwargs: Additional parameters + api_key: API key + model: rerank model name + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) Returns: - List of reranked documents with relevance scores + List of dictionary of ["index": int, "relevance_score": float] """ if api_key is None: - api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") + api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") return await generic_rerank_api( query=query, @@ -274,24 +198,39 @@ async def cohere_rerank( base_url=base_url, api_key=api_key, top_n=top_n, - **kwargs, + return_documents=None, # Cohere doesn't support this parameter + extra_body=extra_body, + response_format="standard", ) -# Convenience function for custom API endpoints -async def custom_rerank( +async def jina_rerank( query: str, - documents: List[Dict[str, Any]], - model: str, - base_url: str, - api_key: str, + documents: List[str], top_n: Optional[int] = None, - **kwargs, + api_key: Optional[str] = None, + model: str = "jina-reranker-v2-base-multilingual", + base_url: str = "https://api.jina.ai/v1/rerank", + extra_body: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ - Rerank documents using a custom API endpoint. - This is useful for self-hosted or custom rerank services. + Rerank documents using Jina AI API. + + Args: + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: API key + model: rerank model name + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) + + Returns: + List of dictionary of ["index": int, "relevance_score": float] """ + if api_key is None: + api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") + return await generic_rerank_api( query=query, documents=documents, @@ -299,26 +238,112 @@ async def custom_rerank( base_url=base_url, api_key=api_key, top_n=top_n, - **kwargs, + return_documents=False, + extra_body=extra_body, + response_format="standard", ) +async def ali_rerank( + query: str, + documents: List[str], + top_n: Optional[int] = None, + api_key: Optional[str] = None, + model: str = "gte-rerank-v2", + base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", + extra_body: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """ + Rerank documents using Aliyun DashScope API. 
+ + Args: + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: Aliyun API key + model: rerank model name + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) + + Returns: + List of dictionary of ["index": int, "relevance_score": float] + """ + if api_key is None: + api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_n=top_n, + return_documents=False, # Aliyun doesn't need this parameter + extra_body=extra_body, + response_format="aliyun", + request_format="aliyun", + ) + + +"""Please run this test as a module: +python -m lightrag.rerank +""" if __name__ == "__main__": import asyncio async def main(): - # Example usage + # Example usage - documents should be strings, not dictionaries docs = [ - {"content": "The capital of France is Paris."}, - {"content": "Tokyo is the capital of Japan."}, - {"content": "London is the capital of England."}, + "The capital of France is Paris.", + "Tokyo is the capital of Japan.", + "London is the capital of England.", ] query = "What is the capital of France?" - result = await jina_rerank( - query=query, documents=docs, top_n=2, api_key="your-api-key-here" - ) - print(result) + # Test Jina rerank + try: + print("=== Jina Rerank ===") + result = await jina_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Jina Error: {e}") + + # Test Cohere rerank + try: + print("\n=== Cohere Rerank ===") + result = await cohere_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Cohere Error: {e}") + + # Test Aliyun rerank + try: + print("\n=== Aliyun Rerank ===") + result = await ali_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Aliyun Error: {e}") asyncio.run(main()) diff --git a/lightrag/utils.py b/lightrag/utils.py index bec45f5f..65e22d1a 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled( return retrieved_docs try: - # Apply reranking - let rerank_model_func handle top_k internally - reranked_docs = await rerank_func( + # Extract document content for reranking + document_texts = [] + for doc in retrieved_docs: + # Try multiple possible content fields + content = ( + doc.get("content") + or doc.get("text") + or doc.get("chunk_content") + or doc.get("document") + or str(doc) + ) + document_texts.append(content) + + # Call the new rerank function that returns index-based results + rerank_results = await rerank_func( query=query, - documents=retrieved_docs, - top_n=top_n, + documents=document_texts, + top_n=top_n or len(retrieved_docs), ) - if reranked_docs and len(reranked_docs) > 0: - if len(reranked_docs) > top_n: - reranked_docs = reranked_docs[:top_n] - logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks") - return reranked_docs + + # 
Process rerank results based on return format + if rerank_results and len(rerank_results) > 0: + # Check if results are in the new index-based format + if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]: + # New format: [{"index": 0, "relevance_score": 0.85}, ...] + reranked_docs = [] + for result in rerank_results: + index = result["index"] + relevance_score = result["relevance_score"] + + # Get original document and add rerank score + if 0 <= index < len(retrieved_docs): + doc = retrieved_docs[index].copy() + doc["rerank_score"] = relevance_score + reranked_docs.append(doc) + + logger.info( + f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks" + ) + return reranked_docs + else: + # Legacy format: assume it's already reranked documents + logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks") + return rerank_results[:top_n] if top_n else rerank_results else: logger.warning("Rerank returned empty results, using original chunks") return retrieved_docs @@ -2027,13 +2060,6 @@ async def process_chunks_unified( # 1. Apply reranking if enabled and query is provided if query_param.enable_rerank and query and unique_chunks: - # 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段 - chunk_ids = {} - for chunk in unique_chunks: - chunk_id = chunk.get("chunk_id") - if chunk_id: - chunk_ids[id(chunk)] = chunk_id - rerank_top_k = query_param.chunk_top_k or len(unique_chunks) unique_chunks = await apply_rerank_if_enabled( query=query, @@ -2043,11 +2069,6 @@ async def process_chunks_unified( top_n=rerank_top_k, ) - # 恢复 chunk_id 字段 - for chunk in unique_chunks: - if id(chunk) in chunk_ids: - chunk["chunk_id"] = chunk_ids[id(chunk)] - # 2. Filter by minimum rerank score if reranking is enabled if query_param.enable_rerank and unique_chunks: min_rerank_score = global_config.get("min_rerank_score", 0.5) @@ -2095,13 +2116,6 @@ async def process_chunks_unified( original_count = len(unique_chunks) - # Keep chunk_id field, cause truncate_list_by_token_size will lose it - chunk_ids_map = {} - for i, chunk in enumerate(unique_chunks): - chunk_id = chunk.get("chunk_id") - if chunk_id: - chunk_ids_map[i] = chunk_id - unique_chunks = truncate_list_by_token_size( unique_chunks, key=lambda x: json.dumps(x, ensure_ascii=False), @@ -2109,11 +2123,6 @@ async def process_chunks_unified( tokenizer=tokenizer, ) - # restore chunk_id feiled - for i, chunk in enumerate(unique_chunks): - if i in chunk_ids_map: - chunk["chunk_id"] = chunk_ids_map[i] - logger.debug( f"Token truncation: {len(unique_chunks)} chunks from {original_count} " f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
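
Usage sketch (not part of the patch): a minimal example of wiring one of the new provider functions into LightRAG directly, mirroring the `server_rerank_func` wrapper added in `lightrag/api/lightrag_server.py`. It assumes the `cohere_rerank` signature introduced above in `lightrag/rerank.py`; model and endpoint fall back to that function's defaults, and the commented-out `LightRAG(...)` call is illustrative only.

```python
# Minimal sketch, assuming the cohere_rerank signature introduced in this patch.
import os

from lightrag.rerank import cohere_rerank  # or jina_rerank / ali_rerank


async def my_rerank_func(
    query: str, documents: list, top_n: int = None, extra_body: dict = None
):
    """Rerank plain-string documents; returns [{"index": int, "relevance_score": float}, ...]."""
    return await cohere_rerank(
        query=query,
        documents=documents,
        top_n=top_n,
        # Falls back to COHERE_API_KEY / RERANK_BINDING_API_KEY inside cohere_rerank
        api_key=os.getenv("RERANK_BINDING_API_KEY"),
        extra_body=extra_body,
    )


# rag = LightRAG(..., rerank_model_func=my_rerank_func)  # same wiring as create_app() above
```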