Files
noteflow/tests/grpc/test_audio_processing.py

441 lines
17 KiB
Python

"""Tests for audio processing helper functions in grpc/mixins/_audio_processing.py.
Tests cover:
- resample_audio: Linear interpolation resampling
- decode_audio_chunk: Bytes to numpy array conversion
- convert_audio_format: Downmixing and resampling pipeline
- validate_stream_format: Format validation and mid-stream checks
"""
from __future__ import annotations
from typing import Final
import numpy as np
import pytest
from numpy.typing import NDArray
from noteflow.grpc.mixins._audio_processing import (
StreamFormatValidation,
convert_audio_format,
decode_audio_chunk,
resample_audio,
validate_stream_format,
)
# Audio test constants

# Sample rates, in Hz, used as source/destination rates throughout the tests.
SAMPLE_RATE_8K: Final[int] = 8_000
SAMPLE_RATE_16K: Final[int] = 16_000
SAMPLE_RATE_44K: Final[int] = 44_100
SAMPLE_RATE_48K: Final[int] = 48_000

# Channel counts for single-channel and two-channel interleaved buffers.
MONO_CHANNELS: Final[int] = 1
STEREO_CHANNELS: Final[int] = 2

# Rates handed to StreamFormatValidation as the supported set (8 kHz is not in it).
SUPPORTED_RATES: Final[frozenset[int]] = frozenset(
    {
        SAMPLE_RATE_16K,
        SAMPLE_RATE_44K,
        SAMPLE_RATE_48K,
    }
)

# Rate handed to StreamFormatValidation as default_sample_rate.
DEFAULT_SAMPLE_RATE: Final[int] = SAMPLE_RATE_16K
def generate_sine_wave(
    frequency_hz: float,
    duration_seconds: float,
    sample_rate: int,
) -> NDArray[np.float32]:
    """Generate a sine wave test signal.

    Args:
        frequency_hz: Frequency of the sine wave.
        duration_seconds: Duration in seconds.
        sample_rate: Sample rate in Hz.

    Returns:
        Float32 numpy array containing the sine wave. Its length is
        ``duration_seconds * sample_rate`` rounded to the nearest integer.
    """
    # round() rather than int(): a product such as 0.29 * 100 evaluates to
    # 28.999999999999996, and truncation would silently drop the last sample.
    num_samples = round(duration_seconds * sample_rate)
    # Sample timestamps in seconds: 0, 1/sr, 2/sr, ...
    t = np.arange(num_samples) / sample_rate
    return np.sin(2 * np.pi * frequency_hz * t).astype(np.float32)
class TestResampleAudio:
    """Tests for resample_audio function."""

    def test_upsample_8k_to_16k_preserves_duration(self) -> None:
        """Upsampling from 8kHz to 16kHz preserves audio duration."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_8K)
        expected_length = int(duration_seconds * SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples, got {resampled.shape[0]}"
        )

    def test_downsample_48k_to_16k_preserves_duration(self) -> None:
        """Downsampling from 48kHz to 16kHz preserves audio duration."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_48K)
        expected_length = int(duration_seconds * SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_48K, SAMPLE_RATE_16K)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples, got {resampled.shape[0]}"
        )

    def test_same_rate_returns_original_unchanged(self) -> None:
        """Resampling with same source and destination rate returns original."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K)

        resampled = resample_audio(original, SAMPLE_RATE_16K, SAMPLE_RATE_16K)

        # Identity, not equality: the test pins that no copy is made.
        assert resampled is original, "Same rate should return original array"

    def test_empty_audio_returns_empty_array(self) -> None:
        """Resampling empty audio returns the same (empty) array object."""
        empty_audio = np.array([], dtype=np.float32)

        resampled = resample_audio(empty_audio, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        # Identity, not equality: the input object itself is expected back.
        assert resampled is empty_audio, "Empty audio should return original empty array"

    def test_resampled_output_is_float32(self) -> None:
        """Resampled audio maintains float32 dtype."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_8K)

        resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K)

        assert resampled.dtype == np.float32, (
            f"Expected float32 dtype, got {resampled.dtype}"
        )

    @pytest.mark.parametrize(
        ("src_rate", "dst_rate", "expected_ratio"),
        [
            pytest.param(SAMPLE_RATE_8K, SAMPLE_RATE_16K, 2.0, id="upsample-2x"),
            pytest.param(SAMPLE_RATE_48K, SAMPLE_RATE_16K, 1 / 3, id="downsample-3x"),
            pytest.param(
                SAMPLE_RATE_44K,
                SAMPLE_RATE_16K,
                SAMPLE_RATE_16K / SAMPLE_RATE_44K,
                id="downsample-44k-to-16k",
            ),
        ],
    )
    def test_resample_length_matches_ratio(
        self,
        src_rate: int,
        dst_rate: int,
        expected_ratio: float,
    ) -> None:
        """Resampled length matches the rate ratio."""
        num_samples = 1000
        # Seeded generator so the input is identical on every run; only the
        # output length is asserted, but a failure should be reproducible.
        rng = np.random.default_rng(seed=0)
        original = rng.random(num_samples).astype(np.float32)
        expected_length = round(num_samples * expected_ratio)

        resampled = resample_audio(original, src_rate, dst_rate)

        assert resampled.shape[0] == expected_length, (
            f"Expected {expected_length} samples for ratio {expected_ratio}, got {resampled.shape[0]}"
        )
class TestDecodeAudioChunk:
    """Tests for decode_audio_chunk function."""

    def test_float32_roundtrip_preserves_values(self) -> None:
        """Encoding to bytes and decoding preserves float32 values."""
        original = np.array([0.5, -0.5, 1.0, -1.0, 0.0], dtype=np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        np.testing.assert_array_equal(decoded, original, err_msg="Roundtrip should preserve values")

    def test_empty_bytes_returns_none(self) -> None:
        """Decoding empty bytes returns None."""
        empty_bytes = b""

        result = decode_audio_chunk(empty_bytes)

        assert result is None, "Empty bytes should return None"

    def test_decoded_dtype_is_float32(self) -> None:
        """Decoded array has float32 dtype."""
        original = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        assert decoded.dtype == np.float32, f"Expected float32, got {decoded.dtype}"

    def test_large_chunk_decode(self) -> None:
        """Decode large audio chunk successfully."""
        num_samples = SAMPLE_RATE_16K  # 1 second at 16kHz
        # Seeded generator keeps the payload identical across runs.
        rng = np.random.default_rng(seed=0)
        original = rng.random(num_samples).astype(np.float32)
        audio_bytes = original.tobytes()

        decoded = decode_audio_chunk(audio_bytes)

        assert decoded is not None, "Decoded audio should not be None"
        assert decoded.shape[0] == num_samples, (
            f"Expected {num_samples} samples, got {decoded.shape[0]}"
        )
class TestConvertAudioFormat:
    """Tests for convert_audio_format function."""

    def test_stereo_to_mono_averages_channels(self) -> None:
        """Stereo to mono conversion averages left and right channels."""
        # Create interleaved stereo: [L0, R0, L1, R1, ...]
        # Left channel: 1.0, Right channel: 0.0 -> Average: 0.5
        left_samples = np.ones(100, dtype=np.float32)
        right_samples = np.zeros(100, dtype=np.float32)
        stereo = np.empty(200, dtype=np.float32)
        stereo[0::2] = left_samples
        stereo[1::2] = right_samples

        mono = convert_audio_format(stereo, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K)

        expected_value = 0.5
        np.testing.assert_allclose(mono, expected_value, rtol=1e-5, err_msg="Stereo should average to 0.5")

    def test_mono_unchanged_when_single_channel(self) -> None:
        """Mono audio passes through without modification when channels=1."""
        original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K)

        result = convert_audio_format(original, SAMPLE_RATE_16K, MONO_CHANNELS, SAMPLE_RATE_16K)

        np.testing.assert_array_equal(result, original, err_msg="Mono should pass through unchanged")

    def test_resample_during_format_conversion(self) -> None:
        """Format conversion performs resampling when rates differ."""
        duration_seconds = 0.1
        original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_48K)
        expected_length = int(duration_seconds * SAMPLE_RATE_16K)

        result = convert_audio_format(original, SAMPLE_RATE_48K, MONO_CHANNELS, SAMPLE_RATE_16K)

        assert result.shape[0] == expected_length, (
            f"Expected {expected_length} samples after resampling, got {result.shape[0]}"
        )

    def test_stereo_downmix_then_resample(self) -> None:
        """Format conversion downmixes stereo then resamples."""
        duration_seconds = 0.1
        # Stereo at 48kHz: the interleaved buffer holds frames * channels values.
        stereo_samples = int(duration_seconds * SAMPLE_RATE_48K * STEREO_CHANNELS)
        # Seeded generator keeps the input identical across runs.
        rng = np.random.default_rng(seed=0)
        stereo = rng.random(stereo_samples).astype(np.float32)
        expected_mono_length = int(duration_seconds * SAMPLE_RATE_16K)

        result = convert_audio_format(stereo, SAMPLE_RATE_48K, STEREO_CHANNELS, SAMPLE_RATE_16K)

        assert result.shape[0] == expected_mono_length, (
            f"Expected {expected_mono_length} samples, got {result.shape[0]}"
        )

    def test_raises_on_buffer_not_divisible_by_channels(self) -> None:
        """Raises ValueError when buffer size not divisible by channel count."""
        odd_buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # 3 samples, 2 channels

        with pytest.raises(ValueError, match="not divisible by channel count"):
            convert_audio_format(odd_buffer, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K)

    @pytest.mark.parametrize(
        "channels",
        [
            pytest.param(2, id="stereo"),
            pytest.param(4, id="quad"),
            pytest.param(6, id="5.1-surround"),
        ],
    )
    def test_multichannel_downmix(self, channels: int) -> None:
        """Multi-channel audio downmixes correctly to mono."""
        num_frames = 100
        # All channels have value 1.0, so average should be 1.0
        multichannel = np.ones(num_frames * channels, dtype=np.float32)

        mono = convert_audio_format(multichannel, SAMPLE_RATE_16K, channels, SAMPLE_RATE_16K)

        assert mono.shape[0] == num_frames, f"Expected {num_frames} mono samples"
        np.testing.assert_allclose(mono, 1.0, rtol=1e-5, err_msg="Mono average should be 1.0")
class TestValidateStreamFormat:
    """Tests for validate_stream_format function."""

    @staticmethod
    def _make_request(
        sample_rate: int,
        channels: int,
        existing_format: tuple[int, int] | None = None,
    ) -> StreamFormatValidation:
        """Build a validation request with the module's default and supported rates.

        Args:
            sample_rate: Requested sample rate.
            channels: Requested channel count.
            existing_format: ``(sample_rate, channels)`` already established
                for the stream, or None when no format has been set yet.

        Returns:
            A StreamFormatValidation using DEFAULT_SAMPLE_RATE and SUPPORTED_RATES.
        """
        return StreamFormatValidation(
            sample_rate=sample_rate,
            channels=channels,
            default_sample_rate=DEFAULT_SAMPLE_RATE,
            supported_sample_rates=SUPPORTED_RATES,
            existing_format=existing_format,
        )

    def test_valid_format_returns_normalized_values(self) -> None:
        """Valid format request returns normalized rate and channels."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=MONO_CHANNELS)

        rate, channels = validate_stream_format(request)

        assert rate == SAMPLE_RATE_16K, f"Expected rate {SAMPLE_RATE_16K}, got {rate}"
        assert channels == MONO_CHANNELS, f"Expected channels {MONO_CHANNELS}, got {channels}"

    def test_zero_sample_rate_uses_default(self) -> None:
        """Zero sample rate falls back to default sample rate."""
        request = self._make_request(sample_rate=0, channels=MONO_CHANNELS)

        rate, _ = validate_stream_format(request)

        assert rate == DEFAULT_SAMPLE_RATE, (
            f"Expected default rate {DEFAULT_SAMPLE_RATE}, got {rate}"
        )

    def test_zero_channels_defaults_to_mono(self) -> None:
        """Zero channels defaults to mono (1 channel)."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=0)

        _, channels = validate_stream_format(request)

        assert channels == MONO_CHANNELS, f"Expected mono, got {channels} channels"

    def test_raises_on_unsupported_sample_rate(self) -> None:
        """Raises ValueError for unsupported sample rate."""
        unsupported_rate = 22050
        request = self._make_request(sample_rate=unsupported_rate, channels=MONO_CHANNELS)

        with pytest.raises(ValueError, match="Unsupported sample_rate"):
            validate_stream_format(request)

    def test_raises_on_negative_channels(self) -> None:
        """Raises ValueError for negative channel count."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=-1)

        with pytest.raises(ValueError, match="channels must be >= 1"):
            validate_stream_format(request)

    def test_raises_on_mid_stream_rate_change(self) -> None:
        """Raises ValueError when sample rate changes mid-stream."""
        existing_rate = SAMPLE_RATE_44K
        new_rate = SAMPLE_RATE_16K
        request = self._make_request(
            sample_rate=new_rate,
            channels=MONO_CHANNELS,
            existing_format=(existing_rate, MONO_CHANNELS),
        )

        with pytest.raises(ValueError, match="cannot change mid-stream"):
            validate_stream_format(request)

    def test_raises_on_mid_stream_channel_change(self) -> None:
        """Raises ValueError when channel count changes mid-stream."""
        request = self._make_request(
            sample_rate=SAMPLE_RATE_16K,
            channels=STEREO_CHANNELS,
            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
        )

        with pytest.raises(ValueError, match="cannot change mid-stream"):
            validate_stream_format(request)

    def test_accepts_matching_existing_format(self) -> None:
        """Accepts format when it matches existing stream format."""
        request = self._make_request(
            sample_rate=SAMPLE_RATE_16K,
            channels=MONO_CHANNELS,
            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
        )

        rate, channels = validate_stream_format(request)

        assert rate == SAMPLE_RATE_16K, "Rate should match existing"
        assert channels == MONO_CHANNELS, "Channels should match existing"

    @pytest.mark.parametrize(
        "sample_rate",
        [
            pytest.param(SAMPLE_RATE_16K, id="16kHz"),
            pytest.param(SAMPLE_RATE_44K, id="44.1kHz"),
            pytest.param(SAMPLE_RATE_48K, id="48kHz"),
        ],
    )
    def test_accepts_all_supported_rates(self, sample_rate: int) -> None:
        """All rates in supported_sample_rates are accepted."""
        request = self._make_request(sample_rate=sample_rate, channels=MONO_CHANNELS)

        rate, _ = validate_stream_format(request)

        assert rate == sample_rate, f"Expected rate {sample_rate} to be accepted"

    @pytest.mark.parametrize(
        "channels",
        [
            pytest.param(1, id="mono"),
            pytest.param(2, id="stereo"),
            pytest.param(6, id="5.1-surround"),
        ],
    )
    def test_accepts_positive_channel_counts(self, channels: int) -> None:
        """Positive channel counts are accepted."""
        request = self._make_request(sample_rate=SAMPLE_RATE_16K, channels=channels)

        _, result_channels = validate_stream_format(request)

        assert result_channels == channels, f"Expected {channels} channels"

    def test_defaults_with_existing_format_keep_existing(self) -> None:
        """Zero (defaulted) values are accepted mid-stream when defaults match the existing format.

        The existing format here is deliberately equal to the defaults
        (DEFAULT_SAMPLE_RATE, mono), so the normalized request agrees with it.
        This test does not cover an existing format that differs from the defaults.
        """
        existing_format = (SAMPLE_RATE_16K, MONO_CHANNELS)
        request = self._make_request(
            sample_rate=0,  # defaulted
            channels=0,  # defaulted
            existing_format=existing_format,
        )

        rate, channels = validate_stream_format(request)

        # 0 in a mid-stream request with existing format normalizes to default,
        # which should match existing format if defaults match the existing format
        assert rate == DEFAULT_SAMPLE_RATE, "Rate should normalize to default"
        assert channels == MONO_CHANNELS, "Channels should normalize to mono"