113 lines
3.9 KiB
Python
113 lines
3.9 KiB
Python
"""Tests for domain/ai/ports.py - AI assistant protocol definitions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Final
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.ai.ports import AssistantPort, AssistantRequest
|
|
|
|
if TYPE_CHECKING:
|
|
from noteflow.domain.ai.state import AssistantOutputState
|
|
|
|
# Question text used when building AssistantRequest instances in these tests.
SAMPLE_QUESTION: Final[str] = "What was discussed in the meeting?"

# Value AssistantRequest.top_k is expected to take when the caller omits it.
DEFAULT_TOP_K: Final[int] = 8

# Substring matched against the AttributeError raised when assigning to a
# field of a frozen request (presumably dataclasses.FrozenInstanceError's
# message — confirm against the AssistantRequest definition).
FROZEN_ASSIGNMENT_MESSAGE: Final[str] = "cannot assign to field"
|
|
|
|
|
|
class TestAssistantRequest:
    """Construction, default values, and immutability of ``AssistantRequest``."""

    def test_request_with_required_fields(self) -> None:
        """Only ``question`` and ``user_id`` are needed to build a request."""
        user_id = uuid4()

        request = AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=user_id,
        )

        assert request.question == SAMPLE_QUESTION, "question should match input"
        assert request.user_id == user_id, "user_id should match input"

    @pytest.mark.parametrize(
        ("attribute", "expected"),
        [
            pytest.param("meeting_id", None, id="meeting_id"),
            pytest.param("thread_id", None, id="thread_id"),
            pytest.param("allow_web", False, id="allow_web"),
            pytest.param("top_k", DEFAULT_TOP_K, id="top_k"),
        ],
    )
    def test_request_defaults(self, attribute: str, expected: object) -> None:
        """Each optional field falls back to its expected default when omitted."""
        request = AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=uuid4(),
        )

        actual = getattr(request, attribute)
        # Name the attribute in the failure message so a parametrized failure
        # is self-explanatory, matching the messages used by sibling tests.
        assert actual == expected, f"{attribute} should default to {expected!r}, got {actual!r}"

    def test_request_with_all_fields(self) -> None:
        """Every field passed to the constructor is stored unchanged."""
        user_id = uuid4()
        meeting_id = uuid4()
        thread_id = "thread-123"
        custom_top_k = 10

        request = AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=user_id,
            meeting_id=meeting_id,
            thread_id=thread_id,
            allow_web=True,
            top_k=custom_top_k,
        )

        assert request.question == SAMPLE_QUESTION, "question should match input"
        assert request.user_id == user_id, "user_id should match input"
        assert request.meeting_id == meeting_id, "meeting_id should match input"
        assert request.thread_id == thread_id, "thread_id should match input"
        assert request.allow_web is True, "allow_web should be enabled"
        assert request.top_k == custom_top_k, "top_k should use custom value"

    def test_request_is_frozen(self) -> None:
        """Assigning to a field raises and leaves the stored value intact."""
        request = AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=uuid4(),
        )

        with pytest.raises(AttributeError, match=FROZEN_ASSIGNMENT_MESSAGE):
            request.question = "Modified question"

        # The rejected write must not have mutated the instance.
        assert request.question == SAMPLE_QUESTION, "question should remain unchanged"
|
|
|
|
|
|
class TestAssistantPort:
    """Checks that ``AssistantPort`` can be satisfied and used via an async ``ask``."""

    def test_protocol_has_ask_method(self) -> None:
        """The protocol exposes an ``ask`` member."""
        assert hasattr(AssistantPort, "ask"), "AssistantPort should declare ask"

    def test_mock_implementation_satisfies_protocol(self) -> None:
        """A plain class with an async ``ask`` is accepted where the port is expected."""

        class MockAssistant:
            async def ask(self, request: AssistantRequest) -> AssistantOutputState:
                return {
                    "answer": "Test answer",
                    "citations": [],
                    "suggested_annotations": [],
                    "thread_id": "thread-1",
                }

        # The annotated assignment is what a static type checker verifies
        # against the protocol; the runtime assert below is only a smoke check.
        assistant: AssistantPort = MockAssistant()

        assert hasattr(assistant, "ask"), "mock should expose ask"

    # NOTE(review): no asyncio marker here — this relies on the project's
    # pytest async configuration (e.g. pytest-asyncio auto mode); confirm.
    async def test_mock_implementation_can_be_called(self) -> None:
        """Awaiting ``ask`` returns the mock's output-state mapping for the request."""

        class MockAssistant:
            async def ask(self, request: AssistantRequest) -> AssistantOutputState:
                return {
                    "answer": f"Response to: {request.question}",
                    "citations": [],
                    "suggested_annotations": [],
                    "thread_id": "thread-1",
                }

        assistant = MockAssistant()
        request = AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=uuid4(),
        )

        result = await assistant.ask(request)

        # Pin the exact echoed answer so the test proves the request reached ask().
        assert result["answer"] == f"Response to: {SAMPLE_QUESTION}", (
            "answer should echo the request question"
        )
        assert result["citations"] == [], "citations should be empty"
        assert result["suggested_annotations"] == [], "suggested_annotations should be empty"
        assert result["thread_id"] == "thread-1", "thread_id should match mock output"
|