Files
noteflow/tests/infrastructure/gpu/test_detection.py
Travis Vasceannie d8090a98e8
Some checks failed
CI / test-typescript (push) Has been cancelled
CI / test-rust (push) Has been cancelled
CI / test-python (push) Has been cancelled
ci/cd fixes
2026-01-26 00:28:15 +00:00

348 lines
12 KiB
Python

"""Tests for GPU backend detection."""
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from unittest.mock import MagicMock, patch
import pytest
from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
from noteflow.infrastructure.gpu.detection import (
SUPPORTED_AMD_ARCHITECTURES,
GpuDetectionError,
detect_gpu_backend,
get_gpu_info,
get_rocm_environment_info,
is_ctranslate2_rocm_available,
is_rocm_architecture_supported,
)
if TYPE_CHECKING:
pass
# Test constants
VRAM_24GB_BYTES: Final[int] = 24 * 1024 * 1024 * 1024  # 24 GiB expressed in bytes (torch props use bytes)
VRAM_24GB_MB: Final[int] = 24 * 1024  # 24 GiB expressed in MiB (GpuInfo stores MB)
VRAM_RX7900_MB: Final[int] = 24576  # RX 7900 XTX VRAM: 24 GiB in MiB
class TestDetectGpuBackend:
    """Test GPU backend detection."""

    def test_no_pytorch_returns_none(self) -> None:
        """Test that missing PyTorch returns NONE backend."""
        # __builtins__ is a plain dict in imported modules but a module object
        # in __main__ (a documented CPython implementation detail), so
        # subscripting __builtins__ is fragile. Go through the builtins
        # module, which works in both cases.
        import builtins

        # Clear cache first so a previously-detected backend is not reused.
        detect_gpu_backend.cache_clear()
        original_import = builtins.__import__

        def mock_import(name: str, *args: object, **kwargs: object) -> object:
            # Only 'torch' is made unimportable; every other import passes
            # through to the real machinery so pytest keeps working.
            if name == "torch":
                raise ImportError("No module named 'torch'")
            return original_import(name, *args, **kwargs)

        with (
            patch.dict("sys.modules", {"torch": None}),
            patch("builtins.__import__", side_effect=mock_import),
        ):
            result = detect_gpu_backend()
        assert result == GpuBackend.NONE, "Missing PyTorch should return NONE"
        # Clear cache after test so later tests re-detect real hardware state.
        detect_gpu_backend.cache_clear()

    def test_cuda_detected(self) -> None:
        """Test CUDA backend detection."""
        detect_gpu_backend.cache_clear()
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = None  # no HIP runtime -> not a ROCm build
        mock_torch.version.cuda = "12.1"
        with patch.dict("sys.modules", {"torch": mock_torch}):
            # Need to reimport to use mocked torch
            from noteflow.infrastructure.gpu import detection

            # Clear the function's cache
            detection.detect_gpu_backend.cache_clear()
            result = detection.detect_gpu_backend()
            assert result == GpuBackend.CUDA, "CUDA should be detected when available"
        detect_gpu_backend.cache_clear()

    def test_rocm_detected(self) -> None:
        """Test ROCm backend detection via HIP."""
        detect_gpu_backend.cache_clear()
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = "6.0"  # HIP version set -> ROCm build of PyTorch
        mock_torch.version.cuda = None
        with patch.dict("sys.modules", {"torch": mock_torch}):
            from noteflow.infrastructure.gpu import detection

            detection.detect_gpu_backend.cache_clear()
            result = detection.detect_gpu_backend()
            assert result == GpuBackend.ROCM, "ROCm should be detected when HIP available"
        detect_gpu_backend.cache_clear()

    def test_mps_detected(self) -> None:
        """Test MPS backend detection."""
        detect_gpu_backend.cache_clear()
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            from noteflow.infrastructure.gpu import detection

            detection.detect_gpu_backend.cache_clear()
            result = detection.detect_gpu_backend()
            assert result == GpuBackend.MPS, "MPS should be detected on Apple Silicon"
        detect_gpu_backend.cache_clear()
class TestSupportedArchitectures:
    """Verify the supported-architecture list covers known AMD GPU families."""

    @pytest.mark.parametrize(
        "architecture",
        ["gfx906", "gfx908", "gfx90a", "gfx942"],
        ids=["MI50", "MI100", "MI210", "MI300X"],
    )
    def test_cdna_architectures_included(self, architecture: str) -> None:
        """CDNA datacenter parts must be present on the supported list."""
        message = f"{architecture} should be supported"
        assert architecture in SUPPORTED_AMD_ARCHITECTURES, message

    @pytest.mark.parametrize(
        "architecture",
        ["gfx1030", "gfx1031", "gfx1032"],
        ids=["RX6800", "RX6700XT", "RX6600"],
    )
    def test_rdna2_architectures_included(self, architecture: str) -> None:
        """RDNA 2 consumer parts must be present on the supported list."""
        message = f"{architecture} should be supported"
        assert architecture in SUPPORTED_AMD_ARCHITECTURES, message

    @pytest.mark.parametrize(
        "architecture",
        ["gfx1100", "gfx1101", "gfx1102"],
        ids=["RX7900XTX", "RX7800XT", "RX7600"],
    )
    def test_rdna3_architectures_included(self, architecture: str) -> None:
        """RDNA 3 consumer parts must be present on the supported list."""
        message = f"{architecture} should be supported"
        assert architecture in SUPPORTED_AMD_ARCHITECTURES, message
class TestIsRocmArchitectureSupported:
    """Exercise is_rocm_architecture_supported() across gfx IDs."""

    @pytest.mark.parametrize(
        "architecture",
        ["gfx1100", "gfx1030", "gfx90a", "gfx942"],
    )
    def test_supported_architectures(self, architecture: str) -> None:
        """Officially supported gfx IDs report True."""
        supported = is_rocm_architecture_supported(architecture)
        assert supported is True

    @pytest.mark.parametrize(
        "architecture",
        ["gfx803", "gfx900", "gfx1010", "unknown"],
    )
    def test_unsupported_architectures(self, architecture: str) -> None:
        """Retired or unknown gfx IDs report False."""
        supported = is_rocm_architecture_supported(architecture)
        assert supported is False

    def test_none_architecture(self) -> None:
        """A missing architecture (None) reports False."""
        assert is_rocm_architecture_supported(None) is False

    def test_override_env_var(self) -> None:
        """HSA_OVERRIDE_GFX_VERSION forces support for any architecture."""
        override_env = {"HSA_OVERRIDE_GFX_VERSION": "10.3.0"}
        with patch.dict("os.environ", override_env):
            # Even unsupported architecture should work
            assert is_rocm_architecture_supported("gfx803") is True
class TestGetGpuInfo:
    """Exercise get_gpu_info() across backends and its failure path."""

    def test_no_gpu_returns_none(self) -> None:
        """When no backend is detected, no info object is produced."""
        detect_gpu_backend.cache_clear()
        with patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.NONE,
        ):
            info = get_gpu_info()
        assert info is None, "No GPU should return None"

    def test_cuda_gpu_info(self) -> None:
        """CUDA device properties are mapped onto GpuInfo fields."""
        props = MagicMock()
        props.name = "NVIDIA GeForce RTX 4090"
        props.total_memory = VRAM_24GB_BYTES
        props.major = 8
        props.minor = 9
        fake_torch = MagicMock()
        fake_torch.cuda.get_device_properties.return_value = props
        fake_torch.version.cuda = "12.1"
        fake_torch.version.hip = None
        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch.dict("sys.modules", {"torch": fake_torch}),
        ):
            info = get_gpu_info()
        assert info is not None, "CUDA GPU info should not be None"
        assert info.backend == GpuBackend.CUDA, "Backend should be CUDA"
        assert info.device_name == "NVIDIA GeForce RTX 4090", "Device name mismatch"
        assert info.vram_total_mb == VRAM_24GB_MB, "VRAM should be 24GB in MB"
        assert info.architecture == "sm_89", "Architecture should be sm_89"

    def test_mps_gpu_info(self) -> None:
        """MPS reports a fixed device name and no VRAM figure."""
        with patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.MPS,
        ):
            info = get_gpu_info()
        assert info is not None, "MPS GPU info should not be None"
        assert info.backend == GpuBackend.MPS, "Backend should be MPS"
        assert info.device_name == "Apple Metal", "Device should be Apple Metal"
        # MPS doesn't expose VRAM
        assert info.vram_total_mb == 0, "MPS doesn't expose VRAM"

    def test_gpu_properties_error_raises(self) -> None:
        """A torch failure while reading properties surfaces as GpuDetectionError."""
        fake_torch = MagicMock()
        fake_torch.cuda.get_device_properties.side_effect = RuntimeError("Device not found")
        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch.dict("sys.modules", {"torch": fake_torch}),
            pytest.raises(GpuDetectionError, match="Failed to get GPU properties"),
        ):
            get_gpu_info()
class TestIsCtranslate2RocmAvailable:
    """Exercise the CTranslate2-ROCm availability probe."""

    def test_not_rocm_returns_false(self) -> None:
        """A non-ROCm backend never reports CTranslate2-ROCm as available."""
        detect_gpu_backend.cache_clear()
        with patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.CUDA,
        ):
            available = is_ctranslate2_rocm_available()
        assert available is False

    def test_no_ctranslate2_returns_false(self) -> None:
        """With CTranslate2 unimportable, availability is False."""
        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.ROCM,
            ),
            patch("builtins.__import__", side_effect=ImportError),
        ):
            available = is_ctranslate2_rocm_available()
        assert available is False
class TestGetRocmEnvironmentInfo:
    """Exercise collection of ROCm-related environment variables."""

    def test_empty_env(self) -> None:
        """An empty environment yields an empty mapping."""
        with patch.dict("os.environ", {}, clear=True):
            info = get_rocm_environment_info()
        assert info == {}, "Empty env should return empty dict"

    def test_rocm_vars_captured(self) -> None:
        """Every recognised ROCm variable appears in the result."""
        expected = {
            "HSA_OVERRIDE_GFX_VERSION": "10.3.0",
            "HIP_VISIBLE_DEVICES": "0,1",
            "ROCM_PATH": "/opt/rocm",
        }
        with patch.dict("os.environ", expected, clear=True):
            info = get_rocm_environment_info()
        assert info == expected, "ROCm env vars should be captured"
class TestGpuBackendEnum:
    """Exercise GpuBackend enum values and string equality."""

    @pytest.mark.parametrize(
        ("backend", "expected_value"),
        [
            (GpuBackend.NONE, "none"),
            (GpuBackend.CUDA, "cuda"),
            (GpuBackend.ROCM, "rocm"),
            (GpuBackend.MPS, "mps"),
        ],
    )
    def test_enum_values(self, backend: GpuBackend, expected_value: str) -> None:
        """Each member carries the lowercase backend name as its value."""
        message = f"{backend} should have value {expected_value}"
        assert backend.value == expected_value, message

    @pytest.mark.parametrize(
        ("backend", "string_value"),
        [
            (GpuBackend.CUDA, "cuda"),
            (GpuBackend.ROCM, "rocm"),
        ],
    )
    def test_string_comparison(self, backend: GpuBackend, string_value: str) -> None:
        """Members compare equal to their plain-string counterparts."""
        message = f"{backend} should equal {string_value}"
        assert backend == string_value, message
class TestGpuInfo:
    """Exercise the GpuInfo value object."""

    def test_creation(self) -> None:
        """All constructor fields round-trip unchanged."""
        gpu = GpuInfo(
            backend=GpuBackend.ROCM,
            device_name="AMD Radeon RX 7900 XTX",
            vram_total_mb=VRAM_RX7900_MB,
            driver_version="6.0",
            architecture="gfx1100",
        )
        assert gpu.backend == GpuBackend.ROCM, "Backend should be ROCM"
        assert gpu.device_name == "AMD Radeon RX 7900 XTX", "Device name mismatch"
        assert gpu.vram_total_mb == VRAM_RX7900_MB, "VRAM mismatch"
        assert gpu.driver_version == "6.0", "Driver version mismatch"
        assert gpu.architecture == "gfx1100", "Architecture mismatch"

    def test_frozen(self) -> None:
        """Assigning to a field of the frozen dataclass raises."""
        gpu = GpuInfo(
            backend=GpuBackend.CUDA,
            device_name="GPU",
            vram_total_mb=1024,
            driver_version="12.0",
        )
        with pytest.raises(AttributeError, match="cannot assign"):
            gpu.device_name = "New Name"  # type: ignore[misc]