noteflow/tests/evaluation/test_prompt_tuning.py

"""Prompt tuning harness tests."""
from __future__ import annotations
from typing import Final
import pytest
from tests.conftest import approx_float
from .conftest import (
ConfigurableLLM,
MockSegment,
build_context_from_segments,
calculate_citation_accuracy,
extract_citations_from_answer,
)
VALID_SEGMENT_IDS: Final[set[int]] = {1, 2, 3, 4, 5}
CONTEXT_LIMIT_TWO: Final[int] = 2
CONTEXT_LIMIT_FIVE: Final[int] = 5
TWO_THIRDS: Final[float] = 2 / 3
ONE_HALF: Final[float] = 0.5
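
# NOTE: These tests rely only on the contract exercised below, which the local evaluation
# ``conftest`` is assumed to provide: ``ConfigurableLLM.complete()`` returns a canned response
# whenever a configured key appears as a substring of the prompt, falls back to
# ``default_response`` otherwise, and records every prompt it receives in ``complete_calls``.
# ``MockSegment`` is assumed to be a lightweight stand-in for a transcript segment.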


class TestConfigurableLLM:
    """Behavior of the ``ConfigurableLLM`` test double."""

    @pytest.mark.asyncio
    async def test_returns_default_response(self) -> None:
        llm = ConfigurableLLM()
        result = await llm.complete("any prompt")
        assert "[1]" in result, "Default response should have citation"

    @pytest.mark.asyncio
    async def test_returns_configured_response(self) -> None:
        llm = ConfigurableLLM(responses={"budget": "The budget is $50,000 [8]."})
        result = await llm.complete("What is the budget?")
        assert "50,000" in result, "Should return configured response"
        assert "[8]" in result, "Should have citation [8]"

    @pytest.mark.asyncio
    async def test_tracks_complete_calls(self) -> None:
        llm = ConfigurableLLM()
        await llm.complete("first prompt")
        await llm.complete("second prompt")
        assert len(llm.complete_calls) == 2, "Should track 2 calls"
        assert llm.complete_calls[0] == "first prompt", "First call recorded"
        assert llm.complete_calls[1] == "second prompt", "Second call recorded"

    @pytest.mark.asyncio
    async def test_matches_partial_key(self) -> None:
        llm = ConfigurableLLM(responses={"deadline": "March 15th [4]."})
        result = await llm.complete("When is the deadline scheduled?")
        assert "March 15th" in result, "Should match on partial key"


class TestContextBuilding:
    """Numbered context built from transcript segments."""

    def test_builds_numbered_context(self) -> None:
        seg1 = MockSegment(
            segment_id=1, meeting_id=None, text="Hello", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=None, text="World", start_time=1.0, end_time=2.0
        )
        context = build_context_from_segments([seg1, seg2], limit=CONTEXT_LIMIT_TWO)
        assert "[1] Hello" in context, "First segment numbered"
        assert "[2] World" in context, "Second segment numbered"

    def test_respects_limit(self) -> None:
        seg1 = MockSegment(segment_id=1, meeting_id=None, text="A", start_time=0.0, end_time=1.0)
        seg2 = MockSegment(segment_id=2, meeting_id=None, text="B", start_time=1.0, end_time=2.0)
        seg3 = MockSegment(segment_id=3, meeting_id=None, text="C", start_time=2.0, end_time=3.0)
        context = build_context_from_segments([seg1, seg2, seg3], limit=CONTEXT_LIMIT_TWO)
        assert "[1] A" in context, "First included"
        assert "[2] B" in context, "Second included"
        assert "[3]" not in context, "Third excluded by limit"

    def test_empty_segments_empty_context(self) -> None:
        context = build_context_from_segments([], limit=CONTEXT_LIMIT_FIVE)
        assert context == "", "Empty input yields empty context"


class TestAnswerCitationQuality:
    """Citation extraction and accuracy scoring for generated answers."""

    def test_answer_all_citations_valid(self) -> None:
        answer = "The budget [1] and deadline [2] were discussed [3]."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)
        assert metrics.citation_accuracy == 1.0, "All citations valid"
        assert metrics.citation_count == 3, "3 citations"

    def test_mixed_valid_invalid(self) -> None:
        answer = "See [1] and [99] and [2]."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)
        assert metrics.citation_accuracy == approx_float(TWO_THIRDS), "2 valid citations"
        assert metrics.invalid_citations == 1, "1 invalid"

    def test_no_citations_perfect_accuracy(self) -> None:
        answer = "This answer has no citations."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)
        assert metrics.citation_accuracy == 1.0, "Vacuously true"
        assert metrics.citation_count == 0, "0 citations"


class TestPromptVariations:
    """Prompt-variation behavior of the configurable LLM double."""

    @pytest.mark.asyncio
    async def test_different_prompts_same_key(self) -> None:
        llm = ConfigurableLLM(responses={"action": "John sends docs [7]."})
        r1 = await llm.complete("What are the action items?")
        r2 = await llm.complete("List action items")
        assert r1 == r2, "Same key matches both prompts"

    @pytest.mark.asyncio
    async def test_no_match_uses_default(self) -> None:
        llm = ConfigurableLLM(
            responses={"budget": "Budget response"},
            default_response="Default [1]",
        )
        result = await llm.complete("unrelated question")
        assert result == "Default [1]", "Falls back to default"


@pytest.mark.parametrize(
    ("answer", "valid_ids", "expected_accuracy"),
    [
        ("Text [1] [2]", {1, 2, 3}, 1.0),
        ("Text [1] [99]", {1, 2, 3}, ONE_HALF),
        ("Text [99] [100]", {1, 2, 3}, 0.0),
        ("No citations", {1, 2, 3}, 1.0),
    ],
    ids=["all-valid", "half-valid", "none-valid", "no-citations"],
)
def test_citation_accuracy_parametrized(
    answer: str,
    valid_ids: set[int],
    expected_accuracy: float,
) -> None:
    citations = extract_citations_from_answer(answer)
    metrics = calculate_citation_accuracy(citations, valid_ids)
    assert metrics.citation_accuracy == approx_float(expected_accuracy), "Accuracy mismatch"