From a05bbf105e3b7f996c3f60497b5146c7dbfb68f6 Mon Sep 17 00:00:00 2001 From: netbrah Date: Sat, 22 Nov 2025 16:43:13 -0500 Subject: [PATCH 1/6] Add Cohere reranker config, chunking, and tests --- env.example | 3 + examples/rerank_example.py | 13 +- lightrag/api/lightrag_server.py | 30 ++- lightrag/rerank.py | 208 ++++++++++++++++- tests/test_rerank_chunking.py | 386 ++++++++++++++++++++++++++++++++ 5 files changed, 620 insertions(+), 20 deletions(-) create mode 100644 tests/test_rerank_chunking.py diff --git a/env.example b/env.example index fea99953..c8419961 100644 --- a/env.example +++ b/env.example @@ -102,6 +102,9 @@ RERANK_BINDING=null # RERANK_MODEL=rerank-v3.5 # RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank # RERANK_BINDING_API_KEY=your_rerank_api_key_here +### Cohere rerank chunking configuration (useful for models with token limits like ColBERT) +# RERANK_ENABLE_CHUNKING=true +# RERANK_MAX_TOKENS_PER_DOC=480 ### Default value for Jina AI # RERANK_MODEL=jina-reranker-v2-base-multilingual diff --git a/examples/rerank_example.py b/examples/rerank_example.py index da3d0efe..889cffe8 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -15,9 +15,12 @@ Configuration Required: EMBEDDING_BINDING_HOST EMBEDDING_BINDING_API_KEY 3. Set your vLLM deployed AI rerank model setting with env vars: - RERANK_MODEL - RERANK_BINDING_HOST + RERANK_BINDING=cohere + RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5) + RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy) RERANK_BINDING_API_KEY + RERANK_ENABLE_CHUNKING=true (optional, for models with token limits) + RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096) Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True) """ @@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray: rerank_model_func = partial( cohere_rerank, - model=os.getenv("RERANK_MODEL"), + model=os.getenv("RERANK_MODEL", "rerank-v3.5"), api_key=os.getenv("RERANK_BINDING_API_KEY"), - base_url=os.getenv("RERANK_BINDING_HOST"), + base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"), + enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true", + max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")), ) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index b29e39b2..0be5d9de 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -967,15 +967,27 @@ def create_app(args): query: str, documents: list, top_n: int = None, extra_body: dict = None ): """Server rerank function with configuration from environment variables""" - return await selected_rerank_func( - query=query, - documents=documents, - top_n=top_n, - api_key=args.rerank_binding_api_key, - model=args.rerank_model, - base_url=args.rerank_binding_host, - extra_body=extra_body, - ) + # Prepare kwargs for rerank function + kwargs = { + "query": query, + "documents": documents, + "top_n": top_n, + "api_key": args.rerank_binding_api_key, + "model": args.rerank_model, + "base_url": args.rerank_binding_host, + } + + # Add Cohere-specific parameters if using cohere binding + if args.rerank_binding == "cohere": + # Enable chunking if configured (useful for models with token limits like ColBERT) + kwargs["enable_chunking"] = ( + os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true" + ) + kwargs["max_tokens_per_doc"] = int( + os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096") + ) + + return await 
selected_rerank_func(**kwargs, extra_body=extra_body) rerank_model_func = server_rerank_func logger.info( diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 35551f5a..b3892d56 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations import os import aiohttp -from typing import Any, List, Dict, Optional +from typing import Any, List, Dict, Optional, Tuple from tenacity import ( retry, stop_after_attempt, @@ -19,6 +19,146 @@ from dotenv import load_dotenv load_dotenv(dotenv_path=".env", override=False) +def chunk_documents_for_rerank( + documents: List[str], + max_tokens: int = 480, + overlap_tokens: int = 32, + tokenizer_model: str = "gpt-4o-mini", +) -> Tuple[List[str], List[int]]: + """ + Chunk documents that exceed token limit for reranking. + + Args: + documents: List of document strings to chunk + max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit) + overlap_tokens: Number of tokens to overlap between chunks + tokenizer_model: Model name for tiktoken tokenizer + + Returns: + Tuple of (chunked_documents, original_doc_indices) + - chunked_documents: List of document chunks (may be more than input) + - original_doc_indices: Maps each chunk back to its original document index + """ + try: + from .utils import TiktokenTokenizer + + tokenizer = TiktokenTokenizer(model_name=tokenizer_model) + except Exception as e: + logger.warning( + f"Failed to initialize tokenizer: {e}. Using character-based approximation." + ) + # Fallback: approximate 1 token ≈ 4 characters + max_chars = max_tokens * 4 + overlap_chars = overlap_tokens * 4 + + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + if len(doc) <= max_chars: + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(doc): + end = min(start + max_chars, len(doc)) + chunk = doc[start:end] + chunked_docs.append(chunk) + doc_indices.append(idx) + + if end >= len(doc): + break + start = end - overlap_chars + + return chunked_docs, doc_indices + + # Use tokenizer for accurate chunking + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + tokens = tokenizer.encode(doc) + + if len(tokens) <= max_tokens: + # Document fits in one chunk + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(tokens): + end = min(start + max_tokens, len(tokens)) + chunk_tokens = tokens[start:end] + chunk_text = tokenizer.decode(chunk_tokens) + chunked_docs.append(chunk_text) + doc_indices.append(idx) + + if end >= len(tokens): + break + start = end - overlap_tokens + + return chunked_docs, doc_indices + + +def aggregate_chunk_scores( + chunk_results: List[Dict[str, Any]], + doc_indices: List[int], + num_original_docs: int, + aggregation: str = "max", +) -> List[Dict[str, Any]]: + """ + Aggregate rerank scores from document chunks back to original documents. + + Args: + chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...] + doc_indices: Maps each chunk index to original document index + num_original_docs: Total number of original documents + aggregation: Strategy for aggregating scores ("max", "mean", "first") + + Returns: + List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...] 
+ """ + # Group scores by original document index + doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)} + + for result in chunk_results: + chunk_idx = result["index"] + score = result["relevance_score"] + + if 0 <= chunk_idx < len(doc_indices): + original_doc_idx = doc_indices[chunk_idx] + doc_scores[original_doc_idx].append(score) + + # Aggregate scores + aggregated_results = [] + for doc_idx, scores in doc_scores.items(): + if not scores: + continue + + if aggregation == "max": + final_score = max(scores) + elif aggregation == "mean": + final_score = sum(scores) / len(scores) + elif aggregation == "first": + final_score = scores[0] + else: + logger.warning(f"Unknown aggregation strategy: {aggregation}, using max") + final_score = max(scores) + + aggregated_results.append( + { + "index": doc_idx, + "relevance_score": final_score, + } + ) + + # Sort by relevance score (descending) + aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True) + + return aggregated_results + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -38,6 +178,8 @@ async def generic_rerank_api( extra_body: Optional[Dict[str, Any]] = None, response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" + enable_chunking: bool = False, + max_tokens_per_doc: int = 480, ) -> List[Dict[str, Any]]: """ Generic rerank API call for Jina/Cohere/Aliyun models. @@ -52,6 +194,9 @@ async def generic_rerank_api( return_documents: Whether to return document text (Jina only) extra_body: Additional body parameters response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun) + request_format: Request format type + enable_chunking: Whether to chunk documents exceeding token limit + max_tokens_per_doc: Maximum tokens per document for chunking Returns: List of dictionary of ["index": int, "relevance_score": float] @@ -63,6 +208,17 @@ async def generic_rerank_api( if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" + # Handle document chunking if enabled + original_documents = documents + doc_indices = None + if enable_chunking: + documents, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=max_tokens_per_doc + ) + logger.debug( + f"Chunked {len(original_documents)} documents into {len(documents)} chunks" + ) + # Build request payload based on request format if request_format == "aliyun": # Aliyun format: nested input/parameters structure @@ -86,7 +242,7 @@ async def generic_rerank_api( if extra_body: payload["parameters"].update(extra_body) else: - # Standard format for Jina/Cohere + # Standard format for Jina/Cohere/OpenAI payload = { "model": model, "query": query, @@ -98,7 +254,7 @@ async def generic_rerank_api( payload["top_n"] = top_n # Only Jina API supports return_documents parameter - if return_documents is not None: + if return_documents is not None and response_format in ("standard",): payload["return_documents"] = return_documents # Add extra parameters @@ -147,7 +303,6 @@ async def generic_rerank_api( f"Expected 'output.results' to be list, got {type(results)}: {results}" ) results = [] - elif response_format == "standard": # Standard format: {"results": [...]} results = response_json.get("results", []) @@ -158,16 +313,28 @@ async def generic_rerank_api( results = [] else: raise ValueError(f"Unsupported response format: {response_format}") + if not results: logger.warning("Rerank API returned empty 
results") return [] # Standardize return format - return [ + standardized_results = [ {"index": result["index"], "relevance_score": result["relevance_score"]} for result in results ] + # Aggregate chunk scores back to original documents if chunking was enabled + if enable_chunking and doc_indices: + standardized_results = aggregate_chunk_scores( + standardized_results, + doc_indices, + len(original_documents), + aggregation="max", + ) + + return standardized_results + async def cohere_rerank( query: str, @@ -177,21 +344,46 @@ async def cohere_rerank( model: str = "rerank-v3.5", base_url: str = "https://api.cohere.com/v2/rerank", extra_body: Optional[Dict[str, Any]] = None, + enable_chunking: bool = False, + max_tokens_per_doc: int = 4096, ) -> List[Dict[str, Any]]: """ Rerank documents using Cohere API. + Supports both standard Cohere API and Cohere-compatible proxies + Args: query: The search query documents: List of strings to rerank top_n: Number of top results to return - api_key: API key - model: rerank model name + api_key: API key for authentication + model: rerank model name (default: rerank-v3.5) base_url: API endpoint extra_body: Additional body for http request(reserved for extra params) + enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc + max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5) Returns: List of dictionary of ["index": int, "relevance_score": float] + + Example: + >>> # Standard Cohere API + >>> results = await cohere_rerank( + ... query="What is the meaning of life?", + ... documents=["Doc1", "Doc2"], + ... api_key="your-cohere-key" + ... ) + + >>> # LiteLLM proxy with user authentication + >>> results = await cohere_rerank( + ... query="What is vector search?", + ... documents=["Doc1", "Doc2"], + ... model="answerai-colbert-small-v1", + ... base_url="https://llm-proxy.example.com/v2/rerank", + ... api_key="your-proxy-key", + ... enable_chunking=True, + ... max_tokens_per_doc=480 + ... ) """ if api_key is None: api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") @@ -206,6 +398,8 @@ async def cohere_rerank( return_documents=None, # Cohere doesn't support this parameter extra_body=extra_body, response_format="standard", + enable_chunking=enable_chunking, + max_tokens_per_doc=max_tokens_per_doc, ) diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py new file mode 100644 index 00000000..f31331d2 --- /dev/null +++ b/tests/test_rerank_chunking.py @@ -0,0 +1,386 @@ +""" +Unit tests for rerank document chunking functionality. + +Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions +in lightrag/rerank.py to ensure proper document splitting and score aggregation. 
+""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from lightrag.rerank import ( + chunk_documents_for_rerank, + aggregate_chunk_scores, + cohere_rerank, +) + + +class TestChunkDocumentsForRerank: + """Test suite for chunk_documents_for_rerank function""" + + def test_no_chunking_needed_for_short_docs(self): + """Documents shorter than max_tokens should not be chunked""" + documents = [ + "Short doc 1", + "Short doc 2", + "Short doc 3", + ] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=100, overlap_tokens=10 + ) + + # No chunking should occur + assert len(chunked_docs) == 3 + assert chunked_docs == documents + assert doc_indices == [0, 1, 2] + + def test_chunking_with_character_fallback(self): + """Test chunking falls back to character-based when tokenizer unavailable""" + # Create a very long document that exceeds character limit + long_doc = "a" * 2000 # 2000 characters + documents = [long_doc, "short doc"] + + with patch("lightrag.rerank.TiktokenTokenizer", side_effect=ImportError): + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, + max_tokens=100, # 100 tokens = ~400 chars + overlap_tokens=10, # 10 tokens = ~40 chars + ) + + # First doc should be split into chunks, second doc stays whole + assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc + assert chunked_docs[-1] == "short doc" # Last chunk is the short doc + # Verify doc_indices maps chunks to correct original document + assert doc_indices[-1] == 1 # Last chunk maps to document 1 + + def test_chunking_with_tiktoken_tokenizer(self): + """Test chunking with actual tokenizer""" + # Create document with known token count + # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens + long_doc = " ".join([f"word{i}" for i in range(200)]) + documents = [long_doc, "short"] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=50, overlap_tokens=10 + ) + + # Long doc should be split, short doc should remain + assert len(chunked_docs) > 2 + assert doc_indices[-1] == 1 # Last chunk is from second document + + # Verify overlapping chunks contain overlapping content + if len(chunked_docs) > 2: + # Check that consecutive chunks from same doc have some overlap + for i in range(len(doc_indices) - 1): + if doc_indices[i] == doc_indices[i + 1] == 0: + # Both chunks from first doc, should have overlap + chunk1_words = chunked_docs[i].split() + chunk2_words = chunked_docs[i + 1].split() + # At least one word should be common due to overlap + assert any(word in chunk2_words for word in chunk1_words[-5:]) + + def test_empty_documents(self): + """Test handling of empty document list""" + documents = [] + chunked_docs, doc_indices = chunk_documents_for_rerank(documents) + + assert chunked_docs == [] + assert doc_indices == [] + + def test_single_document_chunking(self): + """Test chunking of a single long document""" + # Create document with ~100 tokens + long_doc = " ".join([f"token{i}" for i in range(100)]) + documents = [long_doc] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=5 + ) + + # Should create multiple chunks + assert len(chunked_docs) > 1 + # All chunks should map to document 0 + assert all(idx == 0 for idx in doc_indices) + + +class TestAggregateChunkScores: + """Test suite for aggregate_chunk_scores function""" + + def test_no_chunking_simple_aggregation(self): + """Test aggregation when no chunking occurred (1:1 mapping)""" + chunk_results = [ + 
{"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + {"index": 2, "relevance_score": 0.5}, + ] + doc_indices = [0, 1, 2] # 1:1 mapping + num_original_docs = 3 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Results should be sorted by score + assert len(aggregated) == 3 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.9 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 + assert aggregated[2]["index"] == 2 + assert aggregated[2]["relevance_score"] == 0.5 + + def test_max_aggregation_with_chunks(self): + """Test max aggregation strategy with multiple chunks per document""" + # 5 chunks: first 3 from doc 0, last 2 from doc 1 + chunk_results = [ + {"index": 0, "relevance_score": 0.5}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.6}, + {"index": 3, "relevance_score": 0.7}, + {"index": 4, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 0, 1, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Should take max score for each document + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4 + + def test_mean_aggregation_with_chunks(self): + """Test mean aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1 + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="mean" + ) + + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_first_aggregation_with_chunks(self): + """Test first aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="first" + ) + + assert len(aggregated) == 2 + # First should use first score seen for each doc + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_empty_chunk_results(self): + """Test handling of empty results""" + aggregated = aggregate_chunk_scores([], [], 3, aggregation="max") + assert aggregated == [] + + def test_documents_with_no_scores(self): + """Test when some documents have no chunks/scores""" + chunk_results = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + doc_indices = [0, 0] # Both chunks from document 0 + num_original_docs = 3 # But we have 3 documents total + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Only doc 0 should appear in results + assert len(aggregated) == 1 + assert aggregated[0]["index"] == 0 + + def 
test_unknown_aggregation_strategy(self): + """Test that unknown strategy falls back to max""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + ] + doc_indices = [0, 0] + num_original_docs = 1 + + # Use invalid strategy + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="invalid" + ) + + # Should fall back to max + assert aggregated[0]["relevance_score"] == 0.8 + + +@pytest.mark.offline +class TestCohereRerankChunking: + """Integration tests for cohere_rerank with chunking enabled""" + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_disabled(self): + """Test that chunking can be disabled""" + documents = ["doc1", "doc2"] + query = "test query" + + # Mock the generic_rerank_api + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=False, + max_tokens_per_doc=100, + ) + + # Verify generic_rerank_api was called with correct parameters + mock_api.assert_called_once() + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 100 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_enabled(self): + """Test that chunking parameters are passed through""" + documents = ["doc1", "doc2"] + query = "test query" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=True, + max_tokens_per_doc=480, + ) + + # Verify parameters were passed + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is True + assert call_kwargs["max_tokens_per_doc"] == 480 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_default_parameters(self): + """Test default parameter values for cohere_rerank""" + documents = ["doc1"] + query = "test" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [{"index": 0, "relevance_score": 0.9}] + + result = await cohere_rerank( + query=query, documents=documents, api_key="test-key" + ) + + # Verify default values + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 4096 + assert call_kwargs["model"] == "rerank-v3.5" + # Result should mirror mocked scores + assert len(result) == 1 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + + +@pytest.mark.offline +class TestEndToEndChunking: + """End-to-end tests for chunking workflow""" + + @pytest.mark.asyncio + async def test_end_to_end_chunking_workflow(self): + """Test complete 
chunking workflow from documents to aggregated results""" + # Create documents where first one needs chunking + long_doc = " ".join([f"word{i}" for i in range(100)]) + documents = [long_doc, "short doc"] + query = "test query" + + # Mock the HTTP call inside generic_rerank_api + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "results": [ + {"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0 + {"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0 + {"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0 + {"index": 3, "relevance_score": 0.7}, # doc 1 (short) + ] + } + ) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + + mock_session = Mock() + mock_session.post = AsyncMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock() + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=True, + max_tokens_per_doc=30, # Force chunking of long doc + ) + + # Should get 2 results (one per original document) + # The long doc's chunks should be aggregated + assert len(result) <= len(documents) + # Results should be sorted by score + assert all( + result[i]["relevance_score"] >= result[i + 1]["relevance_score"] + for i in range(len(result) - 1) + ) From e136da968bded7c2cc0b772ce0a383d891a3c19c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 24 Nov 2025 03:33:26 +0000 Subject: [PATCH 2/6] Initial plan From 1d6ea0c5f7dd48d6f2c1e9ea0cafeb54478c490d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 24 Nov 2025 03:40:58 +0000 Subject: [PATCH 3/6] Fix chunking infinite loop when overlap_tokens >= max_tokens Co-authored-by: netbrah <162479981+netbrah@users.noreply.github.com> --- lightrag/rerank.py | 10 ++++ tests/test_overlap_validation.py | 100 +++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 tests/test_overlap_validation.py diff --git a/lightrag/rerank.py b/lightrag/rerank.py index b3892d56..1b5d7612 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -39,6 +39,16 @@ def chunk_documents_for_rerank( - chunked_documents: List of document chunks (may be more than input) - original_doc_indices: Maps each chunk back to its original document index """ + # Clamp overlap_tokens to ensure the loop always advances + # If overlap_tokens >= max_tokens, the chunking loop would hang + if overlap_tokens >= max_tokens: + original_overlap = overlap_tokens + overlap_tokens = max(1, max_tokens - 1) + logger.warning( + f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " + f"Clamping to {overlap_tokens} to prevent infinite loop." + ) + try: from .utils import TiktokenTokenizer diff --git a/tests/test_overlap_validation.py b/tests/test_overlap_validation.py new file mode 100644 index 00000000..da364719 --- /dev/null +++ b/tests/test_overlap_validation.py @@ -0,0 +1,100 @@ +""" +Test for overlap_tokens validation to prevent infinite loop. + +This test validates the fix for the bug where overlap_tokens >= max_tokens +causes an infinite loop in the chunking function. 
+""" + +from lightrag.rerank import chunk_documents_for_rerank + + +class TestOverlapValidation: + """Test suite for overlap_tokens validation""" + + def test_overlap_greater_than_max_tokens(self): + """Test that overlap_tokens > max_tokens is clamped and doesn't hang""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should clamp overlap_tokens to 29 (max_tokens - 1) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=32 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_overlap_equal_to_max_tokens(self): + """Test that overlap_tokens == max_tokens is clamped and doesn't hang""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should clamp overlap_tokens to 29 (max_tokens - 1) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=30 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_overlap_slightly_less_than_max_tokens(self): + """Test that overlap_tokens < max_tokens works normally""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should work without clamping + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=29 + ) + + # Should complete successfully + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_small_max_tokens_with_large_overlap(self): + """Test edge case with very small max_tokens""" + documents = [" ".join([f"word{i}" for i in range(50)])] + + # max_tokens=5, overlap_tokens=10 should clamp to 4 + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=5, overlap_tokens=10 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_multiple_documents_with_invalid_overlap(self): + """Test multiple documents with overlap_tokens >= max_tokens""" + documents = [ + " ".join([f"word{i}" for i in range(50)]), + "short document", + " ".join([f"word{i}" for i in range(75)]), + ] + + # overlap_tokens > max_tokens + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=25, overlap_tokens=30 + ) + + # Should complete successfully and chunk the long documents + assert len(chunked_docs) >= len(documents) + # Short document should not be chunked + assert "short document" in chunked_docs + + def test_normal_operation_unaffected(self): + """Test that normal cases continue to work correctly""" + documents = [ + " ".join([f"word{i}" for i in range(100)]), + "short doc", + ] + + # Normal case: overlap_tokens (10) < max_tokens (50) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=50, overlap_tokens=10 + ) + + # Long document should be chunked, short one should not + assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short) + assert "short doc" in chunked_docs + # Verify doc_indices maps correctly + assert doc_indices[-1] == 1 # Last chunk is from second document From 8835fc244a90017b8fc98f60017ae4e78e975c0b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 24 Nov 2025 03:43:05 +0000 Subject: [PATCH 4/6] Improve edge case handling for max_tokens=1 Co-authored-by: netbrah <162479981+netbrah@users.noreply.github.com> --- lightrag/rerank.py | 4 +++- 
tests/test_overlap_validation.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 1b5d7612..81632b71 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -43,7 +43,9 @@ def chunk_documents_for_rerank( # If overlap_tokens >= max_tokens, the chunking loop would hang if overlap_tokens >= max_tokens: original_overlap = overlap_tokens - overlap_tokens = max(1, max_tokens - 1) + # Ensure overlap is at least 1 token less than max to guarantee progress + # For very small max_tokens (e.g., 1), set overlap to 0 + overlap_tokens = max(0, max_tokens - 1) logger.warning( f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " f"Clamping to {overlap_tokens} to prevent infinite loop." diff --git a/tests/test_overlap_validation.py b/tests/test_overlap_validation.py index da364719..7f84a3cf 100644 --- a/tests/test_overlap_validation.py +++ b/tests/test_overlap_validation.py @@ -98,3 +98,16 @@ class TestOverlapValidation: assert "short doc" in chunked_docs # Verify doc_indices maps correctly assert doc_indices[-1] == 1 # Last chunk is from second document + + def test_edge_case_max_tokens_one(self): + """Test edge case where max_tokens=1""" + documents = [" ".join([f"word{i}" for i in range(20)])] + + # max_tokens=1, overlap_tokens=5 should clamp to 0 + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=1, overlap_tokens=5 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) From 561ba4e4b5d5d9b7939930be479f2fad353d1128 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 3 Dec 2025 12:40:48 +0800 Subject: [PATCH 5/6] Fix trailing whitespace and update test mocking for rerank module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Remove trailing whitespace • Fix TiktokenTokenizer import patch • Add async context manager mocks • Update aiohttp.ClientSession patch • Improve test reliability --- lightrag/rerank.py | 2 +- tests/test_overlap_validation.py | 28 ++++++++++++++-------------- tests/test_rerank_chunking.py | 12 ++++++++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 81632b71..2e22f19a 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -50,7 +50,7 @@ def chunk_documents_for_rerank( f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " f"Clamping to {overlap_tokens} to prevent infinite loop." 
) - + try: from .utils import TiktokenTokenizer diff --git a/tests/test_overlap_validation.py b/tests/test_overlap_validation.py index 7f84a3cf..4e7c9cbd 100644 --- a/tests/test_overlap_validation.py +++ b/tests/test_overlap_validation.py @@ -14,12 +14,12 @@ class TestOverlapValidation: def test_overlap_greater_than_max_tokens(self): """Test that overlap_tokens > max_tokens is clamped and doesn't hang""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should clamp overlap_tokens to 29 (max_tokens - 1) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=32 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -27,12 +27,12 @@ class TestOverlapValidation: def test_overlap_equal_to_max_tokens(self): """Test that overlap_tokens == max_tokens is clamped and doesn't hang""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should clamp overlap_tokens to 29 (max_tokens - 1) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=30 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -40,12 +40,12 @@ class TestOverlapValidation: def test_overlap_slightly_less_than_max_tokens(self): """Test that overlap_tokens < max_tokens works normally""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should work without clamping chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=29 ) - + # Should complete successfully assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -53,12 +53,12 @@ class TestOverlapValidation: def test_small_max_tokens_with_large_overlap(self): """Test edge case with very small max_tokens""" documents = [" ".join([f"word{i}" for i in range(50)])] - + # max_tokens=5, overlap_tokens=10 should clamp to 4 chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=5, overlap_tokens=10 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -70,12 +70,12 @@ class TestOverlapValidation: "short document", " ".join([f"word{i}" for i in range(75)]), ] - + # overlap_tokens > max_tokens chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=25, overlap_tokens=30 ) - + # Should complete successfully and chunk the long documents assert len(chunked_docs) >= len(documents) # Short document should not be chunked @@ -87,12 +87,12 @@ class TestOverlapValidation: " ".join([f"word{i}" for i in range(100)]), "short doc", ] - + # Normal case: overlap_tokens (10) < max_tokens (50) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=50, overlap_tokens=10 ) - + # Long document should be chunked, short one should not assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short) assert "short doc" in chunked_docs @@ -102,12 +102,12 @@ class TestOverlapValidation: def test_edge_case_max_tokens_one(self): """Test edge case where max_tokens=1""" documents = [" ".join([f"word{i}" for i in range(20)])] - + # max_tokens=1, overlap_tokens=5 should clamp to 0 chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=1, overlap_tokens=5 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py index 
f31331d2..1700988a 100644 --- a/tests/test_rerank_chunking.py +++ b/tests/test_rerank_chunking.py @@ -40,7 +40,7 @@ class TestChunkDocumentsForRerank: long_doc = "a" * 2000 # 2000 characters documents = [long_doc, "short doc"] - with patch("lightrag.rerank.TiktokenTokenizer", side_effect=ImportError): + with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError): chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=100, # 100 tokens = ~400 chars @@ -360,13 +360,17 @@ class TestEndToEndChunking: mock_response.request_info = None mock_response.history = None mock_response.headers = {} + # Make mock_response an async context manager (for `async with session.post() as response`) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) mock_session = Mock() - mock_session.post = AsyncMock(return_value=mock_response) + # session.post() returns an async context manager, so return mock_response which is now one + mock_session.post = Mock(return_value=mock_response) mock_session.__aenter__ = AsyncMock(return_value=mock_session) - mock_session.__aexit__ = AsyncMock() + mock_session.__aexit__ = AsyncMock(return_value=None) - with patch("aiohttp.ClientSession", return_value=mock_session): + with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session): result = await cohere_rerank( query=query, documents=documents, From 9009abed3ecd61605f3ec43dcda8ada1787bd3a6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 3 Dec 2025 13:08:26 +0800 Subject: [PATCH 6/6] Fix top_n behavior with chunking to limit documents not chunks - Disable API-level top_n when chunking - Apply top_n to aggregated documents - Add comprehensive test coverage --- lightrag/rerank.py | 17 ++++ tests/test_rerank_chunking.py | 174 ++++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 2e22f19a..12950fe6 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -223,6 +223,8 @@ async def generic_rerank_api( # Handle document chunking if enabled original_documents = documents doc_indices = None + original_top_n = top_n # Save original top_n for post-aggregation limiting + if enable_chunking: documents, doc_indices = chunk_documents_for_rerank( documents, max_tokens=max_tokens_per_doc @@ -230,6 +232,14 @@ async def generic_rerank_api( logger.debug( f"Chunked {len(original_documents)} documents into {len(documents)} chunks" ) + # When chunking is enabled, disable top_n at API level to get all chunk scores + # This ensures proper document-level coverage after aggregation + # We'll apply top_n to aggregated document results instead + if top_n is not None: + logger.debug( + f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage" + ) + top_n = None # Build request payload based on request format if request_format == "aliyun": @@ -344,6 +354,13 @@ async def generic_rerank_api( len(original_documents), aggregation="max", ) + # Apply original top_n limit at document level (post-aggregation) + # This preserves document-level semantics: top_n limits documents, not chunks + if ( + original_top_n is not None + and len(standardized_results) > original_top_n + ): + standardized_results = standardized_results[:original_top_n] return standardized_results diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py index 1700988a..09f1816b 100644 --- a/tests/test_rerank_chunking.py +++ b/tests/test_rerank_chunking.py 
@@ -234,6 +234,180 @@ class TestAggregateChunkScores:
         assert aggregated[0]["relevance_score"] == 0.8
 
 
+@pytest.mark.offline
+class TestTopNWithChunking:
+    """Tests for top_n behavior when chunking is enabled (Bug fix verification)"""
+
+    @pytest.mark.asyncio
+    async def test_top_n_limits_documents_not_chunks(self):
+        """
+        Test that top_n correctly limits documents (not chunks) when chunking is enabled.
+
+        Bug scenario: 10 docs expand to 50 chunks. With the old behavior, top_n=5 would
+        return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
+        fewer than 5 documents would be returned.
+
+        Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
+        """
+        # Setup: 5 documents, each producing multiple chunks when chunked
+        # Using small max_tokens to force chunking
+        long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)]
+        query = "test query"
+
+        # Determine the real chunk count, using the same overlap_tokens default
+        _, doc_indices = chunk_documents_for_rerank(
+            long_docs, max_tokens=50, overlap_tokens=32  # 32 matches generic_rerank_api
+        )
+        num_chunks = len(doc_indices)
+
+        # Mock API returns scores for ALL chunks (simulating disabled API-level top_n)
+        # Give different scores to ensure doc 0 gets highest, doc 1 second, etc.
+        # Assign scores based on original document index (lower doc index = higher score)
+        mock_chunk_scores = []
+        for i in range(num_chunks):
+            original_doc = doc_indices[i]
+            # Higher score for lower doc index (same score for every chunk of a doc)
+            base_score = 0.9 - (original_doc * 0.1)
+            mock_chunk_scores.append({"index": i, "relevance_score": base_score})
+
+        mock_response = Mock()
+        mock_response.status = 200
+        mock_response.json = AsyncMock(return_value={"results": mock_chunk_scores})
+        mock_response.request_info = None
+        mock_response.history = None
+        mock_response.headers = {}
+        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_response.__aexit__ = AsyncMock(return_value=None)
+
+        mock_session = Mock()
+        mock_session.post = Mock(return_value=mock_response)
+        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+        mock_session.__aexit__ = AsyncMock(return_value=None)
+
+        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
+            result = await cohere_rerank(
+                query=query,
+                documents=long_docs,
+                api_key="test-key",
+                base_url="http://test.com/rerank",
+                enable_chunking=True,
+                max_tokens_per_doc=50,  # Match chunking above
+                top_n=3,  # Request top 3 documents
+            )
+
+        # Verify: should get exactly 3 documents (not unlimited chunks)
+        assert len(result) == 3
+        # All results should have valid document indices (0-4)
+        assert all(0 <= r["index"] < 5 for r in result)
+        # Results should be sorted by score (descending)
+        assert all(
+            result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
+            for i in range(len(result) - 1)
+        )
+        # The top 3 docs should be 0, 1, 2 (highest scores)
+        result_indices = [r["index"] for r in result]
+        assert set(result_indices) == {0, 1, 2}
+
+    @pytest.mark.asyncio
+    async def test_api_receives_no_top_n_when_chunking_enabled(self):
+        """
+        Test that the API request does NOT include top_n when chunking is enabled.
+
+        This ensures all chunk scores are retrieved for proper aggregation.
+        """
+        documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
+        query = "test query"
+
+        captured_payload = {}
+
+        mock_response = Mock()
+        mock_response.status = 200
+        mock_response.json = AsyncMock(
+            return_value={
+                "results": [
+                    {"index": 0, "relevance_score": 0.9},
+                    {"index": 1, "relevance_score": 0.8},
+                    {"index": 2, "relevance_score": 0.7},
+                ]
+            }
+        )
+        mock_response.request_info = None
+        mock_response.history = None
+        mock_response.headers = {}
+        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_response.__aexit__ = AsyncMock(return_value=None)
+
+        def capture_post(*args, **kwargs):
+            captured_payload.update(kwargs.get("json", {}))
+            return mock_response
+
+        mock_session = Mock()
+        mock_session.post = Mock(side_effect=capture_post)
+        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+        mock_session.__aexit__ = AsyncMock(return_value=None)
+
+        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
+            await cohere_rerank(
+                query=query,
+                documents=documents,
+                api_key="test-key",
+                base_url="http://test.com/rerank",
+                enable_chunking=True,
+                max_tokens_per_doc=30,
+                top_n=1,  # User wants top 1 document
+            )
+
+        # Verify: API payload should NOT have top_n (disabled for chunking)
+        assert "top_n" not in captured_payload
+
+    @pytest.mark.asyncio
+    async def test_top_n_not_modified_when_chunking_disabled(self):
+        """
+        Test that top_n is passed through to the API when chunking is disabled.
+        """
+        documents = ["doc1", "doc2"]
+        query = "test query"
+
+        captured_payload = {}
+
+        mock_response = Mock()
+        mock_response.status = 200
+        mock_response.json = AsyncMock(
+            return_value={
+                "results": [
+                    {"index": 0, "relevance_score": 0.9},
+                ]
+            }
+        )
+        mock_response.request_info = None
+        mock_response.history = None
+        mock_response.headers = {}
+        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_response.__aexit__ = AsyncMock(return_value=None)
+
+        def capture_post(*args, **kwargs):
+            captured_payload.update(kwargs.get("json", {}))
+            return mock_response
+
+        mock_session = Mock()
+        mock_session.post = Mock(side_effect=capture_post)
+        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+        mock_session.__aexit__ = AsyncMock(return_value=None)
+
+        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
+            await cohere_rerank(
+                query=query,
+                documents=documents,
+                api_key="test-key",
+                base_url="http://test.com/rerank",
+                enable_chunking=False,  # Chunking disabled
+                top_n=1,
+            )
+
+        # Verify: API payload should have top_n when chunking is disabled
+        assert captured_payload.get("top_n") == 1
+
+
 @pytest.mark.offline
 class TestCohereRerankChunking:
     """Integration tests for cohere_rerank with chunking enabled"""