fix chunk_top_k limiting

This commit is contained in:
zrguo
2025-07-08 15:05:30 +08:00
parent 04a57445da
commit c295d355a0
2 changed files with 20 additions and 4 deletions

View File

@@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger
from lightrag.kg.shared_storage import initialize_pipeline_status
# Set up your working directory
WORKING_DIR = "./test_rerank"
@@ -87,6 +88,9 @@ async def create_rag_with_rerank():
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
@@ -120,6 +124,9 @@ async def create_rag_with_rerank_model():
rerank_model_func=rerank_model.rerank,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag

View File

@@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled(
documents=retrieved_docs,
top_k=top_k,
)
if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_k:
reranked_docs = reranked_docs[:top_k]
logger.info(
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
)
@@ -2846,7 +2847,7 @@ async def process_chunks_unified(
source_type: str = "mixed",
) -> list[dict]:
"""
Unified processing for text chunks: deduplication, reranking, and token truncation.
Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
Args:
query: Search query for reranking
@@ -2874,7 +2875,15 @@ async def process_chunks_unified(
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
)
# 2. Apply reranking if enabled and query is provided
# 2. Apply chunk_top_k limiting if specified
if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
if len(unique_chunks) > query_param.chunk_top_k:
unique_chunks = unique_chunks[: query_param.chunk_top_k]
logger.debug(
f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
)
# 3. Apply reranking if enabled and query is provided
if global_config.get("enable_rerank", False) and query and unique_chunks:
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
@@ -2885,7 +2894,7 @@ async def process_chunks_unified(
)
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
# 3. Token-based final truncation
# 4. Token-based final truncation
tokenizer = global_config.get("tokenizer")
if tokenizer and unique_chunks:
original_count = len(unique_chunks)