303 lines
9.3 KiB
Python
303 lines
9.3 KiB
Python
from collections.abc import Awaitable, Callable, Sequence
|
|
from dataclasses import dataclass
|
|
from typing import cast
|
|
from unittest.mock import AsyncMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.value_objects import MeetingId
|
|
from noteflow.infrastructure.ai.tools.retrieval import (
|
|
BatchEmbedderProtocol,
|
|
MeetingBatchRetrievalRequest,
|
|
MeetingRetrievalDependencies,
|
|
MeetingRetrievalRequest,
|
|
RetrievalResult,
|
|
retrieve_segments,
|
|
retrieve_segments_batch,
|
|
)
|
|
|
|
|
|
@dataclass
class MockSegment:
    """Minimal segment stand-in handed back by the mocked segment repository.

    Tests in this module pair instances with a score and return them from the
    mocked ``search_semantic`` call.
    """

    # Integer identifier; tests compare it against the retrieval result.
    segment_id: int
    # Typed loosely as ``object``; the tests in this file pass a ``MeetingId``.
    meeting_id: object
    # Segment text echoed back in the retrieval result.
    text: str
    # Start and end offsets of the segment (units are not shown in this file).
    start_time: float
    end_time: float
|
|
|
|
|
|
class TestRetrieveSegments:
    """Tests for ``retrieve_segments`` against mocked embedder and repository."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Return an async mock whose ``embed`` resolves to a fixed vector."""
        stub = AsyncMock()
        stub.embed.return_value = [0.1, 0.2, 0.3]
        return stub

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Return a bare async mock standing in for the segment repository."""
        repo = AsyncMock()
        return repo

    @pytest.fixture
    def sample_meeting_id(self) -> MeetingId:
        """Return a ``MeetingId`` wrapping a freshly generated UUID."""
        raw_uuid = uuid4()
        return MeetingId(raw_uuid)

    @pytest.fixture
    def deps(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> MeetingRetrievalDependencies:
        """Bundle the mocked embedder and repository into a dependencies object."""
        bundle = MeetingRetrievalDependencies(
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )
        return bundle

    async def test_retrieve_segments_success(
        self,
        deps: MeetingRetrievalDependencies,
        sample_meeting_id: MeetingId,
    ) -> None:
        """A single repository hit is surfaced with its id, text and score."""
        hit = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test segment",
            start_time=0.0,
            end_time=5.0,
        )
        deps.segment_repo.search_semantic.return_value = [(hit, 0.95)]
        retrieval_request = MeetingRetrievalRequest(
            query="test query",
            meeting_id=sample_meeting_id,
            top_k=5,
        )

        results = await retrieve_segments(retrieval_request, deps)

        assert len(results) == 1, "Expected one retrieval result"
        only_result = results[0]
        assert only_result.segment_id == 1, "Segment ID should match input"
        assert only_result.text == "Test segment", "Segment text should match input"
        assert only_result.score == 0.95, "Score should preserve search score"

    async def test_retrieve_segments_calls_embedder_with_query(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """The request's query string is what gets passed to ``embed``."""
        deps.segment_repo.search_semantic.return_value = []
        retrieval_request = MeetingRetrievalRequest(query="what happened in the meeting")

        await retrieve_segments(retrieval_request, deps)

        # ``cast`` only informs the type checker that ``embed`` is a mock.
        embed_mock = cast(AsyncMock, deps.embedder.embed)
        embed_mock.assert_called_once_with("what happened in the meeting")

    async def test_retrieve_segments_passes_embedding_to_repo(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """The embedder output, ``top_k`` and absent meeting id reach the repository."""
        deps.embedder.embed.return_value = [1.0, 2.0, 3.0]
        deps.segment_repo.search_semantic.return_value = []
        retrieval_request = MeetingRetrievalRequest(query="test", top_k=10)

        await retrieve_segments(retrieval_request, deps)

        # Built from a separate literal so the expectation does not alias the
        # list object handed out by the embedder mock.
        expected_kwargs = {
            "query_embedding": [1.0, 2.0, 3.0],
            "meeting_id": None,
            "limit": 10,
        }
        search_mock = cast(AsyncMock, deps.segment_repo.search_semantic)
        search_mock.assert_called_once_with(**expected_kwargs)

    async def test_retrieve_segments_empty_result(
        self,
        deps: MeetingRetrievalDependencies,
    ) -> None:
        """An empty repository response yields an empty result list."""
        deps.segment_repo.search_semantic.return_value = []
        retrieval_request = MeetingRetrievalRequest(query="test")

        results = await retrieve_segments(retrieval_request, deps)

        assert results == [], "Expected no results for empty search response"

    async def test_retrieval_result_is_frozen(self) -> None:
        """Assigning to a ``RetrievalResult`` field raises ``AttributeError``."""
        field_values = {
            "segment_id": 1,
            "meeting_id": uuid4(),
            "text": "Test",
            "start_time": 0.0,
            "end_time": 5.0,
            "score": 0.9,
        }
        frozen_result = RetrievalResult(**field_values)

        with pytest.raises(AttributeError, match="cannot assign to field"):
            frozen_result.text = "Modified"
|
|
|
|
|
|
class MockBatchEmbedder:
    """Recording embedder test double exposing ``embed`` and ``embed_batch``.

    Every call is logged so tests can assert which entry point was used.
    Each returned vector is a fresh copy of the configured embedding, so a
    consumer that mutates one result cannot alter the other vectors of the
    same batch, later calls, or the stored template.
    """

    def __init__(self, embedding: list[float]) -> None:
        """Store the canned embedding and start with empty call logs.

        Args:
            embedding: Vector returned (as a copy) for every embedded text.
        """
        # Copy so later mutation of the caller's list cannot change outputs.
        self._embedding = list(embedding)
        # Texts passed to ``embed``, in call order.
        self.embed_calls: list[str] = []
        # Sequences passed to ``embed_batch``, in call order.
        self.embed_batch_calls: list[Sequence[str]] = []

    async def embed(self, text: str) -> list[float]:
        """Record ``text`` and return a copy of the canned embedding.

        Args:
            text: The text being embedded; only recorded, never inspected.

        Returns:
            A new list equal to the configured embedding.
        """
        self.embed_calls.append(text)
        return list(self._embedding)

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        """Record ``texts`` and return one independent embedding per text.

        Args:
            texts: The texts being embedded; only the count affects the result.

        Returns:
            A list with ``len(texts)`` distinct lists, each equal to the
            configured embedding. Empty input yields an empty list.
        """
        # The caller's sequence object is recorded as-is (not copied), so the
        # log keeps whatever sequence type the caller supplied.
        self.embed_batch_calls.append(texts)
        return [list(self._embedding) for _ in texts]
|
|
|
|
|
|
def _ordered_search_side_effect(
    first: MockSegment,
    second: MockSegment,
) -> Callable[[list[float], int, object], Awaitable[list[tuple[MockSegment, float]]]]:
    """Build an async search stand-in whose answer depends on call order.

    The returned coroutine function answers its first invocation with
    ``[(first, 0.9)]`` and every later invocation with ``[(second, 0.8)]``.
    Its arguments are accepted but never inspected.
    """
    awaiting_first_call = True

    async def search_stub(
        query_embedding: list[float],
        limit: int,
        meeting_id: object,
    ) -> list[tuple[MockSegment, float]]:
        nonlocal awaiting_first_call
        if awaiting_first_call:
            # Flip the flag so all subsequent calls take the other branch.
            awaiting_first_call = False
            return [(first, 0.9)]
        return [(second, 0.8)]

    return search_stub
|
|
|
|
|
|
class TestRetrieveSegmentsBatch:
    """Tests for ``retrieve_segments_batch`` with mocked and recording embedders."""

    @pytest.fixture
    def mock_embedder(self) -> AsyncMock:
        """Return an async mock whose ``embed`` resolves to a fixed vector."""
        stub = AsyncMock()
        stub.embed.return_value = [0.1, 0.2, 0.3]
        return stub

    @pytest.fixture
    def batch_embedder(self) -> MockBatchEmbedder:
        """Return a recording embedder that also offers ``embed_batch``."""
        canned_vector = [0.1, 0.2, 0.3]
        return MockBatchEmbedder(canned_vector)

    @pytest.fixture
    def mock_segment_repo(self) -> AsyncMock:
        """Return a bare async mock standing in for the segment repository."""
        repo = AsyncMock()
        return repo

    @pytest.fixture
    def sample_meeting_id(self) -> MeetingId:
        """Return a ``MeetingId`` wrapping a freshly generated UUID."""
        raw_uuid = uuid4()
        return MeetingId(raw_uuid)

    async def test_batch_returns_empty_for_no_queries(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
    ) -> None:
        """An empty query list gives an empty result and never calls ``embed``."""
        empty_request = MeetingBatchRetrievalRequest(queries=[])
        dependencies = MeetingRetrievalDependencies(
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

        results = await retrieve_segments_batch(empty_request, dependencies)

        assert results == [], "Expected empty results for empty query list"
        mock_embedder.embed.assert_not_called()

    async def test_batch_uses_embed_batch_when_available(
        self,
        batch_embedder: MockBatchEmbedder,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """With a batch-capable embedder, one ``embed_batch`` call covers all queries."""
        hit = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(hit, 0.9)]

        # Precondition: the recording double is recognised as a batch embedder.
        assert isinstance(batch_embedder, BatchEmbedderProtocol)

        batch_request = MeetingBatchRetrievalRequest(
            queries=["query1", "query2"],
            meeting_id=sample_meeting_id,
        )
        dependencies = MeetingRetrievalDependencies(
            embedder=batch_embedder,
            segment_repo=mock_segment_repo,
        )

        results = await retrieve_segments_batch(batch_request, dependencies)

        assert len(results) == 2, "Expected one result list per query"
        assert len(batch_embedder.embed_batch_calls) == 1, "Expected batch embedding call"
        recorded_batch = list(batch_embedder.embed_batch_calls[0])
        assert recorded_batch == ["query1", "query2"], (
            "Batch embedder should receive queries in order"
        )
        assert batch_embedder.embed_calls == [], "Single embed should not be used"

    async def test_batch_falls_back_to_parallel_embed(
        self,
        mock_embedder: AsyncMock,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """With the plain mock embedder, ``embed`` is awaited once per query."""
        hit = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="Test",
            start_time=0.0,
            end_time=5.0,
        )
        mock_segment_repo.search_semantic.return_value = [(hit, 0.9)]

        batch_request = MeetingBatchRetrievalRequest(
            queries=["query1", "query2"],
            meeting_id=sample_meeting_id,
        )
        dependencies = MeetingRetrievalDependencies(
            embedder=mock_embedder,
            segment_repo=mock_segment_repo,
        )

        results = await retrieve_segments_batch(batch_request, dependencies)

        assert len(results) == 2, "Expected one result list per query"
        assert mock_embedder.embed.call_count == 2, "Expected parallel embed fallback"

    async def test_batch_preserves_query_order(
        self,
        mock_segment_repo: AsyncMock,
        sample_meeting_id: MeetingId,
    ) -> None:
        """Result lists come back in the same order as the submitted queries."""
        leading_segment = MockSegment(
            segment_id=1,
            meeting_id=sample_meeting_id,
            text="First",
            start_time=0.0,
            end_time=5.0,
        )
        trailing_segment = MockSegment(
            segment_id=2,
            meeting_id=sample_meeting_id,
            text="Second",
            start_time=5.0,
            end_time=10.0,
        )
        # The first repository call yields ``leading_segment``; later calls
        # yield ``trailing_segment``.
        ordered_search = _ordered_search_side_effect(leading_segment, trailing_segment)
        mock_segment_repo.search_semantic.side_effect = ordered_search

        embedder = MockBatchEmbedder([0.1, 0.2])
        batch_request = MeetingBatchRetrievalRequest(
            queries=["first", "second"],
            meeting_id=sample_meeting_id,
        )
        dependencies = MeetingRetrievalDependencies(
            embedder=embedder,
            segment_repo=mock_segment_repo,
        )

        results = await retrieve_segments_batch(batch_request, dependencies)

        assert len(results) == 2, "Expected results for two queries"
        first_hits, second_hits = results[0], results[1]
        assert first_hits[0].text == "First", "First query should map to first result"
        assert second_hits[0].text == "Second", "Second query should map to second result"
|