refactor: enhance test readability by reformatting with statements and assertion messages, and refine version parsing mock.
This commit is contained in:
@@ -151,12 +151,10 @@ class TestPatchTorchaudio:
|
|||||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(mock_torchaudio, "AudioMetaData"), "should add AudioMetaData"
|
||||||
mock_torchaudio, "AudioMetaData"
|
assert mock_torchaudio.AudioMetaData is compat_module.AudioMetaData, (
|
||||||
), "should add AudioMetaData"
|
"should use our AudioMetaData class"
|
||||||
assert (
|
)
|
||||||
mock_torchaudio.AudioMetaData is compat_module.AudioMetaData
|
|
||||||
), "should use our AudioMetaData class"
|
|
||||||
|
|
||||||
def test_does_not_override_existing_audiometadata(
|
def test_does_not_override_existing_audiometadata(
|
||||||
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
||||||
@@ -166,23 +164,19 @@ class TestPatchTorchaudio:
|
|||||||
existing_class = type("ExistingAudioMetaData", (), {})
|
existing_class = type("ExistingAudioMetaData", (), {})
|
||||||
mock.AudioMetaData = existing_class
|
mock.AudioMetaData = existing_class
|
||||||
|
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert (
|
assert mock.AudioMetaData is existing_class, (
|
||||||
mock.AudioMetaData is existing_class
|
"should not override existing AudioMetaData"
|
||||||
), "should not override existing AudioMetaData"
|
)
|
||||||
|
|
||||||
def test_torchaudio_handles_import_error_gracefully(
|
def test_torchaudio_handles_import_error_gracefully(
|
||||||
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
||||||
) -> None:
|
) -> None:
|
||||||
"""_patch_torchaudio doesn't raise when torchaudio not installed."""
|
"""_patch_torchaudio doesn't raise when torchaudio not installed."""
|
||||||
# Remove torchaudio from modules if present, mock torch to prevent real import
|
# Remove torchaudio from modules if present, mock torch to prevent real import
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": None, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": None, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
@@ -201,9 +195,14 @@ class TestPatchTorchLoad:
|
|||||||
"""_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
|
"""_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
|
||||||
original_load = mock_torch.load
|
original_load = mock_torch.load
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"torch": mock_torch}), patch(
|
def mock_parse_version(version_str: str) -> str:
|
||||||
"packaging.version.Version"
|
return version_str
|
||||||
) as mock_version:
|
|
||||||
|
with (
|
||||||
|
patch.dict(sys.modules, {"torch": mock_torch}),
|
||||||
|
patch("packaging.version.Version") as mock_version,
|
||||||
|
patch("packaging.version.parse", mock_parse_version),
|
||||||
|
):
|
||||||
mock_version.return_value = mock_version
|
mock_version.return_value = mock_version
|
||||||
mock_version.__ge__ = MagicMock(return_value=True)
|
mock_version.__ge__ = MagicMock(return_value=True)
|
||||||
|
|
||||||
@@ -218,9 +217,10 @@ class TestPatchTorchLoad:
|
|||||||
mock.__version__ = "2.5.0"
|
mock.__version__ = "2.5.0"
|
||||||
original_load = mock.load
|
original_load = mock.load
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"torch": mock}), patch(
|
with (
|
||||||
"packaging.version.Version"
|
patch.dict(sys.modules, {"torch": mock}),
|
||||||
) as mock_version:
|
patch("packaging.version.Version") as mock_version,
|
||||||
|
):
|
||||||
mock_version.return_value = mock_version
|
mock_version.return_value = mock_version
|
||||||
mock_version.__ge__ = MagicMock(return_value=False)
|
mock_version.__ge__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
@@ -229,9 +229,7 @@ class TestPatchTorchLoad:
|
|||||||
# load should not have been replaced
|
# load should not have been replaced
|
||||||
assert mock.load is original_load, "should not patch older PyTorch"
|
assert mock.load is original_load, "should not patch older PyTorch"
|
||||||
|
|
||||||
def test_torch_load_handles_import_error_gracefully(
|
def test_torch_load_handles_import_error_gracefully(self, compat_module: _CompatModule) -> None:
|
||||||
self, compat_module: _CompatModule
|
|
||||||
) -> None:
|
|
||||||
"""_patch_torch_load doesn't raise when torch not installed."""
|
"""_patch_torch_load doesn't raise when torch not installed."""
|
||||||
with patch.dict(sys.modules, {"torch": None}):
|
with patch.dict(sys.modules, {"torch": None}):
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
@@ -272,9 +270,7 @@ class TestPatchHuggingfaceAuth:
|
|||||||
call_kwargs = original_download.call_args[1]
|
call_kwargs = original_download.call_args[1]
|
||||||
assert "token" in call_kwargs, "should convert to token parameter"
|
assert "token" in call_kwargs, "should convert to token parameter"
|
||||||
assert call_kwargs["token"] == "my_token", "should preserve token value"
|
assert call_kwargs["token"] == "my_token", "should preserve token value"
|
||||||
assert (
|
assert "use_auth_token" not in call_kwargs, "should remove use_auth_token"
|
||||||
"use_auth_token" not in call_kwargs
|
|
||||||
), "should remove use_auth_token"
|
|
||||||
|
|
||||||
def test_preserves_token_parameter(
|
def test_preserves_token_parameter(
|
||||||
self,
|
self,
|
||||||
@@ -305,9 +301,7 @@ class TestPatchHuggingfaceAuth:
|
|||||||
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
|
||||||
) -> None:
|
) -> None:
|
||||||
"""_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
|
"""_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"huggingface_hub": None, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"huggingface_hub": None, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
|
|
||||||
@@ -326,14 +320,10 @@ class TestPatchSpeechbrainBackend:
|
|||||||
mock_torch_minimal: MagicMock,
|
mock_torch_minimal: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""_patch_speechbrain_backend adds list_audio_backends when missing."""
|
"""_patch_speechbrain_backend adds list_audio_backends when missing."""
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(mock_torchaudio, "list_audio_backends"), "should add list_audio_backends"
|
||||||
mock_torchaudio, "list_audio_backends"
|
|
||||||
), "should add list_audio_backends"
|
|
||||||
result = mock_torchaudio.list_audio_backends()
|
result = mock_torchaudio.list_audio_backends()
|
||||||
assert isinstance(result, list), "should return list"
|
assert isinstance(result, list), "should return list"
|
||||||
|
|
||||||
@@ -344,14 +334,10 @@ class TestPatchSpeechbrainBackend:
|
|||||||
mock_torch_minimal: MagicMock,
|
mock_torch_minimal: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""_patch_speechbrain_backend adds get_audio_backend when missing."""
|
"""_patch_speechbrain_backend adds get_audio_backend when missing."""
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(mock_torchaudio, "get_audio_backend"), "should add get_audio_backend"
|
||||||
mock_torchaudio, "get_audio_backend"
|
|
||||||
), "should add get_audio_backend"
|
|
||||||
result = mock_torchaudio.get_audio_backend()
|
result = mock_torchaudio.get_audio_backend()
|
||||||
assert result is None, "should return None"
|
assert result is None, "should return None"
|
||||||
|
|
||||||
@@ -362,14 +348,10 @@ class TestPatchSpeechbrainBackend:
|
|||||||
mock_torch_minimal: MagicMock,
|
mock_torch_minimal: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""_patch_speechbrain_backend adds set_audio_backend when missing."""
|
"""_patch_speechbrain_backend adds set_audio_backend when missing."""
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(mock_torchaudio, "set_audio_backend"), "should add set_audio_backend"
|
||||||
mock_torchaudio, "set_audio_backend"
|
|
||||||
), "should add set_audio_backend"
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
mock_torchaudio.set_audio_backend("sox")
|
mock_torchaudio.set_audio_backend("sox")
|
||||||
|
|
||||||
@@ -381,14 +363,12 @@ class TestPatchSpeechbrainBackend:
|
|||||||
existing_list = MagicMock(return_value=["ffmpeg"])
|
existing_list = MagicMock(return_value=["ffmpeg"])
|
||||||
mock.list_audio_backends = existing_list
|
mock.list_audio_backends = existing_list
|
||||||
|
|
||||||
with patch.dict(
|
with patch.dict(sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}):
|
||||||
sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}
|
|
||||||
):
|
|
||||||
compat_module.apply_patches()
|
compat_module.apply_patches()
|
||||||
|
|
||||||
assert (
|
assert mock.list_audio_backends is existing_list, (
|
||||||
mock.list_audio_backends is existing_list
|
"should not override existing function"
|
||||||
), "should not override existing function"
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -399,17 +379,20 @@ class TestPatchSpeechbrainBackend:
|
|||||||
class TestApplyPatches:
|
class TestApplyPatches:
|
||||||
"""Tests for the main apply_patches function."""
|
"""Tests for the main apply_patches function."""
|
||||||
|
|
||||||
def test_apply_patches_is_idempotent(
|
def test_apply_patches_is_idempotent(self, compat_module: _CompatModule) -> None:
|
||||||
self, compat_module: _CompatModule
|
|
||||||
) -> None:
|
|
||||||
"""apply_patches only applies patches once."""
|
"""apply_patches only applies patches once."""
|
||||||
mock_torch = MagicMock()
|
mock_torch = MagicMock()
|
||||||
mock_torch.__version__ = "2.6.0"
|
mock_torch.__version__ = "2.6.0"
|
||||||
original_load = mock_torch.load
|
original_load = mock_torch.load
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"torch": mock_torch}), patch(
|
def mock_parse_version(version_str: str) -> str:
|
||||||
"packaging.version.Version"
|
return version_str
|
||||||
) as mock_version:
|
|
||||||
|
with (
|
||||||
|
patch.dict(sys.modules, {"torch": mock_torch}),
|
||||||
|
patch("packaging.version.Version") as mock_version,
|
||||||
|
patch("packaging.version.parse", mock_parse_version),
|
||||||
|
):
|
||||||
mock_version.return_value = mock_version
|
mock_version.return_value = mock_version
|
||||||
mock_version.__ge__ = MagicMock(return_value=True)
|
mock_version.__ge__ = MagicMock(return_value=True)
|
||||||
|
|
||||||
@@ -429,9 +412,7 @@ class TestApplyPatches:
|
|||||||
class TestEnsureCompatibility:
|
class TestEnsureCompatibility:
|
||||||
"""Tests for the ensure_compatibility entry point."""
|
"""Tests for the ensure_compatibility entry point."""
|
||||||
|
|
||||||
def test_ensure_compatibility_calls_apply_patches(
|
def test_ensure_compatibility_calls_apply_patches(self, compat_module: _CompatModule) -> None:
|
||||||
self, compat_module: _CompatModule
|
|
||||||
) -> None:
|
|
||||||
"""ensure_compatibility delegates to apply_patches."""
|
"""ensure_compatibility delegates to apply_patches."""
|
||||||
with patch.object(compat_module, "apply_patches") as mock_apply:
|
with patch.object(compat_module, "apply_patches") as mock_apply:
|
||||||
compat_module.ensure_compatibility()
|
compat_module.ensure_compatibility()
|
||||||
|
|||||||
Reference in New Issue
Block a user