412 lines
15 KiB
Python
412 lines
15 KiB
Python
"""Unit tests for AsrConfigService.

Tests cover:

- get_capabilities: Returns current ASR configuration and available options
- validate_configuration: Validates model size, device, and compute type
- start_reconfiguration: Starts background reconfiguration job
- get_job_status: Returns status of reconfiguration jobs
- detect_cuda_available: Reports whether a CUDA GPU backend is detected
- Reconfiguration behavior: Swaps the engine on success and keeps the active
  engine on failure
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
|
|
from noteflow.application.services.asr_config import (
|
|
AsrComputeType,
|
|
AsrConfigJob,
|
|
AsrConfigService,
|
|
AsrDevice,
|
|
)
|
|
from noteflow.domain.constants.fields import (
|
|
JOB_STATUS_COMPLETED,
|
|
JOB_STATUS_FAILED,
|
|
)
|
|
|
|
# =============================================================================
|
|
# Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def mock_asr_engine() -> MagicMock:
    """Build a stand-in ASR engine reporting a loaded "base" model on CPU/int8.

    ``unload`` and ``load_model`` are set explicitly to fresh ``MagicMock``
    callables so tests can assert on how they were called.
    """
    fake = MagicMock()
    fake.model_size, fake.device, fake.compute_type = "base", "cpu", "int8"
    fake.is_loaded = True
    fake.unload = MagicMock()
    fake.load_model = MagicMock()
    return fake
|
|
|
|
|
|
@pytest.fixture
def asr_config_service(mock_asr_engine: MagicMock) -> AsrConfigService:
    """Provide a service wired to the ``mock_asr_engine`` fixture."""
    service = AsrConfigService(asr_engine=mock_asr_engine)
    return service
|
|
|
|
|
|
@pytest.fixture
def asr_config_service_no_engine() -> AsrConfigService:
    """Provide a service constructed with no ASR engine attached."""
    engineless = AsrConfigService(asr_engine=None)
    return engineless
|
|
|
|
|
|
# =============================================================================
|
|
# get_capabilities tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_get_capabilities_returns_current_config(
    asr_config_service: AsrConfigService,
) -> None:
    """Capabilities mirror the mock engine's base/cpu/int8 settings (no CUDA)."""
    manager = asr_config_service.engine_manager
    with patch.object(manager, "detect_cuda_available", return_value=False):
        capabilities = asr_config_service.get_capabilities()

    # Current configuration reported from the engine.
    assert capabilities.model_size == "base", "model_size should be 'base' from engine"
    assert capabilities.device == AsrDevice.CPU, "device should be CPU"
    assert capabilities.compute_type == AsrComputeType.INT8, "compute_type should be INT8"
    assert capabilities.is_ready is True, "is_ready should be True when engine is loaded"
    assert capabilities.cuda_available is False, "cuda_available should be False"

    # Options offered to the caller.
    sizes = capabilities.available_model_sizes
    computes = capabilities.available_compute_types
    assert "base" in sizes, "available_model_sizes should include 'base'"
    assert AsrComputeType.INT8 in computes, "INT8 should be available"
|
|
|
|
|
|
def test_get_capabilities_no_engine_returns_defaults(
    asr_config_service_no_engine: AsrConfigService,
) -> None:
    """get_capabilities returns defaults when no engine is configured.

    CUDA detection is stubbed on ``engine_manager``, the same seam the other
    capability and validation tests patch. This keeps the result independent
    of the GPU on the machine running the suite.
    """
    with patch.object(
        asr_config_service_no_engine.engine_manager,
        "detect_cuda_available",
        return_value=False,
    ):
        caps = asr_config_service_no_engine.get_capabilities()

    assert caps.model_size is None, "model_size should be None without engine"
    assert caps.device == AsrDevice.CPU, "device should default to CPU"
    assert caps.is_ready is False, "is_ready should be False without engine"
|
|
|
|
|
|
def test_get_capabilities_with_cuda_available(
    asr_config_service: AsrConfigService,
    mock_asr_engine: MagicMock,
) -> None:
    """With a CUDA backend detected, capabilities expose CUDA and FLOAT16."""
    from noteflow.domain.ports.gpu import GpuBackend

    backend_probe = (
        "noteflow.application.services.asr_config."
        "_engine_manager.detect_gpu_backend"
    )
    mock_asr_engine.device = "cuda"

    with patch(backend_probe, return_value=GpuBackend.CUDA):
        caps = asr_config_service.get_capabilities()

    assert caps.cuda_available is True, "cuda_available should be True when CUDA detected"
    assert caps.device == AsrDevice.CUDA, "device should be CUDA"
    float16_msg = "FLOAT16 should be available for CUDA"
    assert AsrComputeType.FLOAT16 in caps.available_compute_types, float16_msg
|
|
|
|
|
|
# =============================================================================
|
|
# validate_configuration tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_validate_configuration_valid_cpu_config(
    asr_config_service: AsrConfigService,
) -> None:
    """A small/CPU/int8 request passes validation when CUDA is absent."""
    request = {
        "model_size": "small",
        "device": AsrDevice.CPU,
        "compute_type": AsrComputeType.INT8,
    }
    with patch.object(
        asr_config_service.engine_manager, "detect_cuda_available", return_value=False
    ):
        error = asr_config_service.validate_configuration(**request)

    assert error is None, "valid CPU configuration should not return an error"
|
|
|
|
|
|
def test_validate_configuration_invalid_model_size(
    asr_config_service: AsrConfigService,
) -> None:
    """An unknown model name is reported as an invalid model size."""
    message = asr_config_service.validate_configuration(
        model_size="invalid-model", device=None, compute_type=None
    )

    assert message is not None, "error should be set for invalid model"
    assert "Invalid model size" in message, "error should mention invalid model size"
|
|
|
|
|
|
def test_validate_configuration_cuda_unavailable(
    asr_config_service: AsrConfigService,
) -> None:
    """Requesting the CUDA device fails validation when CUDA is not detected."""
    no_cuda = patch.object(
        asr_config_service.engine_manager, "detect_cuda_available", return_value=False
    )
    with no_cuda:
        message = asr_config_service.validate_configuration(
            model_size=None, device=AsrDevice.CUDA, compute_type=None
        )

    assert message is not None, "error should be set when CUDA unavailable"
    assert "CUDA" in message, "error should mention CUDA"
|
|
|
|
|
|
def test_validate_configuration_invalid_compute_for_device(
    asr_config_service: AsrConfigService,
) -> None:
    """A compute type the chosen device does not offer is rejected."""
    # FLOAT16 is not offered for the CPU device, so this pairing must fail.
    cpu_with_float16 = {
        "model_size": None,
        "device": AsrDevice.CPU,
        "compute_type": AsrComputeType.FLOAT16,
    }
    message = asr_config_service.validate_configuration(**cpu_with_float16)

    assert message is not None, "error should be set for invalid compute type"
    assert "not available" in message, "error should mention unavailability"
|
|
|
|
|
|
def test_validate_configuration_none_values_accepted(
    asr_config_service: AsrConfigService,
) -> None:
    """Passing None for every field (i.e. "keep current") is valid."""
    outcome = asr_config_service.validate_configuration(
        model_size=None, device=None, compute_type=None
    )
    assert outcome is None, "None values should be accepted for validation"
|
|
|
|
|
|
# =============================================================================
|
|
# start_reconfiguration tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_reconfiguration_returns_job_id(
    asr_config_service: AsrConfigService,
) -> None:
    """start_reconfiguration returns job ID on success.

    The spawned background task is cancelled in a ``finally`` block. That way
    it is cleaned up even when an assertion fails, rather than being left
    pending on the event loop.
    """
    with patch.object(
        asr_config_service.engine_manager, "detect_cuda_available", return_value=False
    ):
        job_id, error = await asr_config_service.start_reconfiguration(
            model_size="small",
            device=None,
            compute_type=None,
            has_active_recordings=False,
        )

    assert job_id is not None, "job_id should be returned on success"
    job = asr_config_service.get_job_status(job_id)
    try:
        assert isinstance(job_id, UUID), "job_id should be a UUID"
        assert error is None, "error should be None on success"
        assert job is not None, "job should exist after start_reconfiguration"
        assert job.task is not None, "job task should be created"
    finally:
        # Cancel the background reconfiguration so it never outlives the test.
        if job is not None and job.task is not None:
            job.task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await job.task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_reconfiguration_blocked_during_recording(
    asr_config_service: AsrConfigService,
) -> None:
    """No job is started while a recording is in progress."""
    outcome = await asr_config_service.start_reconfiguration(
        model_size="small", device=None, compute_type=None, has_active_recordings=True
    )
    job_id, error = outcome

    assert job_id is None, "job_id should be None when blocked"
    assert error is not None, "error should be set when recordings active"
    assert "recordings are active" in error, "error should explain why blocked"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_reconfiguration_no_engine(
    asr_config_service_no_engine: AsrConfigService,
) -> None:
    """Without an engine, reconfiguration is refused with an explanation."""
    service = asr_config_service_no_engine
    job_id, error = await service.start_reconfiguration(
        model_size="small", device=None, compute_type=None, has_active_recordings=False
    )

    assert job_id is None, "job_id should be None without engine"
    assert error is not None, "error should be set without engine"
    assert "not available" in error, "error should explain unavailability"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_reconfiguration_validation_failure(
    asr_config_service: AsrConfigService,
) -> None:
    """An invalid model size is surfaced as an error and no job is created."""
    request = {
        "model_size": "invalid-model",
        "device": None,
        "compute_type": None,
        "has_active_recordings": False,
    }
    with patch.object(
        asr_config_service.engine_manager, "detect_cuda_available", return_value=False
    ):
        job_id, error = await asr_config_service.start_reconfiguration(**request)

    assert job_id is None, "job_id should be None on validation failure"
    assert error is not None, "error should be set on validation failure"
    assert "Invalid model size" in error, "error should mention invalid model"
|
|
|
|
|
|
# =============================================================================
|
|
# get_job_status tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_job_status_returns_job(
    asr_config_service: AsrConfigService,
) -> None:
    """get_job_status returns job info for valid ID.

    The background task is cancelled in a ``finally`` block so a failing
    assertion does not leave it pending on the event loop.
    """
    with patch.object(
        asr_config_service.engine_manager, "detect_cuda_available", return_value=False
    ):
        job_id, _ = await asr_config_service.start_reconfiguration(
            model_size="small",
            device=None,
            compute_type=None,
            has_active_recordings=False,
        )

    assert job_id is not None, "job_id should be returned"
    job = asr_config_service.get_job_status(job_id)
    try:
        assert job is not None, "job should be found for valid ID"
        assert isinstance(job, AsrConfigJob), "job should be AsrConfigJob type"
        assert job.target_model_size == "small", "target_model_size should match request"
        assert job.task is not None, "job task should be created"
    finally:
        # Cancel the background reconfiguration so it never outlives the test.
        if job is not None and job.task is not None:
            job.task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await job.task
|
|
|
|
|
|
def test_get_job_status_returns_none_for_unknown_id(
    asr_config_service: AsrConfigService,
) -> None:
    """A freshly generated UUID matches no job, so the lookup yields None."""
    from uuid import uuid4

    unknown_id = uuid4()
    result = asr_config_service.get_job_status(unknown_id)
    assert result is None, "job should be None for unknown ID"
|
|
|
|
|
|
# =============================================================================
|
|
# detect_cuda_available tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_detect_cuda_available_with_cuda(
    asr_config_service: AsrConfigService,
) -> None:
    """A CUDA GPU backend makes detect_cuda_available report True."""
    from noteflow.domain.ports.gpu import GpuBackend

    backend_probe = (
        "noteflow.application.services.asr_config."
        "_engine_manager.detect_gpu_backend"
    )
    with patch(backend_probe, return_value=GpuBackend.CUDA):
        detected = asr_config_service.detect_cuda_available()

    assert detected is True, "detect_cuda_available should return True when CUDA available"
|
|
|
|
|
|
def test_detect_cuda_available_no_cuda(
    asr_config_service: AsrConfigService,
) -> None:
    """With no GPU backend, detect_cuda_available reports False."""
    from noteflow.domain.ports.gpu import GpuBackend

    backend_probe = (
        "noteflow.application.services.asr_config."
        "_engine_manager.detect_gpu_backend"
    )
    with patch(backend_probe, return_value=GpuBackend.NONE):
        detected = asr_config_service.detect_cuda_available()

    assert detected is False, "detect_cuda_available should return False when CUDA unavailable"
|
|
|
|
|
|
# =============================================================================
|
|
# reconfiguration behavior tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reconfiguration_failure_keeps_active_engine(mock_asr_engine: MagicMock) -> None:
    """A model-load error leaves the current engine in place and unreported.

    The failed job must not unload the active engine or fire the update callback.
    """
    from noteflow.infrastructure.asr.engine import FasterWhisperEngine

    published: list[FasterWhisperEngine] = []
    service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=published.append)
    manager = service.engine_manager
    candidate = MagicMock()

    with (
        patch.object(manager, "build_engine_for_job", return_value=(candidate, True)),
        patch.object(manager, "load_model", side_effect=RuntimeError("boom")),
    ):
        job_id, _ = await service.start_reconfiguration("small", None, None, False)
        assert job_id is not None, "job_id should not be None"
        job = service.get_job_status(job_id)
        assert job is not None and job.task is not None, "job should be created with task"
        # Run the job to completion while the patches are still active.
        await job.task

    assert job.status == JOB_STATUS_FAILED, "job marked failed on load error"
    mock_asr_engine.unload.assert_not_called()
    assert not published, "no callback on failure"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reconfiguration_success_swaps_engine(mock_asr_engine: MagicMock) -> None:
    """A successful job unloads the old engine and publishes the new one once."""
    from noteflow.infrastructure.asr.engine import FasterWhisperEngine

    published: list[FasterWhisperEngine] = []
    service = AsrConfigService(asr_engine=mock_asr_engine, on_engine_update=published.append)
    replacement = MagicMock()
    manager = service.engine_manager

    with (
        patch.object(manager, "build_engine_for_job", return_value=(replacement, True)),
        patch.object(manager, "load_model", return_value=None),
    ):
        job_id, _ = await service.start_reconfiguration("small", None, None, False)
        assert job_id is not None, "job_id should not be None"
        job = service.get_job_status(job_id)
        assert job is not None and job.task is not None, "job should be created with task"
        # Run the job to completion while the patches are still active.
        await job.task

    assert job.status == JOB_STATUS_COMPLETED, "job completed successfully"
    mock_asr_engine.unload.assert_called_once()
    assert published == [replacement], "callback fired with new engine"
|