From 75dd4f3498d06d754f9ddff62a6e650d639823e7 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Mon, 7 Jul 2025 22:44:59 +0800
Subject: [PATCH] add rerank model

---
 docs/rerank_integration.md | 271 ++++++++++++++++++++++++++++++++
 env.example                |  11 ++
 examples/rerank_example.py | 193 +++++++++++++++++++++++
 lightrag/lightrag.py       |  45 ++++++
 lightrag/operate.py        |  85 ++++++++++
 lightrag/rerank.py         | 307 +++++++++++++++++++++++++++++++++++++
 6 files changed, 912 insertions(+)
 create mode 100644 docs/rerank_integration.md
 create mode 100644 examples/rerank_example.py
 create mode 100644 lightrag/rerank.py

diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md
new file mode 100644
index 00000000..647c0f91
--- /dev/null
+++ b/docs/rerank_integration.md
@@ -0,0 +1,271 @@
+# Rerank Integration in LightRAG
+
+This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
+
+## ⚠️ Important: Parameter Priority
+
+**QueryParam.top_k has higher priority than the rerank_top_k configuration:**
+
+- When you set `QueryParam(top_k=5)`, it overrides the `rerank_top_k=10` setting in the LightRAG configuration
+- The number of documents sent to the reranker is therefore determined by QueryParam.top_k
+- For predictable rerank behavior, set top_k explicitly in your QueryParam calls
+- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` uses 20, not rerank_top_k
+
+## Overview
+
+Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. It is particularly useful when you need higher precision in document retrieval, and it applies across all query modes (naive, local, global, hybrid, mix).
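+
+To make the mechanics concrete, here is a minimal, self-contained sketch of the retrieve-then-rerank idea. The overlap scorer is a toy stand-in (it is not LightRAG's implementation); a real deployment would call a rerank API such as the `custom_rerank` function shown later in this document:
+
+```python
+import asyncio
+
+async def toy_rerank(query: str, documents: list[dict], top_k: int) -> list[dict]:
+    # Toy scorer: rank by query-term overlap. A real reranker model scores
+    # query/document pairs far more accurately.
+    terms = set(query.lower().split())
+    scored = [
+        {**doc, "rerank_score": len(terms & set(doc["content"].lower().split()))}
+        for doc in documents
+    ]
+    scored.sort(key=lambda d: d["rerank_score"], reverse=True)
+    return scored[:top_k]
+
+docs = [{"content": "reranking improves precision"}, {"content": "unrelated text"}]
+print(asyncio.run(toy_rerank("how does reranking work", docs, top_k=1)))
+```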
+ +## Architecture + +The rerank integration follows the same design pattern as the LLM integration: + +- **Configurable Models**: Support for multiple rerank providers through a generic API +- **Async Processing**: Non-blocking rerank operations +- **Error Handling**: Graceful fallback to original results +- **Optional Feature**: Can be enabled/disabled via configuration +- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs + +## Configuration + +### Environment Variables + +Set these variables in your `.env` file or environment: + +```bash +# Enable/disable reranking +ENABLE_RERANK=True + +# Rerank model configuration +RERANK_MODEL=BAAI/bge-reranker-v2-m3 +RERANK_MAX_ASYNC=4 +RERANK_TOP_K=10 + +# API configuration +RERANK_API_KEY=your_rerank_api_key_here +RERANK_BASE_URL=https://api.your-provider.com/v1/rerank + +# Provider-specific keys (optional alternatives) +JINA_API_KEY=your_jina_api_key_here +COHERE_API_KEY=your_cohere_api_key_here +``` + +### Programmatic Configuration + +```python +from lightrag import LightRAG +from lightrag.rerank import custom_rerank, RerankModel + +# Method 1: Using environment variables (recommended) +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + # Rerank automatically configured from environment variables +) + +# Method 2: Explicit configuration +rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-provider.com/v1/rerank", + "api_key": "your_api_key_here", + } +) + +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + rerank_top_k=10, +) +``` + +## Supported Providers + +### 1. Custom/Generic API (Recommended) + +For Jina/Cohere compatible APIs: + +```python +from lightrag.rerank import custom_rerank + +# Your custom API endpoint +result = await custom_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=10 +) +``` + +### 2. Jina AI + +```python +from lightrag.rerank import jina_rerank + +result = await jina_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key" +) +``` + +### 3. Cohere + +```python +from lightrag.rerank import cohere_rerank + +result = await cohere_rerank( + query="your query", + documents=documents, + model="rerank-english-v2.0", + api_key="your_cohere_api_key" +) +``` + +## Integration Points + +Reranking is automatically applied at these key retrieval stages: + +1. **Naive Mode**: After vector similarity search in `_get_vector_context` +2. **Local Mode**: After entity retrieval in `_get_node_data` +3. **Global Mode**: After relationship retrieval in `_get_edge_data` +4. 
**Hybrid/Mix Modes**: Applied to all relevant components
+
+## Configuration Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `enable_rerank` | bool | False | Enable/disable reranking |
+| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier |
+| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls |
+| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** |
+| `rerank_model_func` | callable | None | Custom rerank function |
+| `rerank_model_kwargs` | dict | {} | Additional rerank parameters |
+
+## Example Usage
+
+### Basic Usage
+
+```python
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
+
+async def main():
+    # Initialize with rerank enabled
+    rag = LightRAG(
+        working_dir="./rag_storage",
+        llm_model_func=gpt_4o_mini_complete,
+        embedding_func=openai_embed,
+        enable_rerank=True,
+    )
+
+    # Insert documents
+    await rag.ainsert([
+        "Document 1 content...",
+        "Document 2 content...",
+    ])
+
+    # Query with rerank (automatically applied)
+    result = await rag.aquery(
+        "Your question here",
+        param=QueryParam(mode="hybrid", top_k=5)  # ⚠️ This top_k=5 overrides rerank_top_k
+    )
+
+    print(result)
+
+asyncio.run(main())
+```
+
+### Direct Rerank Usage
+
+```python
+from lightrag.rerank import custom_rerank
+
+async def test_rerank():
+    documents = [
+        {"content": "Text about topic A"},
+        {"content": "Text about topic B"},
+        {"content": "Text about topic C"},
+    ]
+
+    reranked = await custom_rerank(
+        query="Tell me about topic A",
+        documents=documents,
+        model="BAAI/bge-reranker-v2-m3",
+        base_url="https://api.your-provider.com/v1/rerank",
+        api_key="your_api_key_here",
+        top_k=2
+    )
+
+    for doc in reranked:
+        print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
+```
+
+## Best Practices
+
+1. **Parameter Priority Awareness**: Remember that QueryParam.top_k always overrides the rerank_top_k configuration
+2. **Performance**: Reranking adds latency and API cost, so enable it where the precision gain justifies the overhead
+3. **API Limits**: Monitor API usage and implement rate limiting if needed
+4. **Fallback**: Always handle rerank failures gracefully (original results are returned on error)
+5. **Top-k Selection**: Choose appropriate `top_k` values in QueryParam based on your use case
+6. **Cost Management**: Consider rerank API costs in your budget planning
+
+## Troubleshooting
+
+### Common Issues
+
+1. **API Key Missing**: Ensure `RERANK_API_KEY` or provider-specific keys are set
+2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity
+3. **Model Errors**: Verify the rerank model name is supported by your API
+4.
**Document Format**: Ensure documents have `content` or `text` fields + +### Debug Mode + +Enable debug logging to see rerank operations: + +```python +import logging +logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG) +``` + +### Error Handling + +The rerank integration includes automatic fallback: + +```python +# If rerank fails, original documents are returned +# No exceptions are raised to the user +# Errors are logged for debugging +``` + +## API Compatibility + +The generic rerank API expects this response format: + +```json +{ + "results": [ + { + "index": 0, + "relevance_score": 0.95 + }, + { + "index": 2, + "relevance_score": 0.87 + } + ] +} +``` + +This is compatible with: +- Jina AI Rerank API +- Cohere Rerank API +- Custom APIs following the same format \ No newline at end of file diff --git a/env.example b/env.example index 1efe4830..49546343 100644 --- a/env.example +++ b/env.example @@ -179,3 +179,14 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +# Rerank Configuration +ENABLE_RERANK=False +RERANK_MODEL=BAAI/bge-reranker-v2-m3 +RERANK_MAX_ASYNC=4 +RERANK_TOP_K=10 +# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting + +# Rerank API Configuration +RERANK_API_KEY=your_rerank_api_key_here +RERANK_BASE_URL=https://api.your-provider.com/v1/rerank diff --git a/examples/rerank_example.py b/examples/rerank_example.py new file mode 100644 index 00000000..30ad794d --- /dev/null +++ b/examples/rerank_example.py @@ -0,0 +1,193 @@ +""" +LightRAG Rerank Integration Example + +This example demonstrates how to use rerank functionality with LightRAG +to improve retrieval quality across different query modes. + +IMPORTANT: Parameter Priority +- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration +- If you set QueryParam(top_k=5), it will override rerank_top_k setting +- For optimal rerank performance, use appropriate top_k values in QueryParam + +Configuration Required: +1. Set your LLM API key and base URL in llm_model_func() +2. Set your embedding API key and base URL in embedding_func() +3. Set your rerank API key and base URL in the rerank configuration +4. 
Or use environment variables (.env file): + - RERANK_API_KEY=your_actual_rerank_api_key + - RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank + - RERANK_MODEL=your_rerank_model_name +""" + +import asyncio +import os +import numpy as np + +from lightrag import LightRAG, QueryParam +from lightrag.rerank import custom_rerank, RerankModel +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc, setup_logger + +# Set up your working directory +WORKING_DIR = "./test_rerank" +setup_logger("test_rerank") + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key="your_llm_api_key_here", + base_url="https://api.your-llm-provider.com/v1", + **kwargs, + ) + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts, + model="text-embedding-3-large", + api_key="your_embedding_api_key_here", + base_url="https://api.your-embedding-provider.com/v1", + ) + +async def create_rag_with_rerank(): + """Create LightRAG instance with rerank configuration""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Create rerank model + rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-rerank-provider.com/v1/rerank", + "api_key": "your_rerank_api_key_here", + } + ) + + # Initialize LightRAG with rerank + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + # Rerank Configuration + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + rerank_top_k=10, # Note: QueryParam.top_k will override this + ) + + return rag + +async def test_rerank_with_different_topk(): + """ + Test rerank functionality with different top_k settings to demonstrate parameter priority + """ + print("🚀 Setting up LightRAG with Rerank functionality...") + + rag = await create_rag_with_rerank() + + # Insert sample documents + sample_docs = [ + "Reranking improves retrieval quality by re-ordering documents based on relevance.", + "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.", + "Vector databases enable efficient similarity search in high-dimensional embedding spaces.", + "Natural language processing has evolved with large language models and transformers.", + "Machine learning algorithms can learn patterns from data without explicit programming." + ] + + print("📄 Inserting sample documents...") + await rag.ainsert(sample_docs) + + query = "How does reranking improve retrieval quality?" 
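+    # Each QueryParam(top_k=...) below takes precedence over the
+    # rerank_top_k=10 configured on the LightRAG instance above.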
+    print(f"\n🔍 Testing query: '{query}'")
+    print("=" * 80)
+
+    # Test different top_k values to show parameter priority
+    top_k_values = [2, 5, 10]
+
+    for top_k in top_k_values:
+        print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:")
+
+        # Test naive mode with specific top_k
+        result = await rag.aquery(
+            query,
+            param=QueryParam(mode="naive", top_k=top_k)
+        )
+        print(f"   Result length: {len(result)} characters")
+        print(f"   Preview: {result[:100]}...")
+
+async def test_direct_rerank():
+    """Test rerank function directly"""
+    print("\n🔧 Direct Rerank API Test")
+    print("=" * 40)
+
+    documents = [
+        {"content": "Reranking significantly improves retrieval quality"},
+        {"content": "LightRAG supports advanced reranking capabilities"},
+        {"content": "Vector search finds semantically similar documents"},
+        {"content": "Natural language processing with modern transformers"},
+        {"content": "The quick brown fox jumps over the lazy dog"}
+    ]
+
+    query = "rerank improve quality"
+    print(f"Query: '{query}'")
+    print(f"Documents: {len(documents)}")
+
+    try:
+        reranked_docs = await custom_rerank(
+            query=query,
+            documents=documents,
+            model="BAAI/bge-reranker-v2-m3",
+            base_url="https://api.your-rerank-provider.com/v1/rerank",
+            api_key="your_rerank_api_key_here",
+            top_k=3
+        )
+
+        print("\n✅ Rerank Results:")
+        for i, doc in enumerate(reranked_docs):
+            # On API failure custom_rerank returns the original documents,
+            # which carry no rerank_score; guard before float formatting.
+            score = doc.get("rerank_score")
+            score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
+            content = doc.get("content", "")[:60]
+            print(f"  {i+1}. Score: {score_str} | {content}...")
+
+    except Exception as e:
+        print(f"❌ Rerank failed: {e}")
+
+async def main():
+    """Main example function"""
+    print("🎯 LightRAG Rerank Integration Example")
+    print("=" * 60)
+
+    try:
+        # Test rerank with different top_k values
+        await test_rerank_with_different_topk()
+
+        # Test direct rerank
+        await test_direct_rerank()
+
+        print("\n✅ Example completed successfully!")
+        print("\n💡 Key Points:")
+        print("   ✓ QueryParam.top_k has higher priority than rerank_top_k")
+        print("   ✓ Rerank improves document relevance ordering")
+        print("   ✓ Configure API keys in your .env file for production")
+        print("   ✓ Monitor API usage and costs when using rerank services")
+
+    except Exception as e:
+        print(f"\n❌ Example failed: {e}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5d96aeba..cee08373 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -240,6 +240,35 @@ class LightRAG:
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
 
+    # Rerank Configuration
+    # ---
+
+    enable_rerank: bool = field(
+        default=os.getenv("ENABLE_RERANK", "False").lower() == "true"
+    )
+    """Enable reranking for improved retrieval quality. Defaults to False."""
+
+    rerank_model_func: Callable[..., object] | None = field(default=None)
+    """Function for reranking retrieved documents.
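+    Expected to be an async callable accepting (query, documents, top_k, **kwargs)
+    and returning the documents in relevance order.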
Optional.""" + + rerank_model_name: str = field( + default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") + ) + """Name of the rerank model used for reranking documents.""" + + rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4))) + """Maximum number of concurrent rerank calls.""" + + rerank_model_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments passed to the rerank model function.""" + + rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10))) + """Number of top documents to return after reranking. + + Note: This value will be overridden by QueryParam.top_k in query calls. + Example: QueryParam(top_k=5) will override rerank_top_k=10 setting. + """ + # Storage # --- @@ -444,6 +473,22 @@ class LightRAG: ) ) + # Init Rerank + if self.enable_rerank and self.rerank_model_func: + self.rerank_model_func = priority_limit_async_func_call( + self.rerank_model_max_async + )( + partial( + self.rerank_model_func, # type: ignore + **self.rerank_model_kwargs, + ) + ) + logger.info("Rerank model initialized for improved retrieval quality") + elif self.enable_rerank and not self.rerank_model_func: + logger.warning( + "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped." + ) + self._storages_status = StoragesStatus.CREATED if self.auto_manage_storages_states: diff --git a/lightrag/operate.py b/lightrag/operate.py index 88837435..b5d74c55 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1783,6 +1783,15 @@ async def _get_vector_context( if not valid_chunks: return [], [], [] + # Apply reranking if enabled + global_config = chunks_vdb.global_config + valid_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=valid_chunks, + global_config=global_config, + top_k=query_param.top_k, + ) + maybe_trun_chunks = truncate_list_by_token_size( valid_chunks, key=lambda x: x["content"], @@ -1966,6 +1975,15 @@ async def _get_node_data( if not len(results): return "", "", "" + # Apply reranking if enabled for entity results + global_config = entities_vdb.global_config + results = await apply_rerank_if_enabled( + query=query, + retrieved_docs=results, + global_config=global_config, + top_k=query_param.top_k, + ) + # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -2269,6 +2287,15 @@ async def _get_edge_data( if not len(results): return "", "", "" + # Apply reranking if enabled for relationship results + global_config = relationships_vdb.global_config + results = await apply_rerank_if_enabled( + query=keywords, + retrieved_docs=results, + global_config=global_config, + top_k=query_param.top_k, + ) + # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] @@ -2806,3 +2833,61 @@ async def query_with_keywords( ) else: raise ValueError(f"Unknown mode {param.mode}") + + +async def apply_rerank_if_enabled( + query: str, + retrieved_docs: list[dict], + global_config: dict, + top_k: int = None, +) -> list[dict]: + """ + Apply reranking to retrieved documents if rerank is enabled. 
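+    Falls back to the original document list when reranking is disabled,
+    no rerank function is configured, or the rerank call fails.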
+
+    Args:
+        query: The search query
+        retrieved_docs: List of retrieved documents
+        global_config: Global configuration containing rerank settings
+        top_k: Number of top documents to return after reranking
+
+    Returns:
+        Reranked documents if rerank is enabled, otherwise original documents
+    """
+    if not global_config.get("enable_rerank", False) or not retrieved_docs:
+        return retrieved_docs
+
+    rerank_func = global_config.get("rerank_model_func")
+    if not rerank_func:
+        logger.debug(
+            "Rerank is enabled but no rerank function provided, skipping rerank"
+        )
+        return retrieved_docs
+
+    try:
+        # Determine top_k for reranking
+        rerank_top_k = top_k or global_config.get("rerank_top_k", 10)
+        rerank_top_k = min(rerank_top_k, len(retrieved_docs))
+
+        logger.debug(
+            f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}"
+        )
+
+        # Apply reranking
+        reranked_docs = await rerank_func(
+            query=query,
+            documents=retrieved_docs,
+            top_k=rerank_top_k,
+        )
+
+        if reranked_docs:
+            logger.info(
+                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
+            )
+            return reranked_docs
+        else:
+            logger.warning("Rerank returned empty results, using original documents")
+            return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs
+
+    except Exception as e:
+        logger.error(f"Error during reranking: {e}, using original documents")
+        return retrieved_docs
diff --git a/lightrag/rerank.py b/lightrag/rerank.py
new file mode 100644
index 00000000..d25a8485
--- /dev/null
+++ b/lightrag/rerank.py
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+import os
+from typing import Callable, Any, List, Dict, Optional
+
+import aiohttp
+from pydantic import BaseModel, Field
+
+from .utils import logger
+
+
+class RerankModel(BaseModel):
+    """
+    Pydantic model class for defining a custom rerank model.
+
+    Attributes:
+        rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
+            The function should take query and documents as input and return reranked results.
+        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+            This could include parameters such as the model name, API key, etc.
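+            Keyword arguments supplied to ``rerank()`` at call time are
+            merged over these defaults, so per-call values take precedence.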
+ + Example usage: + Rerank model example from jina: + ```python + rerank_model = RerankModel( + rerank_func=jina_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "api_key": "your_api_key_here", + "base_url": "https://api.jina.ai/v1/rerank" + } + ) + ``` + """ + + rerank_func: Callable[[Any], List[Dict]] + kwargs: Dict[str, Any] = Field(default_factory=dict) + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + top_k: Optional[int] = None, + **extra_kwargs + ) -> List[Dict[str, Any]]: + """Rerank documents using the configured model function.""" + # Merge extra kwargs with model kwargs + kwargs = {**self.kwargs, **extra_kwargs} + return await self.rerank_func( + query=query, + documents=documents, + top_k=top_k, + **kwargs + ) + + +class MultiRerankModel(BaseModel): + """Multiple rerank models for different modes/scenarios.""" + + # Primary rerank model (used if mode-specific models are not defined) + rerank_model: Optional[RerankModel] = None + + # Mode-specific rerank models + entity_rerank_model: Optional[RerankModel] = None + relation_rerank_model: Optional[RerankModel] = None + chunk_rerank_model: Optional[RerankModel] = None + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + mode: str = "default", + top_k: Optional[int] = None, + **kwargs + ) -> List[Dict[str, Any]]: + """Rerank using the appropriate model based on mode.""" + + # Select model based on mode + if mode == "entity" and self.entity_rerank_model: + model = self.entity_rerank_model + elif mode == "relation" and self.relation_rerank_model: + model = self.relation_rerank_model + elif mode == "chunk" and self.chunk_rerank_model: + model = self.chunk_rerank_model + elif self.rerank_model: + model = self.rerank_model + else: + logger.warning(f"No rerank model available for mode: {mode}") + return documents + + return await model.rerank(query, documents, top_k, **kwargs) + + +async def generic_rerank_api( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Generic rerank function that works with Jina/Cohere compatible APIs. 
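+    Expects a JSON response containing a ``results`` list of objects with
+    ``index`` and ``relevance_score`` fields; on any error the original
+    document order is returned unchanged.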
+
+    Args:
+        query: The search query
+        documents: List of documents to rerank
+        model: Model identifier
+        base_url: API endpoint URL
+        api_key: API authentication key
+        top_k: Number of top results to return
+        **kwargs: Additional API-specific parameters
+
+    Returns:
+        List of reranked documents with relevance scores
+    """
+    if not api_key:
+        logger.warning("No API key provided for rerank service")
+        return documents
+
+    if not documents:
+        return documents
+
+    # Prepare documents for reranking - handle both text and dict formats
+    prepared_docs = []
+    for doc in documents:
+        if isinstance(doc, dict):
+            # Use 'content' field if available, otherwise use 'text' or convert to string
+            text = doc.get('content') or doc.get('text') or str(doc)
+        else:
+            text = str(doc)
+        prepared_docs.append(text)
+
+    # Prepare request
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}"
+    }
+
+    data = {
+        "model": model,
+        "query": query,
+        "documents": prepared_docs,
+        **kwargs
+    }
+
+    if top_k is not None:
+        # Jina and Cohere rerank APIs both name this parameter "top_n"
+        data["top_n"] = min(top_k, len(prepared_docs))
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(base_url, headers=headers, json=data) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    logger.error(f"Rerank API error {response.status}: {error_text}")
+                    return documents
+
+                result = await response.json()
+
+                # Extract reranked results
+                if "results" in result:
+                    # Standard format: results contain index and relevance_score
+                    reranked_docs = []
+                    for item in result["results"]:
+                        if "index" in item:
+                            doc_idx = item["index"]
+                            if 0 <= doc_idx < len(documents):
+                                reranked_doc = documents[doc_idx].copy()
+                                if "relevance_score" in item:
+                                    reranked_doc["rerank_score"] = item["relevance_score"]
+                                reranked_docs.append(reranked_doc)
+                    return reranked_docs
+                else:
+                    logger.warning("Unexpected rerank API response format")
+                    return documents
+
+    except Exception as e:
+        logger.error(f"Error during reranking: {e}")
+        return documents
+
+
+async def jina_rerank(
+    query: str,
+    documents: List[Dict[str, Any]],
+    model: str = "BAAI/bge-reranker-v2-m3",
+    top_k: Optional[int] = None,
+    base_url: str = "https://api.jina.ai/v1/rerank",
+    api_key: Optional[str] = None,
+    **kwargs
+) -> List[Dict[str, Any]]:
+    """
+    Rerank documents using Jina AI API.
+
+    Args:
+        query: The search query
+        documents: List of documents to rerank
+        model: Jina rerank model name
+        top_k: Number of top results to return
+        base_url: Jina API endpoint
+        api_key: Jina API key
+        **kwargs: Additional parameters
+
+    Returns:
+        List of reranked documents with relevance scores
+    """
+    if api_key is None:
+        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
+
+    return await generic_rerank_api(
+        query=query,
+        documents=documents,
+        model=model,
+        base_url=base_url,
+        api_key=api_key,
+        top_k=top_k,
+        **kwargs
+    )
+
+
+async def cohere_rerank(
+    query: str,
+    documents: List[Dict[str, Any]],
+    model: str = "rerank-english-v2.0",
+    top_k: Optional[int] = None,
+    base_url: str = "https://api.cohere.ai/v1/rerank",
+    api_key: Optional[str] = None,
+    **kwargs
+) -> List[Dict[str, Any]]:
+    """
+    Rerank documents using Cohere API.
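+    Reads COHERE_API_KEY or RERANK_API_KEY from the environment when
+    api_key is not provided.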
+ + Args: + query: The search query + documents: List of documents to rerank + model: Cohere rerank model name + top_k: Number of top results to return + base_url: Cohere API endpoint + api_key: Cohere API key + **kwargs: Additional parameters + + Returns: + List of reranked documents with relevance scores + """ + if api_key is None: + api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs + ) + + +# Convenience function for custom API endpoints +async def custom_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs +) -> List[Dict[str, Any]]: + """ + Rerank documents using a custom API endpoint. + This is useful for self-hosted or custom rerank services. + """ + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs + ) + + +if __name__ == "__main__": + import asyncio + + async def main(): + # Example usage + docs = [ + {"content": "The capital of France is Paris."}, + {"content": "Tokyo is the capital of Japan."}, + {"content": "London is the capital of England."}, + ] + + query = "What is the capital of France?" + + result = await jina_rerank( + query=query, + documents=docs, + top_k=2, + api_key="your-api-key-here" + ) + print(result) + + asyncio.run(main()) \ No newline at end of file