- Deleted .env.example file as it is no longer needed. - Added .gitignore to manage ignored files and directories. - Introduced CLAUDE.md for AI provider integration documentation. - Created dev.sh for development setup and scripts. - Updated Dockerfile and Dockerfile.production for improved build processes. - Added multiple test files and directories for comprehensive testing. - Introduced new utility and service files for enhanced functionality. - Organized codebase with new directories and files for better maintainability.
734 lines
27 KiB
Python
734 lines
27 KiB
Python
"""
Integration tests for NVIDIA NeMo Audio Processing Pipeline.

Tests the end-to-end integration of NeMo speaker diarization with the Discord bot's
audio processing pipeline, including recording, transcription, and quote analysis.
"""
|
|
|
|
import asyncio
|
|
import tempfile
|
|
import wave
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import List
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from core.consent_manager import ConsentManager
|
|
from core.database import DatabaseManager
|
|
from services.audio.audio_recorder import AudioRecorderService
|
|
from services.audio.speaker_diarization import (DiarizationResult,
|
|
SpeakerDiarizationService,
|
|
SpeakerSegment)
|
|
from services.audio.transcription_service import TranscriptionService
|
|
from services.quotes.quote_analyzer import QuoteAnalyzer
|
|
|
|
|
|
class TestNeMoAudioPipeline:
    """Integration test suite for the NeMo-based audio processing pipeline.

    Fixtures build a real ``SpeakerDiarizationService`` (with NeMo model
    loading patched out during initialization) around mocked database,
    consent-manager and audio-processor collaborators. Transcription, quote
    analysis and recording services are ``AsyncMock`` stand-ins with canned
    return values.
    """
|
|
|
|
@pytest.fixture
|
|
def mock_database_manager(self):
|
|
"""Create mock database manager with realistic responses."""
|
|
db_manager = AsyncMock(spec=DatabaseManager)
|
|
|
|
# Mock user consent data
|
|
db_manager.execute_query.return_value = [
|
|
{"user_id": 111, "consent_given": True, "username": "Alice"},
|
|
{"user_id": 222, "consent_given": True, "username": "Bob"},
|
|
{"user_id": 333, "consent_given": False, "username": "Charlie"},
|
|
]
|
|
|
|
return db_manager
|
|
|
|
@pytest.fixture
|
|
def mock_consent_manager(self, mock_database_manager):
|
|
"""Create consent manager with database integration."""
|
|
consent_manager = ConsentManager(mock_database_manager)
|
|
consent_manager.has_recording_consent = AsyncMock(return_value=True)
|
|
consent_manager.get_consented_users = AsyncMock(return_value=[111, 222])
|
|
return consent_manager
|
|
|
|
@pytest.fixture
|
|
def mock_audio_processor(self):
|
|
"""Create audio processor for format conversions."""
|
|
processor = MagicMock()
|
|
processor.tensor_to_bytes = AsyncMock(return_value=b"processed_audio_bytes")
|
|
processor.bytes_to_tensor = AsyncMock(return_value=torch.randn(1, 16000))
|
|
return processor
|
|
|
|
@pytest.fixture
|
|
async def diarization_service(
|
|
self, mock_database_manager, mock_consent_manager, mock_audio_processor
|
|
):
|
|
"""Create initialized speaker diarization service."""
|
|
service = SpeakerDiarizationService(
|
|
db_manager=mock_database_manager,
|
|
consent_manager=mock_consent_manager,
|
|
audio_processor=mock_audio_processor,
|
|
)
|
|
|
|
# Mock successful initialization
|
|
with patch.object(service, "_load_nemo_models") as mock_load:
|
|
mock_load.return_value = True
|
|
await service.initialize()
|
|
|
|
return service
|
|
|
|
@pytest.fixture
|
|
async def transcription_service(self):
|
|
"""Create transcription service."""
|
|
service = AsyncMock(spec=TranscriptionService)
|
|
service.transcribe_audio.return_value = {
|
|
"segments": [
|
|
{
|
|
"start": 0.0,
|
|
"end": 2.5,
|
|
"text": "This is a funny quote",
|
|
"confidence": 0.95,
|
|
},
|
|
{
|
|
"start": 3.0,
|
|
"end": 5.5,
|
|
"text": "Another interesting statement",
|
|
"confidence": 0.88,
|
|
},
|
|
],
|
|
"full_text": "This is a funny quote. Another interesting statement.",
|
|
}
|
|
return service
|
|
|
|
@pytest.fixture
|
|
async def quote_analyzer(self):
|
|
"""Create quote analyzer service."""
|
|
analyzer = AsyncMock(spec=QuoteAnalyzer)
|
|
analyzer.analyze_quote.return_value = {
|
|
"funny_score": 8.5,
|
|
"dark_score": 2.1,
|
|
"silly_score": 7.3,
|
|
"suspicious_score": 1.8,
|
|
"asinine_score": 3.2,
|
|
"overall_score": 7.2,
|
|
}
|
|
return analyzer
|
|
|
|
@pytest.fixture
|
|
async def audio_recorder(self):
|
|
"""Create audio recorder service."""
|
|
recorder = AsyncMock(spec=AudioRecorderService)
|
|
recorder.get_active_recordings.return_value = {
|
|
67890: {
|
|
"guild_id": 12345,
|
|
"participants": [111, 222],
|
|
"start_time": datetime.utcnow() - timedelta(seconds=30),
|
|
"buffer": MagicMock(),
|
|
}
|
|
}
|
|
return recorder
|
|
|
|
@pytest.fixture
|
|
def sample_discord_audio(self):
|
|
"""Create sample Discord-compatible audio data."""
|
|
# Generate 10 seconds of mock audio with two speakers
|
|
sample_rate = 48000 # Discord's sample rate
|
|
duration = 10
|
|
samples = int(sample_rate * duration)
|
|
|
|
# Create stereo audio with different patterns for each channel
|
|
left_channel = np.sin(
|
|
2 * np.pi * 440 * np.linspace(0, duration, samples)
|
|
) # 440 Hz
|
|
right_channel = np.sin(
|
|
2 * np.pi * 880 * np.linspace(0, duration, samples)
|
|
) # 880 Hz
|
|
|
|
# Combine channels
|
|
stereo_audio = np.array([left_channel, right_channel])
|
|
return torch.from_numpy(stereo_audio.astype(np.float32))
|
|
|
|
@pytest.fixture
|
|
def create_test_wav_file(self):
|
|
"""Create a temporary WAV file with test audio."""
|
|
|
|
def _create_wav(duration_seconds=10, sample_rate=16000, num_channels=1):
|
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
|
# Generate sine wave audio
|
|
samples = int(duration_seconds * sample_rate)
|
|
audio_data = np.sin(
|
|
2 * np.pi * 440 * np.linspace(0, duration_seconds, samples)
|
|
)
|
|
audio_data = (audio_data * 32767).astype(np.int16)
|
|
|
|
# Write WAV file
|
|
with wave.open(f.name, "wb") as wav_file:
|
|
wav_file.setnchannels(num_channels)
|
|
wav_file.setsampwidth(2) # 16-bit
|
|
wav_file.setframerate(sample_rate)
|
|
wav_file.writeframes(audio_data.tobytes())
|
|
|
|
return f.name
|
|
|
|
return _create_wav
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_end_to_end_pipeline(
|
|
self,
|
|
diarization_service,
|
|
transcription_service,
|
|
quote_analyzer,
|
|
create_test_wav_file,
|
|
):
|
|
"""Test complete end-to-end audio processing pipeline."""
|
|
# Create test audio file
|
|
audio_file = create_test_wav_file(duration_seconds=10)
|
|
|
|
try:
|
|
# Mock NeMo diarization output
|
|
with patch.object(
|
|
diarization_service, "_run_nemo_diarization"
|
|
) as mock_diar:
|
|
mock_diar.return_value = [
|
|
{
|
|
"start_time": 0.0,
|
|
"end_time": 2.5,
|
|
"speaker_label": "SPEAKER_01",
|
|
"confidence": 0.95,
|
|
},
|
|
{
|
|
"start_time": 3.0,
|
|
"end_time": 5.5,
|
|
"speaker_label": "SPEAKER_02",
|
|
"confidence": 0.88,
|
|
},
|
|
]
|
|
|
|
# Step 1: Perform speaker diarization
|
|
diarization_result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
assert diarization_result is not None
|
|
assert len(diarization_result.speaker_segments) == 2
|
|
|
|
# Step 2: Transcribe audio with speaker segments
|
|
transcription_result = await transcription_service.transcribe_audio(
|
|
audio_file
|
|
)
|
|
assert "segments" in transcription_result
|
|
|
|
# Step 3: Combine diarization and transcription
|
|
combined_segments = await self._combine_diarization_and_transcription(
|
|
diarization_result.speaker_segments, transcription_result["segments"]
|
|
)
|
|
|
|
assert len(combined_segments) > 0
|
|
assert all(
|
|
"speaker_label" in seg and "text" in seg for seg in combined_segments
|
|
)
|
|
|
|
# Step 4: Analyze quotes for each speaker segment
|
|
for segment in combined_segments:
|
|
if segment["text"].strip():
|
|
analysis = await quote_analyzer.analyze_quote(
|
|
text=segment["text"],
|
|
speaker_id=segment.get("user_id"),
|
|
context={"duration": segment["end"] - segment["start"]},
|
|
)
|
|
|
|
assert "overall_score" in analysis
|
|
assert 0.0 <= analysis["overall_score"] <= 10.0
|
|
|
|
finally:
|
|
# Cleanup
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_discord_voice_integration(
|
|
self, diarization_service, audio_recorder, sample_discord_audio
|
|
):
|
|
"""Test integration with Discord voice recording system."""
|
|
channel_id = 67890
|
|
guild_id = 12345
|
|
participants = [111, 222, 333]
|
|
|
|
# Mock Discord voice client
|
|
mock_voice_client = MagicMock()
|
|
mock_voice_client.is_connected.return_value = True
|
|
mock_voice_client.channel.id = channel_id
|
|
|
|
# Start recording
|
|
with patch.object(audio_recorder, "start_recording") as mock_start:
|
|
mock_start.return_value = True
|
|
success = await audio_recorder.start_recording(
|
|
voice_client=mock_voice_client, channel_id=channel_id, guild_id=guild_id
|
|
)
|
|
|
|
assert success
|
|
|
|
# Simulate audio processing
|
|
with patch.object(diarization_service, "process_audio_clip") as mock_process:
|
|
mock_result = DiarizationResult(
|
|
audio_file_path="/temp/discord_audio.wav",
|
|
total_duration=10.0,
|
|
speaker_segments=[
|
|
SpeakerSegment(0.0, 5.0, "SPEAKER_01", 0.9, user_id=111),
|
|
SpeakerSegment(5.0, 10.0, "SPEAKER_02", 0.8, user_id=222),
|
|
],
|
|
unique_speakers=["SPEAKER_01", "SPEAKER_02"],
|
|
processing_time=2.1,
|
|
timestamp=datetime.utcnow(),
|
|
)
|
|
mock_process.return_value = mock_result
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path="/temp/discord_audio.wav",
|
|
guild_id=guild_id,
|
|
channel_id=channel_id,
|
|
participants=participants,
|
|
)
|
|
|
|
assert result.unique_speakers == 2
|
|
assert (
|
|
len([seg for seg in result.speaker_segments if seg.user_id is not None])
|
|
== 2
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_multi_language_support(
|
|
self, diarization_service, create_test_wav_file
|
|
):
|
|
"""Test pipeline support for multiple languages."""
|
|
languages = ["en", "es", "fr", "de", "zh"]
|
|
|
|
for language in languages:
|
|
audio_file = create_test_wav_file()
|
|
|
|
try:
|
|
with patch.object(
|
|
diarization_service, "_detect_language"
|
|
) as mock_detect:
|
|
mock_detect.return_value = language
|
|
|
|
with patch.object(
|
|
diarization_service, "_run_nemo_diarization"
|
|
) as mock_diar:
|
|
mock_diar.return_value = [
|
|
{
|
|
"start_time": 0.0,
|
|
"end_time": 5.0,
|
|
"speaker_label": "SPEAKER_01",
|
|
"confidence": 0.9,
|
|
}
|
|
]
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111],
|
|
)
|
|
|
|
assert result is not None
|
|
assert len(result.speaker_segments) == 1
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_real_time_processing(self, diarization_service, audio_recorder):
|
|
"""Test real-time audio processing capabilities."""
|
|
# Simulate streaming audio chunks
|
|
chunk_duration = 2.0 # 2-second chunks
|
|
total_duration = 10.0
|
|
sample_rate = 16000
|
|
|
|
chunks = []
|
|
for i in range(int(total_duration / chunk_duration)):
|
|
chunk_samples = int(chunk_duration * sample_rate)
|
|
chunk = torch.randn(1, chunk_samples)
|
|
chunks.append(chunk)
|
|
|
|
# Process chunks in real-time
|
|
accumulated_results = []
|
|
|
|
for i, chunk in enumerate(chunks):
|
|
with patch.object(
|
|
diarization_service, "_process_audio_chunk"
|
|
) as mock_chunk:
|
|
mock_chunk.return_value = [
|
|
SpeakerSegment(
|
|
start_time=i * chunk_duration,
|
|
end_time=(i + 1) * chunk_duration,
|
|
speaker_label=f"SPEAKER_{i % 2:02d}",
|
|
confidence=0.85,
|
|
)
|
|
]
|
|
|
|
chunk_result = await diarization_service._process_audio_chunk(
|
|
chunk, sample_rate, chunk_index=i
|
|
)
|
|
|
|
accumulated_results.extend(chunk_result)
|
|
|
|
assert len(accumulated_results) == len(chunks)
|
|
|
|
# Verify temporal continuity
|
|
for i in range(1, len(accumulated_results)):
|
|
assert (
|
|
accumulated_results[i].start_time >= accumulated_results[i - 1].end_time
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_concurrent_channel_processing(
|
|
self, diarization_service, create_test_wav_file
|
|
):
|
|
"""Test processing multiple Discord channels simultaneously."""
|
|
channels = [
|
|
{"id": 67890, "guild_id": 12345, "participants": [111, 222]},
|
|
{"id": 67891, "guild_id": 12345, "participants": [333, 444]},
|
|
{"id": 67892, "guild_id": 12346, "participants": [555, 666]},
|
|
]
|
|
|
|
# Create test audio files for each channel
|
|
audio_files = [create_test_wav_file() for _ in channels]
|
|
|
|
try:
|
|
# Process all channels concurrently
|
|
tasks = []
|
|
for i, channel in enumerate(channels):
|
|
task = diarization_service.process_audio_clip(
|
|
audio_file_path=audio_files[i],
|
|
guild_id=channel["guild_id"],
|
|
channel_id=channel["id"],
|
|
participants=channel["participants"],
|
|
)
|
|
tasks.append(task)
|
|
|
|
results = await asyncio.gather(*tasks)
|
|
|
|
# Verify all channels processed successfully
|
|
assert len(results) == len(channels)
|
|
assert all(result is not None for result in results)
|
|
|
|
# Verify channel isolation
|
|
for i, result in enumerate(results):
|
|
assert (
|
|
str(channels[i]["id"]) in result.audio_file_path
|
|
or result.audio_file_path == audio_files[i]
|
|
)
|
|
|
|
finally:
|
|
# Cleanup
|
|
for audio_file in audio_files:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_error_recovery_and_fallbacks(
|
|
self, diarization_service, create_test_wav_file
|
|
):
|
|
"""Test error recovery mechanisms and fallback strategies."""
|
|
audio_file = create_test_wav_file()
|
|
|
|
try:
|
|
# Test NeMo model failure with fallback
|
|
with patch.object(
|
|
diarization_service, "_run_nemo_diarization"
|
|
) as mock_nemo:
|
|
mock_nemo.side_effect = Exception("NeMo model failed")
|
|
|
|
with patch.object(
|
|
diarization_service, "_fallback_basic_vad"
|
|
) as mock_fallback:
|
|
mock_fallback.return_value = [
|
|
SpeakerSegment(0.0, 10.0, "SPEAKER_00", 0.6, needs_tagging=True)
|
|
]
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
assert result is not None
|
|
assert len(result.speaker_segments) == 1
|
|
assert result.speaker_segments[
|
|
0
|
|
].needs_tagging # Indicates fallback was used
|
|
mock_fallback.assert_called_once()
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_memory_management(self, diarization_service, create_test_wav_file):
|
|
"""Test memory management during intensive processing."""
|
|
# Create multiple large audio files
|
|
large_audio_files = [
|
|
create_test_wav_file(duration_seconds=120) # 2-minute files
|
|
for _ in range(5)
|
|
]
|
|
|
|
try:
|
|
# Track memory usage
|
|
initial_memory = (
|
|
torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
|
|
)
|
|
|
|
# Process files sequentially with memory monitoring
|
|
for audio_file in large_audio_files:
|
|
await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
# Force garbage collection
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
|
|
current_memory = (
|
|
torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
|
|
)
|
|
memory_increase = current_memory - initial_memory
|
|
|
|
# Memory should not grow excessively
|
|
assert memory_increase < 1024 * 1024 * 1024 # Less than 1GB increase
|
|
|
|
finally:
|
|
# Cleanup
|
|
for audio_file in large_audio_files:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_performance_benchmarks(
|
|
self, diarization_service, create_test_wav_file
|
|
):
|
|
"""Test performance benchmarks for different scenarios."""
|
|
scenarios = [
|
|
{"duration": 10, "expected_max_time": 5.0, "description": "Short audio"},
|
|
{"duration": 60, "expected_max_time": 15.0, "description": "Medium audio"},
|
|
{"duration": 120, "expected_max_time": 30.0, "description": "Long audio"},
|
|
]
|
|
|
|
for scenario in scenarios:
|
|
audio_file = create_test_wav_file(duration_seconds=scenario["duration"])
|
|
|
|
try:
|
|
start_time = datetime.utcnow()
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
processing_time = (datetime.utcnow() - start_time).total_seconds()
|
|
|
|
assert result is not None
|
|
assert processing_time <= scenario["expected_max_time"], (
|
|
f"{scenario['description']}: Processing took {processing_time:.2f}s, "
|
|
f"expected <= {scenario['expected_max_time']}s"
|
|
)
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_data_consistency(
|
|
self, diarization_service, mock_database_manager, create_test_wav_file
|
|
):
|
|
"""Test data consistency between diarization results and database storage."""
|
|
audio_file = create_test_wav_file()
|
|
|
|
try:
|
|
# Mock database storage
|
|
stored_segments = []
|
|
|
|
async def mock_store_segment(*args):
|
|
stored_segments.append(args)
|
|
return {"id": len(stored_segments)}
|
|
|
|
mock_database_manager.execute_query.side_effect = mock_store_segment
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
# Verify data consistency
|
|
assert result is not None
|
|
assert len(stored_segments) == len(result.speaker_segments)
|
|
|
|
# Verify timestamp consistency
|
|
for segment in result.speaker_segments:
|
|
assert segment.start_time < segment.end_time
|
|
assert segment.end_time <= result.total_duration
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
async def _combine_diarization_and_transcription(
|
|
self, diar_segments: List[SpeakerSegment], transcription_segments: List[dict]
|
|
) -> List[dict]:
|
|
"""Combine diarization and transcription results."""
|
|
combined = []
|
|
|
|
for trans_seg in transcription_segments:
|
|
# Find overlapping speaker segment
|
|
best_overlap = 0
|
|
best_speaker = None
|
|
|
|
for diar_seg in diar_segments:
|
|
# Calculate overlap
|
|
overlap_start = max(trans_seg["start"], diar_seg.start_time)
|
|
overlap_end = min(trans_seg["end"], diar_seg.end_time)
|
|
overlap = max(0, overlap_end - overlap_start)
|
|
|
|
if overlap > best_overlap:
|
|
best_overlap = overlap
|
|
best_speaker = diar_seg
|
|
|
|
combined_segment = {
|
|
"start": trans_seg["start"],
|
|
"end": trans_seg["end"],
|
|
"text": trans_seg["text"],
|
|
"confidence": trans_seg["confidence"],
|
|
"speaker_label": (
|
|
best_speaker.speaker_label if best_speaker else "UNKNOWN"
|
|
),
|
|
"user_id": best_speaker.user_id if best_speaker else None,
|
|
}
|
|
combined.append(combined_segment)
|
|
|
|
return combined
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_speaker_continuity(self, diarization_service, create_test_wav_file):
|
|
"""Test speaker label continuity across segments."""
|
|
audio_file = create_test_wav_file(duration_seconds=30)
|
|
|
|
try:
|
|
with patch.object(
|
|
diarization_service, "_run_nemo_diarization"
|
|
) as mock_diar:
|
|
# Simulate alternating speakers
|
|
mock_diar.return_value = [
|
|
{
|
|
"start_time": 0.0,
|
|
"end_time": 5.0,
|
|
"speaker_label": "SPEAKER_01",
|
|
"confidence": 0.9,
|
|
},
|
|
{
|
|
"start_time": 5.0,
|
|
"end_time": 10.0,
|
|
"speaker_label": "SPEAKER_02",
|
|
"confidence": 0.85,
|
|
},
|
|
{
|
|
"start_time": 10.0,
|
|
"end_time": 15.0,
|
|
"speaker_label": "SPEAKER_01",
|
|
"confidence": 0.88,
|
|
},
|
|
{
|
|
"start_time": 15.0,
|
|
"end_time": 20.0,
|
|
"speaker_label": "SPEAKER_02",
|
|
"confidence": 0.92,
|
|
},
|
|
{
|
|
"start_time": 20.0,
|
|
"end_time": 25.0,
|
|
"speaker_label": "SPEAKER_01",
|
|
"confidence": 0.87,
|
|
},
|
|
]
|
|
|
|
result = await diarization_service.process_audio_clip(
|
|
audio_file_path=audio_file,
|
|
guild_id=12345,
|
|
channel_id=67890,
|
|
participants=[111, 222],
|
|
)
|
|
|
|
# Verify speaker continuity
|
|
speaker_01_segments = [
|
|
seg
|
|
for seg in result.speaker_segments
|
|
if seg.speaker_label == "SPEAKER_01"
|
|
]
|
|
speaker_02_segments = [
|
|
seg
|
|
for seg in result.speaker_segments
|
|
if seg.speaker_label == "SPEAKER_02"
|
|
]
|
|
|
|
assert len(speaker_01_segments) == 3
|
|
assert len(speaker_02_segments) == 2
|
|
|
|
# Verify temporal ordering
|
|
for segments in [speaker_01_segments, speaker_02_segments]:
|
|
for i in range(1, len(segments)):
|
|
assert segments[i].start_time > segments[i - 1].end_time
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_quote_scoring_integration(
|
|
self, diarization_service, quote_analyzer, create_test_wav_file
|
|
):
|
|
"""Test integration between diarization and quote scoring."""
|
|
audio_file = create_test_wav_file()
|
|
|
|
try:
|
|
# Mock diarization with speaker identification
|
|
with patch.object(diarization_service, "process_audio_clip") as mock_diar:
|
|
mock_result = DiarizationResult(
|
|
audio_file_path=audio_file,
|
|
total_duration=10.0,
|
|
speaker_segments=[
|
|
SpeakerSegment(0.0, 5.0, "Alice", 0.9, user_id=111),
|
|
SpeakerSegment(5.0, 10.0, "Bob", 0.85, user_id=222),
|
|
],
|
|
unique_speakers=["Alice", "Bob"],
|
|
processing_time=2.0,
|
|
timestamp=datetime.utcnow(),
|
|
)
|
|
mock_diar.return_value = mock_result
|
|
|
|
diar_result = await mock_diar(audio_file, 12345, 67890, [111, 222])
|
|
|
|
# Test quote scoring for each speaker
|
|
for segment in diar_result.speaker_segments:
|
|
if segment.user_id:
|
|
# Mock transcription for this segment
|
|
segment_text = f"This is a quote from {segment.speaker_label}"
|
|
|
|
analysis = await quote_analyzer.analyze_quote(
|
|
text=segment_text,
|
|
speaker_id=segment.user_id,
|
|
context={
|
|
"speaker_confidence": segment.confidence,
|
|
"duration": segment.end_time - segment.start_time,
|
|
},
|
|
)
|
|
|
|
assert "overall_score" in analysis
|
|
assert analysis["overall_score"] > 0
|
|
|
|
finally:
|
|
Path(audio_file).unlink(missing_ok=True)
|