242 lines
9.8 KiB
Python
242 lines
9.8 KiB
Python
"""Tests for AskAssistant gRPC mixin."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Final
|
|
from uuid import uuid4
|
|
|
|
import grpc
|
|
import pytest
|
|
|
|
from tests.conftest import approx_float
|
|
|
|
from noteflow.application.services.assistant import (
|
|
AssistantResponse,
|
|
AssistantService,
|
|
AssistantServiceSettings,
|
|
)
|
|
from noteflow.domain.ai.citations import SegmentCitation
|
|
from noteflow.domain.ai.ports import AssistantRequest
|
|
from noteflow.grpc.config.config import ServicesConfig
|
|
from noteflow.grpc.proto import noteflow_pb2
|
|
from noteflow.grpc.service import NoteFlowServicer
|
|
|
|
if TYPE_CHECKING:
|
|
from noteflow.infrastructure.ai.nodes.annotation_suggester import SuggestedAnnotation
|
|
|
|
# Answer text the servicer returns when no assistant service is configured.
EMPTY_ANSWER_MESSAGE: Final[str] = "AI assistant is not currently available."

# top_k values: the default expected when the request sends top_k=0, and a
# non-default value.
DEFAULT_TOP_K: Final[int] = 8
CUSTOM_TOP_K: Final[int] = 15

# Citation fixture data: segment identifiers, time spans, and relevance scores.
SEGMENT_ID_ONE: Final[int] = 1
SEGMENT_ID_TWO: Final[int] = 2
START_TIME_FIRST: Final[float] = 10.0
END_TIME_FIRST: Final[float] = 15.0
START_TIME_SECOND: Final[float] = 30.0
END_TIME_SECOND: Final[float] = 35.0
SCORE_HIGH: Final[float] = 0.95
SCORE_MEDIUM: Final[float] = 0.87

# Generated thread IDs look like "<THREAD_PREFIX>-<HEX_SLICE_LEN hex chars>".
THREAD_PREFIX: Final[str] = "test-thread"
HEX_SLICE_LEN: Final[int] = 8

CITATION_COUNT_TWO: Final[int] = 2
|
|
|
|
|
|
class _DummyContext:
|
|
"""Minimal gRPC context that records abort calls."""
|
|
|
|
abort_called: bool = False
|
|
abort_code: grpc.StatusCode | None = None
|
|
abort_details: str = ""
|
|
|
|
async def abort(self, code: grpc.StatusCode, details: str) -> None:
|
|
self.abort_called = True
|
|
self.abort_code = code
|
|
self.abort_details = details
|
|
raise AssertionError("Unreachable")
|
|
|
|
def invocation_metadata(self) -> list[tuple[str, str]]:
|
|
return []
|
|
|
|
|
|
class _MockAssistantService(AssistantService):
    """AssistantService double that answers every question with canned data.

    The most recent request handed to ``ask`` is kept in ``last_request`` so
    tests can inspect what the servicer forwarded.
    """

    def __init__(
        self,
        answer: str = "Test answer",
        citations: list[SegmentCitation] | None = None,
        suggested_annotations: list[SuggestedAnnotation] | None = None,
    ) -> None:
        super().__init__(settings=AssistantServiceSettings())
        self._answer = answer
        # Falsy inputs (None or an empty list) fall back to a fresh empty list.
        self._citations = citations if citations else []
        self._suggested_annotations = suggested_annotations if suggested_annotations else []
        self.last_request: AssistantRequest | None = None

    async def ask(self, request: AssistantRequest) -> AssistantResponse:
        """Record ``request`` and return the canned response.

        A truthy ``request.thread_id`` is echoed back. Otherwise a thread ID of
        the form ``"<THREAD_PREFIX>-<first HEX_SLICE_LEN hex chars of a uuid4>"``
        is generated.
        """
        self.last_request = request
        if request.thread_id:
            thread_id = request.thread_id
        else:
            suffix = uuid4().hex[:HEX_SLICE_LEN]
            thread_id = f"{THREAD_PREFIX}-{suffix}"
        return AssistantResponse(
            thread_id=thread_id,
            answer=self._answer,
            citations=self._citations,
            suggested_annotations=self._suggested_annotations,
        )
|
|
|
|
|
|
class TestAskAssistantUnavailable:
    """Tests for when assistant service is not configured."""

    @pytest.mark.asyncio
    async def test_returns_unavailable_message(self) -> None:
        """When service is not configured, return unavailable message."""
        servicer = NoteFlowServicer()

        response = await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What was discussed?"),
            _DummyContext(),
        )

        assert response.answer == EMPTY_ANSWER_MESSAGE, "Should return unavailable message"
        # Check emptiness via truthiness rather than ``== []``. The pure-Python
        # protobuf backend raises TypeError when a repeated message container
        # (e.g. ``citations``) is compared with ``==`` to a plain list.
        assert not response.citations, "Should return empty citations"
        assert not response.suggested_annotations, "Should return empty annotations"
|
|
|
|
|
|
class TestAskAssistantValidation:
    """Tests for input validation.

    Each case must abort with INVALID_ARGUMENT, and the request must never be
    forwarded to the assistant service.
    """

    @staticmethod
    def _make_servicer() -> tuple[NoteFlowServicer, _MockAssistantService]:
        """Build a servicer wired to a fresh recording mock assistant service."""
        mock_service = _MockAssistantService()
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))
        return servicer, mock_service

    @pytest.mark.asyncio
    async def test_rejects_empty_question(self) -> None:
        """Empty question should trigger INVALID_ARGUMENT abort."""
        servicer, mock_service = self._make_servicer()
        context = _DummyContext()

        # _DummyContext.abort raises AssertionError("Unreachable"), so an abort
        # surfaces here as that exception.
        with pytest.raises(AssertionError, match="Unreachable"):
            await servicer.AskAssistant(noteflow_pb2.AskAssistantRequest(question=""), context)

        assert context.abort_called, "Should abort on empty question"
        assert context.abort_code == grpc.StatusCode.INVALID_ARGUMENT, "Should be INVALID_ARGUMENT"
        assert mock_service.last_request is None, "Service should not be called"

    @pytest.mark.asyncio
    async def test_rejects_whitespace_only_question(self) -> None:
        """Whitespace-only question should trigger INVALID_ARGUMENT abort."""
        servicer, mock_service = self._make_servicer()
        context = _DummyContext()

        with pytest.raises(AssertionError, match="Unreachable"):
            await servicer.AskAssistant(noteflow_pb2.AskAssistantRequest(question="   "), context)

        assert context.abort_called, "Should abort on whitespace-only question"
        assert context.abort_code == grpc.StatusCode.INVALID_ARGUMENT, "Should be INVALID_ARGUMENT"
        assert mock_service.last_request is None, "Service should not be called"

    @pytest.mark.asyncio
    async def test_rejects_invalid_meeting_id(self) -> None:
        """Invalid meeting ID format should trigger abort."""
        servicer, mock_service = self._make_servicer()
        context = _DummyContext()

        with pytest.raises(AssertionError, match="Unreachable"):
            await servicer.AskAssistant(
                noteflow_pb2.AskAssistantRequest(question="What?", meeting_id="not-a-uuid"),
                context,
            )

        assert context.abort_called, "Should abort on invalid meeting ID"
        assert context.abort_code == grpc.StatusCode.INVALID_ARGUMENT, "Should be INVALID_ARGUMENT"
        assert mock_service.last_request is None, "Service should not be called"
|
|
|
|
|
|
class TestAskAssistantRequestPassing:
    """Tests for passing request parameters to service."""

    @staticmethod
    def _make_servicer() -> tuple[NoteFlowServicer, _MockAssistantService]:
        """Build a servicer wired to a fresh recording mock assistant service."""
        mock_service = _MockAssistantService()
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))
        return servicer, mock_service

    @pytest.mark.asyncio
    async def test_passes_question(self) -> None:
        """Question should be passed to assistant service."""
        servicer, mock_service = self._make_servicer()

        await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What were the action items?"),
            _DummyContext(),
        )

        assert mock_service.last_request is not None, "Request should be captured"
        assert mock_service.last_request.question == "What were the action items?", (
            "Question should match"
        )

    @pytest.mark.asyncio
    async def test_passes_meeting_id(self) -> None:
        """Meeting ID should be parsed and passed when provided."""
        servicer, mock_service = self._make_servicer()
        test_meeting_id = uuid4()

        await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(
                question="What happened?",
                meeting_id=str(test_meeting_id),
            ),
            _DummyContext(),
        )

        assert mock_service.last_request is not None, "Request should be captured"
        # The servicer converts the wire string back into a UUID.
        assert mock_service.last_request.meeting_id == test_meeting_id, "Meeting ID should match"

    @pytest.mark.asyncio
    async def test_uses_default_top_k(self) -> None:
        """When top_k is 0, default value should be used."""
        servicer, mock_service = self._make_servicer()

        await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="Test", top_k=0),
            _DummyContext(),
        )

        assert mock_service.last_request is not None, "Request should be captured"
        assert mock_service.last_request.top_k == DEFAULT_TOP_K, "Should use default top_k"

    @pytest.mark.asyncio
    async def test_passes_custom_top_k(self) -> None:
        """A non-zero top_k should be forwarded unchanged instead of the default."""
        servicer, mock_service = self._make_servicer()

        await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="Test", top_k=CUSTOM_TOP_K),
            _DummyContext(),
        )

        assert mock_service.last_request is not None, "Request should be captured"
        assert mock_service.last_request.top_k == CUSTOM_TOP_K, "Should use requested top_k"
|
|
|
|
|
|
class TestAskAssistantResponse:
    """Tests for response conversion."""

    @pytest.mark.asyncio
    async def test_returns_answer(self) -> None:
        """Answer from service should be returned."""
        expected_answer = "The meeting discussed project timelines."
        mock_service = _MockAssistantService(answer=expected_answer)
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))

        response = await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What was discussed?"),
            _DummyContext(),
        )

        assert response.answer == expected_answer, "Answer should match service output"

    @pytest.mark.asyncio
    async def test_returns_citations(self) -> None:
        """Citations from service should be converted and returned."""
        test_meeting_id = uuid4()
        citations = [
            SegmentCitation(
                meeting_id=test_meeting_id,
                segment_id=SEGMENT_ID_ONE,
                start_time=START_TIME_FIRST,
                end_time=END_TIME_FIRST,
                text="Project deadline is next week",
                score=SCORE_HIGH,
            ),
        ]
        mock_service = _MockAssistantService(citations=citations)
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))

        response = await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What was discussed?"),
            _DummyContext(),
        )

        # Compare against the fixture's length, not a segment ID constant.
        assert len(response.citations) == len(citations), "Should return one citation"
        assert response.citations[0].segment_id == SEGMENT_ID_ONE, "Segment ID should match"
        assert response.citations[0].text == "Project deadline is next week", (
            "Citation text should match"
        )
        assert response.citations[0].score == approx_float(SCORE_HIGH), (
            "Citation score should match"
        )

    @pytest.mark.asyncio
    async def test_returns_multiple_citations_in_order(self) -> None:
        """All citations should be returned in the order the service produced them."""
        test_meeting_id = uuid4()
        citations = [
            SegmentCitation(
                meeting_id=test_meeting_id,
                segment_id=SEGMENT_ID_ONE,
                start_time=START_TIME_FIRST,
                end_time=END_TIME_FIRST,
                text="Project deadline is next week",
                score=SCORE_HIGH,
            ),
            SegmentCitation(
                meeting_id=test_meeting_id,
                segment_id=SEGMENT_ID_TWO,
                start_time=START_TIME_SECOND,
                end_time=END_TIME_SECOND,
                text="Budget review happens on Friday",
                score=SCORE_MEDIUM,
            ),
        ]
        mock_service = _MockAssistantService(citations=citations)
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))

        response = await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What was discussed?"),
            _DummyContext(),
        )

        assert len(response.citations) == CITATION_COUNT_TWO, "Should return two citations"
        assert [c.segment_id for c in response.citations] == [SEGMENT_ID_ONE, SEGMENT_ID_TWO], (
            "Citation order should be preserved"
        )
        assert response.citations[1].text == "Budget review happens on Friday", (
            "Second citation text should match"
        )
        assert response.citations[1].score == approx_float(SCORE_MEDIUM), (
            "Second citation score should match"
        )

    @pytest.mark.asyncio
    async def test_returns_thread_id(self) -> None:
        """Thread ID should be included in response."""
        mock_service = _MockAssistantService()
        servicer = NoteFlowServicer(services=ServicesConfig(assistant_service=mock_service))

        response = await servicer.AskAssistant(
            noteflow_pb2.AskAssistantRequest(question="What was discussed?"),
            _DummyContext(),
        )

        assert response.thread_id != "", "Thread ID should not be empty"
        # The mock service generates "<THREAD_PREFIX>-<hex>" when no thread ID is sent.
        assert response.thread_id.startswith(f"{THREAD_PREFIX}-"), (
            "Thread ID should start with expected prefix"
        )
|