feat: Add multiple rerank provider support to LightRAG Server via new env vars and CLI params

- Add --enable-rerank CLI argument and ENABLE_RERANK env var
- Simplify rerank configuration logic to only check enable flag and binding
- Update health endpoint to show enable_rerank and rerank_configured status
- Improve logging messages for rerank enable/disable states
- Maintain backward compatibility with default value True
yangdx
2025-08-22 19:29:45 +08:00
parent 0019a3adc6
commit 580cb7906c
6 changed files with 368 additions and 568 deletions

View File

@@ -1,281 +0,0 @@
# Rerank Integration Guide
LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
## Quick Start
### Environment Variables
Set these variables in your `.env` file or environment for rerank model configuration:
```bash
# Rerank model configuration (required when enable_rerank=True in queries)
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
RERANK_BINDING_API_KEY=your_api_key_here
```
### Programmatic Configuration
```python
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=top_n or 10, # Handle top_n within the function
**kwargs
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=my_rerank_func, # Configure rerank function
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Control rerank per query
)
# Query with rerank disabled
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=False)
)
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-provider.com/v1/rerank",
"api_key": "your_api_key_here",
}
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=rerank_model.rerank,
)
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(
enable_rerank=True, # Enable rerank for this query
chunk_top_k=5 # Number of chunks to keep after reranking
)
)
```
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank
# Your custom API endpoint
result = await custom_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=10
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank
result = await jina_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key",
top_n=10
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank
result = await cohere_rerank(
query="your query",
documents=documents,
model="rerank-english-v2.0",
api_key="your_cohere_api_key",
top_n=10
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
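For any of these modes, reranking is toggled by the same per-query flag. A minimal sketch (assuming the standard `QueryParam.mode` field; `rag` is a configured LightRAG instance as in the examples above):
```python
from lightrag import QueryParam

# Rerank runs after retrieval in every mode; only the retrieval stage differs
result = await rag.aquery(
    "your query",
    param=QueryParam(mode="hybrid", enable_rerank=True, chunk_top_k=5),
)
```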
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | True | Enable/disable reranking per query (set in `QueryParam`) |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) |
## Example Usage
### Basic Usage
```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key_here",
top_n=top_n or 10, # Default top_n if not provided
**kwargs
)
async def main():
# Initialize with rerank enabled
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embedding,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
# Insert documents
await rag.ainsert([
"Document 1 content...",
"Document 2 content...",
])
# Query with rerank (automatically applied)
result = await rag.aquery(
"Your question here",
param=QueryParam(enable_rerank=True) # Rerank is applied per query when enabled
)
print(result)
asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank
async def test_rerank():
documents = [
{"content": "Text about topic A"},
{"content": "Text about topic B"},
{"content": "Text about topic C"},
]
reranked = await custom_rerank(
query="Tell me about topic A",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=2
)
for doc in reranked:
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function (see the sketch after this list)
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
5. **Top-n Handling**: Handle top_n parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
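The snippet below is a minimal sketch of practices 1, 3, and 5: a self-contained rerank function that uses an `asyncio.Semaphore` as a simple in-process rate limiter and applies a default `top_n`. The semaphore size and defaults are illustrative values, not LightRAG settings.
```python
import asyncio
from lightrag.rerank import jina_rerank

# Illustrative limit: at most 4 concurrent rerank API calls from this process
_rerank_semaphore = asyncio.Semaphore(4)

async def rate_limited_rerank(query: str, documents: list, top_n: int = None, **kwargs):
    """Self-contained rerank function: holds its own config, limits concurrency, handles top_n."""
    async with _rerank_semaphore:
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_jina_api_key_here",
            top_n=top_n or 10,  # fall back to a default when top_n is not provided
            **kwargs,
        )
```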
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields
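As a reference for issue 4, a sketch of the document shapes the rerank helpers accept (per the `content`/`text` handling described above; anything else is converted with `str()`):
```python
documents = [
    {"content": "Chunk stored under the 'content' key"},
    {"text": "Chunk stored under the 'text' key"},
    "A bare string is also accepted",
]
```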
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
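If you call a rerank function directly (outside LightRAG's query path), a minimal sketch of the same fallback pattern, assuming the `custom_rerank` signature shown earlier:
```python
from lightrag.rerank import custom_rerank

async def rerank_with_fallback(query: str, documents: list, top_n: int = None):
    """Return reranked documents, or the original list if the rerank call fails."""
    try:
        return await custom_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            base_url="https://api.your-provider.com/v1/rerank",
            api_key="your_api_key_here",
            top_n=top_n,
        )
    except Exception as exc:
        # Mirror LightRAG's behavior: log and degrade gracefully
        print(f"Rerank failed, using original order: {exc}")
        return documents
```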
## API Compatibility
The generic rerank API expects this response format:
```json
{
"results": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 2,
"relevance_score": 0.87
}
]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format
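A minimal sketch of how a client can map this index-based response back onto the documents it sent (field names follow the JSON above; `apply_rerank_results` is an illustrative helper, not part of LightRAG):
```python
def apply_rerank_results(documents: list, results: list) -> list:
    """Pair each result index with its original document and relevance score."""
    reranked = []
    for item in results:
        idx = item["index"]
        if 0 <= idx < len(documents):
            reranked.append(
                {"document": documents[idx], "rerank_score": item["relevance_score"]}
            )
    return reranked

# With the response above, documents[0] and documents[2] are kept,
# annotated with scores 0.95 and 0.87
```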

View File

@@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
# KG_CHUNK_PICK_METHOD=VECTOR
#########################################################
### Reranking configuration
### Set ENABLE_RERANK to true when a reranking model is configured
# ENABLE_RERANK=True
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
### RERANK_BINDING type: cohere, jina, aliyun
### For rerank models deployed with vLLM, use the cohere binding
#########################################################
ENABLE_RERANK=False
RERANK_BINDING=cohere
### Rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
# MIN_RERANK_SCORE=0.0
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=jina-reranker-v2-base-multilingual
### For local deployment
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=http://localhost:8000
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Cohere AI
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Aliyun
# RERANK_MODEL=gte-rerank-v2
# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
########################################
### Document processing configuration
########################################

View File

@@ -225,6 +225,19 @@ def parse_args() -> argparse.Namespace:
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
help="Embedding binding type (default: from env or ollama)",
)
parser.add_argument(
"--rerank-binding",
type=str,
default=get_env_value("RERANK_BINDING", "cohere"),
choices=["cohere", "jina", "aliyun"],
help="Rerank binding type (default: from env or cohere)",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", True, bool),
help="Enable rerank functionality (default: from env or True)",
)
# Conditionally add binding options defined in binding_options module
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
@@ -340,6 +353,7 @@ def parse_args() -> argparse.Namespace:
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Note: rerank_binding is already set by argparse, no need to override from env
# Min rerank score configuration
args.min_rerank_score = get_env_value(

View File

@@ -390,33 +390,44 @@ def create_app(args):
),
)
# Configure rerank function if model and API are configured
# Configure rerank function based on enable_rerank parameter
rerank_model_func = None
if args.rerank_binding_api_key and args.rerank_binding_host:
from lightrag.rerank import custom_rerank
if args.enable_rerank and args.rerank_binding:
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
# Map rerank binding to corresponding function
rerank_functions = {
"cohere": cohere_rerank,
"jina": jina_rerank,
"aliyun": ali_rerank,
}
# Select the appropriate rerank function based on binding
selected_rerank_func = rerank_functions.get(args.rerank_binding)
if not selected_rerank_func:
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
async def server_rerank_func(
query: str, documents: list, top_n: int = None, **kwargs
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
return await custom_rerank(
return await selected_rerank_func(
query=query,
documents=documents,
model=args.rerank_model,
base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key,
top_n=top_n,
**kwargs,
extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
)
else:
logger.info(
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
)
logger.info("Rerank disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
@@ -622,13 +633,15 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration (based on whether rerank model is configured)
"enable_rerank": rerank_model_func is not None,
"rerank_model": args.rerank_model
if rerank_model_func is not None
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_configured": rerank_model_func is not None,
"rerank_binding": args.rerank_binding
if args.enable_rerank
else None,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host
if rerank_model_func is not None
if args.enable_rerank
else None,
# Environment variable status (requested configuration)
"summary_language": args.summary_language,

View File

@@ -2,270 +2,194 @@ from __future__ import annotations
import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field
from typing import Any, List, Dict, Optional
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from .utils import logger
from dotenv import load_dotenv
class RerankModel(BaseModel):
"""
Wrapper for rerank functions that can be used with LightRAG.
Example usage:
```python
from lightrag.rerank import RerankModel, jina_rerank
# Create rerank model
rerank_model = RerankModel(
rerank_func=jina_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"api_key": "your_api_key_here",
"base_url": "https://api.jina.ai/v1/rerank"
}
)
# Use in LightRAG
rag = LightRAG(
rerank_model_func=rerank_model.rerank,
# ... other configurations
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True)
)
```
Or define a custom function directly:
```python
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_n=top_n or 10,
**kwargs
)
rag = LightRAG(
rerank_model_func=my_rerank_func,
# ... other configurations
)
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Enable rerank for this query
)
```
"""
rerank_func: Callable[[Any], List[Dict]]
kwargs: Dict[str, Any] = Field(default_factory=dict)
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
top_n: Optional[int] = None,
**extra_kwargs,
) -> List[Dict[str, Any]]:
"""Rerank documents using the configured model function."""
# Merge extra kwargs with model kwargs
kwargs = {**self.kwargs, **extra_kwargs}
return await self.rerank_func(
query=query, documents=documents, top_n=top_n, **kwargs
)
class MultiRerankModel(BaseModel):
"""Multiple rerank models for different modes/scenarios."""
# Primary rerank model (used if mode-specific models are not defined)
rerank_model: Optional[RerankModel] = None
# Mode-specific rerank models
entity_rerank_model: Optional[RerankModel] = None
relation_rerank_model: Optional[RerankModel] = None
chunk_rerank_model: Optional[RerankModel] = None
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
mode: str = "default",
top_n: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""Rerank using the appropriate model based on mode."""
# Select model based on mode
if mode == "entity" and self.entity_rerank_model:
model = self.entity_rerank_model
elif mode == "relation" and self.relation_rerank_model:
model = self.relation_rerank_model
elif mode == "chunk" and self.chunk_rerank_model:
model = self.chunk_rerank_model
elif self.rerank_model:
model = self.rerank_model
else:
logger.warning(f"No rerank model available for mode: {mode}")
return documents
return await model.rerank(query, documents, top_n, **kwargs)
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(aiohttp.ClientError)
| retry_if_exception_type(aiohttp.ClientResponseError)
),
)
async def generic_rerank_api(
query: str,
documents: List[Dict[str, Any]],
documents: List[str],
model: str,
base_url: str,
api_key: str,
top_n: Optional[int] = None,
**kwargs,
return_documents: Optional[bool] = None,
extra_body: Optional[Dict[str, Any]] = None,
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]:
"""
Generic rerank function that works with Jina/Cohere compatible APIs.
Generic rerank API call for Jina/Cohere/Aliyun models.
Args:
query: The search query
documents: List of documents to rerank
model: Model identifier
documents: List of strings to rerank
model: Model name to use
base_url: API endpoint URL
api_key: API authentication key
api_key: API key for authentication
top_n: Number of top results to return
**kwargs: Additional API-specific parameters
return_documents: Whether to return document text (Jina only)
extra_body: Additional body parameters
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
request_format: Request payload format ("standard" for Jina/Cohere, "aliyun" for Aliyun)
Returns:
List of reranked documents with relevance scores
List of dicts in the form [{"index": int, "relevance_score": float}, ...]
"""
if not api_key:
logger.warning("No API key provided for rerank service")
return documents
raise ValueError("API key is required")
if not documents:
return documents
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
# Prepare documents for reranking - handle both text and dict formats
prepared_docs = []
for doc in documents:
if isinstance(doc, dict):
# Use 'content' field if available, otherwise use 'text' or convert to string
text = doc.get("content") or doc.get("text") or str(doc)
else:
text = str(doc)
prepared_docs.append(text)
# Build request payload based on request format
if request_format == "aliyun":
# Aliyun format: nested input/parameters structure
payload = {
"model": model,
"input": {
"query": query,
"documents": documents,
},
"parameters": {},
}
# Prepare request
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
# Add optional parameters to parameters object
if top_n is not None:
payload["parameters"]["top_n"] = top_n
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
if return_documents is not None:
payload["parameters"]["return_documents"] = return_documents
if top_n is not None:
data["top_n"] = min(top_n, len(prepared_docs))
# Add extra parameters to parameters object
if extra_body:
payload["parameters"].update(extra_body)
else:
# Standard format for Jina/Cohere
payload = {
"model": model,
"query": query,
"documents": documents,
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Rerank API error {response.status}: {error_text}")
return documents
# Add optional parameters
if top_n is not None:
payload["top_n"] = top_n
result = await response.json()
# Only Jina API supports return_documents parameter
if return_documents is not None:
payload["return_documents"] = return_documents
# Extract reranked results
if "results" in result:
# Standard format: results contain index and relevance_score
reranked_docs = []
for item in result["results"]:
if "index" in item:
doc_idx = item["index"]
if 0 <= doc_idx < len(documents):
reranked_doc = documents[doc_idx].copy()
if "relevance_score" in item:
reranked_doc["rerank_score"] = item[
"relevance_score"
]
reranked_docs.append(reranked_doc)
return reranked_docs
else:
logger.warning("Unexpected rerank API response format")
return documents
# Add extra parameters
if extra_body:
payload.update(extra_body)
except Exception as e:
logger.error(f"Error during reranking: {e}")
return documents
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3",
top_n: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of documents to rerank
model: Jina rerank model name
top_n: Number of top results to return
base_url: Jina API endpoint
api_key: Jina API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
logger.debug(
f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
)
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
if response.status != 200:
error_text = await response.text()
content_type = response.headers.get("content-type", "").lower()
is_html_error = (
error_text.strip().startswith("<!DOCTYPE html>")
or "text/html" in content_type
)
if is_html_error:
if response.status == 502:
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
elif response.status == 503:
clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
elif response.status == 504:
clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
else:
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
else:
clean_error = error_text
logger.error(f"Rerank API error {response.status}: {clean_error}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Rerank API error: {clean_error}",
)
response_json = await response.json()
# Handle different response formats
if response_format == "aliyun":
# Aliyun format: {"output": {"results": [...]}}
output = response_json.get("output", {})
results = output.get("results", [])
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
# Standardize return format
return [
{"index": result["index"], "relevance_score": result["relevance_score"]}
for result in results
]
async def cohere_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "rerank-english-v2.0",
documents: List[str],
top_n: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
model: str = "rerank-v3.5",
base_url: str = "https://ai.znipower.com:5017/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Args:
query: The search query
documents: List of documents to rerank
model: Cohere rerank model name
documents: List of strings to rerank
top_n: Number of top results to return
base_url: Cohere API endpoint
api_key: Cohere API key
**kwargs: Additional parameters
api_key: API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for HTTP request (reserved for extra params)
Returns:
List of reranked documents with relevance scores
List of dicts in the form [{"index": int, "relevance_score": float}, ...]
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
@@ -274,24 +198,39 @@ async def cohere_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
return_documents=None, # Cohere doesn't support this parameter
extra_body=extra_body,
response_format="standard",
)
# Convenience function for custom API endpoints
async def custom_rerank(
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
documents: List[str],
top_n: Optional[int] = None,
**kwargs,
api_key: Optional[str] = None,
model: str = "jina-reranker-v2-base-multilingual",
base_url: str = "https://api.jina.ai/v1/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using a custom API endpoint.
This is useful for self-hosted or custom rerank services.
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for HTTP request (reserved for extra params)
Returns:
List of dicts in the form [{"index": int, "relevance_score": float}, ...]
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
@@ -299,26 +238,112 @@ async def custom_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
return_documents=False,
extra_body=extra_body,
response_format="standard",
)
async def ali_rerank(
query: str,
documents: List[str],
top_n: Optional[int] = None,
api_key: Optional[str] = None,
model: str = "gte-rerank-v2",
base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Aliyun DashScope API.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: Aliyun API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for HTTP request (reserved for extra params)
Returns:
List of dicts in the form [{"index": int, "relevance_score": float}, ...]
"""
if api_key is None:
api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_n=top_n,
return_documents=False, # Aliyun doesn't need this parameter
extra_body=extra_body,
response_format="aliyun",
request_format="aliyun",
)
"""Please run this test as a module:
python -m lightrag.rerank
"""
if __name__ == "__main__":
import asyncio
async def main():
# Example usage
# Example usage - documents should be strings, not dictionaries
docs = [
{"content": "The capital of France is Paris."},
{"content": "Tokyo is the capital of Japan."},
{"content": "London is the capital of England."},
"The capital of France is Paris.",
"Tokyo is the capital of Japan.",
"London is the capital of England.",
]
query = "What is the capital of France?"
result = await jina_rerank(
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
)
print(result)
# Test Jina rerank
try:
print("=== Jina Rerank ===")
result = await jina_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Jina Error: {e}")
# Test Cohere rerank
try:
print("\n=== Cohere Rerank ===")
result = await cohere_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Cohere Error: {e}")
# Test Aliyun rerank
try:
print("\n=== Aliyun Rerank ===")
result = await ali_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Aliyun Error: {e}")
asyncio.run(main())

View File

@@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled(
return retrieved_docs
try:
# Apply reranking - let rerank_model_func handle top_k internally
reranked_docs = await rerank_func(
# Extract document content for reranking
document_texts = []
for doc in retrieved_docs:
# Try multiple possible content fields
content = (
doc.get("content")
or doc.get("text")
or doc.get("chunk_content")
or doc.get("document")
or str(doc)
)
document_texts.append(content)
# Call the new rerank function that returns index-based results
rerank_results = await rerank_func(
query=query,
documents=retrieved_docs,
top_n=top_n,
documents=document_texts,
top_n=top_n or len(retrieved_docs),
)
if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_n:
reranked_docs = reranked_docs[:top_n]
logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
return reranked_docs
# Process rerank results based on return format
if rerank_results and len(rerank_results) > 0:
# Check if results are in the new index-based format
if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
# New format: [{"index": 0, "relevance_score": 0.85}, ...]
reranked_docs = []
for result in rerank_results:
index = result["index"]
relevance_score = result["relevance_score"]
# Get original document and add rerank score
if 0 <= index < len(retrieved_docs):
doc = retrieved_docs[index].copy()
doc["rerank_score"] = relevance_score
reranked_docs.append(doc)
logger.info(
f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
)
return reranked_docs
else:
# Legacy format: assume it's already reranked documents
logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
return rerank_results[:top_n] if top_n else rerank_results
else:
logger.warning("Rerank returned empty results, using original chunks")
return retrieved_docs
@@ -2027,13 +2060,6 @@ async def process_chunks_unified(
# 1. Apply reranking if enabled and query is provided
if query_param.enable_rerank and query and unique_chunks:
# Preserve the chunk_id field, because rerank may drop it
chunk_ids = {}
for chunk in unique_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids[id(chunk)] = chunk_id
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
@@ -2043,11 +2069,6 @@ async def process_chunks_unified(
top_n=rerank_top_k,
)
# Restore the chunk_id field
for chunk in unique_chunks:
if id(chunk) in chunk_ids:
chunk["chunk_id"] = chunk_ids[id(chunk)]
# 2. Filter by minimum rerank score if reranking is enabled
if query_param.enable_rerank and unique_chunks:
min_rerank_score = global_config.get("min_rerank_score", 0.5)
@@ -2095,13 +2116,6 @@ async def process_chunks_unified(
original_count = len(unique_chunks)
# Keep chunk_id field, because truncate_list_by_token_size will lose it
chunk_ids_map = {}
for i, chunk in enumerate(unique_chunks):
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids_map[i] = chunk_id
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: json.dumps(x, ensure_ascii=False),
@@ -2109,11 +2123,6 @@ async def process_chunks_unified(
tokenizer=tokenizer,
)
# restore chunk_id field
for i, chunk in enumerate(unique_chunks):
if i in chunk_ids_map:
chunk["chunk_id"] = chunk_ids_map[i]
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"