"""Tests for GPU backend detection."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Final
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from noteflow.domain.ports.gpu import GpuBackend, GpuInfo
|
|
from noteflow.infrastructure.gpu.detection import (
|
|
SUPPORTED_AMD_ARCHITECTURES,
|
|
GpuDetectionError,
|
|
detect_gpu_backend,
|
|
get_gpu_info,
|
|
get_rocm_environment_info,
|
|
is_ctranslate2_rocm_available,
|
|
is_rocm_architecture_supported,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
pass
|
|
|
|
# Test constants
|
|
VRAM_24GB_BYTES: Final[int] = 24 * 1024 * 1024 * 1024
|
|
VRAM_24GB_MB: Final[int] = 24 * 1024
|
|
VRAM_RX7900_MB: Final[int] = 24576
|
|
|
|
|
|
class TestDetectGpuBackend:
    """Test GPU backend detection."""

    def test_no_pytorch_returns_none(self) -> None:
        """Test that missing PyTorch returns NONE backend."""
        # Clear cache first so a previously detected backend is not reused.
        detect_gpu_backend.cache_clear()

        # Capture the real import hook via the builtin name itself.
        # The previous `__builtins__["__import__"]` lookup is a CPython
        # implementation detail: `__builtins__` is a dict in imported modules
        # but a module object in `__main__`, so subscripting can raise there.
        original_import = __import__

        def mock_import(name: str, *args: object, **kwargs: object) -> object:
            # Only 'torch' is made unavailable; every other import proceeds.
            if name == "torch":
                raise ImportError("No module named 'torch'")
            return original_import(name, *args, **kwargs)

        with (
            patch.dict("sys.modules", {"torch": None}),
            patch("builtins.__import__", side_effect=mock_import),
        ):
            result = detect_gpu_backend()
            assert result == GpuBackend.NONE, "Missing PyTorch should return NONE"

        # Clear cache after test so later tests re-run detection.
        detect_gpu_backend.cache_clear()

    def test_cuda_detected(self) -> None:
        """Test CUDA backend detection."""
        detect_gpu_backend.cache_clear()

        # A torch stub reporting an available CUDA device with no HIP runtime.
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = None
        mock_torch.version.cuda = "12.1"

        with patch.dict("sys.modules", {"torch": mock_torch}):
            # Need to reimport to use mocked torch
            from noteflow.infrastructure.gpu import detection

            # Clear the function's cache
            detection.detect_gpu_backend.cache_clear()

            result = detection.detect_gpu_backend()
            assert result == GpuBackend.CUDA, "CUDA should be detected when available"

        detect_gpu_backend.cache_clear()

    def test_rocm_detected(self) -> None:
        """Test ROCm backend detection via HIP."""
        detect_gpu_backend.cache_clear()

        # ROCm builds of torch report cuda.is_available() True with a HIP
        # version set and version.cuda unset.
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.version.hip = "6.0"
        mock_torch.version.cuda = None

        with patch.dict("sys.modules", {"torch": mock_torch}):
            from noteflow.infrastructure.gpu import detection

            detection.detect_gpu_backend.cache_clear()

            result = detection.detect_gpu_backend()
            assert result == GpuBackend.ROCM, "ROCm should be detected when HIP available"

        detect_gpu_backend.cache_clear()

    def test_mps_detected(self) -> None:
        """Test MPS backend detection."""
        detect_gpu_backend.cache_clear()

        # No CUDA device, but the Metal Performance Shaders backend is up.
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        mock_torch.backends.mps.is_available.return_value = True

        with patch.dict("sys.modules", {"torch": mock_torch}):
            from noteflow.infrastructure.gpu import detection

            detection.detect_gpu_backend.cache_clear()

            result = detection.detect_gpu_backend()
            assert result == GpuBackend.MPS, "MPS should be detected on Apple Silicon"

        detect_gpu_backend.cache_clear()
|
|
class TestSupportedArchitectures:
    """Verify the contents of the supported AMD architecture list."""

    @pytest.mark.parametrize(
        "arch",
        ["gfx906", "gfx908", "gfx90a", "gfx942"],
        ids=["MI50", "MI100", "MI210", "MI300X"],
    )
    def test_cdna_architectures_included(self, arch: str) -> None:
        """CDNA datacenter parts must appear in the supported set."""
        assert arch in SUPPORTED_AMD_ARCHITECTURES, f"{arch} should be supported"

    @pytest.mark.parametrize(
        "arch",
        ["gfx1030", "gfx1031", "gfx1032"],
        ids=["RX6800", "RX6700XT", "RX6600"],
    )
    def test_rdna2_architectures_included(self, arch: str) -> None:
        """RDNA 2 consumer parts must appear in the supported set."""
        assert arch in SUPPORTED_AMD_ARCHITECTURES, f"{arch} should be supported"

    @pytest.mark.parametrize(
        "arch",
        ["gfx1100", "gfx1101", "gfx1102"],
        ids=["RX7900XTX", "RX7800XT", "RX7600"],
    )
    def test_rdna3_architectures_included(self, arch: str) -> None:
        """RDNA 3 consumer parts must appear in the supported set."""
        assert arch in SUPPORTED_AMD_ARCHITECTURES, f"{arch} should be supported"
|
|
class TestIsRocmArchitectureSupported:
    """Exercise the ROCm architecture support predicate."""

    @pytest.mark.parametrize(
        "arch",
        ["gfx1100", "gfx1030", "gfx90a", "gfx942"],
    )
    def test_supported_architectures(self, arch: str) -> None:
        """Officially supported architectures report True."""
        assert is_rocm_architecture_supported(arch) is True

    @pytest.mark.parametrize(
        "arch",
        ["gfx803", "gfx900", "gfx1010", "unknown"],
    )
    def test_unsupported_architectures(self, arch: str) -> None:
        """Architectures outside the support list report False."""
        assert is_rocm_architecture_supported(arch) is False

    def test_none_architecture(self) -> None:
        """A None architecture is treated as unsupported."""
        assert is_rocm_architecture_supported(None) is False

    def test_override_env_var(self) -> None:
        """HSA_OVERRIDE_GFX_VERSION forces any architecture to be accepted."""
        override_env = {"HSA_OVERRIDE_GFX_VERSION": "10.3.0"}
        with patch.dict("os.environ", override_env):
            # Even an otherwise unsupported architecture passes the check.
            assert is_rocm_architecture_supported("gfx803") is True
|
|
class TestGetGpuInfo:
    """Exercise GPU info retrieval for each backend outcome."""

    def test_no_gpu_returns_none(self) -> None:
        """When no backend is detected, get_gpu_info yields None."""
        detect_gpu_backend.cache_clear()

        backend_patch = patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.NONE,
        )
        with backend_patch:
            info = get_gpu_info()
            assert info is None, "No GPU should return None"

    def test_cuda_gpu_info(self) -> None:
        """CUDA device properties are translated into a GpuInfo."""
        device_props = MagicMock()
        device_props.name = "NVIDIA GeForce RTX 4090"
        device_props.total_memory = VRAM_24GB_BYTES
        device_props.major = 8
        device_props.minor = 9

        fake_torch = MagicMock()
        fake_torch.cuda.get_device_properties.return_value = device_props
        fake_torch.version.cuda = "12.1"
        fake_torch.version.hip = None

        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch.dict("sys.modules", {"torch": fake_torch}),
        ):
            info = get_gpu_info()
            assert info is not None, "CUDA GPU info should not be None"
            assert info.backend == GpuBackend.CUDA, "Backend should be CUDA"
            assert info.device_name == "NVIDIA GeForce RTX 4090", "Device name mismatch"
            assert info.vram_total_mb == VRAM_24GB_MB, "VRAM should be 24GB in MB"
            assert info.architecture == "sm_89", "Architecture should be sm_89"

    def test_mps_gpu_info(self) -> None:
        """MPS reports a generic Apple Metal device with zero VRAM."""
        backend_patch = patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.MPS,
        )
        with backend_patch:
            info = get_gpu_info()
            assert info is not None, "MPS GPU info should not be None"
            assert info.backend == GpuBackend.MPS, "Backend should be MPS"
            assert info.device_name == "Apple Metal", "Device should be Apple Metal"
            # Metal does not expose a total-VRAM figure.
            assert info.vram_total_mb == 0, "MPS doesn't expose VRAM"

    def test_gpu_properties_error_raises(self) -> None:
        """A failure reading device properties surfaces as GpuDetectionError."""
        fake_torch = MagicMock()
        fake_torch.cuda.get_device_properties.side_effect = RuntimeError("Device not found")

        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.CUDA,
            ),
            patch.dict("sys.modules", {"torch": fake_torch}),
            pytest.raises(GpuDetectionError, match="Failed to get GPU properties"),
        ):
            get_gpu_info()
|
|
|
|
class TestIsCtranslate2RocmAvailable:
    """Exercise the CTranslate2-ROCm availability probe."""

    def test_not_rocm_returns_false(self) -> None:
        """Any backend other than ROCm reports CT2-ROCm as unavailable."""
        detect_gpu_backend.cache_clear()

        backend_patch = patch(
            "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
            return_value=GpuBackend.CUDA,
        )
        with backend_patch:
            assert is_ctranslate2_rocm_available() is False

    def test_no_ctranslate2_returns_false(self) -> None:
        """ROCm backend without an importable CTranslate2 reports False."""
        with (
            patch(
                "noteflow.infrastructure.gpu.detection.detect_gpu_backend",
                return_value=GpuBackend.ROCM,
            ),
            # Any import attempt inside the probe fails, simulating the
            # ctranslate2 package being absent.
            patch("builtins.__import__", side_effect=ImportError),
        ):
            assert is_ctranslate2_rocm_available() is False
|
|
|
|
class TestGetRocmEnvironmentInfo:
    """Exercise ROCm environment info retrieval."""

    def test_empty_env(self) -> None:
        """An empty process environment yields an empty mapping."""
        with patch.dict("os.environ", {}, clear=True):
            captured = get_rocm_environment_info()
            assert captured == {}, "Empty env should return empty dict"

    def test_rocm_vars_captured(self) -> None:
        """All ROCm-related environment variables are reported back."""
        rocm_env = {
            "HSA_OVERRIDE_GFX_VERSION": "10.3.0",
            "HIP_VISIBLE_DEVICES": "0,1",
            "ROCM_PATH": "/opt/rocm",
        }
        with patch.dict("os.environ", rocm_env, clear=True):
            captured = get_rocm_environment_info()
            assert captured == rocm_env, "ROCm env vars should be captured"
|
|
|
|
class TestGpuBackendEnum:
    """Exercise the GpuBackend enum's values and string semantics."""

    @pytest.mark.parametrize(
        ("backend", "expected_value"),
        [
            (GpuBackend.NONE, "none"),
            (GpuBackend.CUDA, "cuda"),
            (GpuBackend.ROCM, "rocm"),
            (GpuBackend.MPS, "mps"),
        ],
    )
    def test_enum_values(self, backend: GpuBackend, expected_value: str) -> None:
        """Each member carries its lowercase string as the enum value."""
        assert backend.value == expected_value, f"{backend} should have value {expected_value}"

    @pytest.mark.parametrize(
        ("backend", "string_value"),
        [
            (GpuBackend.CUDA, "cuda"),
            (GpuBackend.ROCM, "rocm"),
        ],
    )
    def test_string_comparison(self, backend: GpuBackend, string_value: str) -> None:
        """Members compare equal to their plain string value."""
        assert backend == string_value, f"{backend} should equal {string_value}"
|
|
|
|
class TestGpuInfo:
    """Exercise the GpuInfo dataclass's construction and immutability."""

    def test_creation(self) -> None:
        """All constructor fields are stored on the instance."""
        gpu = GpuInfo(
            backend=GpuBackend.ROCM,
            device_name="AMD Radeon RX 7900 XTX",
            vram_total_mb=VRAM_RX7900_MB,
            driver_version="6.0",
            architecture="gfx1100",
        )
        assert gpu.backend == GpuBackend.ROCM, "Backend should be ROCM"
        assert gpu.device_name == "AMD Radeon RX 7900 XTX", "Device name mismatch"
        assert gpu.vram_total_mb == VRAM_RX7900_MB, "VRAM mismatch"
        assert gpu.driver_version == "6.0", "Driver version mismatch"
        assert gpu.architecture == "gfx1100", "Architecture mismatch"

    def test_frozen(self) -> None:
        """Attribute assignment on the immutable GpuInfo must be rejected."""
        gpu = GpuInfo(
            backend=GpuBackend.CUDA,
            device_name="GPU",
            vram_total_mb=1024,
            driver_version="12.0",
        )
        # Assignment must raise with a "cannot assign" message.
        with pytest.raises(AttributeError, match="cannot assign"):
            gpu.device_name = "New Name"  # type: ignore[misc]