"""Shared fixtures and utilities for pipeline evaluation tests."""

from __future__ import annotations

import math
import re
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Final, Protocol
from uuid import UUID, uuid4

import pytest

from noteflow.domain.entities.named_entity import EntityCategory, NamedEntity
from noteflow.domain.entities.segment import Segment
from noteflow.domain.value_objects import MeetingId

# Pipeline defaults.
DEFAULT_EMBEDDING_DIM: Final[int] = 1536
DEFAULT_SAMPLE_RATE: Final[int] = 16000
DEFAULT_TOP_K: Final[int] = 8

# Confidence bounds for extracted entities.
MIN_CONFIDENCE: Final[float] = 0.0
MAX_CONFIDENCE: Final[float] = 1.0
PERFECT_SCORE: Final[float] = 1.0

# Minimum acceptable evaluation scores.
ACCEPTABLE_PRECISION: Final[float] = 0.7
ACCEPTABLE_RECALL: Final[float] = 0.7
ACCEPTABLE_F1: Final[float] = 0.7


@dataclass(frozen=True)
class TranscriptSegment:
    """A single transcript segment used as pipeline evaluation input."""

    segment_id: int
    text: str
    start_time: float
    end_time: float
    speaker_id: str | None = None


SAMPLE_MEETING_TRANSCRIPT: Final[tuple[TranscriptSegment, ...]] = (
    TranscriptSegment(
        segment_id=1,
        text="Good morning everyone. Let's start the product review meeting.",
        start_time=0.0,
        end_time=4.5,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=2,
        text="I'm John from engineering. We've completed the API refactoring last week.",
        start_time=5.0,
        end_time=10.2,
        speaker_id="speaker_1",
    ),
    TranscriptSegment(
        segment_id=3,
        text="Sarah here from design. The new UI mockups are ready for review.",
        start_time=11.0,
        end_time=15.8,
        speaker_id="speaker_2",
    ),
    TranscriptSegment(
        segment_id=4,
        text="We need to discuss the deadline. The launch is scheduled for March 15th.",
        start_time=16.5,
        end_time=22.0,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=5,
        text="Acme Corporation has requested an early demo next Tuesday.",
        start_time=23.0,
        end_time=28.5,
        speaker_id="speaker_1",
    ),
    TranscriptSegment(
        segment_id=6,
        text="I'll prepare the presentation slides. Can we meet at 2 PM tomorrow?",
        start_time=29.0,
        end_time=34.2,
        speaker_id="speaker_2",
    ),
    TranscriptSegment(
        segment_id=7,
        text="Action item: John to send the technical documentation to the client by Friday.",
        start_time=35.0,
        end_time=41.0,
        speaker_id="speaker_0",
    ),
    TranscriptSegment(
        segment_id=8,
        text="The budget is $50,000 for this quarter. We're currently at 60% utilization.",
        start_time=42.0,
        end_time=48.5,
        speaker_id="speaker_1",
    ),
)

EXPECTED_ENTITIES_IN_SAMPLE: Final[tuple[tuple[str, EntityCategory], ...]] = (
    ("John", EntityCategory.PERSON),
    ("Sarah", EntityCategory.PERSON),
    ("Acme Corporation", EntityCategory.COMPANY),
    ("API", EntityCategory.TECHNICAL),
    ("March 15th", EntityCategory.DATE),
    ("next Tuesday", EntityCategory.DATE),
    ("Friday", EntityCategory.DATE),
    ("2 PM tomorrow", EntityCategory.DATE),
)

EXPECTED_SPEAKERS_IN_SAMPLE: Final[tuple[str, ...]] = (
    "speaker_0",
    "speaker_1",
    "speaker_2",
)


@dataclass(frozen=True)
class PrecisionRecallMetrics:
    """Precision/recall/F1 scores with their supporting counts."""

    precision: float
    recall: float
    f1_score: float
    true_positives: int
    false_positives: int
    false_negatives: int


@dataclass(frozen=True)
class RetrievalMetrics:
    """Standard ranking metrics for retrieval evaluation."""

    mean_reciprocal_rank: float
    precision_at_k: float
    recall_at_k: float
    ndcg: float
    hit_rate: float


@dataclass(frozen=True)
class AnswerQualityMetrics:
    """Citation and grounding metrics for generated answers."""

    citation_accuracy: float
    answer_relevance: float
    factual_grounding: float
    citation_count: int
    invalid_citations: int


@dataclass
class EvaluationResult:
    """Outcome of a single evaluation test."""

    test_name: str
    passed: bool
    metrics: dict[str, float] = field(default_factory=dict)
    details: str = ""


class MockSegmentLike(Protocol):
    """Structural type for segment-like objects used in retrieval tests."""

    segment_id: int
    meeting_id: MeetingId | None
    text: str
    start_time: float
    end_time: float


@dataclass
class MockSegment:
    """Lightweight stand-in for a domain Segment, with an optional embedding."""

    segment_id: int
    meeting_id: MeetingId | None
    text: str
    start_time: float
    end_time: float
    embedding: list[float] | None = None


class ConfigurableEmbedder:
    """Test double for an embedding service with canned, per-text vectors."""

    def __init__(
        self,
        embeddings: dict[str, list[float]] | None = None,
        default_embedding: list[float] | None = None,
    ) -> None:
        self._embeddings = embeddings or {}
        self._default = default_embedding or [0.1] * DEFAULT_EMBEDDING_DIM
        self.embed_calls: list[str] = []

    async def embed(self, text: str) -> list[float]:
        """Return the configured vector for ``text``, or a copy of the default."""
        self.embed_calls.append(text)
        return self._embeddings.get(text, self._default.copy())

    async def embed_batch(self, texts: Sequence[str]) -> list[list[float]]:
        """Embed each text in order."""
        return [await self.embed(t) for t in texts]
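
# Illustrative usage of ConfigurableEmbedder (assumed values; kept as a
# comment so nothing runs at import time):
#
#     embedder = ConfigurableEmbedder(embeddings={"budget": [1.0] * DEFAULT_EMBEDDING_DIM})
#     vec = await embedder.embed("budget")     # -> the configured vector
#     other = await embedder.embed("unknown")  # -> a copy of the default vector
#     assert embedder.embed_calls == ["budget", "unknown"]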


class ConfigurableLLM:
    """Test double for an LLM that answers by substring-matching the prompt."""

    def __init__(
        self,
        responses: dict[str, str] | None = None,
        default_response: str = "Based on the transcript [1], the answer is relevant.",
    ) -> None:
        self._responses = responses or {}
        self._default = default_response
        self.complete_calls: list[str] = []

    async def complete(self, prompt: str) -> str:
        """Return the first configured response whose key occurs in the prompt."""
        self.complete_calls.append(prompt)
        for key, response in self._responses.items():
            if key in prompt:
                return response
        return self._default
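
# Illustrative usage: responses are selected by substring match against the
# prompt; anything unmatched falls through to the default (assumed values):
#
#     llm = ConfigurableLLM(responses={"budget": "The budget is $50,000 [8]."})
#     await llm.complete("What is the budget?")  # -> "The budget is $50,000 [8]."
#     await llm.complete("Unrelated question")   # -> the default response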


class ConfigurableSegmentRepository:
    """Test double for per-meeting semantic search over canned scored segments."""

    def __init__(
        self,
        segments: list[tuple[MockSegment, float]] | None = None,
    ) -> None:
        self._segments = segments or []
        self.search_calls: list[tuple[list[float], int, MeetingId | None]] = []

    async def search_semantic(
        self,
        query_embedding: list[float],
        limit: int,
        meeting_id: MeetingId | None,
    ) -> Sequence[tuple[MockSegment, float]]:
        """Return up to ``limit`` segments, optionally filtered by meeting."""
        self.search_calls.append((query_embedding, limit, meeting_id))
        filtered = [
            (seg, score)
            for seg, score in self._segments
            if meeting_id is None or seg.meeting_id == meeting_id
        ]
        return filtered[:limit]
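
# Illustrative usage (hypothetical seg_a/seg_b): results are filtered by
# meeting_id when one is given, then truncated to ``limit``:
#
#     repo = ConfigurableSegmentRepository([(seg_a, 0.9), (seg_b, 0.8)])
#     await repo.search_semantic([0.1] * DEFAULT_EMBEDDING_DIM, limit=1, meeting_id=None)
#     # -> [(seg_a, 0.9)]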


class ConfigurableWorkspaceSegmentRepository:
    """Test double for workspace-wide semantic search over canned segments."""

    def __init__(
        self,
        segments: list[tuple[MockSegment, float]] | None = None,
    ) -> None:
        self._segments = segments or []
        self.search_calls: list[tuple[list[float], UUID, UUID | None, int]] = []

    async def search_semantic_workspace(
        self,
        query_embedding: list[float],
        workspace_id: UUID,
        project_id: UUID | None,
        limit: int,
    ) -> Sequence[tuple[MockSegment, float]]:
        """Return up to ``limit`` segments, recording the search call."""
        self.search_calls.append((query_embedding, workspace_id, project_id, limit))
        return self._segments[:limit]


def calculate_precision_recall(
    predicted: set[str],
    expected: set[str],
) -> PrecisionRecallMetrics:
    """Compute set-based precision, recall, and F1 between predictions and gold."""
    true_positives = len(predicted & expected)
    false_positives = len(predicted - expected)
    false_negatives = len(expected - predicted)

    precision = true_positives / (true_positives + false_positives) if predicted else 0.0
    recall = true_positives / (true_positives + false_negatives) if expected else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return PrecisionRecallMetrics(
        precision=precision,
        recall=recall,
        f1_score=f1,
        true_positives=true_positives,
        false_positives=false_positives,
        false_negatives=false_negatives,
    )
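
# Worked example (values checked by hand): one shared item out of two
# predictions and two gold labels gives 0.5 across the board:
#
#     m = calculate_precision_recall({"John", "Acme"}, {"John", "Sarah"})
#     # precision=0.5, recall=0.5, f1_score=0.5
#     # true_positives=1, false_positives=1, false_negatives=1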


def calculate_retrieval_metrics(
    retrieved_ids: list[int],
    relevant_ids: set[int],
    k: int,
) -> RetrievalMetrics:
    """Compute MRR, precision@k, recall@k, nDCG, and hit rate for a ranking."""
    if not relevant_ids:
        return RetrievalMetrics(
            mean_reciprocal_rank=0.0,
            precision_at_k=0.0,
            recall_at_k=0.0,
            ndcg=0.0,
            hit_rate=0.0,
        )

    # Reciprocal rank of the first relevant result within the top k.
    mrr = 0.0
    for i, rid in enumerate(retrieved_ids[:k], start=1):
        if rid in relevant_ids:
            mrr = 1.0 / i
            break

    hits_at_k = sum(1 for rid in retrieved_ids[:k] if rid in relevant_ids)
    precision_at_k = hits_at_k / k if k > 0 else 0.0
    recall_at_k = hits_at_k / len(relevant_ids)
    hit_rate = 1.0 if hits_at_k > 0 else 0.0

    # Binary-relevance nDCG: DCG over the top k, normalized by the ideal DCG
    # in which all relevant items occupy the top ranks.
    dcg = sum(
        (1.0 if rid in relevant_ids else 0.0) / _log2(i + 1)
        for i, rid in enumerate(retrieved_ids[:k], start=1)
    )
    ideal_dcg = sum(1.0 / _log2(i + 1) for i in range(1, min(k, len(relevant_ids)) + 1))
    ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0.0

    return RetrievalMetrics(
        mean_reciprocal_rank=mrr,
        precision_at_k=precision_at_k,
        recall_at_k=recall_at_k,
        ndcg=ndcg,
        hit_rate=hit_rate,
    )
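
# Worked example (values checked by hand): ranking [3, 1, 7] against
# relevant {1, 2} at k=3:
#
#     m = calculate_retrieval_metrics([3, 1, 7], {1, 2}, k=3)
#     # mean_reciprocal_rank=0.5  (first hit at rank 2)
#     # precision_at_k=1/3, recall_at_k=0.5, hit_rate=1.0
#     # ndcg ~ 0.387  (DCG=1/log2(3); ideal DCG=1/log2(2) + 1/log2(3))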


def _log2(x: float) -> float:
    """Return log2(x), treating non-positive input as 0.0."""
    return math.log2(x) if x > 0 else 0.0


def calculate_citation_accuracy(
    cited_ids: list[int],
    valid_ids: set[int],
) -> AnswerQualityMetrics:
    """Score cited segment IDs against the set of retrievable IDs."""
    valid_citations = [cid for cid in cited_ids if cid in valid_ids]
    invalid_citations = len(cited_ids) - len(valid_citations)
    # An answer with no citations makes no unsupported claims, so it scores 1.0.
    accuracy = len(valid_citations) / len(cited_ids) if cited_ids else 1.0

    return AnswerQualityMetrics(
        citation_accuracy=accuracy,
        answer_relevance=0.0,
        factual_grounding=0.0,
        citation_count=len(cited_ids),
        invalid_citations=invalid_citations,
    )
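
# Worked example: citations [1, 2, 9] against valid IDs {1, 2, 3}:
#
#     m = calculate_citation_accuracy([1, 2, 9], {1, 2, 3})
#     # citation_accuracy=2/3, citation_count=3, invalid_citations=1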


@pytest.fixture
def sample_meeting_id() -> MeetingId:
    """A random meeting ID."""
    return MeetingId(uuid4())


@pytest.fixture
def sample_workspace_id() -> UUID:
    """A random workspace ID."""
    return uuid4()


@pytest.fixture
def sample_project_id() -> UUID:
    """A random project ID."""
    return uuid4()


@pytest.fixture
def sample_transcript() -> tuple[TranscriptSegment, ...]:
    """The canned meeting transcript."""
    return SAMPLE_MEETING_TRANSCRIPT


@pytest.fixture
def expected_entities() -> tuple[tuple[str, EntityCategory], ...]:
    """Gold entities expected in the sample transcript."""
    return EXPECTED_ENTITIES_IN_SAMPLE


@pytest.fixture
def sample_segments(sample_meeting_id: MeetingId) -> list[MockSegment]:
    """Mock segments built from the sample transcript."""
    return [
        MockSegment(
            segment_id=seg.segment_id,
            meeting_id=sample_meeting_id,
            text=seg.text,
            start_time=seg.start_time,
            end_time=seg.end_time,
        )
        for seg in SAMPLE_MEETING_TRANSCRIPT
    ]


@pytest.fixture
def configurable_embedder() -> ConfigurableEmbedder:
    """An embedder test double with default settings."""
    return ConfigurableEmbedder()


@pytest.fixture
def configurable_llm() -> ConfigurableLLM:
    """An LLM test double with default settings."""
    return ConfigurableLLM()


@pytest.fixture
def segment_repo_with_data(
    sample_segments: list[MockSegment],
) -> ConfigurableSegmentRepository:
    """A segment repository pre-loaded with sample segments at decaying scores."""
    scored_segments = [(seg, 0.9 - i * 0.05) for i, seg in enumerate(sample_segments)]
    return ConfigurableSegmentRepository(scored_segments)


@pytest.fixture
def workspace_repo_with_data(
    sample_segments: list[MockSegment],
) -> ConfigurableWorkspaceSegmentRepository:
    """A workspace repository pre-loaded with sample segments at decaying scores."""
    scored_segments = [(seg, 0.9 - i * 0.05) for i, seg in enumerate(sample_segments)]
    return ConfigurableWorkspaceSegmentRepository(scored_segments)


def create_domain_segment(
    segment_id: int,
    meeting_id: MeetingId,
    text: str,
    start_time: float,
    end_time: float,
    speaker_id: str | None = None,
) -> Segment:
    """Build a domain Segment, reusing segment_id as db_id."""
    return Segment(
        segment_id=segment_id,
        db_id=segment_id,
        meeting_id=meeting_id,
        text=text,
        start_time=start_time,
        end_time=end_time,
        speaker_id=speaker_id,
    )


def create_named_entity(
    text: str,
    category: EntityCategory,
    segment_ids: list[int],
    confidence: float = 0.9,
) -> NamedEntity:
    """Build a NamedEntity with a default high confidence."""
    return NamedEntity(
        text=text,
        category=category,
        segment_ids=segment_ids,
        confidence=confidence,
    )


# ---------------------------------------------------------------------------
# Helper functions for test assertions (avoids loops in test bodies)
# ---------------------------------------------------------------------------


def extract_entity_texts(entities: Sequence[NamedEntity]) -> set[str]:
    """Extract text from entities as a set."""
    return {e.text for e in entities}


def extract_entity_categories(entities: Sequence[NamedEntity]) -> dict[str, EntityCategory]:
    """Map entity text to category."""
    return {e.text: e.category for e in entities}


def extract_entity_confidences(entities: Sequence[NamedEntity]) -> dict[str, float]:
    """Map entity text to confidence."""
    return {e.text: e.confidence for e in entities}


def get_expected_entity_texts() -> set[str]:
    """Get expected entity texts from sample data."""
    return {text for text, _ in EXPECTED_ENTITIES_IN_SAMPLE}


def get_expected_entity_categories() -> dict[str, EntityCategory]:
    """Get expected entity categories from sample data."""
    return dict(EXPECTED_ENTITIES_IN_SAMPLE)


def build_segment_tuples_for_ner(
    transcript: tuple[TranscriptSegment, ...],
) -> list[tuple[int, str]]:
    """Build (segment_id, text) tuples for NER engine."""
    return [(seg.segment_id, seg.text) for seg in transcript]


def extract_speaker_ids(segments: Sequence[TranscriptSegment]) -> set[str]:
    """Extract unique speaker IDs from segments."""
    return {seg.speaker_id for seg in segments if seg.speaker_id is not None}


def build_time_ranges(
    segments: Sequence[TranscriptSegment],
) -> list[tuple[float, float]]:
    """Build (start, end) time ranges from segments."""
    return [(seg.start_time, seg.end_time) for seg in segments]


def count_category_occurrences(
    entities: Sequence[NamedEntity],
    category: EntityCategory,
) -> int:
    """Count entities of a specific category."""
    return sum(1 for e in entities if e.category == category)


def filter_entities_by_category(
    entities: Sequence[NamedEntity],
    category: EntityCategory,
) -> list[NamedEntity]:
    """Filter entities by category."""
    return [e for e in entities if e.category == category]


def extract_segment_ids_from_retrieved(
    retrieved: Sequence[tuple[MockSegment, float]],
) -> list[int]:
    """Extract segment IDs from retrieval results."""
    return [seg.segment_id for seg, _ in retrieved]


def map_speaker_to_segments(
    segments: Sequence[TranscriptSegment],
) -> dict[str, list[int]]:
    """Map speaker_id to list of segment_ids."""
    result: dict[str, list[int]] = {}
    for seg in segments:
        if seg.speaker_id is not None:
            result.setdefault(seg.speaker_id, []).append(seg.segment_id)
    return result
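
# Illustrative result on the sample transcript: each speaker's segment IDs
# in order of appearance:
#
#     map_speaker_to_segments(SAMPLE_MEETING_TRANSCRIPT)
#     # -> {"speaker_0": [1, 4, 7], "speaker_1": [2, 5, 8], "speaker_2": [3, 6]}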


def get_segments_for_speaker(
    segments: Sequence[TranscriptSegment],
    speaker_id: str,
) -> list[TranscriptSegment]:
    """Get all segments for a specific speaker."""
    return [seg for seg in segments if seg.speaker_id == speaker_id]


def calculate_speaker_coverage(
    assigned_speakers: Sequence[str | None],
    expected_speakers: set[str],
) -> float:
    """Calculate what fraction of expected speakers were assigned."""
    assigned_set = {s for s in assigned_speakers if s is not None}
    if not expected_speakers:
        return 1.0
    return len(assigned_set & expected_speakers) / len(expected_speakers)
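
# Worked example: two of the three expected speakers were assigned:
#
#     calculate_speaker_coverage(
#         ["speaker_0", "speaker_1", None],
#         {"speaker_0", "speaker_1", "speaker_2"},
#     )
#     # -> 2/3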


def count_speaker_matches(
    segments: Sequence[TranscriptSegment],
    assigned_speakers: Sequence[str | None],
) -> int:
    """Count how many assigned speakers match expected."""
    return sum(
        1 for seg, assigned in zip(segments, assigned_speakers) if seg.speaker_id == assigned
    )


def extract_expected_speakers_from_segments(
    segments: Sequence[TranscriptSegment],
) -> list[str | None]:
    """Extract speaker_id from each segment in order."""
    return [seg.speaker_id for seg in segments]


def extract_citations_from_answer(answer: str) -> list[int]:
    """Extract citation numbers [N] from an answer string."""
    return [int(m) for m in re.findall(r"\[(\d+)\]", answer)]
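
# Worked example: [N] markers are returned in order of appearance, with
# duplicates preserved:
#
#     extract_citations_from_answer("See [1] and [3]; again [1].")
#     # -> [1, 3, 1]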


def build_context_from_segments(
    segments: Sequence[MockSegment],
    limit: int,
) -> str:
    """Build context string from segments for synthesis."""
    return "\n".join(
        f"[{i}] {seg.text}" for i, seg in enumerate(segments[:limit], start=1)
    )


# ---------------------------------------------------------------------------
# Pre-computed test data constants
# ---------------------------------------------------------------------------

EXPECTED_ENTITY_TEXTS: Final[set[str]] = get_expected_entity_texts()
EXPECTED_ENTITY_CATEGORIES: Final[dict[str, EntityCategory]] = get_expected_entity_categories()
SAMPLE_SEGMENT_TUPLES: Final[list[tuple[int, str]]] = build_segment_tuples_for_ner(
    SAMPLE_MEETING_TRANSCRIPT
)
SAMPLE_TIME_RANGES: Final[list[tuple[float, float]]] = build_time_ranges(SAMPLE_MEETING_TRANSCRIPT)
SAMPLE_SPEAKER_TO_SEGMENTS: Final[dict[str, list[int]]] = map_speaker_to_segments(
    SAMPLE_MEETING_TRANSCRIPT
)


def get_all_transcript_text() -> str:
    """Concatenate all transcript segment texts."""
    return " ".join(seg.text for seg in SAMPLE_MEETING_TRANSCRIPT)


SAMPLE_ALL_TEXT: Final[str] = get_all_transcript_text()