214 lines
7.3 KiB
Python
214 lines
7.3 KiB
Python
"""Tests for ASR engine factory."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.ports.gpu import GpuBackend
|
|
from noteflow.infrastructure.asr.factory import (
|
|
EngineCreationError,
|
|
create_asr_engine,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
pass
|
|
|
|
|
|
class TestCreateAsrEngine:
    """Tests for ``create_asr_engine`` device dispatch and compute-type handling."""

    def test_cpu_engine_creation(self) -> None:
        """Test CPU engine is created for cpu device."""
        with patch(
            "noteflow.infrastructure.asr.factory.resolve_device",
            return_value="cpu",
        ):
            engine = create_asr_engine(device="cpu", compute_type="int8")
            assert engine.device == "cpu"

    def test_cpu_forces_float32_for_float16(self) -> None:
        """Test CPU engine converts float16 to float32."""
        with patch(
            "noteflow.infrastructure.asr.factory.resolve_device",
            return_value="cpu",
        ):
            engine = create_asr_engine(device="cpu", compute_type="float16")
            # CPU doesn't support float16, so the request must be downgraded.
            # The previous check ``in ("float32", "float16")`` also accepted the
            # unconverted value and therefore could never fail.
            assert engine.compute_type == "float32"

    def test_auto_device_resolution(self) -> None:
        """Test "auto" resolves to CUDA and the CUDA builder's engine is returned."""
        cuda_engine = MagicMock()

        # Mock GPU detection to report CUDA and stub out the real CUDA builder.
        with (
            patch(
                "noteflow.infrastructure.asr.factory.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch(
                "noteflow.infrastructure.asr.factory._create_cuda_engine",
                return_value=cuda_engine,
            ) as mock_cuda,
        ):
            engine = create_asr_engine(device="auto", compute_type="float16")

        mock_cuda.assert_called_once()
        # The factory must hand back exactly the engine the CUDA builder produced.
        assert engine is cuda_engine

    def test_unsupported_device_raises(self) -> None:
        """Test unsupported device raises EngineCreationError."""
        with (
            patch(
                "noteflow.infrastructure.asr.factory.resolve_device",
                return_value="invalid_device",
            ),
            pytest.raises(EngineCreationError, match="Unsupported device"),
        ):
            create_asr_engine(device="invalid_device")
|
|
|
|
|
|
class TestDeviceResolution:
    """Tests for ``resolve_device``: explicit passthrough and ``"auto"`` probing."""

    # Dotted path of the module whose GPU-probing helpers are patched below.
    _FACTORY = "noteflow.infrastructure.asr.factory"

    @pytest.mark.parametrize("device", ["cpu", "cuda", "rocm"])
    def test_explicit_device_not_changed(self, device: str) -> None:
        """An explicit device name is returned exactly as given."""
        from noteflow.infrastructure.asr.factory import resolve_device

        resolved = resolve_device(device)
        assert resolved == device, f"Device {device} should remain unchanged"

    def test_auto_with_cuda(self) -> None:
        """When a CUDA backend is detected, "auto" resolves to "cuda"."""
        from noteflow.infrastructure.asr.factory import resolve_device

        cuda_detected = patch(
            f"{self._FACTORY}.detect_gpu_backend", return_value=GpuBackend.CUDA
        )
        with cuda_detected:
            resolved = resolve_device("auto")
        assert resolved == "cuda"

    def test_auto_with_rocm_supported(self) -> None:
        """When ROCm is detected on a supported architecture, "auto" resolves to "rocm"."""
        from noteflow.infrastructure.asr.factory import resolve_device

        gpu_info = MagicMock(architecture="gfx1100")

        with (
            patch(f"{self._FACTORY}.detect_gpu_backend", return_value=GpuBackend.ROCM),
            patch(f"{self._FACTORY}.get_gpu_info", return_value=gpu_info),
            patch(f"{self._FACTORY}.is_rocm_architecture_supported", return_value=True),
        ):
            resolved = resolve_device("auto")
        assert resolved == "rocm"

    def test_auto_with_rocm_unsupported_falls_to_cpu(self) -> None:
        """When ROCm is detected on an unsupported architecture, "auto" falls back to "cpu"."""
        from noteflow.infrastructure.asr.factory import resolve_device

        gpu_info = MagicMock(architecture="gfx803")

        with (
            patch(f"{self._FACTORY}.detect_gpu_backend", return_value=GpuBackend.ROCM),
            patch(f"{self._FACTORY}.get_gpu_info", return_value=gpu_info),
            patch(f"{self._FACTORY}.is_rocm_architecture_supported", return_value=False),
        ):
            resolved = resolve_device("auto")
        assert resolved == "cpu"

    def test_auto_with_mps_falls_to_cpu(self) -> None:
        """MPS is not supported for ASR, so "auto" falls back to "cpu"."""
        from noteflow.infrastructure.asr.factory import resolve_device

        mps_detected = patch(
            f"{self._FACTORY}.detect_gpu_backend", return_value=GpuBackend.MPS
        )
        with mps_detected:
            resolved = resolve_device("auto")
        assert resolved == "cpu"
|
|
|
|
|
|
class TestRocmEngineCreation:
    """Test ROCm engine creation: CTranslate2-ROCm first, PyTorch Whisper as fallback."""

    def test_rocm_with_ctranslate2_available(self) -> None:
        """Test ROCm uses CTranslate2 when available."""
        mock_engine = MagicMock()

        with (
            patch(
                "noteflow.infrastructure.asr.factory.resolve_device",
                return_value="rocm",
            ),
            patch(
                "noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
                return_value=True,
            ),
            patch(
                "noteflow.infrastructure.asr.rocm_engine.FasterWhisperRocmEngine",
                return_value=mock_engine,
            ) as mock_rocm_cls,
        ):
            engine = create_asr_engine(device="rocm", compute_type="float16")

        mock_rocm_cls.assert_called_once()
        # Identity rather than ``==``: the factory must return the instance it built.
        assert engine is mock_engine

    def test_rocm_falls_back_to_pytorch(self) -> None:
        """Test ROCm falls back to PyTorch Whisper when CTranslate2 unavailable."""
        mock_engine = MagicMock()

        with (
            patch(
                "noteflow.infrastructure.asr.factory.resolve_device",
                return_value="rocm",
            ),
            patch(
                "noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
                return_value=False,
            ),
            patch(
                "noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
                return_value=mock_engine,
            ) as mock_pytorch_cls,
        ):
            engine = create_asr_engine(device="rocm", compute_type="float16")

        mock_pytorch_cls.assert_called_once()
        # Identity rather than ``==``: the factory must return the instance it built.
        assert engine is mock_engine
|
|
|
|
|
|
class TestPytorchEngineFallback:
    """Failure path of the PyTorch Whisper fallback used on ROCm."""

    # Dotted path of the factory module whose helpers are patched below.
    _FACTORY = "noteflow.infrastructure.asr.factory"

    def test_pytorch_engine_import_error(self) -> None:
        """An ImportError from the PyTorch engine surfaces as EngineCreationError."""
        rocm_requested = patch(f"{self._FACTORY}.resolve_device", return_value="rocm")
        ctranslate2_missing = patch(
            f"{self._FACTORY}.is_ctranslate2_rocm_available", return_value=False
        )
        pytorch_broken = patch(
            "noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
            side_effect=ImportError("No module"),
        )
        expect_failure = pytest.raises(EngineCreationError, match="Neither CTranslate2")

        # Entered left-to-right: patches first, then the expected-exception guard.
        with rocm_requested, ctranslate2_missing, pytorch_broken, expect_failure:
            create_asr_engine(device="rocm", compute_type="float16")
|