441 lines
17 KiB
Python
441 lines
17 KiB
Python
"""Tests for audio processing helper functions in grpc/mixins/_audio_processing.py.

Tests cover:

- resample_audio: Linear interpolation resampling
- decode_audio_chunk: Conversion of raw bytes to a float32 NumPy array
- convert_audio_format: Downmixing and resampling pipeline
- validate_stream_format: Format validation and mid-stream checks
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Final
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from numpy.typing import NDArray
|
|
|
|
from noteflow.grpc.mixins._audio_processing import (
|
|
StreamFormatValidation,
|
|
convert_audio_format,
|
|
decode_audio_chunk,
|
|
resample_audio,
|
|
validate_stream_format,
|
|
)
|
|
|
|
# Audio test constants

# Sample rates, in Hz.
SAMPLE_RATE_8K: Final[int] = 8_000
SAMPLE_RATE_16K: Final[int] = 16_000
SAMPLE_RATE_44K: Final[int] = 44_100  # 44.1 kHz
SAMPLE_RATE_48K: Final[int] = 48_000

# Channel counts.
MONO_CHANNELS: Final[int] = 1
STEREO_CHANNELS: Final[int] = 2

# Stream-format validation configuration used by the tests; note that
# 8 kHz is not part of the supported set.
SUPPORTED_RATES: Final[frozenset[int]] = frozenset(
    (SAMPLE_RATE_16K, SAMPLE_RATE_44K, SAMPLE_RATE_48K)
)
DEFAULT_SAMPLE_RATE: Final[int] = SAMPLE_RATE_16K
|
|
|
|
|
|
def generate_sine_wave(
    frequency_hz: float,
    duration_seconds: float,
    sample_rate: int,
) -> NDArray[np.float32]:
    """Generate a sine wave test signal.

    Args:
        frequency_hz: Frequency of the sine wave.
        duration_seconds: Duration in seconds.
        sample_rate: Sample rate in Hz.

    Returns:
        Float32 numpy array of ``round(duration_seconds * sample_rate)``
        samples of a unit-amplitude sine wave starting at phase zero.
    """
    # round() rather than int(): float products that land just below an
    # integer (e.g. 0.29 * 100 == 28.999999999999996) would otherwise be
    # truncated and silently drop a sample.
    num_samples = round(duration_seconds * sample_rate)
    t = np.arange(num_samples) / sample_rate
    return np.sin(2 * np.pi * frequency_hz * t).astype(np.float32)
|
|
|
|
|
|
class TestResampleAudio:
    """Tests for resample_audio function."""

    def test_upsample_8k_to_16k_preserves_duration(self) -> None:
        """Upsampling from 8kHz to 16kHz preserves audio duration."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_8K)
        # round() (not int()) so a float product just below an integer
        # does not truncate the expected length.
        expected_length = round(duration_seconds * SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples, got {resampled.shape[0]}"
        )

    def test_downsample_48k_to_16k_preserves_duration(self) -> None:
        """Downsampling from 48kHz to 16kHz preserves audio duration."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_48K)
        expected_length = round(duration_seconds * SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_48K, SAMPLE_RATE_16K)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples, got {resampled.shape[0]}"
        )

    def test_same_rate_returns_original_unchanged(self) -> None:
        """Resampling with same source and destination rate returns original."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_16K, SAMPLE_RATE_16K)

        assert resampled is original, "Same rate should return original array"

    def test_empty_audio_returns_empty_array(self) -> None:
        """Resampling empty audio returns empty array."""
        empty_audio = np.array([], dtype=np.float32)

        resampled = resample_audio(empty_audio, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        assert resampled is empty_audio, "Empty audio should return original empty array"

    def test_resampled_output_is_float32(self) -> None:
        """Resampled audio maintains float32 dtype."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_8K)

        resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        assert resampled.dtype == np.float32, (
            f"Expected float32 dtype, got {resampled.dtype}"
        )

    @pytest.mark.parametrize(
        ("src_rate", "dst_rate", "expected_ratio"),
        [
            pytest.param(SAMPLE_RATE_8K, SAMPLE_RATE_16K, 2.0, id="upsample-2x"),
            pytest.param(SAMPLE_RATE_48K, SAMPLE_RATE_16K, 1 / 3, id="downsample-3x"),
            pytest.param(SAMPLE_RATE_44K, SAMPLE_RATE_16K, 16000 / 44100, id="downsample-44k-to-16k"),
        ],
    )
    def test_resample_length_matches_ratio(
        self,
        src_rate: int,
        dst_rate: int,
        expected_ratio: float,
    ) -> None:
        """Resampled length matches the rate ratio."""
        num_samples = 1000
        # Seeded generator keeps the input reproducible across runs.
        rng = np.random.default_rng(seed=0)
        original = rng.random(num_samples, dtype=np.float32)
        expected_length = round(num_samples * expected_ratio)

        resampled = resample_audio(original, src_rate, dst_rate)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples for ratio {expected_ratio}, got {resampled.shape[0]}"
        )
|
|
|
|
|
|
class TestDecodeAudioChunk:
    """Tests for decode_audio_chunk function."""

    def test_float32_roundtrip_preserves_values(self) -> None:
        """Encoding to bytes and decoding preserves float32 values."""
        original = np.array([0.5, -0.5, 1.0, -1.0, 0.0], dtype=np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        np.testing.assert_array_equal(decoded, original, err_msg="Roundtrip should preserve values")

    def test_empty_bytes_returns_none(self) -> None:
        """Decoding empty bytes returns None."""
        empty_bytes = b""

        result = decode_audio_chunk(empty_bytes)

        assert result is None, "Empty bytes should return None"

    def test_decoded_dtype_is_float32(self) -> None:
        """Decoded array has float32 dtype."""
        original = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        assert decoded.dtype == np.float32, f"Expected float32, got {decoded.dtype}"

    def test_large_chunk_decode(self) -> None:
        """Decode a large audio chunk, preserving length and sample values."""
        num_samples = SAMPLE_RATE_16K  # 1 second at 16kHz
        # Seeded generator keeps the input reproducible across runs.
        rng = np.random.default_rng(seed=0)
        original = rng.random(num_samples, dtype=np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        assert decoded.shape[0] == num_samples, (
            f"Expected {num_samples} samples, got {decoded.shape[0]}"
        )
        # Length alone would not catch a byte-order or dtype mix-up.
        np.testing.assert_array_equal(
            decoded, original, err_msg="Large chunk roundtrip should preserve values"
        )
|
|
|
|
|
|
class TestConvertAudioFormat:
    """Tests for convert_audio_format function."""

    def test_stereo_to_mono_averages_channels(self) -> None:
        """Stereo to mono conversion averages left and right channels."""
        num_frames = 100
        # Create interleaved stereo: [L0, R0, L1, R1, ...]
        # Left channel: 1.0, Right channel: 0.0 -> Average: 0.5
        left_samples = np.ones(num_frames, dtype=np.float32)
        right_samples = np.zeros(num_frames, dtype=np.float32)
        stereo = np.empty(num_frames * STEREO_CHANNELS, dtype=np.float32)
        stereo[0::2] = left_samples
        stereo[1::2] = right_samples

        mono = convert_audio_format(stereo, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K)

        # One mono sample per stereo frame (no resampling at equal rates).
        assert mono.shape[0] == num_frames, f"Expected {num_frames} mono samples, got {mono.shape[0]}"
        expected_value = 0.5
        np.testing.assert_allclose(mono, expected_value, rtol=1e-5, err_msg="Stereo should average to 0.5")

    def test_mono_unchanged_when_single_channel(self) -> None:
        """Mono audio passes through without modification when channels=1."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K)

        result = convert_audio_format(original, SAMPLE_RATE_16K, MONO_CHANNELS, SAMPLE_RATE_16K)

        np.testing.assert_array_equal(result, original, err_msg="Mono should pass through unchanged")

    def test_resample_during_format_conversion(self) -> None:
        """Format conversion performs resampling when rates differ."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_48K)
        # round() (not int()) so a float product just below an integer
        # does not truncate the expected length.
        expected_length = round(duration_seconds * SAMPLE_RATE_16K)

        result = convert_audio_format(original, SAMPLE_RATE_48K, MONO_CHANNELS, SAMPLE_RATE_16K)

        assert result.shape[0] == expected_length, (
            f"Expected {expected_length} samples after resampling, got {result.shape[0]}"
        )

    def test_stereo_downmix_then_resample(self) -> None:
        """Format conversion downmixes stereo then resamples."""
        duration_seconds = 0.1
        # Stereo at 48kHz
        stereo_samples = round(duration_seconds * SAMPLE_RATE_48K * STEREO_CHANNELS)
        # Seeded generator keeps the input reproducible across runs.
        rng = np.random.default_rng(seed=0)
        stereo = rng.random(stereo_samples, dtype=np.float32)
        expected_mono_length = round(duration_seconds * SAMPLE_RATE_16K)

        result = convert_audio_format(stereo, SAMPLE_RATE_48K, STEREO_CHANNELS, SAMPLE_RATE_16K)

        assert result.shape[0] == expected_mono_length, (
            f"Expected {expected_mono_length} samples, got {result.shape[0]}"
        )

    def test_raises_on_buffer_not_divisible_by_channels(self) -> None:
        """Raises ValueError when buffer size not divisible by channel count."""
        odd_buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # 3 samples, 2 channels

        with pytest.raises(ValueError, match="not divisible by channel count"):
            convert_audio_format(odd_buffer, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K)

    @pytest.mark.parametrize(
        "channels",
        [
            pytest.param(2, id="stereo"),
            pytest.param(4, id="quad"),
            pytest.param(6, id="5.1-surround"),
        ],
    )
    def test_multichannel_downmix(self, channels: int) -> None:
        """Multi-channel audio downmixes correctly to mono."""
        num_frames = 100
        # All channels have value 1.0, so average should be 1.0
        multichannel = np.ones(num_frames * channels, dtype=np.float32)

        mono = convert_audio_format(multichannel, SAMPLE_RATE_16K, channels, SAMPLE_RATE_16K)

        assert mono.shape[0] == num_frames, f"Expected {num_frames} mono samples"
        np.testing.assert_allclose(mono, 1.0, rtol=1e-5, err_msg="Mono average should be 1.0")
|
|
|
|
|
|
class TestValidateStreamFormat:
    """Tests for validate_stream_format function."""

    @staticmethod
    def _make_request(
        *,
        sample_rate: int,
        channels: int,
        existing_format: tuple[int, int] | None = None,
    ) -> StreamFormatValidation:
        """Build a request using the module-level default rate and supported-rate set."""
        return StreamFormatValidation(
            sample_rate=sample_rate,
            channels=channels,
            default_sample_rate=DEFAULT_SAMPLE_RATE,
            supported_sample_rates=SUPPORTED_RATES,
            existing_format=existing_format,
        )

    def test_valid_format_returns_normalized_values(self) -> None:
        """Valid format request returns normalized rate and channels."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=MONO_CHANNELS)

        rate, channels = validate_stream_format(request)

        assert rate == SAMPLE_RATE_16K, f"Expected rate {SAMPLE_RATE_16K}, got {rate}"
        assert channels == MONO_CHANNELS, f"Expected channels {MONO_CHANNELS}, got {channels}"

    def test_zero_sample_rate_uses_default(self) -> None:
        """Zero sample rate falls back to default sample rate."""
        request = self._make_request(sample_rate=0, channels=MONO_CHANNELS)

        rate, _ = validate_stream_format(request)

        assert rate == DEFAULT_SAMPLE_RATE, (
            f"Expected default rate {DEFAULT_SAMPLE_RATE}, got {rate}"
        )

    def test_zero_channels_defaults_to_mono(self) -> None:
        """Zero channels defaults to mono (1 channel)."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=0)

        _, channels = validate_stream_format(request)

        assert channels == MONO_CHANNELS, f"Expected mono, got {channels} channels"

    def test_raises_on_unsupported_sample_rate(self) -> None:
        """Raises ValueError for unsupported sample rate."""
        unsupported_rate = 22050
        request = self._make_request(sample_rate=unsupported_rate, channels=MONO_CHANNELS)

        with pytest.raises(ValueError, match="Unsupported sample_rate"):
            validate_stream_format(request)

    def test_raises_on_negative_channels(self) -> None:
        """Raises ValueError for negative channel count."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=-1)

        with pytest.raises(ValueError, match="channels must be >= 1"):
            validate_stream_format(request)

    def test_raises_on_mid_stream_rate_change(self) -> None:
        """Raises ValueError when sample rate changes mid-stream."""
        existing_rate = SAMPLE_RATE_44K
        new_rate = SAMPLE_RATE_16K
        request = self._make_request(
            sample_rate=new_rate,
            channels=MONO_CHANNELS,
            existing_format=(existing_rate, MONO_CHANNELS),
        )

        with pytest.raises(ValueError, match="cannot change mid-stream"):
            validate_stream_format(request)

    def test_raises_on_mid_stream_channel_change(self) -> None:
        """Raises ValueError when channel count changes mid-stream."""
        request = self._make_request(
            sample_rate=SAMPLE_RATE_16K,
            channels=STEREO_CHANNELS,
            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
        )

        with pytest.raises(ValueError, match="cannot change mid-stream"):
            validate_stream_format(request)

    def test_accepts_matching_existing_format(self) -> None:
        """Accepts format when it matches existing stream format."""
        request = self._make_request(
            sample_rate=SAMPLE_RATE_16K,
            channels=MONO_CHANNELS,
            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
        )

        rate, channels = validate_stream_format(request)

        assert rate == SAMPLE_RATE_16K, "Rate should match existing"
        assert channels == MONO_CHANNELS, "Channels should match existing"

    @pytest.mark.parametrize(
        "sample_rate",
        [
            pytest.param(SAMPLE_RATE_16K, id="16kHz"),
            pytest.param(SAMPLE_RATE_44K, id="44.1kHz"),
            pytest.param(SAMPLE_RATE_48K, id="48kHz"),
        ],
    )
    def test_accepts_all_supported_rates(self, sample_rate: int) -> None:
        """All rates in supported_sample_rates are accepted."""
        request = self._make_request(sample_rate=sample_rate, channels=MONO_CHANNELS)

        rate, _ = validate_stream_format(request)

        assert rate == sample_rate, f"Expected rate {sample_rate} to be accepted"

    @pytest.mark.parametrize(
        "channels",
        [
            pytest.param(1, id="mono"),
            pytest.param(2, id="stereo"),
            pytest.param(6, id="5.1-surround"),
        ],
    )
    def test_accepts_positive_channel_counts(self, channels: int) -> None:
        """Positive channel counts are accepted."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=channels)

        _, result_channels = validate_stream_format(request)

        assert result_channels == channels, f"Expected {channels} channels"

    def test_defaults_with_existing_format_keep_existing(self) -> None:
        """Defaulted values (0) mid-stream are accepted when the existing format equals the defaults.

        The existing format here is deliberately (default rate, mono), so this
        test cannot distinguish "keep the existing format" from "normalize to
        the defaults, then compare against the existing format"; it only pins
        that the defaulted request is accepted and yields those values.
        """
        existing_format = (DEFAULT_SAMPLE_RATE, MONO_CHANNELS)
        request = self._make_request(
            sample_rate=0,  # defaulted
            channels=0,  # defaulted
            existing_format=existing_format,
        )

        rate, channels = validate_stream_format(request)

        assert rate == DEFAULT_SAMPLE_RATE, "Rate should normalize to default"
        assert channels == MONO_CHANNELS, "Channels should normalize to mono"
|