421 lines
16 KiB
Python
421 lines
16 KiB
Python
"""Tests for diarization compatibility patches.

Tests cover:
- _patch_torchaudio: AudioMetaData class injection
- _patch_torch_load: weights_only=False default for PyTorch 2.6+
- _patch_huggingface_auth: use_auth_token → token parameter conversion
- _patch_speechbrain_backend: torchaudio backend API restoration
- apply_patches: Idempotency and warning suppression
- ensure_compatibility: Alias for apply_patches
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import sys
|
|
from typing import TYPE_CHECKING, Protocol, cast
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.diarization import _compat
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
# =============================================================================
|
|
# Constants for test assertions
|
|
# =============================================================================
|
|
|
|
SAMPLE_RATE_16K = 16000
"""Sample rate in Hz for wideband speech audio (a common ASR input rate)."""

SAMPLE_RATE_48K = 48000
"""Sample rate in Hz for high-quality audio (CD audio, by contrast, is 44.1 kHz)."""

BIT_DEPTH_16 = 16
"""Bits per sample for standard 16-bit PCM audio."""
|
|
|
|
|
|
class _CompatModule(Protocol):
|
|
"""Protocol for the diarization compatibility module."""
|
|
|
|
AudioMetaData: type
|
|
|
|
def apply_patches(self) -> None: ...
|
|
|
|
def ensure_compatibility(self) -> None: ...
|
|
|
|
|
|
# =============================================================================
|
|
# Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def compat_module() -> Generator[_CompatModule, None, None]:
    """Yield a freshly reloaded ``_compat`` module.

    ``importlib.reload`` re-executes the module so that every test starts
    from its import-time state instead of patch state left by an earlier test.
    """
    reloaded = cast(_CompatModule, importlib.reload(_compat))
    yield reloaded
|
|
|
|
|
|
@pytest.fixture
def mock_torchaudio() -> MagicMock:
    """Provide a stand-in torchaudio module that exposes no attributes.

    ``spec=[]`` makes attribute reads fail until a name is explicitly set,
    so ``hasattr`` reports ``AudioMetaData`` (and anything else) as missing.
    """
    bare_module = MagicMock(spec=[])
    return bare_module
|
|
|
|
|
|
@pytest.fixture
def mock_torch() -> MagicMock:
    """Provide a fake torch module reporting version 2.6.0 with a stubbed ``load``."""
    load_stub = MagicMock(return_value={"model": "weights"})
    fake_torch = MagicMock()
    fake_torch.__version__ = "2.6.0"
    fake_torch.load = load_stub
    return fake_torch
|
|
|
|
|
|
@pytest.fixture
def mock_torch_minimal() -> MagicMock:
    """Provide a bare fake torch module so the real torch is never imported.

    Intended for tests that call apply_patches() but are not about torch
    itself; they install this mock in ``sys.modules`` in place of the real
    package.
    """
    fake_torch = MagicMock()
    # Reported version is below 2.6 so that torch.load patching is skipped.
    fake_torch.__version__ = "2.5.0"
    return fake_torch
|
|
|
|
|
|
@pytest.fixture
def mock_huggingface_hub() -> MagicMock:
    """Provide a fake huggingface_hub module with a stubbed ``hf_hub_download``."""
    download_stub = MagicMock(return_value="/path/to/file")
    fake_hub = MagicMock()
    fake_hub.hf_hub_download = download_stub
    return fake_hub
|
|
|
|
|
|
# =============================================================================
|
|
# Test: AudioMetaData Dataclass
|
|
# =============================================================================
|
|
|
|
|
|
class TestAudioMetaData:
    """Tests for the replacement AudioMetaData dataclass."""

    def test_audiometadata_has_required_fields(self, compat_module: _CompatModule) -> None:
        """AudioMetaData has all fields expected by pyannote.audio."""
        num_channels = 1
        metadata = compat_module.AudioMetaData(
            sample_rate=SAMPLE_RATE_16K,
            # The frame count is an arbitrary value; the 48000 constant is reused.
            num_frames=SAMPLE_RATE_48K,
            num_channels=num_channels,
            bits_per_sample=BIT_DEPTH_16,
            encoding="PCM_S",
        )

        assert metadata.sample_rate == SAMPLE_RATE_16K, "should store sample_rate"
        assert metadata.num_frames == SAMPLE_RATE_48K, "should store num_frames"
        assert metadata.num_channels == num_channels, "should store num_channels"
        assert metadata.bits_per_sample == BIT_DEPTH_16, "should store bits_per_sample"
        assert metadata.encoding == "PCM_S", "should store encoding"

    def test_audiometadata_is_immutable(self, compat_module: _CompatModule) -> None:
        """AudioMetaData reacts consistently to attribute assignment.

        Whether the dataclass is frozen is not pinned by this test. It only
        checks that the outcome is coherent: a rejected write leaves the
        original value in place, and an accepted write takes effect.
        """
        num_channels = 1
        sample_rate_44k = 44100
        metadata = compat_module.AudioMetaData(
            sample_rate=SAMPLE_RATE_16K,
            num_frames=SAMPLE_RATE_48K,
            num_channels=num_channels,
            bits_per_sample=BIT_DEPTH_16,
            encoding="PCM_S",
        )

        try:
            metadata.sample_rate = sample_rate_44k
        except AttributeError:
            # dataclasses.FrozenInstanceError is a subclass of AttributeError,
            # so this branch covers the frozen-dataclass case.
            assert metadata.sample_rate == SAMPLE_RATE_16K, (
                "rejected assignment should leave sample_rate unchanged"
            )
        else:
            assert metadata.sample_rate == sample_rate_44k, (
                "accepted assignment should update sample_rate"
            )
|
|
|
|
|
|
# =============================================================================
|
|
# Test: _patch_torchaudio
|
|
# =============================================================================
|
|
|
|
|
|
class TestPatchTorchaudio:
    """Tests for torchaudio AudioMetaData patching."""

    def test_patches_audiometadata_when_missing(
        self,
        compat_module: _CompatModule,
        mock_torchaudio: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """_patch_torchaudio adds AudioMetaData when not present."""
        # torch is mocked as in the sibling tests so that apply_patches()
        # does not trigger an import of the real package.
        with patch.dict(sys.modules, {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}):
            compat_module.apply_patches()

        assert hasattr(mock_torchaudio, "AudioMetaData"), "should add AudioMetaData"
        assert mock_torchaudio.AudioMetaData is compat_module.AudioMetaData, (
            "should use our AudioMetaData class"
        )

    def test_does_not_override_existing_audiometadata(
        self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
    ) -> None:
        """_patch_torchaudio preserves existing AudioMetaData if present."""
        mock = MagicMock()
        existing_class = type("ExistingAudioMetaData", (), {})
        mock.AudioMetaData = existing_class

        with patch.dict(sys.modules, {"torchaudio": mock, "torch": mock_torch_minimal}):
            compat_module.apply_patches()

        assert mock.AudioMetaData is existing_class, (
            "should not override existing AudioMetaData"
        )

    def test_torchaudio_handles_import_error_gracefully(
        self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
    ) -> None:
        """_patch_torchaudio doesn't raise when torchaudio not installed."""
        # A None entry in sys.modules makes "import torchaudio" raise
        # ImportError; torch is mocked to prevent a real import.
        with patch.dict(sys.modules, {"torchaudio": None, "torch": mock_torch_minimal}):
            # Should not raise
            compat_module.apply_patches()
|
|
|
|
|
|
# =============================================================================
|
|
# Test: _patch_torch_load
|
|
# =============================================================================
|
|
|
|
|
|
class TestPatchTorchLoad:
    """Tests covering the torch.load ``weights_only`` compatibility patch."""

    def test_patches_torch_load_for_pytorch_2_6_plus(
        self, compat_module: _CompatModule, mock_torch: MagicMock
    ) -> None:
        """apply_patches replaces torch.load when the version check reports >= 2.6."""
        load_before = mock_torch.load

        def passthrough_parse(version_str: str) -> str:
            return version_str

        fake_modules = {"torch": mock_torch}
        with (
            patch.dict(sys.modules, fake_modules),
            patch("packaging.version.Version") as version_cls,
            patch("packaging.version.parse", passthrough_parse),
        ):
            # Calling the patched Version returns the mock itself, and its
            # __ge__ is stubbed to answer True.
            version_cls.return_value = version_cls
            version_cls.__ge__ = MagicMock(return_value=True)

            compat_module.apply_patches()

        # A different object on torch.load means the patch was installed.
        assert mock_torch.load is not load_before, "load should be patched"

    def test_does_not_patch_older_pytorch(self, compat_module: _CompatModule) -> None:
        """apply_patches leaves torch.load alone when the version check reports < 2.6."""
        old_torch = MagicMock()
        old_torch.__version__ = "2.5.0"
        load_before = old_torch.load

        fake_modules = {"torch": old_torch}
        with (
            patch.dict(sys.modules, fake_modules),
            patch("packaging.version.Version") as version_cls,
        ):
            # Same self-returning mock as above, with __ge__ answering False.
            version_cls.return_value = version_cls
            version_cls.__ge__ = MagicMock(return_value=False)

            compat_module.apply_patches()

        # The very same attribute object must still be in place.
        assert old_torch.load is load_before, "should not patch older PyTorch"

    def test_torch_load_handles_import_error_gracefully(self, compat_module: _CompatModule) -> None:
        """apply_patches does not raise when importing torch fails."""
        # A None entry in sys.modules makes "import torch" raise ImportError.
        torch_unavailable = {"torch": None}
        with patch.dict(sys.modules, torch_unavailable):
            compat_module.apply_patches()
|
|
|
|
|
|
# =============================================================================
|
|
# Test: _patch_huggingface_auth
|
|
# =============================================================================
|
|
|
|
|
|
class TestPatchHuggingfaceAuth:
    """Tests covering the huggingface_hub ``use_auth_token`` compatibility patch."""

    def test_converts_use_auth_token_to_token(
        self,
        compat_module: _CompatModule,
        mock_huggingface_hub: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """A legacy ``use_auth_token`` argument reaches the original as ``token``."""
        download_before = mock_huggingface_hub.hf_hub_download

        fake_modules = {"huggingface_hub": mock_huggingface_hub, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        # Invoke whatever now sits on the module using the legacy keyword.
        mock_huggingface_hub.hf_hub_download(
            repo_id="test/repo",
            filename="model.bin",
            use_auth_token="my_token",
        )

        # The pre-patch callable must have received the renamed keyword.
        download_before.assert_called_once()
        forwarded = download_before.call_args.kwargs
        assert "token" in forwarded, "should convert to token parameter"
        assert forwarded["token"] == "my_token", "should preserve token value"
        assert "use_auth_token" not in forwarded, "should remove use_auth_token"

    def test_preserves_token_parameter(
        self,
        compat_module: _CompatModule,
        mock_huggingface_hub: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """A ``token`` argument given in the new style is forwarded unchanged."""
        download_before = mock_huggingface_hub.hf_hub_download

        fake_modules = {"huggingface_hub": mock_huggingface_hub, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        mock_huggingface_hub.hf_hub_download(
            repo_id="test/repo",
            filename="model.bin",
            token="my_token",
        )

        download_before.assert_called_once()
        forwarded = download_before.call_args.kwargs
        assert forwarded["token"] == "my_token", "should preserve token"

    def test_huggingface_handles_import_error_gracefully(
        self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
    ) -> None:
        """apply_patches does not raise when importing huggingface_hub fails."""
        # A None entry in sys.modules makes "import huggingface_hub" raise ImportError.
        fake_modules = {"huggingface_hub": None, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()
|
|
|
|
|
|
# =============================================================================
|
|
# Test: _patch_speechbrain_backend
|
|
# =============================================================================
|
|
|
|
|
|
class TestPatchSpeechbrainBackend:
    """Tests covering restoration of the torchaudio backend functions."""

    def test_patches_list_audio_backends(
        self,
        compat_module: _CompatModule,
        mock_torchaudio: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """``list_audio_backends`` is added when torchaudio lacks it."""
        fake_modules = {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        assert hasattr(mock_torchaudio, "list_audio_backends"), "should add list_audio_backends"
        backends = mock_torchaudio.list_audio_backends()
        assert isinstance(backends, list), "should return list"

    def test_patches_get_audio_backend(
        self,
        compat_module: _CompatModule,
        mock_torchaudio: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """``get_audio_backend`` is added when torchaudio lacks it."""
        fake_modules = {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        assert hasattr(mock_torchaudio, "get_audio_backend"), "should add get_audio_backend"
        current_backend = mock_torchaudio.get_audio_backend()
        assert current_backend is None, "should return None"

    def test_patches_set_audio_backend(
        self,
        compat_module: _CompatModule,
        mock_torchaudio: MagicMock,
        mock_torch_minimal: MagicMock,
    ) -> None:
        """``set_audio_backend`` is added when torchaudio lacks it."""
        fake_modules = {"torchaudio": mock_torchaudio, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        assert hasattr(mock_torchaudio, "set_audio_backend"), "should add set_audio_backend"
        # The call itself is the check: it must complete without raising.
        mock_torchaudio.set_audio_backend("sox")

    def test_does_not_override_existing_functions(
        self, compat_module: _CompatModule, mock_torch_minimal: MagicMock
    ) -> None:
        """A backend function already present on torchaudio is left in place."""
        existing_list = MagicMock(return_value=["ffmpeg"])
        fake_torchaudio = MagicMock()
        fake_torchaudio.list_audio_backends = existing_list

        fake_modules = {"torchaudio": fake_torchaudio, "torch": mock_torch_minimal}
        with patch.dict(sys.modules, fake_modules):
            compat_module.apply_patches()

        assert fake_torchaudio.list_audio_backends is existing_list, (
            "should not override existing function"
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Test: apply_patches
|
|
# =============================================================================
|
|
|
|
|
|
class TestApplyPatches:
    """Tests covering the top-level apply_patches function."""

    def test_apply_patches_is_idempotent(self, compat_module: _CompatModule) -> None:
        """A second apply_patches call does not replace torch.load again."""
        fake_torch = MagicMock()
        fake_torch.__version__ = "2.6.0"
        load_before = fake_torch.load

        def passthrough_parse(version_str: str) -> str:
            return version_str

        fake_modules = {"torch": fake_torch}
        with (
            patch.dict(sys.modules, fake_modules),
            patch("packaging.version.Version") as version_cls,
            patch("packaging.version.parse", passthrough_parse),
        ):
            # Calling the patched Version returns the mock itself, and its
            # __ge__ is stubbed to answer True.
            version_cls.return_value = version_cls
            version_cls.__ge__ = MagicMock(return_value=True)

            compat_module.apply_patches()
            load_after_first = fake_torch.load
            compat_module.apply_patches()

        assert load_after_first is not load_before, "initial call should patch torch.load"
        assert fake_torch.load is load_after_first, "subsequent calls should be idempotent"
|
|
|
|
|
|
# =============================================================================
|
|
# Test: ensure_compatibility
|
|
# =============================================================================
|
|
|
|
|
|
class TestEnsureCompatibility:
    """Tests covering the ensure_compatibility entry point."""

    def test_ensure_compatibility_calls_apply_patches(self, compat_module: _CompatModule) -> None:
        """Calling ensure_compatibility invokes apply_patches exactly once."""
        # Swap the module attribute for a mock for the duration of the call.
        apply_patcher = patch.object(compat_module, "apply_patches")
        with apply_patcher as apply_spy:
            compat_module.ensure_compatibility()

        apply_spy.assert_called_once()
|