feat: Add multiple rerank provider support to LightRAG Server via new env vars and CLI params
- Add --enable-rerank CLI argument and ENABLE_RERANK env var
- Simplify rerank configuration logic to only check enable flag and binding
- Update health endpoint to show enable_rerank and rerank_configured status
- Improve logging messages for rerank enable/disable states
- Maintain backward compatibility with default value True
@@ -1,281 +0,0 @@
# Rerank Integration Guide

LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).

## Quick Start

### Environment Variables

Set these variables in your `.env` file or environment for rerank model configuration:

```bash
# Rerank model configuration (required when enable_rerank=True in queries)
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
RERANK_BINDING_API_KEY=your_api_key_here
```

### Programmatic Configuration

```python
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel

# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
    return await custom_rerank(
        query=query,
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        base_url="https://api.your-provider.com/v1/rerank",
        api_key="your_api_key_here",
        top_n=top_n or 10,  # Handle top_n within the function
        **kwargs
    )

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    rerank_model_func=my_rerank_func,  # Configure rerank function
)

# Query with rerank enabled (default)
result = await rag.aquery(
    "your query",
    param=QueryParam(enable_rerank=True)  # Control rerank per query
)

# Query with rerank disabled
result = await rag.aquery(
    "your query",
    param=QueryParam(enable_rerank=False)
)

# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
    rerank_func=custom_rerank,
    kwargs={
        "model": "BAAI/bge-reranker-v2-m3",
        "base_url": "https://api.your-provider.com/v1/rerank",
        "api_key": "your_api_key_here",
    }
)

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    rerank_model_func=rerank_model.rerank,
)

# Control rerank per query
result = await rag.aquery(
    "your query",
    param=QueryParam(
        enable_rerank=True,  # Enable rerank for this query
        chunk_top_k=5  # Number of chunks to keep after reranking
    )
)
```

## Supported Providers

### 1. Custom/Generic API (Recommended)

For Jina/Cohere compatible APIs:

```python
from lightrag.rerank import custom_rerank

# Your custom API endpoint
result = await custom_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    base_url="https://api.your-provider.com/v1/rerank",
    api_key="your_api_key_here",
    top_n=10
)
```

### 2. Jina AI

```python
from lightrag.rerank import jina_rerank

result = await jina_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    api_key="your_jina_api_key",
    top_n=10
)
```

### 3. Cohere

```python
from lightrag.rerank import cohere_rerank

result = await cohere_rerank(
    query="your query",
    documents=documents,
    model="rerank-english-v2.0",
    api_key="your_cohere_api_key",
    top_n=10
)
```

## Integration Points

Reranking is automatically applied at these key retrieval stages (a simplified sketch of the shared flow follows the list):

1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
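
All of these stages share the same retrieve-then-rerank pattern. The sketch below is illustrative only, not LightRAG's actual implementation: `vector_search` is a hypothetical retrieval helper, and `rerank_func` stands in for whatever `rerank_model_func` you configured.

```python
# Illustrative sketch of the retrieve -> rerank -> truncate flow (hypothetical helpers).
from typing import Any, Awaitable, Callable

async def retrieve_with_rerank(
    query: str,
    vector_search: Callable[[str, int], Awaitable[list[dict[str, Any]]]],  # hypothetical retriever
    rerank_func: Callable[..., Awaitable[list[dict[str, Any]]]],           # your rerank_model_func
    top_k: int = 60,
    chunk_top_k: int = 10,
) -> list[dict[str, Any]]:
    # 1. Broad vector similarity search
    candidates = await vector_search(query, top_k)
    # 2. Re-order candidates by relevance to the query
    reranked = await rerank_func(query=query, documents=candidates, top_n=chunk_top_k)
    # 3. Keep only the top chunks for the LLM context
    return reranked[:chunk_top_k]
```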
## Configuration Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | True | Enable/disable reranking per query |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) |
## Example Usage

### Basic Usage

```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank

async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
    """Custom rerank function with all settings included"""
    return await jina_rerank(
        query=query,
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        api_key="your_jina_api_key_here",
        top_n=top_n or 10,  # Default top_n if not provided
        **kwargs
    )

async def main():
    # Initialize with rerank enabled
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=openai_embedding,
        rerank_model_func=my_rerank_func,
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    # Insert documents
    await rag.ainsert([
        "Document 1 content...",
        "Document 2 content...",
    ])

    # Query with rerank (automatically applied)
    result = await rag.aquery(
        "Your question here",
        param=QueryParam(enable_rerank=True)  # top_n handling stays inside the rerank function
    )

    print(result)

asyncio.run(main())
```

### Direct Rerank Usage

```python
from lightrag.rerank import custom_rerank

async def test_rerank():
    documents = [
        {"content": "Text about topic A"},
        {"content": "Text about topic B"},
        {"content": "Text about topic C"},
    ]

    reranked = await custom_rerank(
        query="Tell me about topic A",
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        base_url="https://api.your-provider.com/v1/rerank",
        api_key="your_api_key_here",
        top_n=2
    )

    for doc in reranked:
        print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices

1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function
2. **Performance**: Use reranking selectively to balance retrieval quality against latency and cost
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
4. **Fallback**: Always handle rerank failures gracefully (original results are returned on failure); see the sketch after this list
5. **Top-n Handling**: Handle the top_n parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
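
To make points 3 and 4 concrete, here is a minimal sketch of a self-contained wrapper that adds basic concurrency limiting and a graceful fallback around the `custom_rerank` call shown earlier. The concurrency limit, endpoint, and credentials are placeholder assumptions, not values from LightRAG.

```python
import asyncio
from lightrag.rerank import custom_rerank

_rerank_semaphore = asyncio.Semaphore(4)  # assumed limit: at most 4 concurrent rerank calls

async def safe_rerank(query: str, documents: list, top_n: int = None, **kwargs):
    """Rate-limited rerank with graceful fallback to the original documents."""
    async with _rerank_semaphore:
        try:
            return await custom_rerank(
                query=query,
                documents=documents,
                model="BAAI/bge-reranker-v2-m3",
                base_url="https://api.your-provider.com/v1/rerank",  # placeholder endpoint
                api_key="your_api_key_here",
                top_n=top_n or 10,
                **kwargs,
            )
        except Exception as exc:  # fall back rather than fail the whole query
            print(f"Rerank failed, returning original documents: {exc}")
            return documents
```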
## Troubleshooting

### Common Issues

1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields

### Debug Mode

Enable debug logging to see rerank operations:

```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```

### Error Handling

The rerank integration includes automatic fallback:

```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```

## API Compatibility

The generic rerank API expects this response format:

```json
{
  "results": [
    {
      "index": 0,
      "relevance_score": 0.95
    },
    {
      "index": 2,
      "relevance_score": 0.87
    }
  ]
}
```

This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format
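
As a minimal sketch of consuming this format, a caller can map the index-based results back onto its own document list. This mirrors the score-mapping logic used elsewhere in this change; the helper name `apply_rerank_results` is made up for illustration.

```python
from typing import Any

def apply_rerank_results(
    documents: list[dict[str, Any]],
    results: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Reorder documents according to the API's index/relevance_score pairs."""
    reranked = []
    for item in results:
        idx = item["index"]
        if 0 <= idx < len(documents):
            doc = dict(documents[idx])  # copy so the original list is untouched
            doc["rerank_score"] = item["relevance_score"]
            reranked.append(doc)
    return reranked

# With the response shown above:
# apply_rerank_results(docs, [{"index": 0, "relevance_score": 0.95},
#                             {"index": 2, "relevance_score": 0.87}])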
env.example
@@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
# KG_CHUNK_PICK_METHOD=VECTOR

#########################################################
### Reranking configuration
### Set ENABLE_RERANK to true if a reranking model is configured
# ENABLE_RERANK=True
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
### RERANK_BINDING type: cohere, jina, aliyun
### For a rerank model deployed with vLLM, use the cohere binding
#########################################################
ENABLE_RERANK=False
RERANK_BINDING=cohere
### Rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
# MIN_RERANK_SCORE=0.0
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=jina-reranker-v2-base-multilingual

### For local deployment
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=http://localhost:8000
# RERANK_BINDING_API_KEY=your_rerank_api_key_here

### Default values for Cohere AI
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here

### Default values for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here

### Default values for Aliyun
# RERANK_MODEL=gte-rerank-v2
# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here

########################################
### Document processing configuration
########################################
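
To make the MIN_RERANK_SCORE knob above concrete, here is a minimal sketch (not the server's actual implementation) of how such a threshold filter works on reranked chunks that carry a `rerank_score` field:

```python
from typing import Any

def filter_by_min_rerank_score(
    chunks: list[dict[str, Any]],
    min_rerank_score: float = 0.0,
) -> list[dict[str, Any]]:
    """Drop chunks whose rerank_score falls below the configured threshold."""
    if min_rerank_score <= 0.0:
        return chunks  # 0.0 keeps all chunks, as described above
    return [c for c in chunks if c.get("rerank_score", 0.0) >= min_rerank_score]
```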
@@ -225,6 +225,19 @@ def parse_args() -> argparse.Namespace:
        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
        help="Embedding binding type (default: from env or ollama)",
    )
    parser.add_argument(
        "--rerank-binding",
        type=str,
        default=get_env_value("RERANK_BINDING", "cohere"),
        choices=["cohere", "jina", "aliyun"],
        help="Rerank binding type (default: from env or cohere)",
    )
    parser.add_argument(
        "--enable-rerank",
        action="store_true",
        default=get_env_value("ENABLE_RERANK", True, bool),
        help="Enable rerank functionality (default: from env or True)",
    )

    # Conditionally add binding options defined in binding_options module
    # This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
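A note on the `--enable-rerank` flag above: because it is a `store_true` argument whose default comes from the environment, the command line can only switch reranking on, while `ENABLE_RERANK` controls the default. A standalone sketch of that pattern, using a hypothetical `env_bool` helper rather than LightRAG's `get_env_value`:

```python
import argparse
import os

def env_bool(name: str, default: bool) -> bool:
    """Hypothetical helper: parse a boolean environment variable."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-rerank",
    action="store_true",
    default=env_bool("ENABLE_RERANK", True),  # env var sets the default; the flag can only enable
)
args = parser.parse_args([])  # e.g. ENABLE_RERANK=false with no flag -> args.enable_rerank is False
print(args.enable_rerank)
```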
@@ -340,6 +353,7 @@ def parse_args() -> argparse.Namespace:
    args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
    args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
    args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
    # Note: rerank_binding is already set by argparse, no need to override from env

    # Min rerank score configuration
    args.min_rerank_score = get_env_value(
@@ -390,33 +390,44 @@ def create_app(args):
        ),
    )

    # Configure rerank function if model and API are configured
    # Configure rerank function based on enable_rerank parameter
    rerank_model_func = None
    if args.rerank_binding_api_key and args.rerank_binding_host:
        from lightrag.rerank import custom_rerank
    if args.enable_rerank and args.rerank_binding:
        from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank

        # Map rerank binding to corresponding function
        rerank_functions = {
            "cohere": cohere_rerank,
            "jina": jina_rerank,
            "aliyun": ali_rerank,
        }

        # Select the appropriate rerank function based on binding
        selected_rerank_func = rerank_functions.get(args.rerank_binding)
        if not selected_rerank_func:
            logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
            raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")

        async def server_rerank_func(
            query: str, documents: list, top_n: int = None, **kwargs
            query: str, documents: list, top_n: int = None, extra_body: dict = None
        ):
            """Server rerank function with configuration from environment variables"""
            return await custom_rerank(
            return await selected_rerank_func(
                query=query,
                documents=documents,
                model=args.rerank_model,
                base_url=args.rerank_binding_host,
                api_key=args.rerank_binding_api_key,
                top_n=top_n,
                **kwargs,
                extra_body=extra_body,
            )

        rerank_model_func = server_rerank_func
        logger.info(
            f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
            f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
        )
    else:
        logger.info(
            "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
        )
        logger.info("Rerank disabled")

    # Create ollama_server_infos from command line arguments
    from lightrag.api.config import OllamaServerInfos
@@ -622,13 +633,15 @@ def create_app(args):
            "enable_llm_cache": args.enable_llm_cache,
            "workspace": args.workspace,
            "max_graph_nodes": args.max_graph_nodes,
            # Rerank configuration (based on whether rerank model is configured)
            "enable_rerank": rerank_model_func is not None,
            "rerank_model": args.rerank_model
            if rerank_model_func is not None
            # Rerank configuration
            "enable_rerank": args.enable_rerank,
            "rerank_configured": rerank_model_func is not None,
            "rerank_binding": args.rerank_binding
            if args.enable_rerank
            else None,
            "rerank_model": args.rerank_model if args.enable_rerank else None,
            "rerank_binding_host": args.rerank_binding_host
            if rerank_model_func is not None
            if args.enable_rerank
            else None,
            # Environment variable status (requested configuration)
            "summary_language": args.summary_language,
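
A quick way to verify the new status fields is to query the server's health endpoint. This is a hedged sketch: it assumes the default port, a `/health` route, and that the rerank fields appear in a `configuration` sub-object; adjust for your deployment.

```python
import asyncio
import aiohttp

async def check_rerank_status(base_url: str = "http://localhost:9621") -> None:
    """Print the rerank-related fields reported by the server's health endpoint."""
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/health") as resp:
            info = await resp.json()
    config = info.get("configuration", info)  # field layout may differ by version
    for key in ("enable_rerank", "rerank_configured", "rerank_binding", "rerank_model"):
        print(key, "=", config.get(key))

asyncio.run(check_rerank_status())
```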
@@ -2,270 +2,194 @@ from __future__ import annotations
import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field

from typing import Any, List, Dict, Optional
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from .utils import logger

from dotenv import load_dotenv


class RerankModel(BaseModel):
    """
    Wrapper for rerank functions that can be used with LightRAG.

    Example usage:
    ```python
    from lightrag.rerank import RerankModel, jina_rerank

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=jina_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "api_key": "your_api_key_here",
            "base_url": "https://api.jina.ai/v1/rerank"
        }
    )

    # Use in LightRAG
    rag = LightRAG(
        rerank_model_func=rerank_model.rerank,
        # ... other configurations
    )

    # Query with rerank enabled (default)
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)
    )
    ```

    Or define a custom function directly:
    ```python
    async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_api_key_here",
            top_n=top_n or 10,
            **kwargs
        )

    rag = LightRAG(
        rerank_model_func=my_rerank_func,
        # ... other configurations
    )

    # Control rerank per query
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True)  # Enable rerank for this query
    )
    ```
    """

    rerank_func: Callable[[Any], List[Dict]]
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_n: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function."""
        # Merge extra kwargs with model kwargs
        kwargs = {**self.kwargs, **extra_kwargs}
        return await self.rerank_func(
            query=query, documents=documents, top_n=top_n, **kwargs
        )

class MultiRerankModel(BaseModel):
    """Multiple rerank models for different modes/scenarios."""

    # Primary rerank model (used if mode-specific models are not defined)
    rerank_model: Optional[RerankModel] = None

    # Mode-specific rerank models
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_n: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank using the appropriate model based on mode."""

        # Select model based on mode
        if mode == "entity" and self.entity_rerank_model:
            model = self.entity_rerank_model
        elif mode == "relation" and self.relation_rerank_model:
            model = self.relation_rerank_model
        elif mode == "chunk" and self.chunk_rerank_model:
            model = self.chunk_rerank_model
        elif self.rerank_model:
            model = self.rerank_model
        else:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await model.rerank(query, documents, top_n, **kwargs)

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(aiohttp.ClientError)
        | retry_if_exception_type(aiohttp.ClientResponseError)
    ),
)
async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    documents: List[str],
    model: str,
    base_url: str,
    api_key: str,
    top_n: Optional[int] = None,
    **kwargs,
    return_documents: Optional[bool] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    response_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
    request_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.
    Generic rerank API call for Jina/Cohere/Aliyun models.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Model identifier
        documents: List of strings to rerank
        model: Model name to use
        base_url: API endpoint URL
        api_key: API authentication key
        api_key: API key for authentication
        top_n: Number of top results to return
        **kwargs: Additional API-specific parameters
        return_documents: Whether to return document text (Jina only)
        extra_body: Additional body parameters
        response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)

    Returns:
        List of reranked documents with relevance scores
        List of dicts of the form {"index": int, "relevance_score": float}
    """
    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents
        raise ValueError("API key is required")

    if not documents:
        return documents
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Prepare documents for reranking - handle both text and dict formats
    prepared_docs = []
    for doc in documents:
        if isinstance(doc, dict):
            # Use 'content' field if available, otherwise use 'text' or convert to string
            text = doc.get("content") or doc.get("text") or str(doc)
        else:
            text = str(doc)
        prepared_docs.append(text)
    # Build request payload based on request format
    if request_format == "aliyun":
        # Aliyun format: nested input/parameters structure
        payload = {
            "model": model,
            "input": {
                "query": query,
                "documents": documents,
            },
            "parameters": {},
        }

        # Prepare request
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        # Add optional parameters to parameters object
        if top_n is not None:
            payload["parameters"]["top_n"] = top_n

        data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
        if return_documents is not None:
            payload["parameters"]["return_documents"] = return_documents

        if top_n is not None:
            data["top_n"] = min(top_n, len(prepared_docs))
        # Add extra parameters to parameters object
        if extra_body:
            payload["parameters"].update(extra_body)
    else:
        # Standard format for Jina/Cohere
        payload = {
            "model": model,
            "query": query,
            "documents": documents,
        }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents
        # Add optional parameters
        if top_n is not None:
            payload["top_n"] = top_n

                result = await response.json()
        # Only Jina API supports return_documents parameter
        if return_documents is not None:
            payload["return_documents"] = return_documents

                # Extract reranked results
                if "results" in result:
                    # Standard format: results contain index and relevance_score
                    reranked_docs = []
                    for item in result["results"]:
                        if "index" in item:
                            doc_idx = item["index"]
                            if 0 <= doc_idx < len(documents):
                                reranked_doc = documents[doc_idx].copy()
                                if "relevance_score" in item:
                                    reranked_doc["rerank_score"] = item[
                                        "relevance_score"
                                    ]
                                reranked_docs.append(reranked_doc)
                    return reranked_docs
                else:
                    logger.warning("Unexpected rerank API response format")
                    return documents
        # Add extra parameters
        if extra_body:
            payload.update(extra_body)

    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return documents

async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_n: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Jina rerank model name
        top_n: Number of top results to return
        base_url: Jina API endpoint
        api_key: Jina API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
    logger.debug(
        f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
    )

    async with aiohttp.ClientSession() as session:
        async with session.post(base_url, headers=headers, json=payload) as response:
            if response.status != 200:
                error_text = await response.text()
                content_type = response.headers.get("content-type", "").lower()
                is_html_error = (
                    error_text.strip().startswith("<!DOCTYPE html>")
                    or "text/html" in content_type
                )

                if is_html_error:
                    if response.status == 502:
                        clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
                    elif response.status == 503:
                        clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
                    elif response.status == 504:
                        clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
                    else:
                        clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
                else:
                    clean_error = error_text

                logger.error(f"Rerank API error {response.status}: {clean_error}")
                raise aiohttp.ClientResponseError(
                    request_info=response.request_info,
                    history=response.history,
                    status=response.status,
                    message=f"Rerank API error: {clean_error}",
                )

            response_json = await response.json()

            # Handle different response formats
            if response_format == "aliyun":
                # Aliyun format: {"output": {"results": [...]}}
                output = response_json.get("output", {})
                results = output.get("results", [])
            elif response_format == "standard":
                # Standard format: {"results": [...]}
                results = response_json.get("results", [])
            else:
                raise ValueError(f"Unsupported response format: {response_format}")

            if not results:
                logger.warning("Rerank API returned empty results")
                return []

            # Standardize return format
            return [
                {"index": result["index"], "relevance_score": result["relevance_score"]}
                for result in results
            ]


async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    documents: List[str],
    top_n: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
    model: str = "rerank-v3.5",
    base_url: str = "https://ai.znipower.com:5017/rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Cohere rerank model name
        documents: List of strings to rerank
        top_n: Number of top results to return
        base_url: Cohere API endpoint
        api_key: Cohere API key
        **kwargs: Additional parameters
        api_key: API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for HTTP request (reserved for extra params)

    Returns:
        List of reranked documents with relevance scores
        List of dicts of the form {"index": int, "relevance_score": float}
    """
    if api_key is None:
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
@@ -274,24 +198,39 @@ async def cohere_rerank(
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
        return_documents=None,  # Cohere doesn't support this parameter
        extra_body=extra_body,
        response_format="standard",
    )


# Convenience function for custom API endpoints
async def custom_rerank(
async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    documents: List[str],
    top_n: Optional[int] = None,
    **kwargs,
    api_key: Optional[str] = None,
    model: str = "jina-reranker-v2-base-multilingual",
    base_url: str = "https://api.jina.ai/v1/rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using a custom API endpoint.
    This is useful for self-hosted or custom rerank services.
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for HTTP request (reserved for extra params)

    Returns:
        List of dicts of the form {"index": int, "relevance_score": float}
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
@@ -299,26 +238,112 @@ async def custom_rerank(
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        **kwargs,
        return_documents=False,
        extra_body=extra_body,
        response_format="standard",
    )


async def ali_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "gte-rerank-v2",
    base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Aliyun DashScope API.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: Aliyun API key
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for HTTP request (reserved for extra params)

    Returns:
        List of dicts of the form {"index": int, "relevance_score": float}
    """
    if api_key is None:
        api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_n=top_n,
        return_documents=False,  # Aliyun doesn't need this parameter
        extra_body=extra_body,
        response_format="aliyun",
        request_format="aliyun",
    )


"""Please run this test as a module:
python -m lightrag.rerank
"""
if __name__ == "__main__":
    import asyncio

    async def main():
        # Example usage
        # Example usage - documents should be strings, not dictionaries
        docs = [
            {"content": "The capital of France is Paris."},
            {"content": "Tokyo is the capital of Japan."},
            {"content": "London is the capital of England."},
            "The capital of France is Paris.",
            "Tokyo is the capital of Japan.",
            "London is the capital of England.",
        ]

        query = "What is the capital of France?"

        result = await jina_rerank(
            query=query, documents=docs, top_n=2, api_key="your-api-key-here"
        )
        print(result)
        # Test Jina rerank
        try:
            print("=== Jina Rerank ===")
            result = await jina_rerank(
                query=query,
                documents=docs,
                top_n=2,
            )
            print("Results:")
            for item in result:
                print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
                print(f"Document: {docs[item['index']]}")
        except Exception as e:
            print(f"Jina Error: {e}")

        # Test Cohere rerank
        try:
            print("\n=== Cohere Rerank ===")
            result = await cohere_rerank(
                query=query,
                documents=docs,
                top_n=2,
            )
            print("Results:")
            for item in result:
                print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
                print(f"Document: {docs[item['index']]}")
        except Exception as e:
            print(f"Cohere Error: {e}")

        # Test Aliyun rerank
        try:
            print("\n=== Aliyun Rerank ===")
            result = await ali_rerank(
                query=query,
                documents=docs,
                top_n=2,
            )
            print("Results:")
            for item in result:
                print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
                print(f"Document: {docs[item['index']]}")
        except Exception as e:
            print(f"Aliyun Error: {e}")

    asyncio.run(main())
@@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled(
        return retrieved_docs

    try:
        # Apply reranking - let rerank_model_func handle top_k internally
        reranked_docs = await rerank_func(
        # Extract document content for reranking
        document_texts = []
        for doc in retrieved_docs:
            # Try multiple possible content fields
            content = (
                doc.get("content")
                or doc.get("text")
                or doc.get("chunk_content")
                or doc.get("document")
                or str(doc)
            )
            document_texts.append(content)

        # Call the new rerank function that returns index-based results
        rerank_results = await rerank_func(
            query=query,
            documents=retrieved_docs,
            top_n=top_n,
            documents=document_texts,
            top_n=top_n or len(retrieved_docs),
        )
        if reranked_docs and len(reranked_docs) > 0:
            if len(reranked_docs) > top_n:
                reranked_docs = reranked_docs[:top_n]
            logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
            return reranked_docs

        # Process rerank results based on return format
        if rerank_results and len(rerank_results) > 0:
            # Check if results are in the new index-based format
            if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
                # New format: [{"index": 0, "relevance_score": 0.85}, ...]
                reranked_docs = []
                for result in rerank_results:
                    index = result["index"]
                    relevance_score = result["relevance_score"]

                    # Get original document and add rerank score
                    if 0 <= index < len(retrieved_docs):
                        doc = retrieved_docs[index].copy()
                        doc["rerank_score"] = relevance_score
                        reranked_docs.append(doc)

                logger.info(
                    f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
                )
                return reranked_docs
            else:
                # Legacy format: assume it's already reranked documents
                logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
                return rerank_results[:top_n] if top_n else rerank_results
        else:
            logger.warning("Rerank returned empty results, using original chunks")
            return retrieved_docs
@@ -2027,13 +2060,6 @@ async def process_chunks_unified(
    # 1. Apply reranking if enabled and query is provided
    if query_param.enable_rerank and query and unique_chunks:
        # Preserve the chunk_id field, because rerank may drop it
        chunk_ids = {}
        for chunk in unique_chunks:
            chunk_id = chunk.get("chunk_id")
            if chunk_id:
                chunk_ids[id(chunk)] = chunk_id

        rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
        unique_chunks = await apply_rerank_if_enabled(
            query=query,
@@ -2043,11 +2069,6 @@ async def process_chunks_unified(
            top_n=rerank_top_k,
        )

        # Restore the chunk_id field
        for chunk in unique_chunks:
            if id(chunk) in chunk_ids:
                chunk["chunk_id"] = chunk_ids[id(chunk)]

    # 2. Filter by minimum rerank score if reranking is enabled
    if query_param.enable_rerank and unique_chunks:
        min_rerank_score = global_config.get("min_rerank_score", 0.5)
@@ -2095,13 +2116,6 @@ async def process_chunks_unified(
        original_count = len(unique_chunks)

        # Keep the chunk_id field, because truncate_list_by_token_size will lose it
        chunk_ids_map = {}
        for i, chunk in enumerate(unique_chunks):
            chunk_id = chunk.get("chunk_id")
            if chunk_id:
                chunk_ids_map[i] = chunk_id

        unique_chunks = truncate_list_by_token_size(
            unique_chunks,
            key=lambda x: json.dumps(x, ensure_ascii=False),
@@ -2109,11 +2123,6 @@ async def process_chunks_unified(
            tokenizer=tokenizer,
        )

        # Restore the chunk_id field
        for i, chunk in enumerate(unique_chunks):
            if i in chunk_ids_map:
                chunk["chunk_id"] = chunk_ids_map[i]

        logger.debug(
            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
            f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"