157 lines
5.7 KiB
Python
157 lines
5.7 KiB
Python
"""Tests for platform detection utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, mock_open, patch
|
|
|
|
from noteflow.infrastructure.platform import (
|
|
configure_pytorch_for_platform,
|
|
has_avx2_support,
|
|
read_linux_cpuinfo,
|
|
read_sysctl_features,
|
|
)
|
|
|
|
|
|
class TestReadLinuxCpuinfo:
    """Checks for ``read_linux_cpuinfo`` with ``os.path.exists`` and ``open`` mocked."""

    def test_returns_none_when_file_not_exists(self) -> None:
        """``None`` comes back when the existence check reports no file."""
        missing_file = patch("os.path.exists", return_value=False)
        with missing_file:
            outcome = read_linux_cpuinfo()

        assert outcome is None, "Should return None when /proc/cpuinfo doesn't exist"

    def test_returns_content_when_file_exists(self) -> None:
        """The text served by the mocked ``open`` is returned unchanged."""
        cpuinfo_text = "processor : 0\nflags : avx avx2 sse\n"
        file_present = patch("os.path.exists", return_value=True)
        readable_file = patch("builtins.open", mock_open(read_data=cpuinfo_text))
        with file_present, readable_file:
            outcome = read_linux_cpuinfo()

        assert outcome == cpuinfo_text, "Should return file content"

    def test_returns_none_on_oserror_linux(self) -> None:
        """An ``OSError`` raised by the mocked ``open`` is reported as ``None``."""
        file_present = patch("os.path.exists", return_value=True)
        failing_open = patch("builtins.open", side_effect=OSError("Permission denied"))
        with file_present, failing_open:
            outcome = read_linux_cpuinfo()

        assert outcome is None, "Should return None on OSError"
|
|
|
|
|
|
class TestReadSysctlFeatures:
    """Checks for ``read_sysctl_features`` with ``subprocess.run`` mocked."""

    def test_returns_features_on_success(self) -> None:
        """A zero return code passes the captured stdout through."""
        completed = MagicMock(returncode=0, stdout="FPU VME AVX2 SSE")
        with patch("subprocess.run", return_value=completed):
            features = read_sysctl_features()

        assert features == "FPU VME AVX2 SSE", "Should return sysctl output"

    def test_returns_none_on_nonzero_returncode(self) -> None:
        """A non-zero return code is reported as ``None``."""
        completed = MagicMock(returncode=1, stdout="")
        with patch("subprocess.run", return_value=completed):
            features = read_sysctl_features()

        assert features is None, "Should return None on non-zero return code"

    def test_returns_none_on_oserror_sysctl(self) -> None:
        """An ``OSError`` raised by the mocked ``subprocess.run`` is reported as ``None``."""
        failing_run = patch("subprocess.run", side_effect=OSError("Command not found"))
        with failing_run:
            features = read_sysctl_features()

        assert features is None, "Should return None on OSError"
|
|
|
|
|
|
class TestHasAvx2Support:
    """Checks for ``has_avx2_support`` with its feature readers mocked.

    ``has_avx2_support`` exposes ``cache_clear``, i.e. it keeps a cached
    result. The cache is emptied before every test so the mocks are actually
    consulted, and again afterwards so a value computed from mocked readers
    cannot leak into tests that run later in the same process.
    """

    # Patch targets: the readers as looked up from the platform module.
    _CPUINFO_READER = "noteflow.infrastructure.platform.read_linux_cpuinfo"
    _SYSCTL_READER = "noteflow.infrastructure.platform.read_sysctl_features"

    def setup_method(self) -> None:
        """Drop any previously cached result before each test."""
        has_avx2_support.cache_clear()

    def teardown_method(self) -> None:
        """Drop the result that was computed from mocked readers."""
        has_avx2_support.cache_clear()

    def test_returns_true_when_linux_cpuinfo_has_avx2(self) -> None:
        """cpuinfo text listing ``avx2`` among the flags gives ``True``."""
        cpuinfo = "processor : 0\nflags : fpu avx2 sse4_2\n"
        with patch(self._CPUINFO_READER, return_value=cpuinfo):
            result = has_avx2_support()

        assert result is True, "Should detect avx2 in cpuinfo"

    def test_returns_false_when_linux_cpuinfo_no_avx2(self) -> None:
        """cpuinfo text without ``avx2`` gives ``False``."""
        cpuinfo = "processor : 0\nflags : fpu sse4_2\n"
        with patch(self._CPUINFO_READER, return_value=cpuinfo):
            result = has_avx2_support()

        assert result is False, "Should not detect avx2 when missing"

    def test_falls_back_to_sysctl_when_no_cpuinfo(self) -> None:
        """With no cpuinfo available, sysctl output containing AVX2 gives ``True``."""
        with (
            patch(self._CPUINFO_READER, return_value=None),
            patch(self._SYSCTL_READER, return_value="FPU AVX2 SSE"),
        ):
            result = has_avx2_support()

        assert result is True, "Should detect avx2 via sysctl fallback"

    def test_returns_false_when_both_sources_unavailable(self) -> None:
        """With both readers returning ``None`` the answer is ``False``."""
        with (
            patch(self._CPUINFO_READER, return_value=None),
            patch(self._SYSCTL_READER, return_value=None),
        ):
            result = has_avx2_support()

        assert result is False, "Should return False when no source available"
|
|
|
|
|
|
class TestConfigurePytorchForPlatform:
    """Checks how ``configure_pytorch_for_platform`` treats ``PYTORCH_DISABLE_NNPACK``."""

    def test_sets_nnpack_disabled_when_no_avx2(self) -> None:
        """Without AVX2 and with the variable unset, it ends up as ``"1"``."""
        has_avx2_support.cache_clear()
        no_avx2 = patch(
            "noteflow.infrastructure.platform.has_avx2_support",
            return_value=False,
        )
        # patch.dict restores os.environ on exit, undoing the pop and any write.
        scratch_env = patch.dict(os.environ, {}, clear=False)
        with no_avx2, scratch_env:
            os.environ.pop("PYTORCH_DISABLE_NNPACK", None)
            configure_pytorch_for_platform()
            # Checked inside the block, before the environment is restored.
            assert os.environ.get("PYTORCH_DISABLE_NNPACK") == "1", "Should disable NNPACK"

    def test_does_not_override_existing_nnpack_setting(self) -> None:
        """A value that is already present is left as it was."""
        has_avx2_support.cache_clear()
        no_avx2 = patch(
            "noteflow.infrastructure.platform.has_avx2_support",
            return_value=False,
        )
        preset_env = patch.dict(os.environ, {"PYTORCH_DISABLE_NNPACK": "0"}, clear=False)
        with no_avx2, preset_env:
            configure_pytorch_for_platform()
            assert os.environ.get("PYTORCH_DISABLE_NNPACK") == "0", "Should not override"

    def test_does_not_set_nnpack_when_avx2_supported(self) -> None:
        """With AVX2 reported as available the variable stays unset."""
        has_avx2_support.cache_clear()
        with_avx2 = patch(
            "noteflow.infrastructure.platform.has_avx2_support",
            return_value=True,
        )
        scratch_env = patch.dict(os.environ, {}, clear=False)
        with with_avx2, scratch_env:
            os.environ.pop("PYTORCH_DISABLE_NNPACK", None)
            configure_pytorch_for_platform()
            assert "PYTORCH_DISABLE_NNPACK" not in os.environ, "Should not set when avx2 supported"
|