diff --git a/tests/infrastructure/diarization/test_compat.py b/tests/infrastructure/diarization/test_compat.py index fb561f9..d097490 100644 --- a/tests/infrastructure/diarization/test_compat.py +++ b/tests/infrastructure/diarization/test_compat.py @@ -151,12 +151,10 @@ class TestPatchTorchaudio: with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}): compat_module.apply_patches() - assert hasattr( - mock_torchaudio, "AudioMetaData" - ), "should add AudioMetaData" - assert ( - mock_torchaudio.AudioMetaData is compat_module.AudioMetaData - ), "should use our AudioMetaData class" + assert hasattr(mock_torchaudio, "AudioMetaData"), "should add AudioMetaData" + assert mock_torchaudio.AudioMetaData is compat_module.AudioMetaData, ( + "should use our AudioMetaData class" + ) def test_does_not_override_existing_audiometadata( self, compat_module: _CompatModule, mock_torch_minimal: MagicMock @@ -166,23 +164,19 @@ class TestPatchTorchaudio: existing_class = type("ExistingAudioMetaData", (), {}) mock.AudioMetaData = existing_class - with patch.dict( - sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}): compat_module.apply_patches() - assert ( - mock.AudioMetaData is existing_class - ), "should not override existing AudioMetaData" + assert mock.AudioMetaData is existing_class, ( + "should not override existing AudioMetaData" + ) def test_torchaudio_handles_import_error_gracefully( self, compat_module: _CompatModule, mock_torch_minimal: MagicMock ) -> None: """_patch_torchaudio doesn't raise when torchaudio not installed.""" # Remove torchaudio from modules if present, mock torch to prevent real import - with patch.dict( - sys.modules, {"torchaudio": None, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": None, "torch": mock_torch_minimal}): # Should not raise compat_module.apply_patches() @@ -201,9 +195,14 @@ class TestPatchTorchLoad: """_patch_torch_load adds weights_only=False default for PyTorch 2.6+.""" original_load = mock_torch.load - with patch.dict(sys.modules, {"torch": mock_torch}), patch( - "packaging.version.Version" - ) as mock_version: + def mock_parse_version(version_str: str) -> str: + return version_str + + with ( + patch.dict(sys.modules, {"torch": mock_torch}), + patch("packaging.version.Version") as mock_version, + patch("packaging.version.parse", mock_parse_version), + ): mock_version.return_value = mock_version mock_version.__ge__ = MagicMock(return_value=True) @@ -218,9 +217,10 @@ class TestPatchTorchLoad: mock.__version__ = "2.5.0" original_load = mock.load - with patch.dict(sys.modules, {"torch": mock}), patch( - "packaging.version.Version" - ) as mock_version: + with ( + patch.dict(sys.modules, {"torch": mock}), + patch("packaging.version.Version") as mock_version, + ): mock_version.return_value = mock_version mock_version.__ge__ = MagicMock(return_value=False) @@ -229,9 +229,7 @@ class TestPatchTorchLoad: # load should not have been replaced assert mock.load is original_load, "should not patch older PyTorch" - def test_torch_load_handles_import_error_gracefully( - self, compat_module: _CompatModule - ) -> None: + def test_torch_load_handles_import_error_gracefully(self, compat_module: _CompatModule) -> None: """_patch_torch_load doesn't raise when torch not installed.""" with patch.dict(sys.modules, {"torch": None}): compat_module.apply_patches() @@ -272,9 +270,7 @@ class TestPatchHuggingfaceAuth: call_kwargs = original_download.call_args[1] assert "token" in call_kwargs, "should convert to token parameter" assert call_kwargs["token"] == "my_token", "should preserve token value" - assert ( - "use_auth_token" not in call_kwargs - ), "should remove use_auth_token" + assert "use_auth_token" not in call_kwargs, "should remove use_auth_token" def test_preserves_token_parameter( self, @@ -305,9 +301,7 @@ class TestPatchHuggingfaceAuth: self, compat_module: _CompatModule, mock_torch_minimal: MagicMock ) -> None: """_patch_huggingface_auth doesn't raise when huggingface_hub not installed.""" - with patch.dict( - sys.modules, {"huggingface_hub": None, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"huggingface_hub": None, "torch": mock_torch_minimal}): compat_module.apply_patches() @@ -326,14 +320,10 @@ class TestPatchSpeechbrainBackend: mock_torch_minimal: MagicMock, ) -> None: """_patch_speechbrain_backend adds list_audio_backends when missing.""" - with patch.dict( - sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}): compat_module.apply_patches() - assert hasattr( - mock_torchaudio, "list_audio_backends" - ), "should add list_audio_backends" + assert hasattr(mock_torchaudio, "list_audio_backends"), "should add list_audio_backends" result = mock_torchaudio.list_audio_backends() assert isinstance(result, list), "should return list" @@ -344,14 +334,10 @@ class TestPatchSpeechbrainBackend: mock_torch_minimal: MagicMock, ) -> None: """_patch_speechbrain_backend adds get_audio_backend when missing.""" - with patch.dict( - sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}): compat_module.apply_patches() - assert hasattr( - mock_torchaudio, "get_audio_backend" - ), "should add get_audio_backend" + assert hasattr(mock_torchaudio, "get_audio_backend"), "should add get_audio_backend" result = mock_torchaudio.get_audio_backend() assert result is None, "should return None" @@ -362,14 +348,10 @@ class TestPatchSpeechbrainBackend: mock_torch_minimal: MagicMock, ) -> None: """_patch_speechbrain_backend adds set_audio_backend when missing.""" - with patch.dict( - sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}): compat_module.apply_patches() - assert hasattr( - mock_torchaudio, "set_audio_backend" - ), "should add set_audio_backend" + assert hasattr(mock_torchaudio, "set_audio_backend"), "should add set_audio_backend" # Should not raise mock_torchaudio.set_audio_backend("sox") @@ -381,14 +363,12 @@ class TestPatchSpeechbrainBackend: existing_list = MagicMock(return_value=["ffmpeg"]) mock.list_audio_backends = existing_list - with patch.dict( - sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal} - ): + with patch.dict(sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}): compat_module.apply_patches() - assert ( - mock.list_audio_backends is existing_list - ), "should not override existing function" + assert mock.list_audio_backends is existing_list, ( + "should not override existing function" + ) # ============================================================================= @@ -399,17 +379,20 @@ class TestPatchSpeechbrainBackend: class TestApplyPatches: """Tests for the main apply_patches function.""" - def test_apply_patches_is_idempotent( - self, compat_module: _CompatModule - ) -> None: + def test_apply_patches_is_idempotent(self, compat_module: _CompatModule) -> None: """apply_patches only applies patches once.""" mock_torch = MagicMock() mock_torch.__version__ = "2.6.0" original_load = mock_torch.load - with patch.dict(sys.modules, {"torch": mock_torch}), patch( - "packaging.version.Version" - ) as mock_version: + def mock_parse_version(version_str: str) -> str: + return version_str + + with ( + patch.dict(sys.modules, {"torch": mock_torch}), + patch("packaging.version.Version") as mock_version, + patch("packaging.version.parse", mock_parse_version), + ): mock_version.return_value = mock_version mock_version.__ge__ = MagicMock(return_value=True) @@ -429,9 +412,7 @@ class TestApplyPatches: class TestEnsureCompatibility: """Tests for the ensure_compatibility entry point.""" - def test_ensure_compatibility_calls_apply_patches( - self, compat_module: _CompatModule - ) -> None: + def test_ensure_compatibility_calls_apply_patches(self, compat_module: _CompatModule) -> None: """ensure_compatibility delegates to apply_patches.""" with patch.object(compat_module, "apply_patches") as mock_apply: compat_module.ensure_compatibility()