220 lines
7.3 KiB
Python
220 lines
7.3 KiB
Python
"""Tests for assistant service."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Final
|
|
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.application.services.assistant import (
|
|
AssistantResponse,
|
|
AssistantService,
|
|
AssistantServiceSettings,
|
|
)
|
|
from noteflow.domain.ai.citations import SegmentCitation
|
|
from noteflow.domain.ai.ports import AssistantRequest
|
|
from noteflow.domain.value_objects import MeetingId
|
|
|
|
# Fresh identifiers are generated each time this module is imported.
SAMPLE_USER_ID: Final = uuid4()
SAMPLE_MEETING_ID: Final = uuid4()
SAMPLE_QUESTION: Final = "What are the key decisions?"
SAMPLE_ANSWER: Final = "The key decisions were about budget allocation."

# Expected length of the session suffix that ends a generated thread ID.
THREAD_SESSION_ID_LENGTH: Final = 8
# Built from THREAD_SESSION_ID_LENGTH so the pattern and the length cannot drift
# apart. \A and \Z anchor the whole string even under ``match``/``search``
# (``$`` would also accept a trailing newline).
THREAD_SESSION_ID_PATTERN: Final = re.compile(
    rf"\A[0-9a-f]{{{THREAD_SESSION_ID_LENGTH}}}\Z"
)
# A meeting-scoped thread ID whose trailing session segment satisfies the pattern.
SAMPLE_THREAD_ID: Final = (
    f"meeting:{SAMPLE_MEETING_ID}:user:{SAMPLE_USER_ID}:graph:meeting_qa:v1:abc12345"
)
|
|
|
|
|
|
def _create_citation(meeting_id: MeetingId, segment_id: int = 0) -> SegmentCitation:
    """Return a canned citation pointing at ``segment_id`` of ``meeting_id``.

    Only the two identifiers vary between calls. The span (0.0 to 5.0), the
    text "Sample cited text" and the score 0.95 are fixed.
    """
    citation = SegmentCitation(
        meeting_id=meeting_id,
        segment_id=segment_id,
        text="Sample cited text",
        start_time=0.0,
        end_time=5.0,
        score=0.95,
    )
    return citation
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ask_meeting_unavailable_returns_fallback_response() -> None:
    """A meeting-scoped question to a no-argument service gets the unavailable fallback."""
    assistant = AssistantService()
    meeting_request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=SAMPLE_MEETING_ID,
    )

    reply = await assistant.ask(meeting_request)

    unavailable_text = "AI assistant is not currently available."
    assert reply.answer == unavailable_text, "Expected unavailable message"
    assert reply.citations == [], "Expected empty citations"
    assert reply.suggested_annotations == [], "Expected empty annotations"
    assert reply.thread_id is not None, "Expected thread_id to be generated"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ask_workspace_unavailable_returns_fallback_response() -> None:
    """A workspace-wide question (no meeting_id) gets the unavailable fallback.

    The assertions mirror the meeting-scoped variant, including the empty
    ``suggested_annotations`` list, so both request shapes are held to the same
    fallback contract.
    """
    service = AssistantService()
    request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=None,
    )

    response = await service.ask(request)

    assert response.answer == "AI assistant is not currently available.", (
        "Expected unavailable message"
    )
    assert response.citations == [], "Expected empty citations"
    assert response.suggested_annotations == [], "Expected empty annotations"
    assert response.thread_id is not None, "Expected thread_id to be generated"
|
|
|
|
|
|
def test_default_settings_values() -> None:
    """Check the defaults of a settings object built with no arguments.

    Expected: web search off, web approval required, annotation approval not
    required, and top_k of 8.
    """
    defaults = AssistantServiceSettings()
    expected_top_k = 8

    assert defaults.enable_web_search is False, "Expected web search disabled by default"
    assert defaults.require_web_approval is True, "Expected web approval required by default"
    assert defaults.require_annotation_approval is False, (
        "Expected annotation approval not required by default"
    )
    assert defaults.default_top_k == expected_top_k, "Expected default top_k of 8"
|
|
|
|
|
|
def test_custom_settings_values() -> None:
    """Explicit constructor arguments are reflected on the settings object."""
    custom_top_k = 5
    overridden = AssistantServiceSettings(
        default_top_k=custom_top_k,
        require_annotation_approval=True,
        require_web_approval=False,
        enable_web_search=True,
    )

    assert overridden.enable_web_search is True, "Expected web search enabled"
    assert overridden.require_web_approval is False, "Expected web approval not required"
    assert overridden.require_annotation_approval is True, "Expected annotation approval required"
    assert overridden.default_top_k == custom_top_k, "Expected custom top_k of 5"
|
|
|
|
|
|
def test_response_has_expected_fields() -> None:
    """AssistantResponse exposes each constructor argument on the matching attribute."""
    built = AssistantResponse(
        thread_id=SAMPLE_THREAD_ID,
        answer=SAMPLE_ANSWER,
        suggested_annotations=[],
        citations=[],
    )

    assert built.answer == SAMPLE_ANSWER, "Expected answer to match"
    assert built.citations == [], "Expected empty citations"
    assert built.suggested_annotations == [], "Expected empty annotations"
    assert built.thread_id == SAMPLE_THREAD_ID, "Expected thread_id to match"
|
|
|
|
|
|
def test_response_with_single_citation() -> None:
    """A single citation passed to AssistantResponse keeps its meeting and segment ids."""
    # Deliberately not _create_citation's default (0): with the default, the
    # assertion below could not tell a forwarded segment_id from a dropped one.
    segment_id = 3
    meeting_id = MeetingId(SAMPLE_MEETING_ID)
    citation = _create_citation(meeting_id, segment_id=segment_id)
    response = AssistantResponse(
        answer=SAMPLE_ANSWER,
        citations=[citation],
        suggested_annotations=[],
        thread_id=SAMPLE_THREAD_ID,
    )

    assert len(response.citations) == 1, "Expected one citation"
    assert response.citations[0].meeting_id == meeting_id, "Expected citation meeting_id to match"
    assert response.citations[0].segment_id == segment_id, (
        f"Expected citation segment_id to be {segment_id}"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generates_thread_id_with_user_id() -> None:
    """Unavailable fallback generates workspace-scoped thread ID even if meeting_id provided."""
    service = AssistantService()
    request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=SAMPLE_MEETING_ID,
    )

    response = await service.ask(request)

    # Guard first. A missing thread_id then fails with a clear message instead
    # of an AttributeError from ``startswith``, and the type is narrowed to str.
    thread_id = response.thread_id
    assert thread_id is not None, "Expected thread_id to be generated"

    # Unavailable fallback always uses workspace scope for thread_id
    expected_prefix = f"meeting:workspace:user:{SAMPLE_USER_ID}:graph:workspace_qa:v1:"
    assert thread_id.startswith(expected_prefix), (
        "Expected thread_id to include workspace, user, graph, and version segments"
    )
    # Whatever follows the fixed prefix is the session id.
    session_id = thread_id.removeprefix(expected_prefix)
    assert len(session_id) == THREAD_SESSION_ID_LENGTH, "Expected session id length"
    assert THREAD_SESSION_ID_PATTERN.fullmatch(session_id), "Expected hex session id"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_thread_id_format_has_expected_structure() -> None:
    """A generated thread ID has the ``meeting:...:user:<id>:graph:...`` layout."""
    service = AssistantService()
    request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=None,
    )

    response = await service.ask(request)

    thread_id = response.thread_id
    assert thread_id is not None, "Expected thread_id to be generated"
    assert thread_id.startswith("meeting:"), "Expected thread_id to start with 'meeting:'"
    assert ":user:" in thread_id, "Expected ':user:' separator in thread"
    assert ":graph:" in thread_id, "Expected ':graph:' separator in thread"
    # The separators alone would also pass for a thread built for another user,
    # so check that the requesting user's id is embedded after ':user:'.
    assert f":user:{SAMPLE_USER_ID}:" in thread_id, (
        "Expected requesting user's id after ':user:' in thread"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_preserves_existing_thread_id() -> None:
    """A caller-supplied thread_id is returned unchanged on the response."""
    existing_thread = "custom-thread-id-123"
    assistant = AssistantService()

    reply = await assistant.ask(
        AssistantRequest(
            question=SAMPLE_QUESTION,
            user_id=SAMPLE_USER_ID,
            meeting_id=None,
            thread_id=existing_thread,
        )
    )

    assert reply.thread_id == existing_thread, "Expected thread_id to be preserved"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unavailable_response_generates_fallback_thread_id() -> None:
    """Even with a meeting_id, the unavailable path yields a workspace-style thread."""
    assistant = AssistantService()
    meeting_request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=SAMPLE_MEETING_ID,
    )

    thread_id = (await assistant.ask(meeting_request)).thread_id

    assert thread_id is not None, "Expected thread_id to be generated"
    assert "workspace" in thread_id, (
        "Expected unavailable response to use workspace-style thread"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unavailable_response_for_workspace_request() -> None:
    """A workspace request's fallback thread ID has the workspace prefix and a hex session."""
    service = AssistantService()
    request = AssistantRequest(
        question=SAMPLE_QUESTION,
        user_id=SAMPLE_USER_ID,
        meeting_id=None,
    )

    response = await service.ask(request)

    # Guard so a missing thread_id fails readably instead of raising
    # AttributeError from ``startswith``.
    thread_id = response.thread_id
    assert thread_id is not None, "Expected thread_id to be generated"

    expected_prefix = f"meeting:workspace:user:{SAMPLE_USER_ID}:graph:workspace_qa:v1:"
    assert thread_id.startswith(expected_prefix), (
        "Expected workspace-scoped thread_id format"
    )
    # Same session-suffix contract as the meeting-id variant of this fallback.
    session_id = thread_id.removeprefix(expected_prefix)
    assert THREAD_SESSION_ID_PATTERN.fullmatch(session_id), "Expected hex session id"
|