Files
noteflow/tests/evaluation/test_diarization.py
Travis Vasceannie d8090a98e8
Some checks failed
CI / test-typescript (push) Has been cancelled
CI / test-rust (push) Has been cancelled
CI / test-python (push) Has been cancelled
ci/cd fixes
2026-01-26 00:28:15 +00:00

244 lines
11 KiB
Python

"""Speaker diarization evaluation tests."""
from __future__ import annotations
from typing import Final
import pytest
from noteflow.infrastructure.diarization.assigner import assign_speaker, assign_speakers_batch
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from tests.conftest import approx_float
from .conftest import (
SAMPLE_MEETING_TRANSCRIPT,
SAMPLE_TIME_RANGES,
calculate_speaker_coverage,
count_speaker_matches,
extract_expected_speakers_from_segments,
extract_speaker_ids,
get_segments_for_speaker,
)
# Start time of the segment used by test_turn_no_overlap_after; it lies past
# the end (5.0) of the turn that test constructs.
SEGMENT_START_AFTER_TURN: Final[float] = 10.0
# End time of that same segment, passed alongside SEGMENT_START_AFTER_TURN.
TURN_NO_OVERLAP_SEGMENT_END: Final[float] = 15.0
class TestSpeakerTurnOverlap:
    """Checks of ``SpeakerTurn.overlaps`` for segments in different positions."""

    @staticmethod
    def _make_turn(begin: float, finish: float) -> SpeakerTurn:
        """Build a ``SPEAKER_00`` turn spanning ``begin`` to ``finish``."""
        return SpeakerTurn(speaker="SPEAKER_00", start=begin, end=finish)

    def test_turn_overlaps_with_contained_segment(self) -> None:
        """A segment lying entirely within the turn overlaps it."""
        has_overlap = self._make_turn(0.0, 10.0).overlaps(2.0, 5.0)
        assert has_overlap, "Segment fully inside turn should overlap"

    def test_turn_overlaps_with_containing_segment(self) -> None:
        """A segment that fully surrounds the turn overlaps it."""
        has_overlap = self._make_turn(2.0, 5.0).overlaps(0.0, 10.0)
        assert has_overlap, "Segment containing turn should overlap"

    def test_turn_overlaps_with_partial_start(self) -> None:
        """A segment that straddles the turn's start overlaps it."""
        has_overlap = self._make_turn(5.0, 10.0).overlaps(3.0, 7.0)
        assert has_overlap, "Partial overlap at start should count"

    def test_turn_overlaps_with_partial_end(self) -> None:
        """A segment that straddles the turn's end overlaps it."""
        has_overlap = self._make_turn(0.0, 5.0).overlaps(3.0, 10.0)
        assert has_overlap, "Partial overlap at end should count"

    def test_turn_no_overlap_before(self) -> None:
        """A segment that finishes before the turn begins does not overlap."""
        has_overlap = self._make_turn(10.0, 15.0).overlaps(0.0, 5.0)
        assert not has_overlap, "No overlap when segment is before turn"

    def test_turn_no_overlap_after(self) -> None:
        """A segment that begins after the turn has finished does not overlap."""
        has_overlap = self._make_turn(0.0, 5.0).overlaps(
            SEGMENT_START_AFTER_TURN,
            TURN_NO_OVERLAP_SEGMENT_END,
        )
        assert not has_overlap, (
            "No overlap when segment is after turn"
        )

    def test_turn_no_overlap_adjacent(self) -> None:
        """A segment starting exactly where the turn ends does not overlap."""
        has_overlap = self._make_turn(0.0, 5.0).overlaps(5.0, 10.0)
        assert not has_overlap, "Adjacent segments should not overlap"
class TestSpeakerTurnOverlapDuration:
    """Checks of ``SpeakerTurn.overlap_duration`` for full, partial and no overlap."""

    def test_full_overlap_returns_segment_duration(self) -> None:
        """Segment 2.0-5.0 inside turn 0.0-10.0 overlaps for its whole length."""
        speaker_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)
        measured = speaker_turn.overlap_duration(2.0, 5.0)
        expected = approx_float(3.0)
        assert measured == expected, "Full overlap should be segment duration"

    def test_partial_overlap_start_calculates_correctly(self) -> None:
        """Segment 3.0-7.0 against turn 5.0-10.0 shares the span 5.0-7.0."""
        speaker_turn = SpeakerTurn(speaker="SPEAKER_00", start=5.0, end=10.0)
        measured = speaker_turn.overlap_duration(3.0, 7.0)
        expected = approx_float(2.0)
        assert measured == expected, "Overlap should be 5.0 to 7.0 = 2.0s"

    def test_partial_overlap_end_calculates_correctly(self) -> None:
        """Segment 3.0-10.0 against turn 0.0-5.0 shares the span 3.0-5.0."""
        speaker_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0)
        measured = speaker_turn.overlap_duration(3.0, 10.0)
        expected = approx_float(2.0)
        assert measured == expected, "Overlap should be 3.0 to 5.0 = 2.0s"

    def test_no_overlap_returns_zero(self) -> None:
        """A segment disjoint from the turn has an overlap duration of 0.0."""
        speaker_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0)
        measured = speaker_turn.overlap_duration(10.0, 15.0)
        assert measured == 0.0, "No overlap should return 0.0"
class TestAssignSpeaker:
    """Checks of ``assign_speaker`` for a single segment against a list of turns."""

    def test_single_turn_exact_match(self) -> None:
        """A segment identical to the only turn gets that speaker at confidence 1.0."""
        only_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0)
        assigned, score = assign_speaker(0.0, 5.0, [only_turn])
        assert assigned == "SPEAKER_00", "Should assign the matching speaker"
        assert score == approx_float(1.0), "Exact match should have confidence 1.0"

    def test_multiple_turns_selects_max_overlap(self) -> None:
        """With two candidate turns, the one overlapping the segment most wins."""
        # The segment spans 0.0-5.0.
        # SPEAKER_00 (0.0-2.0) shares 2.0 seconds with it.
        # SPEAKER_01 (1.0-10.0) shares 1.0-5.0, i.e. 4.0 seconds, so it should win.
        short_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=2.0)
        long_turn = SpeakerTurn(speaker="SPEAKER_01", start=1.0, end=10.0)
        assigned, _ = assign_speaker(0.0, 5.0, [short_turn, long_turn])
        assert assigned == "SPEAKER_01", "Should select speaker with most overlap"

    def test_no_overlapping_turns_returns_none(self) -> None:
        """A segment that misses every turn yields ``None`` and confidence 0.0."""
        distant_turn = SpeakerTurn(speaker="SPEAKER_00", start=10.0, end=15.0)
        assigned, score = assign_speaker(0.0, 5.0, [distant_turn])
        assert assigned is None, "No overlap should return None"
        assert score == 0.0, "No overlap should have confidence 0.0"

    def test_assign_speaker_empty_turns_returns_none(self) -> None:
        """An empty turn list yields ``None`` and confidence 0.0."""
        no_turns: list[SpeakerTurn] = []
        assigned, score = assign_speaker(0.0, 5.0, no_turns)
        assert assigned is None, "Empty turns should return None"
        assert score == 0.0, "Empty turns should have confidence 0.0"

    def test_assign_speaker_zero_duration_segment_returns_none(self) -> None:
        """A segment whose start equals its end yields ``None`` and confidence 0.0."""
        covering_turn = SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)
        assigned, score = assign_speaker(5.0, 5.0, [covering_turn])
        assert assigned is None, "Zero duration segment should return None"
        assert score == 0.0, "Zero duration should have confidence 0.0"
class TestAssignSpeakersBatch:
    """Checks of ``assign_speakers_batch`` over lists of segment time ranges."""

    def test_batch_assigns_all_segments(self) -> None:
        """Two segments matching two back-to-back turns each get their own speaker."""
        diarization_turns = [
            SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=5.0),
            SpeakerTurn(speaker="SPEAKER_01", start=5.0, end=10.0),
        ]
        time_ranges = [(0.0, 5.0), (5.0, 10.0)]
        assignments = assign_speakers_batch(time_ranges, diarization_turns)
        # Verify the length before indexing so a short result fails on the assert.
        assert len(assignments) == 2, "Should return result for each segment"
        first_speaker = assignments[0][0]
        assert first_speaker == "SPEAKER_00", "First segment should be SPEAKER_00"
        second_speaker = assignments[1][0]
        assert second_speaker == "SPEAKER_01", "Second segment should be SPEAKER_01"

    def test_batch_handles_empty_segments(self) -> None:
        """An empty segment list produces an empty result list."""
        diarization_turns = [SpeakerTurn(speaker="SPEAKER_00", start=0.0, end=10.0)]
        time_ranges: list[tuple[float, float]] = []
        assignments = assign_speakers_batch(time_ranges, diarization_turns)
        assert assignments == [], "Empty segments should return empty list"
class TestSpeakerCoverageMetrics:
    """Checks of the ``calculate_speaker_coverage`` evaluation helper."""

    def test_full_coverage_returns_one(self) -> None:
        """Coverage is 1.0 when every expected speaker was assigned."""
        found_speakers = ["speaker_0", "speaker_1", "speaker_2"]
        wanted_speakers = {"speaker_0", "speaker_1", "speaker_2"}
        ratio = calculate_speaker_coverage(found_speakers, wanted_speakers)
        assert ratio == 1.0, "All speakers found should give 100% coverage"

    def test_partial_coverage_calculates_correctly(self) -> None:
        """Coverage is 2/3 when two of the three expected speakers were assigned."""
        found_speakers = ["speaker_0", "speaker_1"]
        wanted_speakers = {"speaker_0", "speaker_1", "speaker_2"}
        ratio = calculate_speaker_coverage(found_speakers, wanted_speakers)
        assert ratio == approx_float(2 / 3), "2 of 3 speakers should be 66.7%"

    def test_no_coverage_returns_zero(self) -> None:
        """Coverage is 0.0 when no assigned speaker is among the expected ones."""
        found_speakers = ["unknown"]
        wanted_speakers = {"speaker_0", "speaker_1"}
        ratio = calculate_speaker_coverage(found_speakers, wanted_speakers)
        assert ratio == 0.0, "No matching speakers should be 0%"

    def test_none_values_filtered_correctly(self) -> None:
        """``None`` entries among the assignments do not reduce coverage."""
        found_speakers: list[str | None] = ["speaker_0", None, "speaker_1", None]
        wanted_speakers = {"speaker_0", "speaker_1"}
        ratio = calculate_speaker_coverage(found_speakers, wanted_speakers)
        assert ratio == 1.0, "None values should be filtered, all speakers found"
class TestSpeakerMatchCounting:
    """Checks of ``count_speaker_matches`` on the first three sample segments."""

    def test_all_matches_counted(self) -> None:
        """Assignments taken from the segments themselves all count as matches."""
        leading_segments = SAMPLE_MEETING_TRANSCRIPT[:3]
        predicted = extract_expected_speakers_from_segments(leading_segments)
        match_count = count_speaker_matches(leading_segments, predicted)
        assert match_count == 3, "All 3 should match"

    def test_no_matches_returns_zero(self) -> None:
        """Assignments that are all wrong produce a match count of zero."""
        leading_segments = SAMPLE_MEETING_TRANSCRIPT[:3]
        predicted = ["wrong_0", "wrong_1", "wrong_2"]
        match_count = count_speaker_matches(leading_segments, predicted)
        assert match_count == 0, "No matches when all speakers wrong"

    def test_partial_matches_counted(self) -> None:
        """Only the correct first and third assignments are counted."""
        leading_segments = SAMPLE_MEETING_TRANSCRIPT[:3]
        # Keep the real speakers at positions 0 and 2; corrupt position 1.
        predicted = [
            SAMPLE_MEETING_TRANSCRIPT[0].speaker_id,
            "wrong",
            SAMPLE_MEETING_TRANSCRIPT[2].speaker_id,
        ]
        match_count = count_speaker_matches(leading_segments, predicted)
        assert match_count == 2, "Should count 2 matches (first and third)"
class TestSampleDataSpeakers:
    """Sanity checks on the shared sample transcript fixtures and their helpers."""

    def test_extract_speaker_ids_from_sample(self) -> None:
        """The sample transcript contains exactly the three expected speaker ids."""
        speakers = extract_speaker_ids(SAMPLE_MEETING_TRANSCRIPT)
        assert speakers == {"speaker_0", "speaker_1", "speaker_2"}, "Should have 3 speakers"

    def test_sample_time_ranges_has_correct_count(self) -> None:
        """The sample time ranges fixture holds eight entries."""
        assert len(SAMPLE_TIME_RANGES) == 8, "Should have 8 time ranges"

    def test_sample_time_ranges_are_sequential(self) -> None:
        """Every time range starts strictly after the one before it.

        Comparing only the first and last start times would let an
        out-of-order range in the middle go unnoticed, so each consecutive
        pair of start times is checked.
        """
        starts = [start for start, _ in SAMPLE_TIME_RANGES]
        # A single (or missing) range cannot demonstrate ordering; the
        # first-vs-last comparison this replaces also failed in that situation.
        assert len(starts) > 1, "Need at least two time ranges to check ordering"
        assert all(
            earlier < later for earlier, later in zip(starts, starts[1:])
        ), "Time ranges should be sequential"

    def test_get_segments_for_speaker_returns_correct_subset(self) -> None:
        """Filtering the sample transcript by ``speaker_0`` yields three segments."""
        speaker_0_segments = get_segments_for_speaker(SAMPLE_MEETING_TRANSCRIPT, "speaker_0")
        assert len(speaker_0_segments) == 3, "speaker_0 should have 3 segments"

    def test_get_segments_for_unknown_speaker_returns_empty(self) -> None:
        """Filtering by a speaker id absent from the transcript yields an empty list."""
        unknown_segments = get_segments_for_speaker(SAMPLE_MEETING_TRANSCRIPT, "unknown")
        assert unknown_segments == [], "Unknown speaker should return empty list"
@pytest.mark.parametrize(
    ("turn_start", "turn_end", "seg_start", "seg_end", "expected_overlap"),
    [
        pytest.param(0.0, 10.0, 0.0, 10.0, 10.0, id="exact-match"),
        pytest.param(0.0, 10.0, 2.0, 8.0, 6.0, id="contained"),
        pytest.param(5.0, 15.0, 0.0, 10.0, 5.0, id="partial"),
        pytest.param(0.0, 5.0, 5.0, 10.0, 0.0, id="adjacent"),
        pytest.param(0.0, 5.0, 10.0, 15.0, 0.0, id="disjoint"),
    ],
)
def test_overlap_duration_parametrized(
    turn_start: float,
    turn_end: float,
    seg_start: float,
    seg_end: float,
    expected_overlap: float,
) -> None:
    """Check ``SpeakerTurn.overlap_duration`` across a table of turn/segment pairs.

    Each case builds a ``SPEAKER_00`` turn from ``turn_start``/``turn_end`` and
    compares its overlap with the segment ``seg_start``..``seg_end`` against
    ``expected_overlap`` using ``approx_float``.
    """
    speaker_turn = SpeakerTurn(speaker="SPEAKER_00", start=turn_start, end=turn_end)
    measured = speaker_turn.overlap_duration(seg_start, seg_end)
    assert measured == approx_float(expected_overlap), "Unexpected overlap"
@pytest.mark.parametrize(
    ("segment_idx", "expected_speaker"),
    [
        pytest.param(0, "speaker_0", id="seg-1"),
        pytest.param(1, "speaker_1", id="seg-2"),
        pytest.param(2, "speaker_2", id="seg-3"),
        pytest.param(3, "speaker_0", id="seg-4"),
        pytest.param(4, "speaker_1", id="seg-5"),
    ],
)
def test_sample_transcript_speaker_pattern(segment_idx: int, expected_speaker: str) -> None:
    """Check the speaker id of each of the first five sample transcript segments."""
    actual_speaker = SAMPLE_MEETING_TRANSCRIPT[segment_idx].speaker_id
    assert actual_speaker == expected_speaker, "Speaker mismatch"