# NOTE(review): the lines below are residue from a web source view (the commit
# message and file-size metadata) that was pasted above the module docstring.
# They are kept as comments so the module parses.
#
# Commit message:
# - Deleted .env.example file as it is no longer needed.
# - Added .gitignore to manage ignored files and directories.
# - Introduced CLAUDE.md for AI provider integration documentation.
# - Created dev.sh for development setup and scripts.
# - Updated Dockerfile and Dockerfile.production for improved build processes.
# - Added multiple test files and directories for comprehensive testing.
# - Introduced new utility and service files for enhanced functionality.
# - Organized codebase with new directories and files for better maintainability.
#
# Source-view metadata: 443 lines, 15 KiB, Python.
"""
Integration tests for the complete audio processing pipeline.

These tests exercise the end-to-end flow from audio recording through quote
analysis.
"""

import asyncio
import tempfile
import wave
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from main import QuoteBot


class TestAudioPipeline:
    """Integration tests for the complete audio pipeline."""

    @pytest.fixture
    async def test_bot(self, mock_discord_environment):
        """Create a test bot instance with mocked Discord environment.

        Installs MagicMock settings from ``_create_test_settings``, a mocked
        ``user`` (id 999999) and the single guild from
        ``mock_discord_environment``, then awaits ``setup_hook()``.
        """
        # NOTE(review): an ``async def`` fixture under plain ``@pytest.fixture``
        # only works with pytest-asyncio in "auto" mode; strict mode needs
        # ``@pytest_asyncio.fixture``. Also, if QuoteBot is a discord.py
        # Client/Bot, ``user`` and ``guilds`` are read-only properties there
        # and cannot be assigned -- confirm against main.QuoteBot.
        bot = QuoteBot()
        bot.settings = self._create_test_settings()

        # Mock Discord connection
        bot.user = MagicMock()
        bot.user.id = 999999
        bot.guilds = [mock_discord_environment["guild"]]

        await bot.setup_hook()
        return bot

    @pytest.fixture
    def mock_discord_environment(self):
        """Create a complete mock Discord environment.

        Returns a dict with a mocked ``guild`` (id 123456789), a voice
        ``channel`` (id 987654321) belonging to that guild, and three
        ``members`` (ids 100-102) whose ``voice.channel`` is that channel.
        """
        guild = MagicMock()
        guild.id = 123456789
        guild.name = "Test Guild"

        channel = MagicMock()
        channel.id = 987654321
        channel.name = "test-voice"
        channel.guild = guild

        members = []
        for i in range(3):
            member = MagicMock()
            member.id = 100 + i
            member.name = f"TestUser{i}"
            member.voice = MagicMock()
            member.voice.channel = channel
            members.append(member)

        channel.members = members

        return {"guild": guild, "channel": channel, "members": members}

    @pytest.fixture
    def test_audio_data(self):
        """Generate test audio data with known characteristics.

        Builds 10 s of mono float32 audio at 48 kHz from four segments: a
        440 Hz tone (0-3 s), a 554 Hz tone (3-6 s), Gaussian noise standing in
        for laughter (6-7 s) and another 440 Hz tone (7-10 s). The noise comes
        from a fixed-seed generator, so every run sees identical samples.

        Returns:
            dict with ``audio``, ``sample_rate``, ``duration`` (seconds) and
            the ``expected_segments`` the pipeline is meant to recover.
        """
        sample_rate = 48000
        duration = 10  # seconds

        # Sample instants are exactly k / sample_rate. np.linspace(0, 3, n)
        # with its default endpoint=True would space them 3 / (n - 1) apart,
        # shifting every tone slightly off-pitch.
        tone_t = np.arange(sample_rate * 3) / sample_rate

        # Generate multi-speaker audio simulation

        # Speaker 1: 0-3 seconds (funny quote)
        speaker1_audio = np.sin(2 * np.pi * 440 * tone_t) * 0.5

        # Speaker 2: 3-6 seconds (response with laughter)
        speaker2_audio = np.sin(2 * np.pi * 554 * tone_t) * 0.5

        # Laughter: 6-7 seconds. Seeded so the fixture is reproducible across
        # runs instead of depending on NumPy's global random state.
        rng = np.random.default_rng(0)
        laughter_audio = rng.normal(0, 0.3, sample_rate)

        # Speaker 1: 7-10 seconds (follow-up)
        speaker1_followup = np.sin(2 * np.pi * 440 * tone_t) * 0.5

        # Combine segments
        full_audio = np.concatenate(
            [speaker1_audio, speaker2_audio, laughter_audio, speaker1_followup]
        ).astype(np.float32)

        return {
            "audio": full_audio,
            "sample_rate": sample_rate,
            "duration": duration,
            "expected_segments": [
                {
                    "start": 0,
                    "end": 3,
                    "speaker": "SPEAKER_01",
                    "text": "This is really funny",
                },
                {
                    "start": 3,
                    "end": 6,
                    "speaker": "SPEAKER_02",
                    "text": "That's hilarious",
                },
                {"start": 6, "end": 7, "type": "laughter"},
                {
                    "start": 7,
                    "end": 10,
                    "speaker": "SPEAKER_01",
                    "text": "I know right",
                },
            ],
        }

    def _create_test_settings(self):
        """Create a MagicMock settings object with the values the tests use."""
        settings = MagicMock()
        settings.database_url = "sqlite:///:memory:"
        settings.audio_buffer_duration = 120
        settings.audio_sample_rate = 48000
        settings.quote_min_length = 5
        settings.quote_score_threshold = 5.0
        settings.high_quality_threshold = 8.0
        return settings

    @pytest.mark.asyncio
    async def test_full_audio_pipeline(
        self, test_bot, test_audio_data, mock_discord_environment, tmp_path
    ):
        """Test the complete audio processing pipeline."""
        channel = mock_discord_environment["channel"]

        # Step 1: Start recording
        voice_client = MagicMock()
        voice_client.is_connected.return_value = True
        voice_client.channel = channel

        recording_started = await test_bot.audio_recorder.start_recording(
            voice_client, channel.id, channel.guild.id
        )
        assert recording_started is True

        # Step 2: Simulate audio input. The WAV goes into pytest's tmp_path
        # so it does not accumulate in the system temp directory.
        audio_clip = await self._simulate_audio_recording(
            test_bot.audio_recorder,
            channel.id,
            test_audio_data["audio"],
            test_audio_data["sample_rate"],
            output_dir=tmp_path,
        )
        assert audio_clip is not None

        # Step 3: Process through diarization
        diarization_result = await test_bot.speaker_diarization.process_audio(
            audio_clip.file_path, audio_clip.participants
        )
        assert len(diarization_result["segments"]) > 0

        # Step 4: Transcribe with speaker mapping
        transcription = await test_bot.transcription_service.transcribe_audio_clip(
            audio_clip.file_path,
            channel.guild.id,
            channel.id,
            diarization_result,
            audio_clip.id,
        )
        assert transcription is not None
        assert len(transcription.transcribed_segments) > 0

        # Step 5: Detect laughter
        laughter_analysis = await test_bot.laughter_detector.detect_laughter(
            audio_clip.file_path, audio_clip.participants
        )
        assert laughter_analysis.total_laughter_duration > 0

        # Step 6: Analyze quotes
        quote_results = []
        for segment in transcription.transcribed_segments:
            if segment.is_quote_candidate:
                quote_data = await test_bot.quote_analyzer.analyze_quote(
                    segment.text,
                    segment.speaker_label,
                    {
                        "user_id": segment.user_id,
                        "laughter_duration": self._get_overlapping_laughter(
                            segment, laughter_analysis
                        ),
                    },
                )
                if quote_data:
                    quote_results.append(quote_data)

        assert len(quote_results) > 0
        assert any(q["overall_score"] > 5.0 for q in quote_results)

        # Step 7: Schedule responses
        for quote_data in quote_results:
            await test_bot.response_scheduler.process_quote_score(quote_data)

        # Verify pipeline metrics
        assert test_bot.metrics.get_counter("audio_clips_processed") > 0

    @pytest.mark.asyncio
    async def test_multi_guild_concurrent_processing(self, test_bot, test_audio_data):
        """Test concurrent audio processing for multiple guilds."""
        guilds = []
        for i in range(3):
            guild = MagicMock()
            guild.id = 1000 + i
            guild.name = f"Guild{i}"

            channel = MagicMock()
            channel.id = 2000 + i
            channel.guild = guild

            guilds.append({"guild": guild, "channel": channel})

        # Start recordings concurrently
        recording_tasks = []
        for g in guilds:
            voice_client = MagicMock()
            voice_client.channel = g["channel"]

            task = test_bot.audio_recorder.start_recording(
                voice_client, g["channel"].id, g["guild"].id
            )
            recording_tasks.append(task)

        results = await asyncio.gather(*recording_tasks)
        assert all(results)

        # Process audio concurrently
        processing_tasks = []
        for g in guilds:
            audio_clip = await self._create_test_audio_clip(
                g["channel"].id, g["guild"].id, test_audio_data
            )

            task = test_bot._process_audio_clip(audio_clip)
            processing_tasks.append(task)

        await asyncio.gather(*processing_tasks)

        # Verify isolation between guilds
        for g in guilds:
            assert test_bot.audio_recorder.get_recording(g["channel"].id) is not None

    @pytest.mark.asyncio
    async def test_pipeline_failure_recovery(self, test_bot, test_audio_data):
        """Test pipeline recovery from failures at various stages."""
        channel_id = 123456
        guild_id = 789012

        audio_clip = await self._create_test_audio_clip(
            channel_id, guild_id, test_audio_data
        )

        # Test transcription failure
        with patch.object(
            test_bot.transcription_service, "transcribe_audio_clip"
        ) as mock_transcribe:
            mock_transcribe.side_effect = Exception("Transcription API error")

            # Should not crash the pipeline
            await test_bot._process_audio_clip(audio_clip)

            # Should log error
            assert test_bot.metrics.get_counter("audio_processing_errors") > 0

        # Test quote analysis failure with fallback
        with patch.object(test_bot.quote_analyzer, "analyze_quote") as mock_analyze:
            mock_analyze.side_effect = [Exception("AI error"), {"overall_score": 5.0}]

            # Should retry and succeed
            await test_bot._process_audio_clip(audio_clip)

    @pytest.mark.asyncio
    async def test_voice_state_changes_during_recording(
        self, test_bot, mock_discord_environment
    ):
        """Test handling voice state changes during active recording."""
        channel = mock_discord_environment["channel"]
        members = mock_discord_environment["members"]

        # Start recording
        voice_client = MagicMock()
        voice_client.channel = channel

        await test_bot.audio_recorder.start_recording(
            voice_client, channel.id, channel.guild.id
        )

        # Simulate member join
        new_member = MagicMock()
        new_member.id = 200
        new_member.name = "NewUser"
        await test_bot.audio_recorder.on_member_join(channel.id, new_member)

        # Simulate member leave
        await test_bot.audio_recorder.on_member_leave(channel.id, members[0])

        # Simulate member mute
        members[1].voice.self_mute = True
        await test_bot.audio_recorder.on_voice_state_update(
            members[1], channel.id, channel.id
        )

        # Verify recording continues with updated participants
        recording = test_bot.audio_recorder.get_recording(channel.id)
        assert 200 in recording["participants"]
        assert members[0].id not in recording["participants"]

    @pytest.mark.asyncio
    async def test_quote_response_generation(self, test_bot):
        """Test the complete quote response generation flow."""
        quote_data = {
            "id": 1,
            "quote": "This is the funniest thing ever said",
            "user_id": 123456,
            "guild_id": 789012,
            "channel_id": 111222,
            "funny_score": 9.5,
            "overall_score": 9.0,
            "is_high_quality": True,
            "timestamp": self._utcnow(),
        }

        # Process high-quality quote
        await test_bot.response_scheduler.process_quote_score(quote_data)

        # Should schedule immediate response for high-quality quote
        scheduled = test_bot.response_scheduler.get_scheduled_responses()
        assert len(scheduled) > 0
        assert scheduled[0]["quote_id"] == 1
        assert scheduled[0]["response_type"] == "immediate"

    @pytest.mark.asyncio
    async def test_memory_context_integration(self, test_bot):
        """Test memory system integration with quote analysis."""
        # Store previous conversation context
        await test_bot.memory_manager.store_conversation(
            {
                "guild_id": 123456,
                "content": "Remember that hilarious thing from yesterday?",
                "timestamp": self._utcnow() - timedelta(hours=24),
            }
        )

        # Analyze new quote that references context
        quote = "Just like yesterday, this is golden"

        with patch.object(test_bot.memory_manager, "retrieve_context") as mock_retrieve:
            mock_retrieve.return_value = [
                {"content": "Yesterday's hilarious moment", "relevance": 0.9}
            ]

            result = await test_bot.quote_analyzer.analyze_quote(
                quote, "SPEAKER_01", {"guild_id": 123456}
            )

            assert result["has_context"] is True
            assert result["overall_score"] > 6.0  # Context should boost score

    @pytest.mark.asyncio
    async def test_consent_flow_integration(self, test_bot, mock_discord_environment):
        """Test consent management integration with recording."""
        channel = mock_discord_environment["channel"]
        members = mock_discord_environment["members"]

        # Set consent status
        await test_bot.consent_manager.update_consent(members[0].id, True)
        await test_bot.consent_manager.update_consent(members[1].id, False)

        # Try to start recording
        voice_client = MagicMock()
        voice_client.channel = channel

        # Should check consent before recording
        with patch.object(
            test_bot.consent_manager, "check_channel_consent"
        ) as mock_check:
            mock_check.return_value = True  # At least one consented user

            success = await test_bot.audio_recorder.start_recording(
                voice_client, channel.id, channel.guild.id
            )
            assert success is True

        # Should only process audio from consented users
        recording = test_bot.audio_recorder.get_recording(channel.id)
        assert members[0].id in recording["consented_participants"]
        assert members[1].id not in recording["consented_participants"]

    async def _simulate_audio_recording(
        self, recorder, channel_id, audio_data, sample_rate, output_dir=None
    ):
        """Write ``audio_data`` to a mono 16-bit WAV file and wrap it in a mock clip.

        Args:
            recorder: Accepted for call-site symmetry; not used here.
            channel_id: Used as both ``id`` and ``channel_id`` of the clip.
            audio_data: Float samples, nominally in [-1.0, 1.0]; values outside
                that range are clipped before conversion to int16.
            sample_rate: Frame rate written to the WAV header.
            output_dir: Directory for the WAV file. ``None`` (the default)
                uses the system temp directory, and the file is left behind
                because it is created with ``delete=False``. Pass pytest's
                ``tmp_path`` to keep it inside pytest-managed storage.

        Returns:
            A MagicMock standing in for an audio clip, with ``file_path``,
            ``id``, ``channel_id`` and ``participants`` set.
        """
        # Clip before scaling: samples beyond +/-1.0 (the Gaussian "laughter"
        # segment produces some) would overflow int16 and wrap to the opposite
        # sign instead of saturating.
        pcm = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

        # Write through the open handle rather than reopening the file by
        # name, which fails on Windows while NamedTemporaryFile holds it open.
        with tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False, dir=output_dir
        ) as f:
            with wave.open(f, "wb") as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(pcm.tobytes())

        audio_clip = MagicMock()
        audio_clip.file_path = f.name
        audio_clip.id = channel_id
        audio_clip.channel_id = channel_id
        audio_clip.participants = [100, 101, 102]

        return audio_clip

    async def _create_test_audio_clip(self, channel_id, guild_id, test_audio_data):
        """Build a MagicMock audio clip without writing any audio to disk.

        The ``file_path`` is a fixed placeholder; nothing here creates that file.
        """
        audio_clip = MagicMock()
        audio_clip.id = f"clip_{channel_id}"
        audio_clip.channel_id = channel_id
        audio_clip.guild_id = guild_id
        audio_clip.file_path = "/tmp/test_audio.wav"
        audio_clip.participants = [100, 101, 102]
        audio_clip.duration = test_audio_data["duration"]

        return audio_clip

    def _get_overlapping_laughter(self, segment, laughter_analysis):
        """Return the total seconds of laughter overlapping ``segment``.

        Each laughter interval that intersects ``segment`` (strict overlap, so
        intervals that merely touch contribute nothing) adds the length of the
        intersection. Returns 0 when there is no analysis or no segments.
        """
        if not laughter_analysis or not laughter_analysis.laughter_segments:
            return 0

        overlap = 0
        for laugh in laughter_analysis.laughter_segments:
            if (
                laugh.start_time < segment.end_time
                and laugh.end_time > segment.start_time
            ):
                overlap += min(laugh.end_time, segment.end_time) - max(
                    laugh.start_time, segment.start_time
                )

        return overlap

    @staticmethod
    def _utcnow():
        """Return the current UTC time as a naive ``datetime``.

        Gives the same naive-UTC value as the deprecated (Python 3.12+)
        ``datetime.utcnow()``, so the code under test keeps receiving naive
        timestamps.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)