"""
Integration tests for NVIDIA NeMo Audio Processing Pipeline.
Tests the end-to-end integration of NeMo speaker diarization with the Discord bot's
audio processing pipeline, including recording, transcription, and quote analysis.
"""
import asyncio
import tempfile
import wave
from datetime import datetime, timedelta
from pathlib import Path
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from core.consent_manager import ConsentManager
from core.database import DatabaseManager
from services.audio.audio_recorder import AudioRecorderService
from services.audio.speaker_diarization import (
    DiarizationResult,
    SpeakerDiarizationService,
    SpeakerSegment,
)
from services.audio.transcription_service import TranscriptionService
from services.quotes.quote_analyzer import QuoteAnalyzer
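# NOTE: these tests assume pytest-asyncio is available for the async
# fixtures and @pytest.mark.asyncio tests; NeMo model loading is mocked
# in the fixtures below, so no checkpoints are downloaded or loaded.
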
class TestNeMoAudioPipeline:
"""Integration test suite for NeMo-based audio processing pipeline."""
@pytest.fixture
def mock_database_manager(self):
"""Create mock database manager with realistic responses."""
db_manager = AsyncMock(spec=DatabaseManager)
# Mock user consent data
db_manager.execute_query.return_value = [
{"user_id": 111, "consent_given": True, "username": "Alice"},
{"user_id": 222, "consent_given": True, "username": "Bob"},
{"user_id": 333, "consent_given": False, "username": "Charlie"},
]
return db_manager
@pytest.fixture
def mock_consent_manager(self, mock_database_manager):
"""Create consent manager with database integration."""
consent_manager = ConsentManager(mock_database_manager)
consent_manager.has_recording_consent = AsyncMock(return_value=True)
consent_manager.get_consented_users = AsyncMock(return_value=[111, 222])
return consent_manager
@pytest.fixture
def mock_audio_processor(self):
"""Create audio processor for format conversions."""
processor = MagicMock()
processor.tensor_to_bytes = AsyncMock(return_value=b"processed_audio_bytes")
processor.bytes_to_tensor = AsyncMock(return_value=torch.randn(1, 16000))
return processor
@pytest.fixture
async def diarization_service(
self, mock_database_manager, mock_consent_manager, mock_audio_processor
):
"""Create initialized speaker diarization service."""
service = SpeakerDiarizationService(
db_manager=mock_database_manager,
consent_manager=mock_consent_manager,
audio_processor=mock_audio_processor,
)
        # Patch model loading so initialize() succeeds without loading
        # any real NeMo checkpoints
with patch.object(service, "_load_nemo_models") as mock_load:
mock_load.return_value = True
await service.initialize()
return service
@pytest.fixture
async def transcription_service(self):
"""Create transcription service."""
service = AsyncMock(spec=TranscriptionService)
service.transcribe_audio.return_value = {
"segments": [
{
"start": 0.0,
"end": 2.5,
"text": "This is a funny quote",
"confidence": 0.95,
},
{
"start": 3.0,
"end": 5.5,
"text": "Another interesting statement",
"confidence": 0.88,
},
],
"full_text": "This is a funny quote. Another interesting statement.",
}
return service
@pytest.fixture
async def quote_analyzer(self):
"""Create quote analyzer service."""
analyzer = AsyncMock(spec=QuoteAnalyzer)
analyzer.analyze_quote.return_value = {
"funny_score": 8.5,
"dark_score": 2.1,
"silly_score": 7.3,
"suspicious_score": 1.8,
"asinine_score": 3.2,
"overall_score": 7.2,
}
return analyzer
@pytest.fixture
async def audio_recorder(self):
"""Create audio recorder service."""
recorder = AsyncMock(spec=AudioRecorderService)
recorder.get_active_recordings.return_value = {
67890: {
"guild_id": 12345,
"participants": [111, 222],
"start_time": datetime.utcnow() - timedelta(seconds=30),
"buffer": MagicMock(),
}
}
return recorder
@pytest.fixture
def sample_discord_audio(self):
"""Create sample Discord-compatible audio data."""
        # Generate 10 seconds of mock stereo audio; a distinct tone per
        # channel stands in for a second speaker
sample_rate = 48000 # Discord's sample rate
duration = 10
samples = int(sample_rate * duration)
# Create stereo audio with different patterns for each channel
left_channel = np.sin(
2 * np.pi * 440 * np.linspace(0, duration, samples)
) # 440 Hz
right_channel = np.sin(
2 * np.pi * 880 * np.linspace(0, duration, samples)
) # 880 Hz
# Combine channels
stereo_audio = np.array([left_channel, right_channel])
return torch.from_numpy(stereo_audio.astype(np.float32))
@pytest.fixture
def create_test_wav_file(self):
"""Create a temporary WAV file with test audio."""
def _create_wav(duration_seconds=10, sample_rate=16000, num_channels=1):
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
# Generate sine wave audio
samples = int(duration_seconds * sample_rate)
audio_data = np.sin(
2 * np.pi * 440 * np.linspace(0, duration_seconds, samples)
)
audio_data = (audio_data * 32767).astype(np.int16)
# Write WAV file
with wave.open(f.name, "wb") as wav_file:
wav_file.setnchannels(num_channels)
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_data.tobytes())
return f.name
return _create_wav
@pytest.mark.asyncio
async def test_end_to_end_pipeline(
self,
diarization_service,
transcription_service,
quote_analyzer,
create_test_wav_file,
):
"""Test complete end-to-end audio processing pipeline."""
# Create test audio file
audio_file = create_test_wav_file(duration_seconds=10)
try:
# Mock NeMo diarization output
with patch.object(
diarization_service, "_run_nemo_diarization"
) as mock_diar:
mock_diar.return_value = [
{
"start_time": 0.0,
"end_time": 2.5,
"speaker_label": "SPEAKER_01",
"confidence": 0.95,
},
{
"start_time": 3.0,
"end_time": 5.5,
"speaker_label": "SPEAKER_02",
"confidence": 0.88,
},
]
# Step 1: Perform speaker diarization
diarization_result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
assert diarization_result is not None
assert len(diarization_result.speaker_segments) == 2
# Step 2: Transcribe audio with speaker segments
transcription_result = await transcription_service.transcribe_audio(
audio_file
)
assert "segments" in transcription_result
# Step 3: Combine diarization and transcription
combined_segments = await self._combine_diarization_and_transcription(
diarization_result.speaker_segments, transcription_result["segments"]
)
assert len(combined_segments) > 0
assert all(
"speaker_label" in seg and "text" in seg for seg in combined_segments
)
# Step 4: Analyze quotes for each speaker segment
for segment in combined_segments:
if segment["text"].strip():
analysis = await quote_analyzer.analyze_quote(
text=segment["text"],
speaker_id=segment.get("user_id"),
context={"duration": segment["end"] - segment["start"]},
)
assert "overall_score" in analysis
assert 0.0 <= analysis["overall_score"] <= 10.0
finally:
# Cleanup
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_discord_voice_integration(
self, diarization_service, audio_recorder, sample_discord_audio
):
"""Test integration with Discord voice recording system."""
channel_id = 67890
guild_id = 12345
participants = [111, 222, 333]
# Mock Discord voice client
mock_voice_client = MagicMock()
mock_voice_client.is_connected.return_value = True
mock_voice_client.channel.id = channel_id
# Start recording
with patch.object(audio_recorder, "start_recording") as mock_start:
mock_start.return_value = True
success = await audio_recorder.start_recording(
voice_client=mock_voice_client, channel_id=channel_id, guild_id=guild_id
)
assert success
# Simulate audio processing
with patch.object(diarization_service, "process_audio_clip") as mock_process:
mock_result = DiarizationResult(
audio_file_path="/temp/discord_audio.wav",
total_duration=10.0,
speaker_segments=[
SpeakerSegment(0.0, 5.0, "SPEAKER_01", 0.9, user_id=111),
SpeakerSegment(5.0, 10.0, "SPEAKER_02", 0.8, user_id=222),
],
unique_speakers=["SPEAKER_01", "SPEAKER_02"],
processing_time=2.1,
timestamp=datetime.utcnow(),
)
mock_process.return_value = mock_result
result = await diarization_service.process_audio_clip(
audio_file_path="/temp/discord_audio.wav",
guild_id=guild_id,
channel_id=channel_id,
participants=participants,
)
            assert len(result.unique_speakers) == 2
assert (
len([seg for seg in result.speaker_segments if seg.user_id is not None])
== 2
)
@pytest.mark.asyncio
async def test_multi_language_support(
self, diarization_service, create_test_wav_file
):
"""Test pipeline support for multiple languages."""
languages = ["en", "es", "fr", "de", "zh"]
for language in languages:
audio_file = create_test_wav_file()
try:
with patch.object(
diarization_service, "_detect_language"
) as mock_detect:
mock_detect.return_value = language
with patch.object(
diarization_service, "_run_nemo_diarization"
) as mock_diar:
mock_diar.return_value = [
{
"start_time": 0.0,
"end_time": 5.0,
"speaker_label": "SPEAKER_01",
"confidence": 0.9,
}
]
result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111],
)
assert result is not None
assert len(result.speaker_segments) == 1
finally:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_real_time_processing(self, diarization_service, audio_recorder):
"""Test real-time audio processing capabilities."""
# Simulate streaming audio chunks
chunk_duration = 2.0 # 2-second chunks
total_duration = 10.0
sample_rate = 16000
chunks = []
for i in range(int(total_duration / chunk_duration)):
chunk_samples = int(chunk_duration * sample_rate)
chunk = torch.randn(1, chunk_samples)
chunks.append(chunk)
# Process chunks in real-time
accumulated_results = []
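        # _process_audio_chunk is patched inside the loop below, so this
        # exercises the streaming bookkeeping and temporal-continuity
        # checks rather than real NeMo inference.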
for i, chunk in enumerate(chunks):
with patch.object(
diarization_service, "_process_audio_chunk"
) as mock_chunk:
mock_chunk.return_value = [
SpeakerSegment(
start_time=i * chunk_duration,
end_time=(i + 1) * chunk_duration,
speaker_label=f"SPEAKER_{i % 2:02d}",
confidence=0.85,
)
]
chunk_result = await diarization_service._process_audio_chunk(
chunk, sample_rate, chunk_index=i
)
accumulated_results.extend(chunk_result)
assert len(accumulated_results) == len(chunks)
# Verify temporal continuity
for i in range(1, len(accumulated_results)):
assert (
accumulated_results[i].start_time >= accumulated_results[i - 1].end_time
)
@pytest.mark.asyncio
async def test_concurrent_channel_processing(
self, diarization_service, create_test_wav_file
):
"""Test processing multiple Discord channels simultaneously."""
channels = [
{"id": 67890, "guild_id": 12345, "participants": [111, 222]},
{"id": 67891, "guild_id": 12345, "participants": [333, 444]},
{"id": 67892, "guild_id": 12346, "participants": [555, 666]},
]
# Create test audio files for each channel
audio_files = [create_test_wav_file() for _ in channels]
try:
# Process all channels concurrently
tasks = []
for i, channel in enumerate(channels):
task = diarization_service.process_audio_clip(
audio_file_path=audio_files[i],
guild_id=channel["guild_id"],
channel_id=channel["id"],
participants=channel["participants"],
)
tasks.append(task)
results = await asyncio.gather(*tasks)
# Verify all channels processed successfully
assert len(results) == len(channels)
assert all(result is not None for result in results)
# Verify channel isolation
for i, result in enumerate(results):
assert (
str(channels[i]["id"]) in result.audio_file_path
or result.audio_file_path == audio_files[i]
)
finally:
# Cleanup
for audio_file in audio_files:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_error_recovery_and_fallbacks(
self, diarization_service, create_test_wav_file
):
"""Test error recovery mechanisms and fallback strategies."""
audio_file = create_test_wav_file()
try:
# Test NeMo model failure with fallback
with patch.object(
diarization_service, "_run_nemo_diarization"
) as mock_nemo:
mock_nemo.side_effect = Exception("NeMo model failed")
with patch.object(
diarization_service, "_fallback_basic_vad"
) as mock_fallback:
mock_fallback.return_value = [
SpeakerSegment(0.0, 10.0, "SPEAKER_00", 0.6, needs_tagging=True)
]
result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
assert result is not None
assert len(result.speaker_segments) == 1
                # needs_tagging indicates the VAD fallback was used
                assert result.speaker_segments[0].needs_tagging
mock_fallback.assert_called_once()
finally:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_memory_management(self, diarization_service, create_test_wav_file):
"""Test memory management during intensive processing."""
# Create multiple large audio files
large_audio_files = [
create_test_wav_file(duration_seconds=120) # 2-minute files
for _ in range(5)
]
try:
# Track memory usage
initial_memory = (
torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
)
# Process files sequentially with memory monitoring
for audio_file in large_audio_files:
await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
# Force garbage collection
if torch.cuda.is_available():
torch.cuda.empty_cache()
current_memory = (
torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
)
memory_increase = current_memory - initial_memory
# Memory should not grow excessively
assert memory_increase < 1024 * 1024 * 1024 # Less than 1GB increase
finally:
# Cleanup
for audio_file in large_audio_files:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_performance_benchmarks(
self, diarization_service, create_test_wav_file
):
"""Test performance benchmarks for different scenarios."""
scenarios = [
{"duration": 10, "expected_max_time": 5.0, "description": "Short audio"},
{"duration": 60, "expected_max_time": 15.0, "description": "Medium audio"},
{"duration": 120, "expected_max_time": 30.0, "description": "Long audio"},
]
for scenario in scenarios:
audio_file = create_test_wav_file(duration_seconds=scenario["duration"])
try:
start_time = datetime.utcnow()
result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
processing_time = (datetime.utcnow() - start_time).total_seconds()
assert result is not None
assert processing_time <= scenario["expected_max_time"], (
f"{scenario['description']}: Processing took {processing_time:.2f}s, "
f"expected <= {scenario['expected_max_time']}s"
)
finally:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_data_consistency(
self, diarization_service, mock_database_manager, create_test_wav_file
):
"""Test data consistency between diarization results and database storage."""
audio_file = create_test_wav_file()
try:
# Mock database storage
stored_segments = []
            async def mock_store_segment(*args, **kwargs):
stored_segments.append(args)
return {"id": len(stored_segments)}
mock_database_manager.execute_query.side_effect = mock_store_segment
result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
# Verify data consistency
assert result is not None
assert len(stored_segments) == len(result.speaker_segments)
# Verify timestamp consistency
for segment in result.speaker_segments:
assert segment.start_time < segment.end_time
assert segment.end_time <= result.total_duration
finally:
Path(audio_file).unlink(missing_ok=True)
async def _combine_diarization_and_transcription(
self, diar_segments: List[SpeakerSegment], transcription_segments: List[dict]
) -> List[dict]:
"""Combine diarization and transcription results."""
combined = []
for trans_seg in transcription_segments:
# Find overlapping speaker segment
best_overlap = 0
best_speaker = None
for diar_seg in diar_segments:
# Calculate overlap
overlap_start = max(trans_seg["start"], diar_seg.start_time)
overlap_end = min(trans_seg["end"], diar_seg.end_time)
overlap = max(0, overlap_end - overlap_start)
if overlap > best_overlap:
best_overlap = overlap
best_speaker = diar_seg
combined_segment = {
"start": trans_seg["start"],
"end": trans_seg["end"],
"text": trans_seg["text"],
"confidence": trans_seg["confidence"],
"speaker_label": (
best_speaker.speaker_label if best_speaker else "UNKNOWN"
),
"user_id": best_speaker.user_id if best_speaker else None,
}
combined.append(combined_segment)
return combined
@pytest.mark.asyncio
async def test_speaker_continuity(self, diarization_service, create_test_wav_file):
"""Test speaker label continuity across segments."""
audio_file = create_test_wav_file(duration_seconds=30)
try:
with patch.object(
diarization_service, "_run_nemo_diarization"
) as mock_diar:
# Simulate alternating speakers
mock_diar.return_value = [
{
"start_time": 0.0,
"end_time": 5.0,
"speaker_label": "SPEAKER_01",
"confidence": 0.9,
},
{
"start_time": 5.0,
"end_time": 10.0,
"speaker_label": "SPEAKER_02",
"confidence": 0.85,
},
{
"start_time": 10.0,
"end_time": 15.0,
"speaker_label": "SPEAKER_01",
"confidence": 0.88,
},
{
"start_time": 15.0,
"end_time": 20.0,
"speaker_label": "SPEAKER_02",
"confidence": 0.92,
},
{
"start_time": 20.0,
"end_time": 25.0,
"speaker_label": "SPEAKER_01",
"confidence": 0.87,
},
]
result = await diarization_service.process_audio_clip(
audio_file_path=audio_file,
guild_id=12345,
channel_id=67890,
participants=[111, 222],
)
# Verify speaker continuity
speaker_01_segments = [
seg
for seg in result.speaker_segments
if seg.speaker_label == "SPEAKER_01"
]
speaker_02_segments = [
seg
for seg in result.speaker_segments
if seg.speaker_label == "SPEAKER_02"
]
assert len(speaker_01_segments) == 3
assert len(speaker_02_segments) == 2
# Verify temporal ordering
for segments in [speaker_01_segments, speaker_02_segments]:
for i in range(1, len(segments)):
assert segments[i].start_time > segments[i - 1].end_time
finally:
Path(audio_file).unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_quote_scoring_integration(
self, diarization_service, quote_analyzer, create_test_wav_file
):
"""Test integration between diarization and quote scoring."""
audio_file = create_test_wav_file()
try:
# Mock diarization with speaker identification
with patch.object(diarization_service, "process_audio_clip") as mock_diar:
mock_result = DiarizationResult(
audio_file_path=audio_file,
total_duration=10.0,
speaker_segments=[
SpeakerSegment(0.0, 5.0, "Alice", 0.9, user_id=111),
SpeakerSegment(5.0, 10.0, "Bob", 0.85, user_id=222),
],
unique_speakers=["Alice", "Bob"],
processing_time=2.0,
timestamp=datetime.utcnow(),
)
mock_diar.return_value = mock_result
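                # The diarization call below is fully mocked; this test
                # focuses on handing speaker-attributed segments to the
                # quote analyzer.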
diar_result = await mock_diar(audio_file, 12345, 67890, [111, 222])
# Test quote scoring for each speaker
for segment in diar_result.speaker_segments:
if segment.user_id:
# Mock transcription for this segment
segment_text = f"This is a quote from {segment.speaker_label}"
analysis = await quote_analyzer.analyze_quote(
text=segment_text,
speaker_id=segment.user_id,
context={
"speaker_confidence": segment.confidence,
"duration": segment.end_time - segment.start_time,
},
)
assert "overall_score" in analysis
assert analysis["overall_score"] > 0
finally:
Path(audio_file).unlink(missing_ok=True)