150 lines
4.3 KiB
Python
150 lines
4.3 KiB
Python
from unittest.mock import AsyncMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.ai.tools.retrieval import RetrievalResult
|
|
from noteflow.infrastructure.ai.tools.synthesis import (
|
|
SynthesisResult,
|
|
extract_cited_ids,
|
|
synthesize_answer,
|
|
)
|
|
|
|
|
|
class TestSynthesizeAnswer:
    """Tests for ``synthesize_answer`` driven by a mocked LLM client.

    NOTE(review): the test methods are plain ``async def`` with no visible
    ``@pytest.mark.asyncio`` marker; this presumably relies on the project's
    pytest configuration (e.g. pytest-asyncio in auto mode) — confirm.
    """

    @pytest.fixture
    def mock_llm(self) -> AsyncMock:
        """Return an ``AsyncMock`` standing in for the LLM client.

        Each test scripts the completion text by setting
        ``mock_llm.complete.return_value``.
        """
        return AsyncMock()

    @pytest.fixture
    def sample_segments(self) -> list[RetrievalResult]:
        """Return two retrieved segments that belong to the same meeting.

        The segment ids are 1 and 3; the citation tests below treat exactly
        these ids as the valid set.
        """
        meeting_id = uuid4()
        return [
            RetrievalResult(
                segment_id=1,
                meeting_id=meeting_id,
                text="John discussed the project timeline",
                start_time=0.0,
                end_time=5.0,
                score=0.95,
            ),
            RetrievalResult(
                segment_id=3,
                meeting_id=meeting_id,
                text="The deadline is next Friday",
                start_time=10.0,
                end_time=15.0,
                score=0.85,
            ),
        ]

    async def test_synthesize_answer_returns_result(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """The LLM completion is surfaced as the answer of a ``SynthesisResult``."""
        mock_llm.complete.return_value = "The project deadline is next Friday [3]."

        result = await synthesize_answer(
            question="What is the deadline?",
            segments=sample_segments,
            llm=mock_llm,
        )

        assert isinstance(result, SynthesisResult)
        assert "deadline" in result.answer.lower()

    async def test_synthesize_answer_extracts_citations(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """Bracketed segment ids in the completion are reported as citations."""
        mock_llm.complete.return_value = "John discussed timelines [1] and the deadline [3]."

        result = await synthesize_answer(
            question="What happened?",
            segments=sample_segments,
            llm=mock_llm,
        )

        assert result.cited_segment_ids == [1, 3]

    async def test_synthesize_answer_filters_invalid_citations(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """Citations of ids not among the retrieved segments (``[99]``) are dropped."""
        mock_llm.complete.return_value = "Found [1], [99], and [3]."

        result = await synthesize_answer(
            question="What happened?",
            segments=sample_segments,
            llm=mock_llm,
        )

        # Targeted check first so a regression reports the offending id directly.
        assert 99 not in result.cited_segment_ids
        assert result.cited_segment_ids == [1, 3]

    async def test_synthesize_answer_builds_prompt_with_segments(
        self,
        mock_llm: AsyncMock,
        sample_segments: list[RetrievalResult],
    ) -> None:
        """The prompt sent to the LLM contains the question and each segment."""
        mock_llm.complete.return_value = "Answer."

        await synthesize_answer(
            question="What is happening?",
            segments=sample_segments,
            llm=mock_llm,
        )

        # Verify the completion was actually awaited before inspecting its
        # arguments: otherwise ``call_args`` is ``None`` and the test fails
        # with an opaque TypeError, and an un-awaited call would go unnoticed.
        mock_llm.complete.assert_awaited()
        # The prompt is expected as the first positional argument.
        prompt = mock_llm.complete.call_args.args[0]
        assert "What is happening?" in prompt
        assert "[1]" in prompt
        assert "[3]" in prompt
        assert "John discussed" in prompt
|
|
|
|
|
|
class TestExtractCitedIds:
    """Unit tests for ``extract_cited_ids``.

    Each case passes answer text together with the set of valid segment ids
    and checks the list of cited ids that comes back.
    """

    def test_extracts_single_citation(self) -> None:
        """A single valid bracketed id is extracted."""
        text, valid_ids = "The answer is here [5].", {1, 3, 5}
        assert extract_cited_ids(text, valid_ids) == [5]

    def test_extracts_multiple_citations(self) -> None:
        """Several valid bracketed ids are all extracted."""
        text, valid_ids = "See [1] and [3] for details.", {1, 3, 5}
        assert extract_cited_ids(text, valid_ids) == [1, 3]

    def test_filters_invalid_ids(self) -> None:
        """Ids outside the valid set are omitted from the result."""
        text, valid_ids = "See [1] and [99].", {1, 3, 5}
        assert extract_cited_ids(text, valid_ids) == [1]

    def test_deduplicates_citations(self) -> None:
        """An id cited more than once is reported only once."""
        text, valid_ids = "See [1] and then [1] again.", {1, 3}
        assert extract_cited_ids(text, valid_ids) == [1]

    def test_preserves_order(self) -> None:
        """Ids come back in the order they appear in the text, not sorted."""
        text, valid_ids = "[3] comes first, then [1].", {1, 3}
        assert extract_cited_ids(text, valid_ids) == [3, 1]

    def test_empty_for_no_citations(self) -> None:
        """Text without bracketed ids yields an empty list."""
        text, valid_ids = "No citations here.", {1, 3}
        assert extract_cited_ids(text, valid_ids) == []
|
|
|
|
|
|
class TestSynthesisResult:
    """Tests for the immutability of ``SynthesisResult``."""

    def test_is_frozen(self) -> None:
        """Assigning to a field of an existing result raises ``AttributeError``."""
        frozen_result = SynthesisResult(answer="Test answer", cited_segment_ids=[1, 2])

        with pytest.raises(AttributeError):
            frozen_result.answer = "Modified"  # type: ignore[misc]
|