Fix trailing whitespace and update test mocking for rerank module

• Remove trailing whitespace
• Fix TiktokenTokenizer import patch
• Add async context manager mocks
• Update aiohttp.ClientSession patch
• Improve test reliability
yangdx
2025-12-03 12:40:48 +08:00
parent 8e50eef58b
commit 561ba4e4b5
3 changed files with 23 additions and 19 deletions

View File

@@ -50,7 +50,7 @@ def chunk_documents_for_rerank(
f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
f"Clamping to {overlap_tokens} to prevent infinite loop." f"Clamping to {overlap_tokens} to prevent infinite loop."
) )
try: try:
from .utils import TiktokenTokenizer from .utils import TiktokenTokenizer
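For context, a minimal sketch of the clamping this warning belongs to. This is an assumed shape inferred from the warning text and the tests below, not the exact code in lightrag/rerank.py:

```python
# Hedged sketch: if overlap_tokens were allowed to reach max_tokens, the
# sliding-window step (max_tokens - overlap_tokens) would be <= 0 and
# chunking would never advance, hence the clamp to max_tokens - 1.
import logging

logger = logging.getLogger(__name__)

def _clamp_overlap(max_tokens: int, overlap_tokens: int) -> int:
    original_overlap = overlap_tokens
    if overlap_tokens >= max_tokens:
        overlap_tokens = max(0, max_tokens - 1)  # keeps the step size >= 1
        logger.warning(
            f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
            f"Clamping to {overlap_tokens} to prevent infinite loop."
        )
    return overlap_tokens
```

This matches the values asserted in the tests: max_tokens=30 with overlap_tokens=32 clamps to 29, and max_tokens=1 with overlap_tokens=5 clamps to 0.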

View File

@@ -14,12 +14,12 @@ class TestOverlapValidation:
     def test_overlap_greater_than_max_tokens(self):
         """Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
         documents = [" ".join([f"word{i}" for i in range(100)])]
         # This should clamp overlap_tokens to 29 (max_tokens - 1)
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=30, overlap_tokens=32
         )
         # Should complete without hanging
         assert len(chunked_docs) > 0
         assert all(idx == 0 for idx in doc_indices)
@@ -27,12 +27,12 @@ class TestOverlapValidation:
     def test_overlap_equal_to_max_tokens(self):
         """Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
         documents = [" ".join([f"word{i}" for i in range(100)])]
         # This should clamp overlap_tokens to 29 (max_tokens - 1)
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=30, overlap_tokens=30
         )
         # Should complete without hanging
         assert len(chunked_docs) > 0
         assert all(idx == 0 for idx in doc_indices)
@@ -40,12 +40,12 @@ class TestOverlapValidation:
     def test_overlap_slightly_less_than_max_tokens(self):
         """Test that overlap_tokens < max_tokens works normally"""
         documents = [" ".join([f"word{i}" for i in range(100)])]
         # This should work without clamping
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=30, overlap_tokens=29
         )
         # Should complete successfully
         assert len(chunked_docs) > 0
         assert all(idx == 0 for idx in doc_indices)
@@ -53,12 +53,12 @@ class TestOverlapValidation:
     def test_small_max_tokens_with_large_overlap(self):
         """Test edge case with very small max_tokens"""
         documents = [" ".join([f"word{i}" for i in range(50)])]
         # max_tokens=5, overlap_tokens=10 should clamp to 4
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=5, overlap_tokens=10
         )
         # Should complete without hanging
         assert len(chunked_docs) > 0
         assert all(idx == 0 for idx in doc_indices)
@@ -70,12 +70,12 @@ class TestOverlapValidation:
             "short document",
             " ".join([f"word{i}" for i in range(75)]),
         ]
         # overlap_tokens > max_tokens
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=25, overlap_tokens=30
         )
         # Should complete successfully and chunk the long documents
         assert len(chunked_docs) >= len(documents)
         # Short document should not be chunked
@@ -87,12 +87,12 @@ class TestOverlapValidation:
             " ".join([f"word{i}" for i in range(100)]),
             "short doc",
         ]
         # Normal case: overlap_tokens (10) < max_tokens (50)
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=50, overlap_tokens=10
         )
         # Long document should be chunked, short one should not
         assert len(chunked_docs) > 2  # At least 3 chunks (2 from long doc + 1 short)
         assert "short doc" in chunked_docs
@@ -102,12 +102,12 @@ class TestOverlapValidation:
     def test_edge_case_max_tokens_one(self):
         """Test edge case where max_tokens=1"""
         documents = [" ".join([f"word{i}" for i in range(20)])]
         # max_tokens=1, overlap_tokens=5 should clamp to 0
         chunked_docs, doc_indices = chunk_documents_for_rerank(
             documents, max_tokens=1, overlap_tokens=5
         )
         # Should complete without hanging
         assert len(chunked_docs) > 0
         assert all(idx == 0 for idx in doc_indices)
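All of these tests guard the same invariant. A worked sketch of the sliding-window arithmetic they exercise (an assumed loop shape for illustration, not the exact lightrag implementation):

```python
# With a window of max_tokens and a step of (max_tokens - overlap_tokens),
# overlap_tokens >= max_tokens gives step <= 0, so `start` never advances
# and the loop below would spin forever without the clamp.
def sliding_windows(tokens: list[str], max_tokens: int, overlap_tokens: int):
    overlap_tokens = min(overlap_tokens, max_tokens - 1) if max_tokens > 1 else 0
    step = max_tokens - overlap_tokens  # guaranteed >= 1 after clamping
    start, windows = 0, []
    while start < len(tokens):
        windows.append(tokens[start : start + max_tokens])
        start += step
    return windows

# e.g. max_tokens=30, overlap_tokens=32 clamps to 29, so step is 1
assert len(sliding_windows([f"word{i}" for i in range(100)], 30, 32)) > 0
```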

View File

@@ -40,7 +40,7 @@ class TestChunkDocumentsForRerank:
long_doc = "a" * 2000 # 2000 characters long_doc = "a" * 2000 # 2000 characters
documents = [long_doc, "short doc"] documents = [long_doc, "short doc"]
with patch("lightrag.rerank.TiktokenTokenizer", side_effect=ImportError): with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
chunked_docs, doc_indices = chunk_documents_for_rerank( chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, documents,
max_tokens=100, # 100 tokens = ~400 chars max_tokens=100, # 100 tokens = ~400 chars
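The patch target moves from lightrag.rerank to lightrag.utils because chunk_documents_for_rerank imports the tokenizer inside the function body: the name is looked up on the defining module at call time and is never bound as a module-level attribute of lightrag.rerank, so patching the latter cannot intercept it. A self-contained sketch of the same rule, using a stdlib module as a stand-in (json.loads plays the role of lightrag.utils.TiktokenTokenizer):

```python
# Illustrative only: shows where patch() must aim when an import is
# function-local, mirroring the try/except ImportError pattern in rerank.
from unittest.mock import patch

def lazy_parse(payload: str):
    try:
        from json import loads  # resolved against the json module at call time
        return loads(payload)
    except ImportError:
        return None  # fallback path, like rerank's character-based chunking

# Patching the defining module intercepts the lazy import's attribute
# lookup, and side_effect=ImportError fires when the fetched mock is called.
with patch("json.loads", side_effect=ImportError):
    assert lazy_parse("{}") is None
assert lazy_parse('{"a": 1}') == {"a": 1}
```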
@@ -360,13 +360,17 @@ class TestEndToEndChunking:
         mock_response.request_info = None
         mock_response.history = None
         mock_response.headers = {}
+        # Make mock_response an async context manager (for `async with session.post() as response`)
+        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
+        mock_response.__aexit__ = AsyncMock(return_value=None)
         mock_session = Mock()
-        mock_session.post = AsyncMock(return_value=mock_response)
+        # session.post() returns an async context manager, so return mock_response which is now one
+        mock_session.post = Mock(return_value=mock_response)
         mock_session.__aenter__ = AsyncMock(return_value=mock_session)
-        mock_session.__aexit__ = AsyncMock()
-        with patch("aiohttp.ClientSession", return_value=mock_session):
+        mock_session.__aexit__ = AsyncMock(return_value=None)
+        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
             result = await cohere_rerank(
                 query=query,
                 documents=documents,
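The key change: aiohttp's session.post() is a plain method that returns an async context manager, not a coroutine, so mocking it with AsyncMock broke `async with session.post(...)`. A minimal, self-contained sketch of the corrected pattern (illustrative names, not the actual test):

```python
import asyncio
from unittest.mock import AsyncMock, Mock

async def call_api(session):
    # Mirrors how the rerank code consumes aiohttp:
    async with session.post("http://example.invalid/rerank") as response:
        return await response.json()

mock_response = Mock(status=200)
mock_response.json = AsyncMock(return_value={"results": []})
# Dunder assignments on a Mock land on its per-instance class, so
# `async with` finds them during special-method lookup on the type.
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)

mock_session = Mock()
mock_session.post = Mock(return_value=mock_response)  # plain Mock: post() is never awaited

assert asyncio.run(call_api(mock_session)) == {"results": []}
```

Patching lightrag.rerank.aiohttp.ClientSession rather than the global aiohttp.ClientSession also scopes the mock to the module under test, which is the reliability improvement the commit message refers to.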