Files
noteflow/tests/infrastructure/asr/test_factory.py
Travis Vasceannie d8090a98e8
Some checks failed
CI / test-typescript (push) Has been cancelled
CI / test-rust (push) Has been cancelled
CI / test-python (push) Has been cancelled
ci/cd fixes
2026-01-26 00:28:15 +00:00

214 lines
7.3 KiB
Python

"""Tests for ASR engine factory."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from noteflow.domain.ports.gpu import GpuBackend
from noteflow.infrastructure.asr.factory import (
EngineCreationError,
create_asr_engine,
)
if TYPE_CHECKING:
    # No type-only imports are needed yet; add them here if annotations grow.
    ...
class TestCreateAsrEngine:
    """Test ASR engine factory."""

    def test_cpu_engine_creation(self) -> None:
        """Test CPU engine is created for cpu device."""
        with patch(
            "noteflow.infrastructure.asr.factory.resolve_device",
            return_value="cpu",
        ):
            engine = create_asr_engine(device="cpu", compute_type="int8")
            assert engine.device == "cpu"

    def test_cpu_forces_float32_for_float16(self) -> None:
        """Test CPU engine converts float16 to float32."""
        with patch(
            "noteflow.infrastructure.asr.factory.resolve_device",
            return_value="cpu",
        ):
            engine = create_asr_engine(device="cpu", compute_type="float16")
            # CPU doesn't support float16, so the factory should downgrade the
            # request to float32. Pin the exact value: also accepting
            # "float16" would let a broken conversion pass unnoticed.
            assert engine.compute_type == "float32"

    def test_auto_device_resolution(self) -> None:
        """Test ``auto`` routes to the CUDA builder when CUDA is detected.

        Only GPU detection is patched, so the real device-resolution step runs.
        The CUDA engine builder is replaced by a mock, so the real builder is
        never invoked.
        """
        with (
            patch(
                "noteflow.infrastructure.asr.factory.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch(
                "noteflow.infrastructure.asr.factory._create_cuda_engine",
            ) as mock_cuda,
        ):
            mock_cuda.return_value = MagicMock()
            create_asr_engine(device="auto", compute_type="float16")
            mock_cuda.assert_called_once()

    def test_unsupported_device_raises(self) -> None:
        """Test unsupported device raises EngineCreationError."""
        with patch(
            "noteflow.infrastructure.asr.factory.resolve_device",
            return_value="invalid_device",
        ), pytest.raises(EngineCreationError, match="Unsupported device"):
            create_asr_engine(device="invalid_device")
class TestDeviceResolution:
    """Verify how ``resolve_device`` maps a requested device to a concrete one."""

    @staticmethod
    def _resolve(device: str) -> str:
        """Import ``resolve_device`` lazily and apply it to ``device``."""
        from noteflow.infrastructure.asr.factory import resolve_device

        return resolve_device(device)

    @classmethod
    def _resolve_auto_on_rocm(cls, *, architecture: str, supported: bool) -> str:
        """Resolve ``auto`` with ROCm detected and the arch-support check stubbed.

        Args:
            architecture: Value exposed as ``architecture`` on the stubbed GPU info.
            supported: Value returned by the patched
                ``is_rocm_architecture_supported``.

        Returns:
            The device string produced by ``resolve_device("auto")``.
        """
        gpu_info = MagicMock()
        gpu_info.architecture = architecture
        with (
            patch(
                "noteflow.infrastructure.asr.factory.detect_gpu_backend",
                return_value=GpuBackend.ROCM,
            ),
            patch(
                "noteflow.infrastructure.asr.factory.get_gpu_info",
                return_value=gpu_info,
            ),
            patch(
                "noteflow.infrastructure.asr.factory.is_rocm_architecture_supported",
                return_value=supported,
            ),
        ):
            return cls._resolve("auto")

    @pytest.mark.parametrize(
        "device",
        ["cpu", "cuda", "rocm"],
    )
    def test_explicit_device_not_changed(self, device: str) -> None:
        """An explicitly named device is passed through untouched."""
        resolved = self._resolve(device)
        assert resolved == device, f"Device {device} should remain unchanged"

    def test_auto_with_cuda(self) -> None:
        """``auto`` resolves to ``cuda`` when a CUDA backend is detected."""
        cuda_detected = patch(
            "noteflow.infrastructure.asr.factory.detect_gpu_backend",
            return_value=GpuBackend.CUDA,
        )
        with cuda_detected:
            resolved = self._resolve("auto")
        assert resolved == "cuda"

    def test_auto_with_rocm_supported(self) -> None:
        """``auto`` resolves to ``rocm`` when the ROCm architecture is supported."""
        resolved = self._resolve_auto_on_rocm(architecture="gfx1100", supported=True)
        assert resolved == "rocm"

    def test_auto_with_rocm_unsupported_falls_to_cpu(self) -> None:
        """``auto`` falls back to ``cpu`` when the ROCm architecture is unsupported."""
        resolved = self._resolve_auto_on_rocm(architecture="gfx803", supported=False)
        assert resolved == "cpu"

    def test_auto_with_mps_falls_to_cpu(self) -> None:
        """``auto`` falls back to ``cpu`` on MPS, which ASR does not use."""
        mps_detected = patch(
            "noteflow.infrastructure.asr.factory.detect_gpu_backend",
            return_value=GpuBackend.MPS,
        )
        with mps_detected:
            resolved = self._resolve("auto")
        assert resolved == "cpu"
class TestRocmEngineCreation:
    """Check which engine implementation the factory builds for ROCm."""

    @staticmethod
    def _create_rocm_engine(
        *,
        ctranslate2_available: bool,
        engine_class_path: str,
        stub: MagicMock,
    ) -> object:
        """Call the factory for ROCm with availability and engine class stubbed.

        Args:
            ctranslate2_available: Value returned by the patched
                ``is_ctranslate2_rocm_available``.
            engine_class_path: Dotted path of the engine class to replace.
            stub: Object the replaced engine class returns when called.

        Returns:
            Whatever ``create_asr_engine(device="rocm", ...)`` returns.
        """
        with (
            patch(
                "noteflow.infrastructure.asr.factory.resolve_device",
                return_value="rocm",
            ),
            patch(
                "noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
                return_value=ctranslate2_available,
            ),
            patch(engine_class_path, return_value=stub),
        ):
            return create_asr_engine(device="rocm", compute_type="float16")

    def test_rocm_with_ctranslate2_available(self) -> None:
        """With CTranslate2 available, the CTranslate2 ROCm engine is returned."""
        ctranslate2_engine = MagicMock()
        engine = self._create_rocm_engine(
            ctranslate2_available=True,
            engine_class_path="noteflow.infrastructure.asr.rocm_engine.FasterWhisperRocmEngine",
            stub=ctranslate2_engine,
        )
        assert engine == ctranslate2_engine

    def test_rocm_falls_back_to_pytorch(self) -> None:
        """Without CTranslate2, the PyTorch Whisper engine is returned instead."""
        pytorch_engine = MagicMock()
        engine = self._create_rocm_engine(
            ctranslate2_available=False,
            engine_class_path="noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
            stub=pytorch_engine,
        )
        assert engine == pytorch_engine
class TestPytorchEngineFallback:
    """Cover the failure path of the PyTorch fallback for ROCm."""

    def test_pytorch_engine_import_error(self) -> None:
        """An ImportError from the PyTorch engine surfaces as EngineCreationError.

        CTranslate2 is reported unavailable, so the factory takes the PyTorch
        path. That engine class is stubbed so that calling it raises
        ``ImportError``.
        """
        failing_pytorch_engine = patch(
            "noteflow.infrastructure.asr.pytorch_engine.WhisperPyTorchEngine",
            side_effect=ImportError("No module"),
        )
        with (
            patch(
                "noteflow.infrastructure.asr.factory.resolve_device",
                return_value="rocm",
            ),
            patch(
                "noteflow.infrastructure.asr.factory.is_ctranslate2_rocm_available",
                return_value=False,
            ),
            failing_pytorch_engine,
        ):
            with pytest.raises(EngineCreationError, match="Neither CTranslate2"):
                create_asr_engine(device="rocm", compute_type="float16")