"""Shared fixtures and utilities for pipeline evaluation tests."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Final, Protocol
from uuid import UUID, uuid4
import pytest
from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.entities.segment import Segment
from noteflow.domain.value_objects import MeetingId
DEFAULT_EMBEDDING_DIM: Final[int] = 1536
DEFAULT_SAMPLE_RATE: Final[int] = 16000
DEFAULT_TOP_K: Final[int] = 8
MIN_CONFIDENCE: Final[float] = 0.0
MAX_CONFIDENCE: Final[float] = 1.0
PERFECT_SCORE: Final[float] = 1.0
ACCEPTABLE_PRECISION: Final[float] = 0.7
ACCEPTABLE_RECALL: Final[float] = 0.7
ACCEPTABLE_F1: Final[float] = 0.7


@dataclass(frozen=True)
class TranscriptSegment:
    """A single transcript segment with timing and an optional speaker."""

    segment_id: int
    text: str
    start_time: float
    end_time: float
    speaker_id: str | None = None


SAMPLE_MEETING_TRANSCRIPT: Final[tuple[TranscriptSegment, ...]] = (
    TranscriptSegment(
        segment_id=1,
        text="Good morning everyone. Let's start the product review meeting.",
        start_time=0.0,
        end_time=4.5,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=2,
        text="I'm John from engineering. We've completed the API refactoring last week.",
        start_time=5.0,
        end_time=10.2,
        speaker_id="speaker_1",
    ),
    TranscriptSegment(
        segment_id=3,
        text="Sarah here from design. The new UI mockups are ready for review.",
        start_time=11.0,
        end_time=15.8,
        speaker_id="speaker_2",
    ),
    TranscriptSegment(
        segment_id=4,
        text="We need to discuss the deadline. The launch is scheduled for March 15th.",
        start_time=16.5,
        end_time=22.0,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=5,
        text="Acme Corporation has requested an early demo next Tuesday.",
        start_time=23.0,
        end_time=28.5,
        speaker_id="speaker_1",
    ),
    TranscriptSegment(
        segment_id=6,
        text="I'll prepare the presentation slides. Can we meet at 2 PM tomorrow?",
        start_time=29.0,
        end_time=34.2,
        speaker_id="speaker_2",
    ),
    TranscriptSegment(
        segment_id=7,
        text="Action item: John to send the technical documentation to the client by Friday.",
        start_time=35.0,
        end_time=41.0,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=8,
        text="The budget is $50,000 for this quarter. We're currently at 60% utilization.",
        start_time=42.0,
        end_time=48.5,
        speaker_id="speaker_1",
    ),
)

EXPECTED_ENTITIES_IN_SAMPLE: Final[tuple[tuple[str, EntityCategory], ...]] = (
    ("John", EntityCategory.PERSON),
    ("Sarah", EntityCategory.PERSON),
    ("Acme Corporation", EntityCategory.COMPANY),
    ("API", EntityCategory.TECHNICAL),
    ("March 15th", EntityCategory.DATE),
    ("next Tuesday", EntityCategory.DATE),
    ("Friday", EntityCategory.DATE),
    ("2 PM tomorrow", EntityCategory.DATE),
)

EXPECTED_SPEAKERS_IN_SAMPLE: Final[tuple[str, ...]] = (
    "speaker_0",
    "speaker_1",
    "speaker_2",
)


@dataclass(frozen=True)
class PrecisionRecallMetrics:
    """Precision/recall/F1 scores with the supporting counts."""

    precision: float
    recall: float
    f1_score: float
    true_positives: int
    false_positives: int
    false_negatives: int


@dataclass(frozen=True)
class RetrievalMetrics:
    """Standard ranked-retrieval quality metrics at a cutoff k."""

    mean_reciprocal_rank: float
    precision_at_k: float
    recall_at_k: float
    ndcg: float
    hit_rate: float


@dataclass(frozen=True)
class AnswerQualityMetrics:
    """Quality metrics for a synthesized answer and its citations."""

    citation_accuracy: float
    answer_relevance: float
    factual_grounding: float
    citation_count: int
    invalid_citations: int


@dataclass
class EvaluationResult:
    """Outcome of a single evaluation test case."""

    test_name: str
    passed: bool
    metrics: dict[str, float] = field(default_factory=dict)
    details: str = ""


class MockSegmentLike(Protocol):
    """Structural type for segment-like objects used by the test doubles."""

    segment_id: int
    meeting_id: MeetingId | None
    text: str
    start_time: float
    end_time: float


@dataclass
class MockSegment:
    """Lightweight stand-in for a persisted transcript segment."""

    segment_id: int
    meeting_id: MeetingId | None
    text: str
    start_time: float
    end_time: float
    embedding: list[float] | None = None


class ConfigurableEmbedder:
    """Embedding-service test double with canned per-text vectors."""

    def __init__(
        self,
        embeddings: dict[str, list[float]] | None = None,
        default_embedding: list[float] | None = None,
    ) -> None:
        self._embeddings = embeddings or {}
        self._default = default_embedding or [0.1] * DEFAULT_EMBEDDING_DIM
        self.embed_calls: list[str] = []

    async def embed(self, text: str) -> list[float]:
        # Record the call, then return the mapped vector or a copy of the default.
        self.embed_calls.append(text)
        return self._embeddings.get(text, self._default.copy())

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        return [await self.embed(t) for t in texts]
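
# Usage sketch (hypothetical values): pin a distinct vector to a specific
# query text so retrieval ordering is deterministic in a test; unmapped texts
# fall back to the 0.1-filled default, and embed_calls records every input.
#
#     embedder = ConfigurableEmbedder(
#         embeddings={"budget status": [1.0] * DEFAULT_EMBEDDING_DIM},
#     )
#     vector = await embedder.embed("budget status")
#     assert embedder.embed_calls == ["budget status"]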


class ConfigurableLLM:
    """LLM test double that answers by substring-matching the prompt."""

    def __init__(
        self,
        responses: dict[str, str] | None = None,
        default_response: str = "Based on the transcript [1], the answer is relevant.",
    ) -> None:
        self._responses = responses or {}
        self._default = default_response
        self.complete_calls: list[str] = []

    async def complete(self, prompt: str) -> str:
        self.complete_calls.append(prompt)
        # First configured key found in the prompt wins; otherwise the default.
        for key, response in self._responses.items():
            if key in prompt:
                return response
        return self._default
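
# Usage sketch (hypothetical values): keys are matched as substrings of the
# full prompt, so a canned response can be keyed on the user question alone.
#
#     llm = ConfigurableLLM(responses={"budget": "The budget is $50,000 [8]."})
#     answer = await llm.complete("Context...\nQuestion: What is the budget?")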


class ConfigurableSegmentRepository:
    """Test double for semantic segment search scoped to a single meeting."""

    def __init__(
        self,
        segments: list[tuple[MockSegment, float]] | None = None,
    ) -> None:
        self._segments = segments or []
        self.search_calls: list[tuple[list[float], int, MeetingId | None]] = []

    async def search_semantic(
        self,
        query_embedding: list[float],
        limit: int,
        meeting_id: MeetingId | None,
    ) -> Sequence[tuple[MockSegment, float]]:
        self.search_calls.append((query_embedding, limit, meeting_id))
        # meeting_id=None searches across all meetings; a concrete id filters
        # candidates before the limit is applied.
        filtered = [
            (seg, score)
            for seg, score in self._segments
            if meeting_id is None or seg.meeting_id == meeting_id
        ]
        return filtered[:limit]
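
# Usage sketch (hypothetical names): seed pre-scored (segment, score) pairs,
# highest score first as the fixtures below do, then assert on search_calls.
#
#     repo = ConfigurableSegmentRepository([(segment, 0.9)])
#     results = await repo.search_semantic(query_vec, limit=5, meeting_id=None)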


class ConfigurableWorkspaceSegmentRepository:
    """Test double for workspace-wide semantic search.

    Unlike the meeting-scoped double, this one records the workspace and
    project ids but does not filter by them.
    """

    def __init__(
        self,
        segments: list[tuple[MockSegment, float]] | None = None,
    ) -> None:
        self._segments = segments or []
        self.search_calls: list[tuple[list[float], UUID, UUID | None, int]] = []

    async def search_semantic_workspace(
        self,
        query_embedding: list[float],
        workspace_id: UUID,
        project_id: UUID | None,
        limit: int,
    ) -> Sequence[tuple[MockSegment, float]]:
        self.search_calls.append((query_embedding, workspace_id, project_id, limit))
        return self._segments[:limit]


def calculate_precision_recall(
    predicted: set[str],
    expected: set[str],
) -> PrecisionRecallMetrics:
    """Compute precision, recall, and F1 for predicted vs. expected labels."""
    true_positives = len(predicted & expected)
    false_positives = len(predicted - expected)
    false_negatives = len(expected - predicted)
    precision = true_positives / (true_positives + false_positives) if predicted else 0.0
    recall = true_positives / (true_positives + false_negatives) if expected else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return PrecisionRecallMetrics(
        precision=precision,
        recall=recall,
        f1_score=f1,
        true_positives=true_positives,
        false_positives=false_positives,
        false_negatives=false_negatives,
    )
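
# Worked example: predicted {"John", "Acme"} against expected
# {"John", "Sarah", "Friday"} gives TP=1, FP=1, FN=2, so precision = 0.5,
# recall = 1/3, and F1 = 2 * (0.5 * 1/3) / (0.5 + 1/3) = 0.4.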


def calculate_retrieval_metrics(
    retrieved_ids: list[int],
    relevant_ids: set[int],
    k: int,
) -> RetrievalMetrics:
    """Compute MRR, precision@k, recall@k, NDCG@k, and hit rate."""
    if not relevant_ids:
        return RetrievalMetrics(
            mean_reciprocal_rank=0.0,
            precision_at_k=0.0,
            recall_at_k=0.0,
            ndcg=0.0,
            hit_rate=0.0,
        )
    # MRR: reciprocal rank of the first relevant result within the top k.
    mrr = 0.0
    for i, rid in enumerate(retrieved_ids[:k], start=1):
        if rid in relevant_ids:
            mrr = 1.0 / i
            break
    hits_at_k = sum(1 for rid in retrieved_ids[:k] if rid in relevant_ids)
    precision_at_k = hits_at_k / k if k > 0 else 0.0
    recall_at_k = hits_at_k / len(relevant_ids)
    hit_rate = 1.0 if hits_at_k > 0 else 0.0
    # Binary-relevance DCG, normalized by the ideal ranking's DCG.
    dcg = sum(
        (1.0 if rid in relevant_ids else 0.0) / _log2(i + 1)
        for i, rid in enumerate(retrieved_ids[:k], start=1)
    )
    ideal_dcg = sum(1.0 / _log2(i + 1) for i in range(1, min(k, len(relevant_ids)) + 1))
    ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0.0
    return RetrievalMetrics(
        mean_reciprocal_rank=mrr,
        precision_at_k=precision_at_k,
        recall_at_k=recall_at_k,
        ndcg=ndcg,
        hit_rate=hit_rate,
    )
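
# Worked example: retrieved_ids=[3, 1, 7] with relevant_ids={1, 2} and k=3
# puts the first relevant hit at rank 2, so MRR = 0.5, precision@3 = 1/3,
# recall@3 = 0.5, hit_rate = 1.0, and NDCG = (1/log2(3)) / (1 + 1/log2(3)) ≈ 0.39.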


def _log2(x: float) -> float:
    """Return log2(x), or 0.0 for non-positive inputs."""
    return math.log2(x) if x > 0 else 0.0


def calculate_citation_accuracy(
    cited_ids: list[int],
    valid_ids: set[int],
) -> AnswerQualityMetrics:
    """Score citation validity; an answer with no citations scores 1.0."""
    valid_citations = [cid for cid in cited_ids if cid in valid_ids]
    invalid_citations = len(cited_ids) - len(valid_citations)
    accuracy = len(valid_citations) / len(cited_ids) if cited_ids else 1.0
    return AnswerQualityMetrics(
        citation_accuracy=accuracy,
        answer_relevance=0.0,
        factual_grounding=0.0,
        citation_count=len(cited_ids),
        invalid_citations=invalid_citations,
    )
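
# Worked example: cited_ids=[1, 2, 5] with valid_ids={1, 2, 3} yields
# citation_accuracy = 2/3 and invalid_citations = 1. The relevance and
# grounding fields are left at 0.0 for the caller to fill in.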


@pytest.fixture
def sample_meeting_id() -> MeetingId:
    """Random meeting identifier."""
    return MeetingId(uuid4())


@pytest.fixture
def sample_workspace_id() -> UUID:
    """Random workspace identifier."""
    return uuid4()


@pytest.fixture
def sample_project_id() -> UUID:
    """Random project identifier."""
    return uuid4()


@pytest.fixture
def sample_transcript() -> tuple[TranscriptSegment, ...]:
    """The canned eight-segment meeting transcript."""
    return SAMPLE_MEETING_TRANSCRIPT


@pytest.fixture
def expected_entities() -> tuple[tuple[str, EntityCategory], ...]:
    """Gold entities expected from the sample transcript."""
    return EXPECTED_ENTITIES_IN_SAMPLE


@pytest.fixture
def sample_segments(sample_meeting_id: MeetingId) -> list[MockSegment]:
    """Sample transcript converted to MockSegment objects for one meeting."""
    return [
        MockSegment(
            segment_id=seg.segment_id,
            meeting_id=sample_meeting_id,
            text=seg.text,
            start_time=seg.start_time,
            end_time=seg.end_time,
        )
        for seg in SAMPLE_MEETING_TRANSCRIPT
    ]


@pytest.fixture
def configurable_embedder() -> ConfigurableEmbedder:
    """Embedder double with default settings."""
    return ConfigurableEmbedder()


@pytest.fixture
def configurable_llm() -> ConfigurableLLM:
    """LLM double with default settings."""
    return ConfigurableLLM()


@pytest.fixture
def segment_repo_with_data(
    sample_segments: list[MockSegment],
) -> ConfigurableSegmentRepository:
    """Meeting-scoped repository seeded with descending-score segments."""
    scored_segments = [(seg, 0.9 - i * 0.05) for i, seg in enumerate(sample_segments)]
    return ConfigurableSegmentRepository(scored_segments)


@pytest.fixture
def workspace_repo_with_data(
    sample_segments: list[MockSegment],
) -> ConfigurableWorkspaceSegmentRepository:
    """Workspace repository seeded with descending-score segments."""
    scored_segments = [(seg, 0.9 - i * 0.05) for i, seg in enumerate(sample_segments)]
    return ConfigurableWorkspaceSegmentRepository(scored_segments)


def create_domain_segment(
    segment_id: int,
    meeting_id: MeetingId,
    text: str,
    start_time: float,
    end_time: float,
    speaker_id: str | None = None,
) -> Segment:
    """Build a domain Segment, reusing segment_id as the database id."""
    return Segment(
        segment_id=segment_id,
        db_id=segment_id,
        meeting_id=meeting_id,
        text=text,
        start_time=start_time,
        end_time=end_time,
        speaker_id=speaker_id,
    )


def create_named_entity(
    text: str,
    category: EntityCategory,
    segment_ids: list[int],
    confidence: float = 0.9,
) -> NamedEntity:
    """Build a NamedEntity with a default high confidence."""
    return NamedEntity(
        text=text,
        category=category,
        segment_ids=segment_ids,
        confidence=confidence,
    )
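
# Usage sketch (segment_ids here are illustrative): the factories pair
# naturally with the expected-sample constants when building gold data.
#
#     gold = [
#         create_named_entity(text, category, segment_ids=[1])
#         for text, category in EXPECTED_ENTITIES_IN_SAMPLE
#     ]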


# ---------------------------------------------------------------------------
# Helper functions for test assertions (avoids loops in test bodies)
# ---------------------------------------------------------------------------


def extract_entity_texts(entities: Sequence[NamedEntity]) -> set[str]:
    """Extract text from entities as a set."""
    return {e.text for e in entities}


def extract_entity_categories(entities: Sequence[NamedEntity]) -> dict[str, EntityCategory]:
    """Map entity text to category."""
    return {e.text: e.category for e in entities}


def extract_entity_confidences(entities: Sequence[NamedEntity]) -> dict[str, float]:
    """Map entity text to confidence."""
    return {e.text: e.confidence for e in entities}


def get_expected_entity_texts() -> set[str]:
    """Get expected entity texts from sample data."""
    return {text for text, _ in EXPECTED_ENTITIES_IN_SAMPLE}


def get_expected_entity_categories() -> dict[str, EntityCategory]:
    """Get expected entity categories from sample data."""
    return dict(EXPECTED_ENTITIES_IN_SAMPLE)


def build_segment_tuples_for_ner(
    transcript: tuple[TranscriptSegment, ...],
) -> list[tuple[int, str]]:
    """Build (segment_id, text) tuples for the NER engine."""
    return [(seg.segment_id, seg.text) for seg in transcript]


def extract_speaker_ids(segments: Sequence[TranscriptSegment]) -> set[str]:
    """Extract unique speaker IDs from segments."""
    return {seg.speaker_id for seg in segments if seg.speaker_id is not None}


def build_time_ranges(
    segments: Sequence[TranscriptSegment],
) -> list[tuple[float, float]]:
    """Build (start, end) time ranges from segments."""
    return [(seg.start_time, seg.end_time) for seg in segments]


def count_category_occurrences(
    entities: Sequence[NamedEntity],
    category: EntityCategory,
) -> int:
    """Count entities of a specific category."""
    return sum(1 for e in entities if e.category == category)


def filter_entities_by_category(
    entities: Sequence[NamedEntity],
    category: EntityCategory,
) -> list[NamedEntity]:
    """Filter entities by category."""
    return [e for e in entities if e.category == category]


def extract_segment_ids_from_retrieved(
    retrieved: Sequence[tuple[MockSegment, float]],
) -> list[int]:
    """Extract segment IDs from retrieval results."""
    return [seg.segment_id for seg, _ in retrieved]


def map_speaker_to_segments(
    segments: Sequence[TranscriptSegment],
) -> dict[str, list[int]]:
    """Map speaker_id to list of segment_ids."""
    result: dict[str, list[int]] = {}
    for seg in segments:
        if seg.speaker_id is not None:
            result.setdefault(seg.speaker_id, []).append(seg.segment_id)
    return result
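
# For the sample transcript above this yields {"speaker_0": [1, 4, 7],
# "speaker_1": [2, 5, 8], "speaker_2": [3, 6]}, also pre-computed below as
# SAMPLE_SPEAKER_TO_SEGMENTS.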


def get_segments_for_speaker(
    segments: Sequence[TranscriptSegment],
    speaker_id: str,
) -> list[TranscriptSegment]:
    """Get all segments for a specific speaker."""
    return [seg for seg in segments if seg.speaker_id == speaker_id]


def calculate_speaker_coverage(
    assigned_speakers: Sequence[str | None],
    expected_speakers: set[str],
) -> float:
    """Calculate what fraction of expected speakers were assigned."""
    assigned_set = {s for s in assigned_speakers if s is not None}
    if not expected_speakers:
        return 1.0
    return len(assigned_set & expected_speakers) / len(expected_speakers)
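
# Worked example: assigned_speakers=["speaker_0", None, "speaker_1"] against
# expected_speakers={"speaker_0", "speaker_1", "speaker_2"} covers two of the
# three expected speakers, so coverage = 2/3.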


def count_speaker_matches(
    segments: Sequence[TranscriptSegment],
    assigned_speakers: Sequence[str | None],
) -> int:
    """Count how many assigned speakers match expected."""
    # zip truncates to the shorter sequence, so unassigned tails never match.
    return sum(
        1
        for seg, assigned in zip(segments, assigned_speakers)
        if seg.speaker_id == assigned
    )


def extract_expected_speakers_from_segments(
    segments: Sequence[TranscriptSegment],
) -> list[str | None]:
    """Extract speaker_id from each segment in order."""
    return [seg.speaker_id for seg in segments]


def extract_citations_from_answer(answer: str) -> list[int]:
    """Extract citation numbers [N] from an answer string."""
    return [int(m) for m in re.findall(r"\[(\d+)\]", answer)]
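
# Example: extract_citations_from_answer("See [1] and [3].") returns [1, 3];
# text without bracketed numbers returns an empty list.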


def build_context_from_segments(
    segments: Sequence[MockSegment],
    limit: int,
) -> str:
    """Build context string from segments for synthesis."""
    return "\n".join(
        f"[{i}] {seg.text}" for i, seg in enumerate(segments[:limit], start=1)
    )


# ---------------------------------------------------------------------------
# Pre-computed test data constants
# ---------------------------------------------------------------------------

EXPECTED_ENTITY_TEXTS: Final[set[str]] = get_expected_entity_texts()
EXPECTED_ENTITY_CATEGORIES: Final[dict[str, EntityCategory]] = get_expected_entity_categories()
SAMPLE_SEGMENT_TUPLES: Final[list[tuple[int, str]]] = build_segment_tuples_for_ner(
    SAMPLE_MEETING_TRANSCRIPT
)
SAMPLE_TIME_RANGES: Final[list[tuple[float, float]]] = build_time_ranges(SAMPLE_MEETING_TRANSCRIPT)
SAMPLE_SPEAKER_TO_SEGMENTS: Final[dict[str, list[int]]] = map_speaker_to_segments(
    SAMPLE_MEETING_TRANSCRIPT
)


def get_all_transcript_text() -> str:
    """Concatenate all transcript segment texts."""
    return " ".join(seg.text for seg in SAMPLE_MEETING_TRANSCRIPT)


SAMPLE_ALL_TEXT: Final[str] = get_all_transcript_text()