"""Prompt tuning harness tests."""

from __future__ import annotations

from typing import Final

import pytest

from tests.conftest import approx_float

from .conftest import (
    ConfigurableLLM,
    MockSegment,
    build_context_from_segments,
    calculate_citation_accuracy,
    extract_citations_from_answer,
)
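
# The conftest helpers are treated as black boxes here. From the way the tests
# below exercise them, the assumed contracts (a sketch, not the canonical
# implementations) are:
#   - ConfigurableLLM(responses=..., default_response=...) answers with a
#     configured response when its key occurs as a substring of the prompt,
#     falls back to the default otherwise, and records prompts in complete_calls.
#   - build_context_from_segments(segments, limit=N) renders at most N segments
#     as "[<segment_id>] <text>" entries.
#   - extract_citations_from_answer(answer) pulls bracketed ids out of the
#     answer text, e.g. 1 and 99 from "See [1] and [99]".
#   - calculate_citation_accuracy(citations, valid_ids) returns metrics with
#     citation_accuracy, citation_count, and invalid_citations fields.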

VALID_SEGMENT_IDS: Final[set[int]] = {1, 2, 3, 4, 5}
CONTEXT_LIMIT_TWO: Final[int] = 2
CONTEXT_LIMIT_FIVE: Final[int] = 5
TWO_THIRDS: Final[float] = 2 / 3
ONE_HALF: Final[float] = 0.5
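
# 2 / 3 has no exact binary floating-point representation, so accuracy values
# built from TWO_THIRDS are compared via approx_float() rather than bare ==.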


class TestConfigurableLLM:
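    """Behavior of the ConfigurableLLM test double."""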

    @pytest.mark.asyncio
    async def test_returns_default_response(self) -> None:
        llm = ConfigurableLLM()
        result = await llm.complete("any prompt")
        assert "[1]" in result, "Default response should have citation"

    @pytest.mark.asyncio
    async def test_returns_configured_response(self) -> None:
        llm = ConfigurableLLM(responses={"budget": "The budget is $50,000 [8]."})
        result = await llm.complete("What is the budget?")
        assert "50,000" in result, "Should return configured response"
        assert "[8]" in result, "Should have citation [8]"

    @pytest.mark.asyncio
    async def test_tracks_complete_calls(self) -> None:
        llm = ConfigurableLLM()
        await llm.complete("first prompt")
        await llm.complete("second prompt")
        assert len(llm.complete_calls) == 2, "Should track 2 calls"
        assert llm.complete_calls[0] == "first prompt", "First call recorded"
        assert llm.complete_calls[1] == "second prompt", "Second call recorded"

    @pytest.mark.asyncio
    async def test_matches_partial_key(self) -> None:
        llm = ConfigurableLLM(responses={"deadline": "March 15th [4]."})
        result = await llm.complete("When is the deadline scheduled?")
        assert "March 15th" in result, "Should match on partial key"


class TestContextBuilding:
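    """build_context_from_segments: numbering, limits, and empty input."""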

    def test_builds_numbered_context(self) -> None:
        seg1 = MockSegment(
            segment_id=1, meeting_id=None, text="Hello", start_time=0.0, end_time=1.0
        )
        seg2 = MockSegment(
            segment_id=2, meeting_id=None, text="World", start_time=1.0, end_time=2.0
        )
        context = build_context_from_segments([seg1, seg2], limit=CONTEXT_LIMIT_TWO)

        assert "[1] Hello" in context, "First segment numbered"
        assert "[2] World" in context, "Second segment numbered"

    def test_respects_limit(self) -> None:
        seg1 = MockSegment(segment_id=1, meeting_id=None, text="A", start_time=0.0, end_time=1.0)
        seg2 = MockSegment(segment_id=2, meeting_id=None, text="B", start_time=1.0, end_time=2.0)
        seg3 = MockSegment(segment_id=3, meeting_id=None, text="C", start_time=2.0, end_time=3.0)
        context = build_context_from_segments([seg1, seg2, seg3], limit=CONTEXT_LIMIT_TWO)

        assert "[1] A" in context, "First included"
        assert "[2] B" in context, "Second included"
        assert "[3]" not in context, "Third excluded by limit"

    def test_empty_segments_empty_context(self) -> None:
        context = build_context_from_segments([], limit=CONTEXT_LIMIT_FIVE)
        assert context == "", "Empty input yields empty context"


class TestAnswerCitationQuality:
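    """Citation extraction and accuracy metrics over answer strings."""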

    def test_answer_all_citations_valid(self) -> None:
        answer = "The budget [1] and deadline [2] were discussed [3]."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)

        assert metrics.citation_accuracy == 1.0, "All citations valid"
        assert metrics.citation_count == 3, "3 citations"

    def test_mixed_valid_invalid(self) -> None:
        answer = "See [1] and [99] and [2]."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)

        assert metrics.citation_accuracy == approx_float(TWO_THIRDS), "2 of 3 citations valid"
        assert metrics.invalid_citations == 1, "1 invalid"

    def test_no_citations_perfect_accuracy(self) -> None:
        answer = "This answer has no citations."
        citations = extract_citations_from_answer(answer)
        metrics = calculate_citation_accuracy(citations, VALID_SEGMENT_IDS)

        assert metrics.citation_accuracy == 1.0, "Vacuously true"
        assert metrics.citation_count == 0, "0 citations"


class TestPromptVariations:
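    """Response-key matching across different prompt phrasings."""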

    @pytest.mark.asyncio
    async def test_different_prompts_same_key(self) -> None:
        llm = ConfigurableLLM(responses={"action": "John sends docs [7]."})

        r1 = await llm.complete("What are the action items?")
        r2 = await llm.complete("List action items")

        assert r1 == r2, "Same key matches both prompts"

    @pytest.mark.asyncio
    async def test_no_match_uses_default(self) -> None:
        llm = ConfigurableLLM(
            responses={"budget": "Budget response"},
            default_response="Default [1]",
        )

        result = await llm.complete("unrelated question")

        assert result == "Default [1]", "Falls back to default"


@pytest.mark.parametrize(
    ("answer", "valid_ids", "expected_accuracy"),
    [
        ("Text [1] [2]", {1, 2, 3}, 1.0),
        ("Text [1] [99]", {1, 2, 3}, ONE_HALF),
        ("Text [99] [100]", {1, 2, 3}, 0.0),
        ("No citations", {1, 2, 3}, 1.0),
    ],
    ids=["all-valid", "half-valid", "none-valid", "no-citations"],
)
def test_citation_accuracy_parametrized(
    answer: str,
    valid_ids: set[int],
    expected_accuracy: float,
) -> None:
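    """Accuracy equals the valid-citation fraction; no citations scores 1.0."""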
    citations = extract_citations_from_answer(answer)
    metrics = calculate_citation_accuracy(citations, valid_ids)
    assert metrics.citation_accuracy == approx_float(expected_accuracy), "Accuracy mismatch"