noteflow/tests/evaluation/test_rag_pipeline.py

"""RAG pipeline evaluation tests."""
from __future__ import annotations
import asyncio
from typing import Final
from uuid import uuid4
import pytest
from noteflow.domain.value_objects import MeetingId
from tests.conftest import approx_float
from .conftest import (
DEFAULT_EMBEDDING_DIM,
ConfigurableEmbedder,
ConfigurableSegmentRepository,
MockSegment,
calculate_retrieval_metrics,
extract_segment_ids_from_retrieved,
)
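
# Mock similarity scores (listed in descending rank order) and the
# retrieval cutoffs shared by the tests below.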
HIGH_SCORE: Final[float] = 0.95
MEDIUM_SCORE: Final[float] = 0.85
LOW_SCORE: Final[float] = 0.75
SEARCH_LIMIT: Final[int] = 5
TOP_K: Final[int] = 3


class TestConfigurableEmbedder:
    @pytest.mark.asyncio
    async def test_embed_returns_default_embedding(self) -> None:
        embedder = ConfigurableEmbedder()
        result = await embedder.embed("test text")
        assert len(result) == DEFAULT_EMBEDDING_DIM, "Default embedding dimension"

    @pytest.mark.asyncio
    async def test_embed_returns_configured_embedding(self) -> None:
        custom = [1.0, 2.0, 3.0]
        embedder = ConfigurableEmbedder(embeddings={"specific": custom})
        result = await embedder.embed("specific")
        assert result == custom, "Should return configured embedding"

    @pytest.mark.asyncio
    async def test_embed_tracks_calls(self) -> None:
        embedder = ConfigurableEmbedder()
        await embedder.embed("first")
        await embedder.embed("second")
        assert embedder.embed_calls == ["first", "second"], "Should track all calls"

    @pytest.mark.asyncio
    async def test_embed_batch_processes_multiple(self) -> None:
        embedder = ConfigurableEmbedder()
        results = await embedder.embed_batch(["a", "b", "c"])
        assert len(results) == 3, "Should return 3 embeddings"
        assert len(embedder.embed_calls) == 3, "Should track 3 calls"
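

# The repository fake mirrors the semantic-search contract exercised
# below: results are filtered to the requested meeting, truncated to the
# caller's limit, and every call is recorded for later inspection.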
class TestConfigurableSegmentRepository:
    @pytest.mark.asyncio
    async def test_search_returns_configured_segments(self) -> None:
        meeting_id = MeetingId(uuid4())
        seg = MockSegment(
            segment_id=1, meeting_id=meeting_id, text="t", start_time=0.0, end_time=1.0
        )
        repo = ConfigurableSegmentRepository([(seg, HIGH_SCORE)])
        results = await repo.search_semantic([0.1], SEARCH_LIMIT, meeting_id)
        assert len(results) == 1, "Should return configured segment"
        assert results[0][0].segment_id == 1, "Segment ID matches"
        assert results[0][1] == HIGH_SCORE, "Score matches"

    @pytest.mark.asyncio
    async def test_search_filters_by_meeting_id(self) -> None:
        meeting1 = MeetingId(uuid4())
        meeting2 = MeetingId(uuid4())
        seg1 = MockSegment(
            segment_id=1, meeting_id=meeting1, text="a", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=meeting2, text="b", start_time=0.0, end_time=1.0
        )
        repo = ConfigurableSegmentRepository([(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE)])
        results = await repo.search_semantic([0.1], SEARCH_LIMIT, meeting1)
        assert len(results) == 1, "Should filter to meeting1 only"
        assert results[0][0].segment_id == 1, "Only seg1 matches"

    @pytest.mark.asyncio
    async def test_search_respects_limit(self) -> None:
        meeting_id = MeetingId(uuid4())
        seg1 = MockSegment(
            segment_id=1, meeting_id=meeting_id, text="a", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=meeting_id, text="b", start_time=1.0, end_time=2.0
        )
        seg3 = MockSegment(
            segment_id=3, meeting_id=meeting_id, text="c", start_time=2.0, end_time=3.0
        )
        repo = ConfigurableSegmentRepository(
            [(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE), (seg3, LOW_SCORE)]
        )
        results = await repo.search_semantic([0.1], 2, meeting_id)
        assert len(results) == 2, "Should respect limit of 2"

    @pytest.mark.asyncio
    async def test_search_tracks_calls(self) -> None:
        meeting_id = MeetingId(uuid4())
        repo = ConfigurableSegmentRepository([])
        await repo.search_semantic([0.1, 0.2], SEARCH_LIMIT, meeting_id)
        assert len(repo.search_calls) == 1, "Should track call"
        embedding, limit, mid = repo.search_calls[0]
        assert embedding == [0.1, 0.2], "Embedding recorded"
        assert limit == SEARCH_LIMIT, "Limit recorded"
        assert mid == meeting_id, "Meeting ID recorded"
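

# The assertions in TestRetrievalQuality pin calculate_retrieval_metrics
# to the standard IR definitions (assumed here, since the helper lives in
# conftest and is not shown). For a ranked list and a set of relevant IDs:
#   precision@k = |top-k hits| / k
#   recall@k    = |top-k hits| / |relevant|
#   hit rate    = 1.0 if any relevant result was retrieved, else 0.0
#   MRR         = 1 / rank of the first relevant result (0.0 if none)
#   NDCG@k      = DCG@k / ideal DCG@k, with DCG = sum(rel_i / log2(i + 1))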
class TestRetrievalQuality:
    def test_perfect_retrieval_all_relevant_first(self) -> None:
        retrieved_ids = [1, 2, 3, 4, 5]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)
        assert metrics.precision_at_k == 1.0, "All top-k are relevant"
        assert metrics.recall_at_k == 1.0, "All relevant in top-k"
        assert metrics.ndcg == approx_float(1.0), "Perfect ranking"

    def test_poor_retrieval_relevant_at_end(self) -> None:
        retrieved_ids = [10, 20, 1, 2, 3]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)
        assert metrics.precision_at_k == approx_float(1 / 3), "Only 1 relevant in top-3"
        assert metrics.recall_at_k == approx_float(1 / 3), "Only 1 relevant found"

    def test_no_relevant_in_results(self) -> None:
        retrieved_ids = [10, 20, 30]
        relevant_ids = {1, 2, 3}
        metrics = calculate_retrieval_metrics(retrieved_ids, relevant_ids, k=TOP_K)
        assert metrics.precision_at_k == 0.0, "No relevant in results"
        assert metrics.recall_at_k == 0.0, "No relevant found"
        assert metrics.hit_rate == 0.0, "No hits"


class TestEmbeddingQuality:
    @pytest.mark.asyncio
    async def test_same_text_same_embedding(self) -> None:
        embedder = ConfigurableEmbedder()
        e1, e2 = await asyncio.gather(
            embedder.embed("hello world"),
            embedder.embed("hello world"),
        )
        assert e1 == e2, "Same text should yield same embedding"

    @pytest.mark.asyncio
    async def test_custom_embeddings_differ(self) -> None:
        embedder = ConfigurableEmbedder(
            embeddings={
                "apple": [1.0, 0.0, 0.0],
                "orange": [0.0, 1.0, 0.0],
            }
        )
        e1 = await embedder.embed("apple")
        e2 = await embedder.embed("orange")
        assert e1 != e2, "Different texts should have different embeddings"


class TestSegmentIdExtraction:
    def test_extraction_preserves_order(self) -> None:
        seg1 = MockSegment(segment_id=5, meeting_id=None, text="a", start_time=0.0, end_time=1.0)
        seg2 = MockSegment(segment_id=2, meeting_id=None, text="b", start_time=1.0, end_time=2.0)
        seg3 = MockSegment(segment_id=8, meeting_id=None, text="c", start_time=2.0, end_time=3.0)
        retrieved = [(seg1, HIGH_SCORE), (seg2, MEDIUM_SCORE), (seg3, LOW_SCORE)]
        ids = extract_segment_ids_from_retrieved(retrieved)
        assert ids == [5, 2, 8], "Order must match retrieval ranking"
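

# Note: test_mrr_by_position is a module-level function (no self
# parameter, so it sits outside any test class). It sweeps the first
# relevant hit across ranks 1-3 plus the not-found case.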
@pytest.mark.parametrize(
    ("retrieved", "relevant", "expected_mrr"),
    [
        ([1], {1}, 1.0),
        ([2, 1], {1}, 0.5),
        ([3, 2, 1], {1}, 1 / 3),
        ([4, 5, 6], {1}, 0.0),
    ],
    ids=["pos1", "pos2", "pos3", "not-found"],
)
def test_mrr_by_position(
    retrieved: list[int],
    relevant: set[int],
    expected_mrr: float,
) -> None:
    metrics = calculate_retrieval_metrics(retrieved, relevant, k=TOP_K)
    assert metrics.mean_reciprocal_rank == approx_float(expected_mrr), "MRR mismatch"
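

# A minimal reference for the MRR expectations above. This helper is
# illustrative only (a sketch, not part of conftest); the conftest
# implementation is assumed to match it, not shown here.
def _reference_reciprocal_rank(retrieved: list[int], relevant: set[int]) -> float:
    """Return 1 / rank of the first relevant result, or 0.0 if none found."""
    for rank, segment_id in enumerate(retrieved, start=1):
        if segment_id in relevant:
            return 1.0 / rank
    return 0.0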