Files
noteflow/tests/infrastructure/diarization/test_compat.py
Travis Vasceannie c70105f2b8 feat: implement identity management features in gRPC service
- Introduced `IdentityMixin` to manage user identity operations, including `GetCurrentUser`, `ListWorkspaces`, and `SwitchWorkspace`.
- Added corresponding gRPC methods and message definitions in the proto file for identity management.
- Enhanced `AuthService` to support user authentication and token management.
- Updated `OAuthManager` to include rate limiting for authentication attempts and improved error handling.
- Implemented unit tests for the new identity management features to ensure functionality and reliability.
2026-01-04 02:21:04 -05:00

380 lines
14 KiB
Python

"""Tests for diarization compatibility patches.
Tests cover:
- _patch_torchaudio: AudioMetaData class injection
- _patch_torch_load: weights_only=False default for PyTorch 2.6+
- _patch_huggingface_auth: use_auth_token → token parameter conversion
- _patch_speechbrain_backend: torchaudio backend API restoration
- apply_patches: Idempotency and warning suppression
- ensure_compatibility: Alias for apply_patches
"""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from noteflow.infrastructure.diarization._compat import (
AudioMetaData,
_patch_huggingface_auth,
_patch_speechbrain_backend,
_patch_torch_load,
_patch_torchaudio,
apply_patches,
ensure_compatibility,
)
if TYPE_CHECKING:
from collections.abc import Generator
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def reset_patches_state() -> Generator[None, None, None]:
    """Force ``_patches_applied`` to False for one test, then restore it.

    The module-level flag's previous value is captured before the test body
    runs and written back after the fixture resumes past ``yield``.
    """
    import noteflow.infrastructure.diarization._compat as compat

    saved_flag = compat._patches_applied
    compat._patches_applied = False
    yield
    compat._patches_applied = saved_flag
@pytest.fixture
def mock_torchaudio() -> MagicMock:
    """Provide a torchaudio stand-in that exposes no attributes up front."""
    # spec=[] turns off MagicMock's automatic attribute creation, so
    # hasattr() stays False until a test (or patch) assigns the attribute.
    return MagicMock(spec=[])
@pytest.fixture
def mock_torch() -> MagicMock:
    """Provide a torch stand-in reporting version 2.6.0 with a stub ``load``."""
    load_stub = MagicMock(return_value={"model": "weights"})
    torch_stub = MagicMock()
    torch_stub.__version__ = "2.6.0"
    torch_stub.load = load_stub
    return torch_stub
@pytest.fixture
def mock_huggingface_hub() -> MagicMock:
    """Provide a huggingface_hub stand-in with a stubbed ``hf_hub_download``."""
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(
        hf_hub_download=MagicMock(return_value="/path/to/file"),
    )
# =============================================================================
# Test: AudioMetaData Dataclass
# =============================================================================
class TestAudioMetaData:
    """Tests for the replacement AudioMetaData dataclass."""

    def test_audiometadata_has_required_fields(self) -> None:
        """AudioMetaData has all fields expected by pyannote.audio."""
        metadata = AudioMetaData(
            sample_rate=16000,
            num_frames=48000,
            num_channels=1,
            bits_per_sample=16,
            encoding="PCM_S",
        )

        assert metadata.sample_rate == 16000, "should store sample_rate"
        assert metadata.num_frames == 48000, "should store num_frames"
        assert metadata.num_channels == 1, "should store num_channels"
        assert metadata.bits_per_sample == 16, "should store bits_per_sample"
        assert metadata.encoding == "PCM_S", "should store encoding"

    def test_audiometadata_is_immutable(self) -> None:
        """Field assignment either takes effect or is rejected without change.

        The test is written to hold whether or not AudioMetaData is a frozen
        dataclass: a rejected write must leave the original value intact, and
        an accepted write must be observable. Previously the assignment was
        unguarded and unasserted, so a frozen dataclass would have made the
        test error out instead of validating anything.
        """
        metadata = AudioMetaData(
            sample_rate=16000,
            num_frames=48000,
            num_channels=1,
            bits_per_sample=16,
            encoding="PCM_S",
        )

        try:
            metadata.sample_rate = 44100
        except AttributeError:
            # dataclasses.FrozenInstanceError is a subclass of AttributeError,
            # so this branch covers the frozen case without a new import.
            assert (
                metadata.sample_rate == 16000
            ), "rejected write should leave sample_rate unchanged"
        else:
            assert (
                metadata.sample_rate == 44100
            ), "accepted write should update sample_rate"
# =============================================================================
# Test: _patch_torchaudio
# =============================================================================
class TestPatchTorchaudio:
    """Checks for the AudioMetaData injection done by ``_patch_torchaudio``."""

    def test_patches_audiometadata_when_missing(
        self, mock_torchaudio: MagicMock
    ) -> None:
        """A torchaudio module lacking AudioMetaData receives our class."""
        fake_modules = {"torchaudio": mock_torchaudio}
        with patch.dict(sys.modules, fake_modules):
            _patch_torchaudio()

            attribute_present = hasattr(mock_torchaudio, "AudioMetaData")
            assert attribute_present, "should add AudioMetaData"

            injected = mock_torchaudio.AudioMetaData
            assert injected is AudioMetaData, "should use our AudioMetaData class"

    def test_does_not_override_existing_audiometadata(self) -> None:
        """An AudioMetaData already on torchaudio is left untouched."""
        sentinel_cls = type("ExistingAudioMetaData", (), {})
        torchaudio_stub = MagicMock()
        torchaudio_stub.AudioMetaData = sentinel_cls

        with patch.dict(sys.modules, {"torchaudio": torchaudio_stub}):
            _patch_torchaudio()

            current = torchaudio_stub.AudioMetaData
            assert (
                current is sentinel_cls
            ), "should not override existing AudioMetaData"

    def test_handles_import_error_gracefully(self) -> None:
        """The patch is a no-op (no exception) when torchaudio is unavailable."""
        # A None entry in sys.modules makes ``import torchaudio`` fail.
        unavailable = {"torchaudio": None}
        with patch.dict(sys.modules, unavailable):
            # Passing means the call returned without raising.
            _patch_torchaudio()
# =============================================================================
# Test: _patch_torch_load
# =============================================================================
class TestPatchTorchLoad:
    """Checks for the ``torch.load`` replacement done by ``_patch_torch_load``."""

    @staticmethod
    def _run_patch(torch_module: MagicMock, version_ge_result: bool) -> None:
        """Run ``_patch_torch_load`` with torch and packaging's Version faked.

        ``Version(...)`` is made to return the patched mock itself, and that
        mock's ``>=`` comparison is stubbed to yield ``version_ge_result``.
        """
        with patch.dict(sys.modules, {"torch": torch_module}), patch(
            "packaging.version.Version"
        ) as fake_version:
            fake_version.return_value = fake_version
            fake_version.__ge__ = MagicMock(return_value=version_ge_result)
            _patch_torch_load()

    def test_patches_torch_load_for_pytorch_2_6_plus(
        self, mock_torch: MagicMock
    ) -> None:
        """``torch.load`` is swapped out when the version check passes."""
        load_before = mock_torch.load

        self._run_patch(mock_torch, True)

        # Only identity is checked: the attribute must now be another object.
        assert mock_torch.load is not load_before, "load should be patched"

    def test_does_not_patch_older_pytorch(self) -> None:
        """``torch.load`` is left alone when the version check fails."""
        old_torch = MagicMock()
        old_torch.__version__ = "2.5.0"
        load_before = old_torch.load

        self._run_patch(old_torch, False)

        assert old_torch.load is load_before, "should not patch older PyTorch"

    def test_handles_import_error_gracefully(self) -> None:
        """The patch is a no-op (no exception) when torch is unavailable."""
        # A None entry in sys.modules makes ``import torch`` fail.
        unavailable = {"torch": None}
        with patch.dict(sys.modules, unavailable):
            _patch_torch_load()
# =============================================================================
# Test: _patch_huggingface_auth
# =============================================================================
class TestPatchHuggingfaceAuth:
    """Checks for the ``use_auth_token`` shim from ``_patch_huggingface_auth``."""

    def test_converts_use_auth_token_to_token(
        self, mock_huggingface_hub: MagicMock
    ) -> None:
        """A legacy ``use_auth_token`` kwarg reaches the original as ``token``."""
        wrapped_download = mock_huggingface_hub.hf_hub_download
        fake_modules = {"huggingface_hub": mock_huggingface_hub}

        with patch.dict(sys.modules, fake_modules):
            _patch_huggingface_auth()

            # Invoke through the module attribute using the legacy keyword.
            mock_huggingface_hub.hf_hub_download(
                repo_id="test/repo",
                filename="model.bin",
                use_auth_token="my_token",
            )

            # The pre-patch callable must have been reached exactly once.
            wrapped_download.assert_called_once()
            _, forwarded = wrapped_download.call_args

            assert "token" in forwarded, "should convert to token parameter"
            assert forwarded["token"] == "my_token", "should preserve token value"
            assert (
                "use_auth_token" not in forwarded
            ), "should remove use_auth_token"

    def test_preserves_token_parameter(
        self, mock_huggingface_hub: MagicMock
    ) -> None:
        """A ``token`` kwarg passed with the new API is forwarded unchanged."""
        wrapped_download = mock_huggingface_hub.hf_hub_download
        fake_modules = {"huggingface_hub": mock_huggingface_hub}

        with patch.dict(sys.modules, fake_modules):
            _patch_huggingface_auth()

            mock_huggingface_hub.hf_hub_download(
                repo_id="test/repo",
                filename="model.bin",
                token="my_token",
            )

            wrapped_download.assert_called_once()
            _, forwarded = wrapped_download.call_args

            assert forwarded["token"] == "my_token", "should preserve token"

    def test_handles_import_error_gracefully(self) -> None:
        """The patch is a no-op (no exception) when huggingface_hub is missing."""
        # A None entry in sys.modules makes ``import huggingface_hub`` fail.
        unavailable = {"huggingface_hub": None}
        with patch.dict(sys.modules, unavailable):
            _patch_huggingface_auth()
# =============================================================================
# Test: _patch_speechbrain_backend
# =============================================================================
class TestPatchSpeechbrainBackend:
    """Checks for the backend helpers added by ``_patch_speechbrain_backend``."""

    def test_patches_list_audio_backends(self, mock_torchaudio: MagicMock) -> None:
        """``list_audio_backends`` is added when absent and returns a list."""
        fake_modules = {"torchaudio": mock_torchaudio}
        with patch.dict(sys.modules, fake_modules):
            _patch_speechbrain_backend()

            helper_present = hasattr(mock_torchaudio, "list_audio_backends")
            assert helper_present, "should add list_audio_backends"

            backends = mock_torchaudio.list_audio_backends()
            assert isinstance(backends, list), "should return list"

    def test_patches_get_audio_backend(self, mock_torchaudio: MagicMock) -> None:
        """``get_audio_backend`` is added when absent and returns None."""
        fake_modules = {"torchaudio": mock_torchaudio}
        with patch.dict(sys.modules, fake_modules):
            _patch_speechbrain_backend()

            helper_present = hasattr(mock_torchaudio, "get_audio_backend")
            assert helper_present, "should add get_audio_backend"

            backend = mock_torchaudio.get_audio_backend()
            assert backend is None, "should return None"

    def test_patches_set_audio_backend(self, mock_torchaudio: MagicMock) -> None:
        """``set_audio_backend`` is added when absent and accepts a name."""
        fake_modules = {"torchaudio": mock_torchaudio}
        with patch.dict(sys.modules, fake_modules):
            _patch_speechbrain_backend()

            helper_present = hasattr(mock_torchaudio, "set_audio_backend")
            assert helper_present, "should add set_audio_backend"

            # Passing means the call returned without raising.
            mock_torchaudio.set_audio_backend("sox")

    def test_does_not_override_existing_functions(self) -> None:
        """A backend function already on torchaudio is left untouched."""
        preexisting = MagicMock(return_value=["ffmpeg"])
        torchaudio_stub = MagicMock()
        torchaudio_stub.list_audio_backends = preexisting

        with patch.dict(sys.modules, {"torchaudio": torchaudio_stub}):
            _patch_speechbrain_backend()

            current = torchaudio_stub.list_audio_backends
            assert (
                current is preexisting
            ), "should not override existing function"
# =============================================================================
# Test: apply_patches
# =============================================================================
class TestApplyPatches:
    """Checks for the ``apply_patches`` entry point."""

    def test_apply_patches_is_idempotent(
        self, reset_patches_state: None
    ) -> None:
        """Repeated ``apply_patches`` calls run each patch function only once."""
        import noteflow.infrastructure.diarization._compat as compat

        # One flat ``with`` replaces the individual patch functions with mocks
        # so only the dispatching done by apply_patches is exercised.
        with patch.object(
            compat, "_patch_torchaudio"
        ) as torchaudio_stub, patch.object(
            compat, "_patch_torch_load"
        ) as torch_load_stub, patch.object(
            compat, "_patch_huggingface_auth"
        ) as hf_auth_stub, patch.object(
            compat, "_patch_speechbrain_backend"
        ) as speechbrain_stub:
            apply_patches()
            apply_patches()  # second call
            apply_patches()  # third call

            # Three invocations above, but each stub must have run once.
            torchaudio_stub.assert_called_once()
            torch_load_stub.assert_called_once()
            hf_auth_stub.assert_called_once()
            speechbrain_stub.assert_called_once()

    def test_apply_patches_sets_flag(self, reset_patches_state: None) -> None:
        """``apply_patches`` flips the module's ``_patches_applied`` flag."""
        import noteflow.infrastructure.diarization._compat as compat

        assert compat._patches_applied is False, "should start False"

        with patch.object(compat, "_patch_torchaudio"), patch.object(
            compat, "_patch_torch_load"
        ), patch.object(compat, "_patch_huggingface_auth"), patch.object(
            compat, "_patch_speechbrain_backend"
        ):
            apply_patches()

            assert compat._patches_applied is True, "should be True after apply"
# =============================================================================
# Test: ensure_compatibility
# =============================================================================
class TestEnsureCompatibility:
    """Checks for the ``ensure_compatibility`` entry point."""

    def test_ensure_compatibility_calls_apply_patches(
        self, reset_patches_state: None
    ) -> None:
        """``ensure_compatibility`` invokes ``apply_patches`` exactly once."""
        import noteflow.infrastructure.diarization._compat as compat

        # Swap the module attribute so the delegation can be observed.
        apply_patcher = patch.object(compat, "apply_patches")
        with apply_patcher as apply_stub:
            ensure_compatibility()

            apply_stub.assert_called_once()