269 lines
8.4 KiB
Python
269 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import cast
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.value_objects import MeetingId
|
|
from noteflow.infrastructure.ai.graphs.meeting_qa import (
|
|
MeetingQADependencies,
|
|
MeetingQAInternalState,
|
|
build_meeting_qa_graph,
|
|
)
|
|
|
|
from .conftest import (
|
|
DEFAULT_TOP_K,
|
|
INVALID_SEGMENT_ID,
|
|
MIN_EXPECTED_CALLS,
|
|
SCORE_HIGH,
|
|
SCORE_MEDIUM,
|
|
SEGMENT_END_FIFTEEN,
|
|
SEGMENT_END_FIVE,
|
|
SEGMENT_ID_ONE,
|
|
SEGMENT_ID_TWO,
|
|
SEGMENT_START_TEN,
|
|
SEGMENT_START_ZERO,
|
|
MeetingQAInputFactory,
|
|
MockEmbedder,
|
|
MockLLM,
|
|
MockSegmentLike,
|
|
MockSegmentRepository,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def qa_sample_meeting_id() -> MeetingId:
    """Provide a fresh meeting ID wrapping a random UUID4 for each test."""
    meeting_uuid = uuid4()
    return MeetingId(meeting_uuid)
|
|
|
|
|
|
@pytest.fixture
def qa_sample_segments(qa_sample_meeting_id: MeetingId) -> list[tuple[MockSegmentLike, float]]:
    """Provide two scored transcript segments for ``qa_sample_meeting_id``.

    Returns:
        A list of ``(segment, score)`` pairs, in this order:
        - a segment about John and the project timeline, scored ``SCORE_HIGH``;
        - a segment about the first-deliverable deadline, scored ``SCORE_MEDIUM``.
    """
    timeline_segment = MockSegmentLike(
        segment_id=SEGMENT_ID_ONE,
        meeting_id=qa_sample_meeting_id,
        text="John discussed the project timeline and milestones.",
        start_time=SEGMENT_START_ZERO,
        end_time=SEGMENT_END_FIVE,
    )
    deadline_segment = MockSegmentLike(
        segment_id=SEGMENT_ID_TWO,
        meeting_id=qa_sample_meeting_id,
        text="The deadline is next Friday for the first deliverable.",
        start_time=SEGMENT_START_TEN,
        end_time=SEGMENT_END_FIFTEEN,
    )
    return [(timeline_segment, SCORE_HIGH), (deadline_segment, SCORE_MEDIUM)]
|
|
|
|
|
|
@pytest.fixture
def qa_mock_segment_repo(
    qa_sample_segments: list[tuple[MockSegmentLike, float]],
) -> MockSegmentRepository:
    """Provide a mock segment repository seeded with ``qa_sample_segments``."""
    repo = MockSegmentRepository(qa_sample_segments)
    return repo
|
|
|
|
|
|
@pytest.fixture
def qa_deps(
    graph_mock_embedder: MockEmbedder,
    qa_mock_segment_repo: MockSegmentRepository,
    graph_mock_llm: MockLLM,
) -> MeetingQADependencies:
    """Bundle the mock embedder, segment repo and plain ``graph_mock_llm``.

    ``graph_mock_embedder`` and ``graph_mock_llm`` are not defined in this
    module; presumably they come from a conftest.
    """
    deps = MeetingQADependencies(
        llm=graph_mock_llm,
        segment_repo=qa_mock_segment_repo,
        embedder=graph_mock_embedder,
    )
    return deps
|
|
|
|
|
|
@pytest.fixture
def qa_deps_with_citations(
    graph_mock_embedder: MockEmbedder,
    qa_mock_segment_repo: MockSegmentRepository,
    graph_mock_llm_with_citations: MockLLM,
) -> MeetingQADependencies:
    """Bundle the mock embedder and segment repo with ``graph_mock_llm_with_citations``.

    The LLM fixture is not defined in this module; presumably it comes from a
    conftest.
    """
    deps = MeetingQADependencies(
        llm=graph_mock_llm_with_citations,
        segment_repo=qa_mock_segment_repo,
        embedder=graph_mock_embedder,
    )
    return deps
|
|
|
|
|
|
@pytest.fixture
def qa_deps_invalid_citation(
    graph_mock_embedder: MockEmbedder,
    qa_mock_segment_repo: MockSegmentRepository,
    graph_mock_llm_with_invalid_citation: MockLLM,
) -> MeetingQADependencies:
    """Bundle the mock embedder and segment repo with ``graph_mock_llm_with_invalid_citation``.

    The LLM fixture is not defined in this module; presumably it comes from a
    conftest.
    """
    deps = MeetingQADependencies(
        llm=graph_mock_llm_with_invalid_citation,
        segment_repo=qa_mock_segment_repo,
        embedder=graph_mock_embedder,
    )
    return deps
|
|
|
|
|
|
async def _invoke_graph(
    deps: MeetingQADependencies,
    question: str,
    meeting_id: MeetingId,
    top_k: int = DEFAULT_TOP_K,
) -> dict[str, object]:
    """Build the meeting QA graph from ``deps`` and run it once asynchronously.

    Args:
        deps: Embedder, segment repository and LLM the graph is built with.
        question: Question text passed to ``MeetingQAInputFactory.create``.
        meeting_id: Meeting the question is scoped to.
        top_k: Retrieval size passed through to the input factory.

    Returns:
        The result of ``ainvoke`` copied into a plain ``dict``.
    """
    qa_graph = build_meeting_qa_graph(deps)
    partial_state = MeetingQAInputFactory.create(question, meeting_id, top_k)
    # Cast needed: LangGraph accepts partial input state at runtime, but the
    # type checker expects the full internal state type.
    final_state = await qa_graph.ainvoke(cast(MeetingQAInternalState, partial_state))
    return dict(final_state)
|
|
|
|
|
|
@pytest.mark.integration
class TestMeetingQAGraphExecution:
    """Run the meeting QA graph end-to-end against mock dependencies.

    Each test runs the graph through ``_invoke_graph``. It then inspects either
    the returned state or the calls recorded on the mock embedder, segment
    repository and LLM.
    """

    # Questions shared by several tests; the text is unchanged from the
    # original per-test literals.
    _TIMELINE_QUESTION = "What was discussed about the timeline?"
    _DEADLINE_QUESTION = "What is the project deadline?"
    _GENERIC_QUESTION = "What was discussed?"

    async def test_meeting_graph_produces_answer_key(
        self,
        qa_deps_with_citations: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
    ) -> None:
        """The final state contains an ``answer`` entry."""
        result = await _invoke_graph(
            qa_deps_with_citations,
            self._TIMELINE_QUESTION,
            qa_sample_meeting_id,
        )

        assert "answer" in result, "Result should contain answer key"

    async def test_meeting_graph_produces_citations_key(
        self,
        qa_deps_with_citations: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
    ) -> None:
        """The final state contains a ``citations`` entry."""
        result = await _invoke_graph(
            qa_deps_with_citations,
            self._TIMELINE_QUESTION,
            qa_sample_meeting_id,
        )

        assert "citations" in result, "Result should contain citations key"

    async def test_meeting_graph_answer_contains_expected_content(
        self,
        qa_deps_with_citations: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
    ) -> None:
        """The answer text mentions "John"."""
        result = await _invoke_graph(
            qa_deps_with_citations,
            self._TIMELINE_QUESTION,
            qa_sample_meeting_id,
        )

        answer = str(result["answer"])
        assert "John" in answer, "Answer should contain expected content"

    async def test_meeting_graph_calls_embedder_at_least_once(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        graph_mock_embedder: MockEmbedder,
    ) -> None:
        """Running the graph records at least ``MIN_EXPECTED_CALLS`` embed calls."""
        await _invoke_graph(qa_deps, self._DEADLINE_QUESTION, qa_sample_meeting_id)

        assert len(graph_mock_embedder.embed_calls) >= MIN_EXPECTED_CALLS, (
            "Embedder should be called at least once"
        )

    async def test_meeting_graph_passes_question_to_embedder(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        graph_mock_embedder: MockEmbedder,
    ) -> None:
        """The question text appears among the embedder's recorded calls."""
        question = self._DEADLINE_QUESTION
        await _invoke_graph(qa_deps, question, qa_sample_meeting_id)

        assert question in graph_mock_embedder.embed_calls, "Embedder should receive the question"

    async def test_meeting_graph_calls_segment_repo(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        qa_mock_segment_repo: MockSegmentRepository,
    ) -> None:
        """Running the graph records at least ``MIN_EXPECTED_CALLS`` repo searches."""
        await _invoke_graph(qa_deps, self._GENERIC_QUESTION, qa_sample_meeting_id)

        assert len(qa_mock_segment_repo.search_calls) >= MIN_EXPECTED_CALLS, (
            "Segment repo should be called"
        )

    async def test_meeting_graph_passes_meeting_id_to_repo(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        qa_mock_segment_repo: MockSegmentRepository,
    ) -> None:
        """The first repo search is made with the requested meeting ID."""
        await _invoke_graph(qa_deps, self._GENERIC_QUESTION, qa_sample_meeting_id)

        # Guard the [0] index so a missing call fails with a clear message
        # rather than an IndexError.
        assert qa_mock_segment_repo.search_calls, "Segment repo should be called"
        # Each recorded search call is unpacked as a 3-tuple whose last item
        # is the meeting ID the search was scoped to.
        _, _, called_meeting_id = qa_mock_segment_repo.search_calls[0]
        assert called_meeting_id == qa_sample_meeting_id, "Meeting ID should be passed to repo"

    async def test_meeting_graph_calls_llm(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        graph_mock_llm: MockLLM,
    ) -> None:
        """Running the graph records at least ``MIN_EXPECTED_CALLS`` LLM completions."""
        await _invoke_graph(qa_deps, self._GENERIC_QUESTION, qa_sample_meeting_id)

        assert len(graph_mock_llm.complete_calls) >= MIN_EXPECTED_CALLS, "LLM should be called"

    async def test_meeting_graph_prompt_contains_segment_text(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        graph_mock_llm: MockLLM,
    ) -> None:
        """The first LLM prompt includes the timeline segment's text."""
        await _invoke_graph(qa_deps, self._GENERIC_QUESTION, qa_sample_meeting_id)

        # Guard the [0] index so a missing call fails with a clear message.
        assert graph_mock_llm.complete_calls, "LLM should be called"
        prompt = graph_mock_llm.complete_calls[0]
        assert "John discussed" in prompt, "Prompt should contain segment text"

    async def test_meeting_graph_prompt_contains_deadline_segment(
        self,
        qa_deps: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
        graph_mock_llm: MockLLM,
    ) -> None:
        """The first LLM prompt includes the deadline segment.

        The question deliberately omits "deadline", so a match can only come
        from retrieved segment text.
        """
        await _invoke_graph(qa_deps, self._GENERIC_QUESTION, qa_sample_meeting_id)

        # Guard the [0] index so a missing call fails with a clear message.
        assert graph_mock_llm.complete_calls, "LLM should be called"
        prompt = graph_mock_llm.complete_calls[0]
        assert "deadline" in prompt.lower(), "Prompt should contain deadline segment"

    async def test_meeting_graph_produces_citations_without_invalid_id(
        self,
        qa_deps_invalid_citation: MeetingQADependencies,
        qa_sample_meeting_id: MeetingId,
    ) -> None:
        """No citation returned by the graph carries ``INVALID_SEGMENT_ID``."""
        from noteflow.domain.ai.citations import SegmentCitation

        result = await _invoke_graph(
            qa_deps_invalid_citation,
            self._GENERIC_QUESTION,
            qa_sample_meeting_id,
        )

        raw_citations = result.get("citations")
        assert raw_citations is not None, "Citations key should exist"
        citations = cast(list[SegmentCitation], raw_citations)
        assert len(citations) >= MIN_EXPECTED_CALLS, "Should have at least one citation"
        # Check every citation, not only the first: an invalid ID surviving at
        # any position means the graph let it through.
        cited_ids = [citation.segment_id for citation in citations]
        assert INVALID_SEGMENT_ID not in cited_ids, (
            "Citation should not have invalid segment ID"
        )
|