add rerank model

zrguo
2025-07-07 22:44:59 +08:00
parent cb14ce6ff3
commit 75dd4f3498
6 changed files with 912 additions and 0 deletions

docs/rerank_integration.md (new file, +271 lines)

@@ -0,0 +1,271 @@
# Rerank Integration in LightRAG
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
## ⚠️ Important: Parameter Priority
**`QueryParam.top_k` takes precedence over the `rerank_top_k` configuration:**
- When you set `QueryParam(top_k=5)`, it overrides the `rerank_top_k=10` setting in the LightRAG configuration
- This means the actual number of documents sent to the reranker is determined by `QueryParam.top_k`
- For optimal rerank performance, always consider the `top_k` value in your `QueryParam` calls
- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` uses 20, not `rerank_top_k` (see the sketch below)
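To make the precedence concrete, here is a minimal sketch assuming a `rag` instance configured with `rerank_top_k=10`, as in the examples below:
```python
from lightrag import LightRAG, QueryParam


async def demo_priority(rag: LightRAG) -> str:
    # `rag` is assumed to be configured with rerank_top_k=10 (see below).
    # The per-query top_k=20 takes precedence, so 20 candidates reach the reranker.
    return await rag.aquery(
        "your question",
        param=QueryParam(mode="naive", top_k=20),
    )
```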
## Overview
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
## Architecture
The rerank integration follows the same design pattern as the LLM integration:
- **Configurable Models**: Support for multiple rerank providers through a generic API
- **Async Processing**: Non-blocking rerank operations
- **Error Handling**: Graceful fallback to original results
- **Optional Feature**: Can be enabled/disabled via configuration
- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
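Concretely, every provider is exposed through the same async callable shape, matching the functions in `lightrag/rerank.py` introduced by this commit:
```python
# Common async signature shared by jina_rerank / cohere_rerank / custom_rerank
async def rerank_func(
    query: str,
    documents: list[dict],
    top_k: int | None = None,
    **kwargs,
) -> list[dict]: ...
```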
## Configuration
### Environment Variables
Set these variables in your `.env` file or environment:
```bash
# Enable/disable reranking
ENABLE_RERANK=True
# Rerank model configuration
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_MAX_ASYNC=4
RERANK_TOP_K=10
# API configuration
RERANK_API_KEY=your_rerank_api_key_here
RERANK_BASE_URL=https://api.your-provider.com/v1/rerank
# Provider-specific keys (optional alternatives)
JINA_API_KEY=your_jina_api_key_here
COHERE_API_KEY=your_cohere_api_key_here
```
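If you keep these settings in a `.env` file, one way to load them before constructing `LightRAG` is with the `python-dotenv` package (an assumption here; any mechanism that populates the process environment works):
```python
import os

from dotenv import load_dotenv  # assumes the python-dotenv package is installed

load_dotenv()  # copy .env entries into os.environ before constructing LightRAG
print(os.getenv("RERANK_MODEL"))  # e.g. "BAAI/bge-reranker-v2-m3"
```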
### Programmatic Configuration
```python
from lightrag import LightRAG
from lightrag.rerank import custom_rerank, RerankModel

# Method 1: Using environment variables (recommended)
rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    # Rerank automatically configured from environment variables
)

# Method 2: Explicit configuration
rerank_model = RerankModel(
    rerank_func=custom_rerank,
    kwargs={
        "model": "BAAI/bge-reranker-v2-m3",
        "base_url": "https://api.your-provider.com/v1/rerank",
        "api_key": "your_api_key_here",
    },
)

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    enable_rerank=True,
    rerank_model_func=rerank_model.rerank,
    rerank_top_k=10,
)
```
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank

# Your custom API endpoint
result = await custom_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    base_url="https://api.your-provider.com/v1/rerank",
    api_key="your_api_key_here",
    top_k=10,
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank

result = await jina_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    api_key="your_jina_api_key",
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank

result = await cohere_rerank(
    query="your query",
    documents=documents,
    model="rerank-english-v2.0",
    api_key="your_cohere_api_key",
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
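Under the hood, each stage routes through one helper; the call pattern below is simplified from `_get_vector_context` in this commit's `operate.py` changes (the other modes follow the same pattern):
```python
# Simplified call site (from _get_vector_context); other modes look the same.
valid_chunks = await apply_rerank_if_enabled(
    query=query,
    retrieved_docs=valid_chunks,
    global_config=chunks_vdb.global_config,  # carries enable_rerank / rerank_model_func
    top_k=query_param.top_k,  # QueryParam.top_k, which wins over rerank_top_k
)
```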
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | False | Enable/disable reranking |
| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier |
| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls |
| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** |
| `rerank_model_func` | callable | None | Custom rerank function |
| `rerank_model_kwargs` | dict | {} | Additional rerank parameters |
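All of these map directly onto `LightRAG` constructor arguments; the following sketch combines them (the `rerank_model_kwargs` value here is purely illustrative and assumes your rerank function accepts it):
```python
rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    enable_rerank=True,
    rerank_model_func=rerank_model.rerank,       # e.g. a RerankModel instance's method
    rerank_model_name="BAAI/bge-reranker-v2-m3",
    rerank_model_max_async=4,                    # caps concurrent rerank calls
    rerank_model_kwargs={"timeout": 30},         # illustrative; forwarded to the rerank function
    rerank_top_k=10,                             # overridden by QueryParam.top_k at query time
)
```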
## Example Usage
### Basic Usage
```python
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed


async def main():
    # Initialize with rerank enabled
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=openai_embed,
        enable_rerank=True,
    )

    # Insert documents
    await rag.ainsert([
        "Document 1 content...",
        "Document 2 content...",
    ])

    # Query with rerank (automatically applied)
    result = await rag.aquery(
        "Your question here",
        param=QueryParam(mode="hybrid", top_k=5),  # ⚠️ This top_k=5 overrides rerank_top_k
    )
    print(result)


asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank


async def test_rerank():
    documents = [
        {"content": "Text about topic A"},
        {"content": "Text about topic B"},
        {"content": "Text about topic C"},
    ]

    reranked = await custom_rerank(
        query="Tell me about topic A",
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        base_url="https://api.your-provider.com/v1/rerank",
        api_key="your_api_key_here",
        top_k=2,
    )

    for doc in reranked:
        print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Parameter Priority Awareness**: Remember that `QueryParam.top_k` always overrides the `rerank_top_k` configuration
2. **Performance**: Apply reranking selectively to balance quality gains against added latency
3. **API Limits**: Monitor API usage and implement rate limiting if needed
4. **Fallback**: Rerank failures are handled gracefully (the original results are returned)
5. **Top-k Selection**: Choose appropriate `top_k` values in `QueryParam` based on your use case
6. **Cost Management**: Consider rerank API costs in your budget planning
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure `RERANK_API_KEY` or provider-specific keys are set
2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields (see the normalization sketch below)
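For issue 4, a small normalization pass before calling the reranker avoids format surprises; this sketch is not part of the library, which already falls back to `str(doc)` for unknown shapes:
```python
def normalize_docs(docs: list) -> list[dict]:
    """Coerce arbitrary items into the {'content': ...} shape the reranker expects."""
    normalized = []
    for doc in docs:
        if isinstance(doc, dict) and ("content" in doc or "text" in doc):
            normalized.append(doc)
        else:
            normalized.append({"content": str(doc)})
    return normalized
```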
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
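The pattern behind this behavior is implemented in `apply_rerank_if_enabled` (see the `operate.py` changes below); an abridged sketch, with the helper name here chosen for illustration:
```python
async def rerank_with_fallback(rerank_func, query, retrieved_docs, rerank_top_k):
    """Abridged version of the fallback logic in apply_rerank_if_enabled."""
    try:
        reranked = await rerank_func(query=query, documents=retrieved_docs, top_k=rerank_top_k)
        return reranked if reranked else retrieved_docs  # empty result: keep originals
    except Exception:
        # Errors are logged by the real implementation; callers always get documents back
        return retrieved_docs
```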
## API Compatibility
The generic rerank API expects this response format:
```json
{
  "results": [
    {
      "index": 0,
      "relevance_score": 0.95
    },
    {
      "index": 2,
      "relevance_score": 0.87
    }
  ]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format
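On the client side, `generic_rerank_api` maps each returned `index` back onto the original document list and attaches the score as `rerank_score`; a condensed sketch (the function name here is illustrative):
```python
def map_results_to_documents(result: dict, documents: list[dict]) -> list[dict]:
    """Condensed from generic_rerank_api: re-order originals by the API's indices."""
    reranked = []
    for item in result.get("results", []):
        idx = item.get("index", -1)
        if 0 <= idx < len(documents):
            doc = documents[idx].copy()
            doc["rerank_score"] = item.get("relevance_score")
            reranked.append(doc)
    return reranked
```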


@@ -179,3 +179,14 @@ QDRANT_URL=http://localhost:6333
### Redis
REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name

# Rerank Configuration
ENABLE_RERANK=False
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_MAX_ASYNC=4
RERANK_TOP_K=10
# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting

# Rerank API Configuration
RERANK_API_KEY=your_rerank_api_key_here
RERANK_BASE_URL=https://api.your-provider.com/v1/rerank

examples/rerank_example.py (new file, +193 lines)

@@ -0,0 +1,193 @@
"""
LightRAG Rerank Integration Example
This example demonstrates how to use rerank functionality with LightRAG
to improve retrieval quality across different query modes.
IMPORTANT: Parameter Priority
- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration
- If you set QueryParam(top_k=5), it will override rerank_top_k setting
- For optimal rerank performance, use appropriate top_k values in QueryParam
Configuration Required:
1. Set your LLM API key and base URL in llm_model_func()
2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file):
- RERANK_API_KEY=your_actual_rerank_api_key
- RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank
- RERANK_MODEL=your_rerank_model_name
"""
import asyncio
import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger
# Set up your working directory
WORKING_DIR = "./test_rerank"
setup_logger("test_rerank")
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key="your_llm_api_key_here",
        base_url="https://api.your-llm-provider.com/v1",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model="text-embedding-3-large",
        api_key="your_embedding_api_key_here",
        base_url="https://api.your-embedding-provider.com/v1",
    )


async def create_rag_with_rerank():
    """Create LightRAG instance with rerank configuration"""
    # Get embedding dimension
    test_embedding = await embedding_func(["test"])
    embedding_dim = test_embedding.shape[1]
    print(f"Detected embedding dimension: {embedding_dim}")

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=custom_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "base_url": "https://api.your-rerank-provider.com/v1/rerank",
            "api_key": "your_rerank_api_key_here",
        },
    )

    # Initialize LightRAG with rerank
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dim,
            max_token_size=8192,
            func=embedding_func,
        ),
        # Rerank Configuration
        enable_rerank=True,
        rerank_model_func=rerank_model.rerank,
        rerank_top_k=10,  # Note: QueryParam.top_k will override this
    )

    return rag


async def test_rerank_with_different_topk():
    """
    Test rerank functionality with different top_k settings to demonstrate parameter priority
    """
    print("🚀 Setting up LightRAG with Rerank functionality...")
    rag = await create_rag_with_rerank()

    # Insert sample documents
    sample_docs = [
        "Reranking improves retrieval quality by re-ordering documents based on relevance.",
        "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
        "Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
        "Natural language processing has evolved with large language models and transformers.",
        "Machine learning algorithms can learn patterns from data without explicit programming.",
    ]

    print("📄 Inserting sample documents...")
    await rag.ainsert(sample_docs)

    query = "How does reranking improve retrieval quality?"
    print(f"\n🔍 Testing query: '{query}'")
    print("=" * 80)

    # Test different top_k values to show parameter priority
    top_k_values = [2, 5, 10]

    for top_k in top_k_values:
        print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:")

        # Test naive mode with specific top_k
        result = await rag.aquery(
            query,
            param=QueryParam(mode="naive", top_k=top_k),
        )
        print(f"   Result length: {len(result)} characters")
        print(f"   Preview: {result[:100]}...")


async def test_direct_rerank():
    """Test rerank function directly"""
    print("\n🔧 Direct Rerank API Test")
    print("=" * 40)

    documents = [
        {"content": "Reranking significantly improves retrieval quality"},
        {"content": "LightRAG supports advanced reranking capabilities"},
        {"content": "Vector search finds semantically similar documents"},
        {"content": "Natural language processing with modern transformers"},
        {"content": "The quick brown fox jumps over the lazy dog"},
    ]

    query = "rerank improve quality"
    print(f"Query: '{query}'")
    print(f"Documents: {len(documents)}")

    try:
        reranked_docs = await custom_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            base_url="https://api.your-rerank-provider.com/v1/rerank",
            api_key="your_rerank_api_key_here",
            top_k=3,
        )

        print("\n✅ Rerank Results:")
        for i, doc in enumerate(reranked_docs):
            score = doc.get("rerank_score")
            # Guard against a missing score: formatting None with ":.4f" would raise
            score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
            content = doc.get("content", "")[:60]
            print(f"  {i + 1}. Score: {score_str} | {content}...")
    except Exception as e:
        print(f"❌ Rerank failed: {e}")


async def main():
    """Main example function"""
    print("🎯 LightRAG Rerank Integration Example")
    print("=" * 60)

    try:
        # Test rerank with different top_k values
        await test_rerank_with_different_topk()

        # Test direct rerank
        await test_direct_rerank()

        print("\n✅ Example completed successfully!")
        print("\n💡 Key Points:")
        print("   ✓ QueryParam.top_k has higher priority than rerank_top_k")
        print("   ✓ Rerank improves document relevance ordering")
        print("   ✓ Configure API keys in your .env file for production")
        print("   ✓ Monitor API usage and costs when using rerank services")
    except Exception as e:
        print(f"\n❌ Example failed: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())


@@ -240,6 +240,35 @@ class LightRAG:
    llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the LLM model function."""

    # Rerank Configuration
    # ---
    enable_rerank: bool = field(
        default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true")
    )
    """Enable reranking for improved retrieval quality. Defaults to False."""

    rerank_model_func: Callable[..., object] | None = field(default=None)
    """Function for reranking retrieved documents. Optional."""

    rerank_model_name: str = field(
        default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
    )
    """Name of the rerank model used for reranking documents."""

    rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4)))
    """Maximum number of concurrent rerank calls."""

    rerank_model_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments passed to the rerank model function."""

    rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10)))
    """Number of top documents to return after reranking.

    Note: This value will be overridden by QueryParam.top_k in query calls.
    Example: QueryParam(top_k=5) will override rerank_top_k=10 setting.
    """

    # Storage
    # ---
@@ -444,6 +473,22 @@ class LightRAG:
)
)
        # Init Rerank
        if self.enable_rerank and self.rerank_model_func:
            # Wrap the user-supplied rerank function with a concurrency limiter
            # and pre-bind any extra keyword arguments
            self.rerank_model_func = priority_limit_async_func_call(
                self.rerank_model_max_async
            )(
                partial(
                    self.rerank_model_func,  # type: ignore
                    **self.rerank_model_kwargs,
                )
            )
            logger.info("Rerank model initialized for improved retrieval quality")
        elif self.enable_rerank and not self.rerank_model_func:
            logger.warning(
                "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
            )

        self._storages_status = StoragesStatus.CREATED

        if self.auto_manage_storages_states:


@@ -1783,6 +1783,15 @@ async def _get_vector_context(
    if not valid_chunks:
        return [], [], []

    # Apply reranking if enabled
    global_config = chunks_vdb.global_config
    valid_chunks = await apply_rerank_if_enabled(
        query=query,
        retrieved_docs=valid_chunks,
        global_config=global_config,
        top_k=query_param.top_k,
    )

    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
@@ -1966,6 +1975,15 @@ async def _get_node_data(
    if not len(results):
        return "", "", ""

    # Apply reranking if enabled for entity results
    global_config = entities_vdb.global_config
    results = await apply_rerank_if_enabled(
        query=query,
        retrieved_docs=results,
        global_config=global_config,
        top_k=query_param.top_k,
    )

    # Extract all entity IDs from your results list
    node_ids = [r["entity_name"] for r in results]
@@ -2269,6 +2287,15 @@ async def _get_edge_data(
    if not len(results):
        return "", "", ""

    # Apply reranking if enabled for relationship results
    global_config = relationships_vdb.global_config
    results = await apply_rerank_if_enabled(
        query=keywords,
        retrieved_docs=results,
        global_config=global_config,
        top_k=query_param.top_k,
    )

    # Prepare edge pairs in two forms:
    # For the batch edge properties function, use dicts.
    edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
@@ -2806,3 +2833,61 @@ async def query_with_keywords(
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")


async def apply_rerank_if_enabled(
    query: str,
    retrieved_docs: list[dict],
    global_config: dict,
    top_k: int | None = None,
) -> list[dict]:
    """
    Apply reranking to retrieved documents if rerank is enabled.

    Args:
        query: The search query
        retrieved_docs: List of retrieved documents
        global_config: Global configuration containing rerank settings
        top_k: Number of top documents to return after reranking

    Returns:
        Reranked documents if rerank is enabled, otherwise original documents
    """
    if not global_config.get("enable_rerank", False) or not retrieved_docs:
        return retrieved_docs

    rerank_func = global_config.get("rerank_model_func")
    if not rerank_func:
        logger.debug(
            "Rerank is enabled but no rerank function provided, skipping rerank"
        )
        return retrieved_docs

    try:
        # Determine top_k for reranking
        rerank_top_k = top_k or global_config.get("rerank_top_k", 10)
        rerank_top_k = min(rerank_top_k, len(retrieved_docs))

        logger.debug(
            f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}"
        )

        # Apply reranking
        reranked_docs = await rerank_func(
            query=query,
            documents=retrieved_docs,
            top_k=rerank_top_k,
        )

        if reranked_docs and len(reranked_docs) > 0:
            logger.info(
                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
            )
            return reranked_docs
        else:
            logger.warning("Rerank returned empty results, using original documents")
            return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs
    except Exception as e:
        logger.error(f"Error during reranking: {e}, using original documents")
        return retrieved_docs

lightrag/rerank.py (new file, +307 lines)

@@ -0,0 +1,307 @@
from __future__ import annotations

import os
from typing import Callable, Any, List, Dict, Optional

import aiohttp
from pydantic import BaseModel, Field

from .utils import logger


class RerankModel(BaseModel):
    """
    Pydantic model class for defining a custom rerank model.

    Attributes:
        rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
            The function should take query and documents as input and return reranked results.
        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
            This could include parameters such as the model name, API key, etc.

    Example usage:
        Rerank model example from jina:
        ```python
        rerank_model = RerankModel(
            rerank_func=jina_rerank,
            kwargs={
                "model": "BAAI/bge-reranker-v2-m3",
                "api_key": "your_api_key_here",
                "base_url": "https://api.jina.ai/v1/rerank",
            }
        )
        ```
    """

    rerank_func: Callable[[Any], List[Dict]]
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function."""
        # Merge extra kwargs with model kwargs
        kwargs = {**self.kwargs, **extra_kwargs}
        return await self.rerank_func(
            query=query,
            documents=documents,
            top_k=top_k,
            **kwargs,
        )


class MultiRerankModel(BaseModel):
    """Multiple rerank models for different modes/scenarios."""

    # Primary rerank model (used if mode-specific models are not defined)
    rerank_model: Optional[RerankModel] = None

    # Mode-specific rerank models
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank using the appropriate model based on mode."""
        # Select model based on mode
        if mode == "entity" and self.entity_rerank_model:
            model = self.entity_rerank_model
        elif mode == "relation" and self.relation_rerank_model:
            model = self.relation_rerank_model
        elif mode == "chunk" and self.chunk_rerank_model:
            model = self.chunk_rerank_model
        elif self.rerank_model:
            model = self.rerank_model
        else:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await model.rerank(query, documents, top_k, **kwargs)


async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Model identifier
        base_url: API endpoint URL
        api_key: API authentication key
        top_k: Number of top results to return
        **kwargs: Additional API-specific parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents

    if not documents:
        return documents

    # Prepare documents for reranking - handle both text and dict formats
    prepared_docs = []
    for doc in documents:
        if isinstance(doc, dict):
            # Use 'content' field if available, otherwise use 'text' or convert to string
            text = doc.get("content") or doc.get("text") or str(doc)
        else:
            text = str(doc)
        prepared_docs.append(text)

    # Prepare request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    data = {
        "model": model,
        "query": query,
        "documents": prepared_docs,
        **kwargs,
    }
    if top_k is not None:
        # NOTE: this generic client sends "top_k"; some providers may expect a
        # different name (Cohere's documented parameter is "top_n", for example),
        # which can be supplied via **kwargs instead.
        data["top_k"] = min(top_k, len(prepared_docs))

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents

                result = await response.json()

                # Extract reranked results
                if "results" in result:
                    # Standard format: results contain index and relevance_score
                    reranked_docs = []
                    for item in result["results"]:
                        if "index" in item:
                            doc_idx = item["index"]
                            if 0 <= doc_idx < len(documents):
                                reranked_doc = documents[doc_idx].copy()
                                if "relevance_score" in item:
                                    reranked_doc["rerank_score"] = item["relevance_score"]
                                reranked_docs.append(reranked_doc)
                    return reranked_docs
                else:
                    logger.warning("Unexpected rerank API response format")
                    return documents
    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return documents


async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_k: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Jina rerank model name
        top_k: Number of top results to return
        base_url: Jina API endpoint
        api_key: Jina API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    top_k: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Cohere rerank model name
        top_k: Number of top results to return
        base_url: Cohere API endpoint
        api_key: Cohere API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


# Convenience function for custom API endpoints
async def custom_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using a custom API endpoint.
    This is useful for self-hosted or custom rerank services.
    """
    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


if __name__ == "__main__":
    import asyncio

    async def main():
        # Example usage
        docs = [
            {"content": "The capital of France is Paris."},
            {"content": "Tokyo is the capital of Japan."},
            {"content": "London is the capital of England."},
        ]
        query = "What is the capital of France?"

        result = await jina_rerank(
            query=query,
            documents=docs,
            top_k=2,
            api_key="your-api-key-here",
        )
        print(result)

    asyncio.run(main())