Files
noteflow/tests/infrastructure/ai/test_retrieval.py
Travis Vasceannie d8090a98e8
Some checks failed
CI / test-typescript (push) Has been cancelled
CI / test-rust (push) Has been cancelled
CI / test-python (push) Has been cancelled
ci/cd fixes
2026-01-26 00:28:15 +00:00

303 lines
9.3 KiB
Python

from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import cast
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from noteflow.domain.value_objects import MeetingId
from noteflow.infrastructure.ai.tools.retrieval import (
BatchEmbedderProtocol,
MeetingBatchRetrievalRequest,
MeetingRetrievalDependencies,
MeetingRetrievalRequest,
RetrievalResult,
retrieve_segments,
retrieve_segments_batch,
)
@dataclass
class MockSegment:
    """Minimal stand-in for a stored transcript segment.

    The tests in this module pair instances with a similarity score and hand
    them to the mocked ``search_semantic`` repository call.
    """

    # Identifier the assertions compare against the retrieval result.
    segment_id: int
    # Owning meeting; typed as ``object`` so any ID-like value can be supplied.
    meeting_id: object
    # Text content of the segment.
    text: str
    # Segment boundaries (presumably seconds from meeting start -- TODO confirm).
    start_time: float
    end_time: float
class TestRetrieveSegments:
    """Tests for ``retrieve_segments`` against a mocked embedder and segment repository."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Return an async embedder mock whose ``embed`` resolves to a fixed vector."""
        embedder = AsyncMock()
        embedder.embed.return_value = [0.1, 0.2, 0.3]
        return embedder

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Return an async repository mock; each test configures ``search_semantic``."""
        return AsyncMock()

    @pytest.fixture
    def sample_meeting_id(self) -> MeetingId:
        """Return a fresh random meeting identifier."""
        return MeetingId(uuid4())

    @pytest.fixture
    def deps(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> MeetingRetrievalDependencies:
        """Bundle the mocked embedder and repository into retrieval dependencies."""
        return MeetingRetrievalDependencies(
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

    async def test_retrieve_segments_success(
        self,
        deps: MeetingRetrievalDependencies,
        sample_meeting_id: MeetingId,
    ) -> None:
        """A single repository hit is returned with its ID, text and score intact."""
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test segment",
            start_time=0.0,
            end_time=5.0,
        )
        deps.segment_repo.search_semantic.return_value = [(segment, 0.95)]

        results = await retrieve_segments(
            MeetingRetrievalRequest(
                query="test query",
                meeting_id=sample_meeting_id,
                top_k=5,
            ),
            deps,
        )

        assert len(results) == 1, "Expected one retrieval result"
        assert results[0].segment_id == 1, "Segment ID should match input"
        assert results[0].text == "Test segment", "Segment text should match input"
        assert results[0].score == 0.95, "Score should preserve search score"

    async def test_retrieve_segments_calls_embedder_with_query(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """The raw query string is what gets embedded, exactly once."""
        deps.segment_repo.search_semantic.return_value = []

        await retrieve_segments(
            MeetingRetrievalRequest(query="what happened in the meeting"),
            deps,
        )

        embed_call = cast(AsyncMock, deps.embedder.embed)
        embed_call.assert_called_once_with("what happened in the meeting")
        # A call assertion alone passes even when the coroutine is never
        # awaited; also require that the embedding was actually awaited.
        embed_call.assert_awaited_once_with("what happened in the meeting")

    async def test_retrieve_segments_passes_embedding_to_repo(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """The embedding, ``top_k`` and absent meeting ID are forwarded to the search."""
        deps.embedder.embed.return_value = [1.0, 2.0, 3.0]
        deps.segment_repo.search_semantic.return_value = []

        await retrieve_segments(
            MeetingRetrievalRequest(query="test", top_k=10),
            deps,
        )

        search_call = cast(AsyncMock, deps.segment_repo.search_semantic)
        search_call.assert_called_once_with(
            query_embedding=[1.0, 2.0, 3.0],
            meeting_id=None,
            limit=10,
        )
        # Same arguments, but additionally proves the search coroutine ran.
        search_call.assert_awaited_once_with(
            query_embedding=[1.0, 2.0, 3.0],
            meeting_id=None,
            limit=10,
        )

    async def test_retrieve_segments_empty_result(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """An empty search response yields an empty result list."""
        deps.segment_repo.search_semantic.return_value = []

        results = await retrieve_segments(
            MeetingRetrievalRequest(query="test"),
            deps,
        )

        assert results == [], "Expected no results for empty search response"

    def test_retrieval_result_is_frozen(self) -> None:
        """Assigning to a field of a ``RetrievalResult`` must raise ``AttributeError``.

        Nothing is awaited here, so this is a plain synchronous test and does
        not rely on the async test plugin.
        """
        result = RetrievalResult(
            segment_id=1,
            meeting_id=uuid4(),
            text="Test",
            start_time=0.0,
            end_time=5.0,
            score=0.9,
        )

        # The matched message is the one frozen dataclasses emit; presumably
        # RetrievalResult is declared as one -- verify against its definition.
        with pytest.raises(AttributeError, match="cannot assign to field"):
            result.text = "Modified"
class MockBatchEmbedder:
    """Hand-written embedder double exposing both ``embed`` and ``embed_batch``.

    Every text maps to the fixed vector supplied at construction. Inputs are
    recorded in ``embed_calls`` and ``embed_batch_calls`` so tests can check
    which entry point was used and with what arguments.
    """

    def __init__(self, embedding: list[float]) -> None:
        """Store a private copy of ``embedding`` and start with empty call logs."""
        # Copy so later changes to the caller's list cannot alter what we return.
        self._embedding = list(embedding)
        self.embed_calls: list[str] = []
        self.embed_batch_calls: list[Sequence[str]] = []

    async def embed(self, text: str) -> list[float]:
        """Record ``text`` and return a fresh copy of the fixed embedding."""
        self.embed_calls.append(text)
        # Hand out a copy: a consumer mutating its result must not affect
        # the vectors returned by other calls.
        return list(self._embedding)

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        """Record the batch and return one independent embedding copy per text."""
        # Snapshot once: the log stays correct if the caller mutates ``texts``
        # afterwards, and the input is iterated a single time.
        batch = list(texts)
        self.embed_batch_calls.append(batch)
        return [list(self._embedding) for _ in batch]
def _ordered_search_side_effect(
    first: MockSegment,
    second: MockSegment,
) -> Callable[[list[float], int, object], Awaitable[list[tuple[MockSegment, float]]]]:
    """Build an async ``side_effect`` whose result depends on invocation order.

    The first invocation resolves to ``[(first, 0.9)]``; every later one
    resolves to ``[(second, 0.8)]``. The search arguments are accepted but
    ignored.
    """
    first_served = False

    async def side_effect(
        query_embedding: list[float],
        limit: int,
        meeting_id: object,
    ) -> list[tuple[MockSegment, float]]:
        nonlocal first_served
        if first_served:
            return [(second, 0.8)]
        first_served = True
        return [(first, 0.9)]

    return side_effect
class TestRetrieveSegmentsBatch:
    """Tests for ``retrieve_segments_batch`` with batch-capable and single-text embedders."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Return an async embedder mock that offers ``embed`` but no ``embed_batch``."""
        embedder = AsyncMock()
        embedder.embed.return_value = [0.1, 0.2, 0.3]
        # A bare mock fabricates any attribute on access, so it would look
        # batch-capable to ``hasattr`` (and to runtime-checkable protocol
        # checks on Python < 3.12). Deleting the attribute makes later access
        # raise AttributeError, so this mock is single-text only however the
        # retrieval code detects batch support.
        del embedder.embed_batch
        return embedder

    @pytest.fixture
    def batch_embedder(self) -> MockBatchEmbedder:
        """Return a recording embedder that implements ``embed_batch``."""
        return MockBatchEmbedder([0.1, 0.2, 0.3])

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Return an async repository mock; each test configures ``search_semantic``."""
        return AsyncMock()

    @pytest.fixture
    def sample_meeting_id(self) -> MeetingId:
        """Return a fresh random meeting identifier."""
        return MeetingId(uuid4())

    async def test_batch_returns_empty_for_no_queries(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """An empty query list yields no results and never touches the embedder."""
        results = await retrieve_segments_batch(
            MeetingBatchRetrievalRequest(queries=[]),
            MeetingRetrievalDependencies(
                embedder=mock_embedder,
                segment_repo=mock_segment_repo,
            ),
        )

        assert results == [], "Expected empty results for empty query list"
        mock_embedder.embed.assert_not_called()

    async def test_batch_uses_embed_batch_when_available(
        self,
        batch_embedder: MockBatchEmbedder,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """A batch-capable embedder gets one ``embed_batch`` call and no ``embed`` calls."""
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]
        assert isinstance(batch_embedder, BatchEmbedderProtocol)

        results = await retrieve_segments_batch(
            MeetingBatchRetrievalRequest(
                queries=["query1", "query2"],
                meeting_id=sample_meeting_id,
            ),
            MeetingRetrievalDependencies(
                embedder=batch_embedder,
                segment_repo=mock_segment_repo,
            ),
        )

        assert len(results) == 2, "Expected one result list per query"
        assert len(batch_embedder.embed_batch_calls) == 1, "Expected batch embedding call"
        assert list(batch_embedder.embed_batch_calls[0]) == ["query1", "query2"], (
            "Batch embedder should receive queries in order"
        )
        assert batch_embedder.embed_calls == [], "Single embed should not be used"

    async def test_batch_falls_back_to_parallel_embed(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """Without ``embed_batch``, each query is embedded through ``embed``."""
        # Precondition of the whole test: the embedder must not look batch-capable.
        assert not hasattr(mock_embedder, "embed_batch")
        segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(segment, 0.9)]

        results = await retrieve_segments_batch(
            MeetingBatchRetrievalRequest(
                queries=["query1", "query2"],
                meeting_id=sample_meeting_id,
            ),
            MeetingRetrievalDependencies(
                embedder=mock_embedder,
                segment_repo=mock_segment_repo,
            ),
        )

        assert len(results) == 2, "Expected one result list per query"
        assert mock_embedder.embed.call_count == 2, "Expected parallel embed fallback"

    async def test_batch_preserves_query_order(
        self,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """Result lists come back in the same order as the submitted queries."""
        segment1 = MockSegment(1, sample_meeting_id, "First", 0.0, 5.0)
        segment2 = MockSegment(2, sample_meeting_id, "Second", 5.0, 10.0)
        # The first search invocation yields segment1, any later one segment2.
        mock_segment_repo.search_semantic.side_effect = _ordered_search_side_effect(
            segment1,
            segment2,
        )
        embedder = MockBatchEmbedder([0.1, 0.2])

        results = await retrieve_segments_batch(
            MeetingBatchRetrievalRequest(
                queries=["first", "second"],
                meeting_id=sample_meeting_id,
            ),
            MeetingRetrievalDependencies(
                embedder=embedder,
                segment_repo=mock_segment_repo,
            ),
        )

        assert len(results) == 2, "Expected results for two queries"
        assert results[0][0].text == "First", "First query should map to first result"
        assert results[1][0].text == "Second", "Second query should map to second result"