"""Tests for Ollama summarization provider.
|
|
|
|
Uses shared helpers from tests/infrastructure/summarization/conftest.py:
|
|
- create_test_segment: Creates Segment instances for testing
|
|
- build_valid_json_response: Creates mock LLM JSON responses
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import types
|
|
from collections.abc import Callable, Mapping, Sequence
|
|
from typing import Protocol
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.summarization import (
|
|
InvalidResponseError,
|
|
ProviderUnavailableError,
|
|
SummarizationRequest,
|
|
)
|
|
from noteflow.domain.value_objects import MeetingId
|
|
from noteflow.infrastructure.summarization import OllamaSummarizer, ollama_provider
|
|
|
|
from .conftest import build_valid_json_response, create_test_segment
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Type Definitions
|
|
# -----------------------------------------------------------------------------
|
|
|
|
# Callable type for mock chat functions
|
|
# Signature shared by all mock chat callables: keyword-only kwargs in,
# a _MockChatResponse out. Forward-referenced since the class is defined below.
ChatFn = Callable[..., "_MockChatResponse"]
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Typed Mock Infrastructure
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class _MockMessage:
    """Minimal stand-in for an ollama chat message."""

    def __init__(self, content: str) -> None:
        # Only .content is read by the provider under test.
        self.content = content
|
|
|
|
|
|
class _MockChatResponse:
    """Stand-in mirroring the ollama.ChatResponse shape (``response.message.content``)."""

    def __init__(self, content: str) -> None:
        # Wrap the raw text the same way the real client response does.
        self.message = _MockMessage(content)
|
|
|
|
|
|
class _MockListResponse:
    """Stand-in mirroring the ollama.ListResponse shape (``response.models``)."""

    def __init__(self, models: list[dict[str, str]]) -> None:
        # Stored as-is; tests only ever pass an empty list.
        self.models = models
|
|
|
|
|
|
class _MockOllamaClient(Protocol):
    """Protocol for mock Ollama client."""

    # Mirrors ollama.Client.list(); exercised by is_available checks.
    def list(self) -> _MockListResponse: ...

    # Mirrors ollama.Client.chat(); all arguments are keyword-only,
    # matching how the summarizer invokes the real client.
    def chat(
        self,
        *,
        model: str,
        messages: Sequence[Mapping[str, str]],
        options: Mapping[str, object],
        format: str,
    ) -> _MockChatResponse: ...
|
|
|
|
|
|
class _MockOllamaModule(types.ModuleType):
    """Fake module installed into sys.modules under the name 'ollama'."""

    # Tests assign a fake client class here so that ``ollama.Client(...)``
    # resolves inside the provider under test.
    Client: type[_MockOllamaClient] | object
|
|
|
|
|
|
def _create_mock_client_class(
    list_response: _MockListResponse,
    chat_response: _MockChatResponse | None = None,
    chat_fn: ChatFn | None = None,
) -> type[object]:
    """Build a fake client class wired with canned responses.

    Args:
        list_response: Value returned by ``list()``.
        chat_response: Optional static reply for ``chat()``.
        chat_fn: Optional callable supplying dynamic ``chat()`` behavior;
            takes precedence over ``chat_response``.

    Returns:
        A class whose instances mimic ``ollama.Client``.
    """

    class _FakeClient:
        def __init__(self, host: str | None = None) -> None:
            self._host = host

        def list(self) -> _MockListResponse:
            return list_response

        def chat(
            self,
            *,
            model: str = "",
            messages: Sequence[Mapping[str, str]] | None = None,
            options: Mapping[str, object] | None = None,
            format: str = "",
        ) -> _MockChatResponse:
            # Precedence: dynamic callable, then static reply, then a
            # generic valid JSON payload as a safe default.
            if chat_fn is not None:
                return chat_fn(model=model, messages=messages, options=options, format=format)
            if chat_response is not None:
                return chat_response
            return _MockChatResponse(build_valid_json_response())

    return _FakeClient
|
|
|
|
|
|
def _create_mock_module(client_class: type[object]) -> _MockOllamaModule:
    """Wrap a fake client class in a module object named 'ollama'.

    Args:
        client_class: Class to expose as ``ollama.Client``.

    Returns:
        Module stand-in ready for insertion into ``sys.modules``.
    """
    module = _MockOllamaModule("ollama")
    module.Client = client_class
    return module
|
|
|
|
|
|
def _setup_import_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``builtins.__import__`` so importing 'ollama' raises ImportError.

    Args:
        monkeypatch: Pytest monkeypatch fixture.
    """
    import builtins

    # Evict any cached module first so the import machinery is actually hit.
    monkeypatch.delitem(sys.modules, "ollama", raising=False)
    real_import = builtins.__import__

    def _failing_import(
        name: str,
        globals: Mapping[str, object] | None = None,
        locals: Mapping[str, object] | None = None,
        fromlist: Sequence[str] | None = None,
        level: int = 0,
    ) -> object:
        # Fail only for 'ollama'; every other import passes through untouched.
        if name == "ollama":
            raise ImportError("No module named 'ollama'")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", _failing_import)
|
|
|
|
|
|
def _create_fresh_summarizer_without_client() -> ollama_provider.OllamaSummarizer:
    """Return an OllamaSummarizer whose cached client has been cleared.

    Clearing ``_client`` forces the next use to re-import ollama, which is
    what the import-failure tests rely on.

    Returns:
        OllamaSummarizer instance with client attribute reset.
    """
    from noteflow.infrastructure.summarization import ollama_provider

    instance = ollama_provider.OllamaSummarizer()
    # object.__setattr__ bypasses any attribute guard on the provider class.
    object.__setattr__(instance, "_client", None)
    return instance
|
|
|
|
|
|
def _setup_mock_ollama(
    monkeypatch: pytest.MonkeyPatch,
    chat_response: _MockChatResponse | None = None,
    chat_fn: ChatFn | None = None,
) -> None:
    """Install a fake ollama module into ``sys.modules``.

    Args:
        monkeypatch: Pytest monkeypatch fixture.
        chat_response: Optional static chat response.
        chat_fn: Optional dynamic chat function.
    """
    client_cls = _create_mock_client_class(
        _MockListResponse(models=[]), chat_response, chat_fn
    )
    monkeypatch.setitem(sys.modules, "ollama", _create_mock_module(client_cls))
|
|
|
|
|
|
def _create_summarizer_and_request(
    meeting_id: MeetingId, segment_count: int = 1
) -> tuple[OllamaSummarizer, SummarizationRequest]:
    """Build a summarizer plus a request holding generated test segments.

    Args:
        meeting_id: Meeting ID for the request.
        segment_count: Number of test segments to create.

    Returns:
        Tuple of (summarizer, request).
    """
    from noteflow.infrastructure.summarization import OllamaSummarizer

    request = SummarizationRequest(
        meeting_id=meeting_id,
        segments=[create_test_segment(i, f"Test segment {i}") for i in range(segment_count)],
    )
    return OllamaSummarizer(), request
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Test Classes
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class TestOllamaSummarizerProperties:
    """Tests for OllamaSummarizer properties."""

    @pytest.fixture
    def mock_ollama_module(self, monkeypatch: pytest.MonkeyPatch) -> types.ModuleType:
        """Install and return a mocked ollama module."""
        client_cls = _create_mock_client_class(
            _MockListResponse(models=[]),
            _MockChatResponse(build_valid_json_response()),
        )
        module = _create_mock_module(client_cls)
        monkeypatch.setitem(sys.modules, "ollama", module)
        return module

    def test_ollama_provider_name(self) -> None:
        """Provider name should be 'ollama'."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        assert OllamaSummarizer().provider_name == "ollama", "provider_name should be 'ollama'"

    def test_requires_cloud_consent_false(self) -> None:
        """Ollama should not require cloud consent (local processing)."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        assert OllamaSummarizer().requires_cloud_consent is False, (
            "local provider should not require cloud consent"
        )

    @pytest.mark.usefixtures("mock_ollama_module")
    def test_is_available_when_server_responds(self) -> None:
        """is_available should be True when server responds."""
        from noteflow.infrastructure.summarization import OllamaSummarizer

        assert OllamaSummarizer().is_available is True, (
            "is_available should be True when server responds"
        )

    def test_is_available_false_when_connection_fails(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """is_available should be False when server unreachable."""

        class _UnreachableClient:
            # Both entry points raise, simulating a server that is down.
            def __init__(self, host: str | None = None) -> None:
                self._host = host

            def list(self) -> _MockListResponse:
                raise ConnectionError("Connection refused")

            def chat(
                self,
                *,
                model: str = "",
                messages: Sequence[Mapping[str, str]] | None = None,
                options: Mapping[str, object] | None = None,
                format: str = "",
            ) -> _MockChatResponse:
                raise ConnectionError("Connection refused")

        monkeypatch.setitem(sys.modules, "ollama", _create_mock_module(_UnreachableClient))

        from noteflow.infrastructure.summarization import OllamaSummarizer

        assert OllamaSummarizer().is_available is False, (
            "is_available should be False when connection fails"
        )
|
|
|
|
|
|
class TestOllamaSummarizerSummarize:
    """Tests for OllamaSummarizer.summarize method.

    All tests route mock construction through the module-level helpers
    ``_setup_mock_ollama`` and ``_create_summarizer_and_request`` so the
    ollama plumbing is built in exactly one place.
    """

    @pytest.mark.asyncio
    async def test_ollama_summarize_empty_segments(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Empty segments should return empty summary without calling LLM."""
        call_count = 0

        def counting_chat(
            *,
            model: str = "",
            messages: Sequence[Mapping[str, str]] | None = None,
            options: Mapping[str, object] | None = None,
            format: str = "",
        ) -> _MockChatResponse:
            # Track invocations so we can assert the LLM was never hit.
            nonlocal call_count
            call_count += 1
            return _MockChatResponse(build_valid_json_response())

        _setup_mock_ollama(monkeypatch, chat_fn=counting_chat)
        summarizer, request = _create_summarizer_and_request(meeting_id, segment_count=0)

        result = await summarizer.summarize(request)

        assert result.summary.key_points == [], "empty segments should produce empty key_points"
        assert result.summary.action_items == [], "empty segments should produce empty action_items"
        assert call_count == 0, "LLM should not be called for empty segments"

    @pytest.mark.asyncio
    async def test_ollama_summarize_returns_result(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Summarize should return SummarizationResult."""
        response = build_valid_json_response(
            summary="Meeting discussed project updates.",
            key_points=[{"text": "Project on track", "segment_ids": [0]}],
            action_items=[
                {"text": "Review code", "assignee": "Alice", "priority": 2, "segment_ids": [1]}
            ],
        )
        _setup_mock_ollama(monkeypatch, chat_response=_MockChatResponse(response))
        summarizer, request = _create_summarizer_and_request(meeting_id, segment_count=2)

        result = await summarizer.summarize(request)

        assert result.provider_name == "ollama", "result should report 'ollama' as provider_name"
        assert result.summary.meeting_id == meeting_id, "summary should have matching meeting_id"
        assert result.summary.executive_summary == "Meeting discussed project updates.", (
            "executive_summary should match"
        )
        assert len(result.summary.key_points) == 1, "should have exactly one key_point"
        assert result.summary.key_points[0].segment_ids == [0], (
            "key_point should reference segment 0"
        )
        assert len(result.summary.action_items) == 1, "should have exactly one action_item"
        assert result.summary.action_items[0].assignee == "Alice", (
            "action_item assignee should be 'Alice'"
        )

    @pytest.mark.asyncio
    async def test_summarize_filters_invalid_segment_ids(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Invalid segment_ids in response should be filtered out."""
        response = build_valid_json_response(
            summary="Test",
            key_points=[{"text": "Point", "segment_ids": [0, 99, 100]}],  # 99, 100 invalid
        )
        _setup_mock_ollama(monkeypatch, chat_response=_MockChatResponse(response))
        # Only one segment exists, so ids 99 and 100 are out of range.
        summarizer, request = _create_summarizer_and_request(meeting_id, segment_count=1)

        result = await summarizer.summarize(request)

        assert result.summary.key_points[0].segment_ids == [0], (
            "invalid segment_ids (99, 100) should be filtered out"
        )

    @pytest.mark.asyncio
    async def test_summarize_respects_max_limits(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Response items exceeding max limits should be truncated."""
        response = build_valid_json_response(
            summary="Test",
            key_points=[{"text": f"Point {i}", "segment_ids": [0]} for i in range(10)],
            action_items=[{"text": f"Action {i}", "segment_ids": [0]} for i in range(10)],
        )
        _setup_mock_ollama(monkeypatch, chat_response=_MockChatResponse(response))

        from noteflow.infrastructure.summarization import OllamaSummarizer

        # Request built by hand because it needs non-default max limits.
        summarizer = OllamaSummarizer()
        request = SummarizationRequest(
            meeting_id=meeting_id,
            segments=[create_test_segment(0, "Test segment")],
            max_key_points=3,
            max_action_items=2,
        )

        result = await summarizer.summarize(request)

        assert len(result.summary.key_points) == 3, (
            "key_points should be truncated to max_key_points=3"
        )
        assert len(result.summary.action_items) == 2, (
            "action_items should be truncated to max_action_items=2"
        )

    @pytest.mark.asyncio
    async def test_summarize_handles_markdown_fenced_json(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Markdown code fences around JSON should be stripped."""
        json_content = build_valid_json_response(summary="Fenced response")
        _setup_mock_ollama(
            monkeypatch, chat_response=_MockChatResponse(f"```json\n{json_content}\n```")
        )
        summarizer, request = _create_summarizer_and_request(meeting_id, segment_count=1)

        result = await summarizer.summarize(request)

        assert result.summary.executive_summary == "Fenced response", (
            "markdown code fences should be stripped from JSON response"
        )
|
|
|
|
|
|
class TestOllamaSummarizerErrors:
    """Tests for OllamaSummarizer error handling.

    Mock construction is routed through the module-level helpers
    ``_setup_mock_ollama`` and ``_create_summarizer_and_request`` instead of
    rebuilding the client/module plumbing per test.
    """

    @pytest.mark.asyncio
    async def test_ollama_raises_unavailable_when_package_missing(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError when ollama not installed."""
        _setup_import_failure(monkeypatch)
        # Fresh summarizer with no cached client, so summarize() re-imports.
        summarizer = _create_fresh_summarizer_without_client()
        request = SummarizationRequest(
            meeting_id=meeting_id, segments=[create_test_segment(0, "Test")]
        )

        with pytest.raises(ProviderUnavailableError, match="ollama package not installed"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_unavailable_on_connection_error(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise ProviderUnavailableError on connection failure."""

        def raise_connection_error(
            *,
            model: str = "",
            messages: Sequence[Mapping[str, str]] | None = None,
            options: Mapping[str, object] | None = None,
            format: str = "",
        ) -> _MockChatResponse:
            raise ConnectionRefusedError("Connection refused")

        _setup_mock_ollama(monkeypatch, chat_fn=raise_connection_error)
        summarizer, request = _create_summarizer_and_request(meeting_id)

        with pytest.raises(ProviderUnavailableError, match="Cannot connect"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_raises_invalid_response_on_bad_json(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on malformed JSON."""
        _setup_mock_ollama(monkeypatch, chat_response=_MockChatResponse("not valid json {{{"))
        summarizer, request = _create_summarizer_and_request(meeting_id)

        with pytest.raises(InvalidResponseError, match="Invalid JSON"):
            await summarizer.summarize(request)

    @pytest.mark.asyncio
    async def test_ollama_raises_invalid_response_on_empty_content(
        self, meeting_id: MeetingId, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should raise InvalidResponseError on empty response."""
        _setup_mock_ollama(monkeypatch, chat_response=_MockChatResponse(""))
        summarizer, request = _create_summarizer_and_request(meeting_id)

        with pytest.raises(InvalidResponseError, match="Empty response"):
            await summarizer.summarize(request)
|
|
|
|
|
|
class TestOllamaSummarizerConfiguration:
    """Tests for OllamaSummarizer configuration."""

    @pytest.mark.asyncio
    async def test_custom_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Custom model name should be used."""
        captured_model: str | None = None

        def capture_chat(
            *,
            model: str = "",
            messages: Sequence[Mapping[str, str]] | None = None,
            options: Mapping[str, object] | None = None,
            format: str = "",
        ) -> _MockChatResponse:
            # Record the model the provider actually sends to the client.
            nonlocal captured_model
            captured_model = model
            return _MockChatResponse(build_valid_json_response())

        # Use the shared helper rather than rebuilding the mock plumbing.
        _setup_mock_ollama(monkeypatch, chat_fn=capture_chat)

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer(model="mistral")
        request = SummarizationRequest(
            meeting_id=MeetingId(uuid4()), segments=[create_test_segment(0, "Test")]
        )

        await summarizer.summarize(request)

        assert captured_model == "mistral", "custom model name should be passed to ollama client"

    def test_custom_host(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Custom host should be passed to client."""
        captured_host: str | None = None

        class _CapturingClient:
            # Captures the host the provider constructs the client with.
            def __init__(self, host: str | None = None) -> None:
                nonlocal captured_host
                captured_host = host

            def list(self) -> _MockListResponse:
                return _MockListResponse(models=[])

            def chat(
                self,
                *,
                model: str = "",
                messages: Sequence[Mapping[str, str]] | None = None,
                options: Mapping[str, object] | None = None,
                format: str = "",
            ) -> _MockChatResponse:
                return _MockChatResponse(build_valid_json_response())

        monkeypatch.setitem(sys.modules, "ollama", _create_mock_module(_CapturingClient))

        from noteflow.infrastructure.summarization import OllamaSummarizer

        summarizer = OllamaSummarizer(host="http://custom:8080")
        # Touching is_available forces lazy client construction.
        _ = summarizer.is_available

        assert captured_host == "http://custom:8080", (
            "custom host should be passed to ollama Client"
        )
|