"""RAG pipeline evaluation tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Final
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.value_objects import MeetingId
|
|
from tests.conftest import approx_float
|
|
|
|
from .conftest import (
|
|
DEFAULT_EMBEDDING_DIM,
|
|
ConfigurableEmbedder,
|
|
ConfigurableSegmentRepository,
|
|
MockSegment,
|
|
calculate_retrieval_metrics,
|
|
extract_segment_ids_from_retrieved,
|
|
)
|
|
|
|
HIGH_SCORE: Final[float] = 0.95
|
|
MEDIUM_SCORE: Final[float] = 0.85
|
|
LOW_SCORE: Final[float] = 0.75
|
|
SEARCH_LIMIT: Final[int] = 5
|
|
TOP_K: Final[int] = 3
|
|
|
|
|
|
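# The doubles exercised below come from the local conftest: as these tests
# assume, ConfigurableEmbedder returns canned vectors (a DEFAULT_EMBEDDING_DIM
# default unless a per-text embedding is configured) and records every
# embed() call, while ConfigurableSegmentRepository serves preloaded
# (segment, score) pairs filtered by meeting id and capped at the limit.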
class TestConfigurableEmbedder:
    """Embedding behaviour of the ConfigurableEmbedder test double."""

    @pytest.mark.asyncio
    async def test_embed_returns_default_embedding(self) -> None:
        embedder = ConfigurableEmbedder()
        result = await embedder.embed("test text")
        assert len(result) == DEFAULT_EMBEDDING_DIM, "Default embedding dimension"

    @pytest.mark.asyncio
    async def test_embed_returns_configured_embedding(self) -> None:
        custom = [1.0, 2.0, 3.0]
        embedder = ConfigurableEmbedder(embeddings={"specific": custom})
        result = await embedder.embed("specific")
        assert result == custom, "Should return configured embedding"

    @pytest.mark.asyncio
    async def test_embed_tracks_calls(self) -> None:
        embedder = ConfigurableEmbedder()
        await embedder.embed("first")
        await embedder.embed("second")
        assert embedder.embed_calls == ["first", "second"], "Should track all calls"

    @pytest.mark.asyncio
    async def test_embed_batch_processes_multiple(self) -> None:
        embedder = ConfigurableEmbedder()
        results = await embedder.embed_batch(["a", "b", "c"])
        assert len(results) == 3, "Should return 3 embeddings"
        assert len(embedder.embed_calls) == 3, "Should track 3 calls"


class TestConfigurableSegmentRepository:
    """Semantic-search behaviour of the ConfigurableSegmentRepository double."""

    @pytest.mark.asyncio
    async def test_search_returns_configured_segments(self) -> None:
        meeting_id = MeetingId(uuid4())
        seg = MockSegment(
            segment_id=1, meeting_id=meeting_id, text="t", start_time=0.0, end_time=1.0
        )
        repo = ConfigurableSegmentRepository([(seg, HIGH_SCORE)])

        results = await repo.search_semantic([0.1], SEARCH_LIMIT, meeting_id)

        assert len(results) == 1, "Should return configured segment"
        assert results[0][0].segment_id == 1, "Segment ID matches"
        assert results[0][1] == HIGH_SCORE, "Score matches"

    @pytest.mark.asyncio
    async def test_search_filters_by_meeting_id(self) -> None:
        meeting1 = MeetingId(uuid4())
        meeting2 = MeetingId(uuid4())
        seg1 = MockSegment(
            segment_id=1, meeting_id=meeting1, text="a", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=meeting2, text="b", start_time=0.0, end_time=1.0
        )
        repo = ConfigurableSegmentRepository([(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE)])

        results = await repo.search_semantic([0.1], SEARCH_LIMIT, meeting1)

        assert len(results) == 1, "Should filter to meeting1 only"
        assert results[0][0].segment_id == 1, "Only seg1 matches"

    @pytest.mark.asyncio
    async def test_search_respects_limit(self) -> None:
        meeting_id = MeetingId(uuid4())
        seg1 = MockSegment(
            segment_id=1, meeting_id=meeting_id, text="a", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=meeting_id, text="b", start_time=1.0, end_time=2.0
        )
        seg3 = MockSegment(
            segment_id=3, meeting_id=meeting_id, text="c", start_time=2.0, end_time=3.0
        )
        repo = ConfigurableSegmentRepository(
            [(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE), (seg3, LOW_SCORE)]
        )

        results = await repo.search_semantic([0.1], 2, meeting_id)

        assert len(results) == 2, "Should respect limit of 2"

    @pytest.mark.asyncio
    async def test_search_tracks_calls(self) -> None:
        meeting_id = MeetingId(uuid4())
        repo = ConfigurableSegmentRepository([])

        await repo.search_semantic([0.1, 0.2], SEARCH_LIMIT, meeting_id)

        assert len(repo.search_calls) == 1, "Should track call"
        embedding, limit, mid = repo.search_calls[0]
        assert embedding == [0.1, 0.2], "Embedding recorded"
        assert limit == SEARCH_LIMIT, "Limit recorded"
        assert mid == meeting_id, "Meeting ID recorded"


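# Metric conventions assumed by the expected values below (the authoritative
# implementation is calculate_retrieval_metrics in the local conftest):
#   precision@k = |relevant ∩ top-k retrieved| / k
#   recall@k    = |relevant ∩ top-k retrieved| / |relevant|
#   hit_rate    = 1.0 if any relevant id appears in the top-k, else 0.0
#   ndcg        = DCG of the actual ranking / DCG of the ideal ranking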
class TestRetrievalQuality:
    """Retrieval metrics computed over known relevance judgements."""

    def test_perfect_retrieval_all_relevant_first(self) -> None:
        retrieved_ids = [1, 2, 3, 4, 5]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)

        assert metrics.precision_at_k == 1.0, "All top-k are relevant"
        assert metrics.recall_at_k == 1.0, "All relevant in top-k"
        assert metrics.ndcg == approx_float(1.0), "Perfect ranking"

    def test_poor_retrieval_relevant_at_end(self) -> None:
        retrieved_ids = [10, 20, 1, 2, 3]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)

        assert metrics.precision_at_k == approx_float(1 / 3), "Only 1 relevant in top-3"
        assert metrics.recall_at_k == approx_float(1 / 3), "Only 1 relevant found"

    def test_no_relevant_in_results(self) -> None:
        retrieved_ids = [10, 20, 30]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)

        assert metrics.precision_at_k == 0.0, "No relevant in results"
        assert metrics.recall_at_k == 0.0, "No relevant found"
        assert metrics.hit_rate == 0.0, "No hits"


class TestEmbeddingQuality:
    """Determinism and separability of the configurable embedder."""

    @pytest.mark.asyncio
    async def test_same_text_same_embedding(self) -> None:
        embedder = ConfigurableEmbedder()
        e1, e2 = await asyncio.gather(
            embedder.embed("hello world"),
            embedder.embed("hello world"),
        )

        assert e1 == e2, "Same text should yield same embedding"

    @pytest.mark.asyncio
    async def test_custom_embeddings_differ(self) -> None:
        embedder = ConfigurableEmbedder(
            embeddings={
                "apple": [1.0, 0.0, 0.0],
                "orange": [0.0, 1.0, 0.0],
            }
        )

        e1 = await embedder.embed("apple")
        e2 = await embedder.embed("orange")

        assert e1 != e2, "Different texts should have different embeddings"


class TestSegmentIdExtraction:
    """Segment-id extraction from scored retrieval results."""

    def test_extraction_preserves_order(self) -> None:
        seg1 = MockSegment(segment_id=5, meeting_id=None, text="a", start_time=0.0, end_time=1.0)
        seg2 = MockSegment(segment_id=2, meeting_id=None, text="b", start_time=1.0, end_time=2.0)
        seg3 = MockSegment(segment_id=8, meeting_id=None, text="c", start_time=2.0, end_time=3.0)
        retrieved = [(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE), (seg3, LOW_SCORE)]

        ids = extract_segment_ids_from_retrieved(retrieved)

        assert ids == [5, 2, 8], "Order must match retrieval ranking"


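# Mean reciprocal rank is assumed to be 1 / rank of the first relevant result
# (0.0 when nothing relevant is retrieved); the cases below pin ranks 1, 2, 3,
# and the not-found fallback.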
@pytest.mark.parametrize(
    ("retrieved", "relevant", "expected_mrr"),
    [
        ([1], {1}, 1.0),
        ([2, 1], {1}, 0.5),
        ([3, 2, 1], {1}, 1 / 3),
        ([4, 5, 6], {1}, 0.0),
    ],
    ids=["pos1", "pos2", "pos3", "not-found"],
)
def test_mrr_by_position(
    retrieved: list[int],
    relevant: set[int],
    expected_mrr: float,
) -> None:
    """Reciprocal rank should track the position of the first relevant id."""
    metrics = calculate_retrieval_metrics(retrieved, relevant, k=TOP_K)
    assert metrics.mean_reciprocal_rank == approx_float(expected_mrr), "MRR mismatch"