Files
disbord/services/audio/speaker_diarization.py
Travis Vasceannie 3acb779569 chore: remove .env.example and add new files for project structure
- Deleted .env.example file as it is no longer needed.
- Added .gitignore to manage ignored files and directories.
- Introduced CLAUDE.md for AI provider integration documentation.
- Created dev.sh for development setup and scripts.
- Updated Dockerfile and Dockerfile.production for improved build processes.
- Added multiple test files and directories for comprehensive testing.
- Introduced new utility and service files for enhanced functionality.
- Organized codebase with new directories and files for better maintainability.
2025-08-27 23:00:19 -04:00

1114 lines
41 KiB
Python

"""
Speaker Diarization Service for Discord Voice Chat Quote Bot
Integrates NVIDIA NeMo for automatic speaker separation and labeling.
Provides speaker segments that can be mapped to Discord users.
"""
import asyncio
import json
import logging
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
import librosa
import numpy as np
import soundfile as sf
import torch
from omegaconf import DictConfig, OmegaConf
from core.consent_manager import ConsentManager
from core.database import DatabaseManager
from utils.audio_processor import AudioProcessor
# Set up logger first so the optional-import block below can report its outcome
logger = logging.getLogger(__name__)
# NeMo imports (with fallback): NeMo is treated as an optional dependency.
# On failure every NeMo name is bound to None and NEMO_AVAILABLE is False,
# which the rest of the module checks before touching NeMo.
try:
    from nemo.collections.asr.models import ClusteringDiarizer, NeuralDiarizer
    from nemo.collections.asr.models.label_models import \
        EncDecSpeakerLabelModel as EncDecDiarLabelModel
    from nemo.utils import logging as nemo_logging
    NEMO_AVAILABLE = True
    logger.info("NVIDIA NeMo is available for speaker diarization")
except (ImportError, AttributeError) as e:
    logger.warning(f"NeMo not available: {e}. Using fallback implementation.")
    # Placeholders keep later references to these names from raising NameError.
    ClusteringDiarizer = None
    NeuralDiarizer = None
    EncDecDiarLabelModel = None
    nemo_logging = None
    NEMO_AVAILABLE = False
@dataclass
class SpeakerSegment:
    """Data structure for speaker segments.

    One contiguous stretch of a clip attributed to a single speaker label.
    """

    # Segment boundaries in seconds from the start of the clip.
    start_time: float
    end_time: float
    # Diarizer label; replaced by a username (or "User_<id>") once the
    # segment is matched to a known speaker profile.
    speaker_label: str
    # Set to 1.0 for raw diarizer output; overwritten by the match
    # confidence when a speaker profile is matched.
    confidence: float
    # Encoded audio for just this segment, when it was extracted.
    audio_data: Optional[bytes] = None
    # Matched user ID, or None while the speaker is unidentified.
    user_id: Optional[int] = None
    # True until the segment has been attributed to a user.
    needs_tagging: bool = True
@dataclass
class DiarizationResult:
    """Complete diarization result for an audio clip."""

    # Path of the audio file the caller asked to have diarized.
    audio_file_path: str
    # Clip length in seconds (frame count divided by sample rate).
    total_duration: float
    # Segments that survived the minimum-duration filter.
    speaker_segments: list[SpeakerSegment]
    # Distinct speaker labels found across speaker_segments.
    unique_speakers: list[str]
    # Wall-clock seconds spent producing this result.
    processing_time: float
    # Creation time (UTC-aware); compared against the cache expiry.
    timestamp: datetime
class NeMoDiarizationConfig:
    """Holds the tunable settings and builds the OmegaConf trees for NeMo.

    All values are plain instance attributes so callers can adjust them
    before asking for a config; each ``get_*_config`` call builds a fresh
    ``DictConfig`` from the attribute values at that moment.
    """

    def __init__(self, device: str = "cuda"):
        """Populate defaults.

        Args:
            device: Device name recorded on the config (e.g. "cuda" or "cpu").
        """
        self.device = device

        # Model names
        self.vad_model = "vad_multilingual_marblenet"
        self.speaker_model = "titanet_large"
        self.neural_diarizer_model = "diar_msdd_telephonic"

        # Processing parameters
        self.sample_rate = 16000
        self.min_segment_duration = 1.0
        self.max_speakers = 8
        self.window_length = 1.5
        self.shift_length = 0.75

        # VAD parameters
        self.vad_onset = 0.8
        self.vad_offset = 0.6
        self.vad_pad_offset = -0.05

        # Clustering parameters
        self.oracle_num_speakers = False
        self.max_num_speakers = 8
        self.enhanced_count_thres = 80
        self.sparse_search_volume = 30

    def get_clustering_config(self) -> DictConfig:
        """Build the config tree for the clustering diarizer.

        Returns:
            A DictConfig with a single top-level ``diarizer`` node. The
            ``manifest_filepath`` and ``out_dir`` entries are left as None.
        """
        vad_section = {
            "model_path": self.vad_model,
            "parameters": {
                "onset": self.vad_onset,
                "offset": self.vad_offset,
                "pad_offset": self.vad_pad_offset,
                "min_duration_on": 0.1,
                "min_duration_off": 0.1,
            },
        }
        embedding_section = {
            "model_path": self.speaker_model,
            "parameters": {
                "window_length_in_sec": self.window_length,
                "shift_length_in_sec": self.shift_length,
                "multiscale_weights": None,
                "save_embeddings": False,
            },
        }
        clustering_section = {
            "parameters": {
                "oracle_num_speakers": self.oracle_num_speakers,
                "max_num_speakers": self.max_num_speakers,
                "enhanced_count_thres": self.enhanced_count_thres,
                "sparse_search_volume": self.sparse_search_volume,
            }
        }
        diarizer_section = {
            "manifest_filepath": None,
            "out_dir": None,
            "oracle_vad": False,
            "collar": 0.25,
            "ignore_overlap": True,
            "vad": vad_section,
            "speaker_embeddings": embedding_section,
            "clustering": clustering_section,
        }
        return OmegaConf.create({"diarizer": diarizer_section})

    def get_neural_config(self) -> DictConfig:
        """Build the config tree for the neural diarizer.

        Returns:
            The clustering config with an extra ``diarizer.msdd_model`` node.
        """
        neural_config = self.get_clustering_config()
        msdd_section = {
            "model_path": self.neural_diarizer_model,
            "parameters": {"sigmoid_threshold": [0.7, 1.0]},
        }
        neural_config.diarizer.msdd_model = msdd_section
        return neural_config
class SpeakerDiarizationService:
    """
    Speaker diarization service using NVIDIA NeMo.

    Features:
    - Automatic speaker separation using clustering and neural approaches
    - Speaker labeling and tracking
    - Integration with consent management
    - Support for user-assisted tagging
    - Caching of diarization results
    - GPU acceleration support

    When NeMo cannot be imported or no NeMo model loads, requests are served
    by a basic energy-based voice activity fallback instead.

    Requests submitted through process_audio_clip() are queued and handled
    one at a time by a background worker; call close() to stop the worker
    and remove the temporary directory.
    """
def __init__(
    self,
    db_manager: DatabaseManager,
    consent_manager: ConsentManager,
    audio_processor: AudioProcessor,
):
    """Store collaborators and set up idle state; no models are loaded here.

    Args:
        db_manager: Database access used for profiles and stored results.
        consent_manager: Used to check per-user recording consent.
        audio_processor: Used to encode extracted segment audio.

    Side effects:
        Creates a temporary directory (prefix ``nemo_diarization_``) that
        lives until close() removes it.
    """
    # Injected collaborators
    self.db_manager = db_manager
    self.consent_manager = consent_manager
    self.audio_processor = audio_processor

    # Prefer the GPU whenever torch reports one
    gpu_present = torch.cuda.is_available()
    self.device = "cuda" if gpu_present else "cpu"
    logger.info(f"Initializing NeMo diarization service on {self.device}")

    diarization_config = NeMoDiarizationConfig(device=self.device)
    self.config = diarization_config

    # Service parameters (for test compatibility)
    self.min_speakers = 1
    self.max_speakers = diarization_config.max_speakers
    self.min_segment_duration = diarization_config.min_segment_duration

    # Models stay unset until initialize() loads them
    self.clustering_model: Optional[ClusteringDiarizer] = None
    self.neural_model: Optional[NeuralDiarizer] = None

    # Request queue consumed by the background worker, plus result cache
    self.processing_queue: asyncio.Queue[dict[str, object]] = asyncio.Queue()
    self.result_cache: dict[str, DiarizationResult] = {}
    self.cache_expiry = timedelta(hours=2)

    # Lifecycle state
    self._initialized = False
    self._processing_task: Optional[asyncio.Task[None]] = None

    # Scratch space for converted audio, manifests and diarizer output
    scratch_dir = tempfile.mkdtemp(prefix="nemo_diarization_")
    self.temp_dir = Path(scratch_dir)
async def initialize(self) -> None:
    """Initialize the diarization models and start the background workers.

    Safe to call repeatedly: returns immediately once initialized.

    A dependency problem (ImportError/AttributeError) while setting up NeMo
    is logged and the service continues in fallback mode. The request worker
    and the cache-cleanup worker are started in that case too, because
    process_audio_clip() waits on the worker for every request.

    Raises:
        Exception: Any other failure is logged and re-raised.
    """
    if self._initialized:
        return
    try:
        logger.info("Initializing NeMo speaker diarization service...")
        fallback_only = False
        try:
            # Suppress NeMo logging noise if available
            if nemo_logging:
                nemo_logging.setLevel(logging.WARNING)
            # Load diarization models
            await self._load_diarization_models()
        except (ImportError, AttributeError) as e:
            fallback_only = True
            logger.warning(
                f"Failed to initialize NeMo diarization service due to dependency issue: {e}"
            )
            logger.info("Continuing with fallback audio processing capabilities")

        # Start processing worker
        self._processing_task = asyncio.create_task(self._processing_worker())
        # Keep a strong reference to the cleanup task: the event loop only
        # holds weak references, so an unreferenced task may be collected
        # before it finishes.
        self._cache_cleanup_task = asyncio.create_task(
            self._cache_cleanup_worker()
        )
        self._initialized = True
        if not fallback_only:
            logger.info(
                f"NeMo diarization service initialized successfully on {self.device}"
            )
    except Exception as e:
        logger.error(f"Failed to initialize NeMo diarization service: {e}")
        raise
async def _load_nemo_models(self) -> bool:
    """Compatibility wrapper (kept for tests) around model loading.

    Returns:
        True when _load_diarization_models() completed without raising,
        False (after logging a warning) otherwise.
    """
    try:
        await self._load_diarization_models()
    except Exception as e:
        logger.warning(f"Failed to load NeMo models: {e}")
        return False
    return True
async def _load_diarization_models(self) -> None:
    """Construct the NeMo diarizers, or leave both unset for fallback mode.

    Each model is built independently so that one failing does not prevent
    the other from loading. Construction runs in the default executor to
    keep the event loop responsive. This method logs failures instead of
    raising them.
    """
    if not NEMO_AVAILABLE:
        logger.warning("NeMo not available. Using basic audio processing fallback.")
        self.clustering_model = None
        self.neural_model = None
        return

    def load_models() -> (
        tuple[Optional[ClusteringDiarizer], Optional[NeuralDiarizer]]
    ):
        # NOTE(review): NeMo's diarizers appear to read `cfg.diarizer`
        # themselves; confirm that passing the `.diarizer` sub-config
        # (rather than the top-level config) is intended.
        try:
            clustering_cfg = self.config.get_clustering_config()
            clustering = ClusteringDiarizer(cfg=clustering_cfg.diarizer)
            logger.info("Clustering diarizer loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load clustering diarizer: {e}")
            clustering = None

        # Neural diarizer for advanced scenarios
        try:
            neural_cfg = self.config.get_neural_config()
            neural = NeuralDiarizer(cfg=neural_cfg.diarizer)
            logger.info("Neural diarizer loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load neural diarizer: {e}")
            neural = None

        return clustering, neural

    try:
        loop = asyncio.get_event_loop()
        loaded = await loop.run_in_executor(None, load_models)
        self.clustering_model, self.neural_model = loaded
        if not (self.clustering_model or self.neural_model):
            logger.warning("No NeMo models loaded. Using fallback implementation.")
    except Exception as e:
        logger.error(f"Failed to load NeMo models: {e}")
        logger.warning("Falling back to basic audio processing.")
async def process_audio_clip(
    self,
    audio_file_path: str,
    guild_id: int,
    channel_id: int,
    participants: list[int],
) -> Optional[DiarizationResult]:
    """
    Process audio clip for speaker diarization.

    Consent is validated on every call, before the cache is consulted, so a
    cached result is never served for a clip once none of the supplied
    participants has consent. The cache key is built from the consented
    users, independent of their order, so a change in consent or a
    reordered participant list cannot hit a stale or duplicate entry.

    Args:
        audio_file_path: Path to audio file
        guild_id: Discord guild ID
        channel_id: Discord channel ID
        participants: List of user IDs who were in the channel

    Returns:
        DiarizationResult or None if processing failed or no participant
        has given recording consent
    """
    try:
        if not self._initialized:
            await self.initialize()

        # Validate consent for all participants
        consented_users = [
            user_id
            for user_id in participants
            if await self.consent_manager.has_recording_consent(user_id, guild_id)
        ]
        if not consented_users:
            logger.info(
                f"No consented users in channel {channel_id}, skipping diarization"
            )
            return None

        # Order-insensitive key over the users actually being processed
        consent_fingerprint = hash(tuple(sorted(set(consented_users))))
        cache_key = f"{audio_file_path}_{consent_fingerprint}"

        cached_result = self.result_cache.get(cache_key)
        if cached_result is not None:
            age = datetime.now(timezone.utc) - cached_result.timestamp
            if age < self.cache_expiry:
                logger.debug(
                    f"Using cached diarization result for {audio_file_path}"
                )
                return cached_result
            # Expired: drop it now rather than waiting for the cleanup pass
            self.result_cache.pop(cache_key, None)

        # Queue for processing; the worker resolves this future
        result_future: asyncio.Future[Optional[DiarizationResult]] = (
            asyncio.get_running_loop().create_future()
        )
        await self.processing_queue.put(
            {
                "audio_file_path": audio_file_path,
                "guild_id": guild_id,
                "channel_id": channel_id,
                "participants": consented_users,
                "result_future": result_future,
            }
        )

        # Wait for processing result
        result = await result_future

        # Cache result
        if result:
            self.result_cache[cache_key] = result
        return result
    except Exception as e:
        logger.error(f"Failed to process audio clip for diarization: {e}")
        return None
async def _processing_worker(self) -> None:
    """Background worker for processing diarization requests.

    Pulls one request at a time from the queue, runs diarization and
    resolves the request's ``result_future``. The future is only touched
    while it is still pending, so a caller that gave up (cancelled its
    future) cannot make the worker raise InvalidStateError. If the worker is
    cancelled mid-request, the in-flight future is resolved with None so the
    waiting caller is released instead of hanging.
    """
    while True:
        try:
            # Get next processing request
            request = await self.processing_queue.get()
            try:
                result_future = request["result_future"]
                try:
                    result = await self._perform_diarization(
                        request["audio_file_path"],
                        request["guild_id"],
                        request["channel_id"],
                        request["participants"],
                    )
                except asyncio.CancelledError:
                    # Shutting down: release the waiting caller, then let
                    # the outer handler stop the loop.
                    if not result_future.done():
                        result_future.set_result(None)
                    raise
                except Exception as e:
                    logger.error(f"Error processing diarization request: {e}")
                    if not result_future.done():
                        result_future.set_exception(e)
                else:
                    if not result_future.done():
                        result_future.set_result(result)
            finally:
                self.processing_queue.task_done()
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in diarization processing worker: {e}")
            await asyncio.sleep(1)
async def _perform_diarization(
    self,
    audio_file_path: str,
    guild_id: int,
    channel_id: int,
    participants: list[int],
) -> Optional[DiarizationResult]:
    """Perform actual speaker diarization using NeMo models.

    Pipeline: convert the audio, write a manifest, run the diarizer (or the
    fallback), convert the output to segments, attempt speaker
    identification, then store the result.

    The converted audio and manifest are removed in a ``finally`` block, so
    they are cleaned up on early returns and on errors as well as on
    success.

    Args:
        audio_file_path: Path to the source audio file
        guild_id: Discord guild ID
        channel_id: Discord channel ID
        participants: User IDs to consider for speaker identification

    Returns:
        DiarizationResult, or None when preparation fails, no speakers are
        detected, or an error occurs (errors are logged, not raised)
    """
    temp_files: list[str] = []
    try:
        start_time = datetime.now(timezone.utc)

        # Prepare audio file for NeMo processing
        prepared_audio_path = await self._prepare_audio_file(audio_file_path)
        if not prepared_audio_path:
            return None
        temp_files.append(prepared_audio_path)

        # Create manifest file for NeMo
        manifest_path = await self._create_manifest_file(prepared_audio_path)
        temp_files.append(manifest_path)

        # Perform diarization using available model
        diarization_output = await self._run_nemo_diarization(
            manifest_path, len(participants)
        )
        if not diarization_output:
            logger.warning(f"No speakers detected in {audio_file_path}")
            return None

        # Convert NeMo output to our format
        speaker_segments = await self._convert_nemo_output_to_segments(
            diarization_output, audio_file_path, prepared_audio_path
        )

        # Attempt speaker identification
        speaker_segments = await self._identify_speakers(
            speaker_segments, guild_id, participants
        )

        # Sorted so the label order is stable from run to run
        unique_speakers = sorted(
            {segment.speaker_label for segment in speaker_segments}
        )

        # Calculate total duration
        total_duration = await self._get_audio_duration(prepared_audio_path)
        processing_time = (datetime.now(timezone.utc) - start_time).total_seconds()

        result = DiarizationResult(
            audio_file_path=audio_file_path,
            total_duration=total_duration,
            speaker_segments=speaker_segments,
            unique_speakers=unique_speakers,
            processing_time=processing_time,
            timestamp=datetime.now(timezone.utc),
        )

        # Store result in database
        await self._store_diarization_result(result, guild_id, channel_id)

        logger.info(
            f"Diarization complete: {len(unique_speakers)} speakers, "
            f"{len(speaker_segments)} segments, {processing_time:.2f}s"
        )
        return result
    except Exception as e:
        logger.error(f"Failed to perform NeMo diarization: {e}")
        return None
    finally:
        # Cleanup temporary files on every exit path
        if temp_files:
            await self._cleanup_temp_files(temp_files)
async def _prepare_audio_file(self, audio_file_path: str) -> Optional[str]:
    """Write a mono WAV copy of the input at the configured sample rate.

    The conversion runs in the default executor. The copy is written to the
    service's temp directory as ``processed_<stem>.wav``.

    Returns:
        Path of the converted file, or None if conversion failed (logged).
    """

    def convert_audio() -> str:
        # sr=None keeps the file's native rate so we can decide on resampling
        waveform, native_rate = librosa.load(audio_file_path, sr=None)

        # Collapse to a single channel when the loader returned several
        if len(waveform.shape) > 1:
            waveform = librosa.to_mono(waveform)

        # Resample only when the native rate differs from the target
        if native_rate != self.config.sample_rate:
            waveform = librosa.resample(
                waveform,
                orig_sr=native_rate,
                target_sr=self.config.sample_rate,
            )

        source_stem = Path(audio_file_path).stem
        destination = self.temp_dir / f"processed_{source_stem}.wav"
        sf.write(str(destination), waveform, self.config.sample_rate)
        return str(destination)

    try:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, convert_audio)
    except Exception as e:
        logger.error(f"Failed to prepare audio file {audio_file_path}: {e}")
        return None
async def _create_manifest_file(self, audio_path: str) -> str:
    """Write the single-entry manifest that NeMo reads its input from.

    The manifest is written to the temp directory as
    ``manifest_<audio stem>.json`` and holds one JSON object followed by a
    newline.

    Returns:
        Path of the manifest file.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        audio_stem = Path(audio_path).stem
        manifest_path = self.temp_dir / f"manifest_{audio_stem}.json"
        entry = {
            "audio_filepath": audio_path,
            "offset": 0,
            "duration": None,
            "label": "infer",
            "text": "-",
            "num_speakers": None,
            "rttm_filepath": None,
            "uem_filepath": None,
        }
        # NeMo expects newline-separated JSON
        serialized = json.dumps(entry) + "\n"
        with open(manifest_path, "w", encoding="utf-8") as handle:
            handle.write(serialized)
        return str(manifest_path)
    except Exception as e:
        logger.error(f"Failed to create manifest file: {e}")
        raise
async def _run_nemo_diarization(
    self, manifest_path: str, expected_speakers: int
) -> Optional[dict[str, object]]:
    """Run NeMo diarization or fallback implementation.

    The neural diarizer is used when it is loaded and at least three
    speakers are expected; otherwise the clustering diarizer is used.
    Whenever NeMo is unavailable, no model is loaded, or the NeMo run fails
    for any reason, the energy-based fallback is used instead.

    Args:
        manifest_path: Path of the manifest describing the audio to process
        expected_speakers: Number of speakers the caller expects

    Returns:
        ``{"segments": [{"start": float, "end": float, "speaker": str}, ...]}``
        or None when the fallback also produced nothing
    """
    # If NeMo is not available, use basic voice activity detection
    if not NEMO_AVAILABLE or (not self.clustering_model and not self.neural_model):
        return await self._run_fallback_diarization(
            manifest_path, expected_speakers
        )

    def parse_rttm(rttm_path: Path) -> list[dict[str, object]]:
        # RTTM is whitespace-separated; field 3 is the start time, field 4
        # the duration and field 7 the speaker name.
        segments: list[dict[str, object]] = []
        with open(rttm_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 8:
                    continue
                start_time = float(parts[3])
                segments.append(
                    {
                        "start": start_time,
                        "end": start_time + float(parts[4]),
                        "speaker": parts[7],
                    }
                )
        return segments

    def run_diarization() -> Optional[dict[str, object]]:
        try:
            # Choose model based on availability and expected speakers
            if self.neural_model and expected_speakers >= 3:
                model = self.neural_model
                logger.debug("Using neural diarizer for complex scenario")
            elif self.clustering_model:
                model = self.clustering_model
                logger.debug("Using clustering diarizer")
            else:
                return None

            # Update model configuration with manifest path
            # NOTE(review): confirm these are the attributes the loaded
            # NeMo version reads for its input manifest and output dir.
            model.cfg.manifest_filepath = manifest_path
            model.cfg.out_dir = str(self.temp_dir)

            # Run diarization
            model.diarize()

            # The RTTM is looked up by the audio stem. Strip only the
            # leading "manifest_" that _create_manifest_file adds:
            # str.replace would also remove the text from inside the stem.
            audio_stem = Path(manifest_path).stem.removeprefix("manifest_")
            rttm_path = self.temp_dir / "pred_rttms" / f"{audio_stem}.rttm"
            if not rttm_path.exists():
                logger.error(f"RTTM file not found: {rttm_path}")
                return None

            return {"segments": parse_rttm(rttm_path)}
        except Exception as e:
            logger.error(f"NeMo diarization failed: {e}")
            return None

    try:
        result = await asyncio.get_running_loop().run_in_executor(
            None, run_diarization
        )
        # If NeMo failed, fall back to basic implementation
        if result is None:
            logger.warning("NeMo diarization failed, using fallback")
            return await self._run_fallback_diarization(
                manifest_path, expected_speakers
            )
        return result
    except Exception as e:
        logger.error(f"Failed to run NeMo diarization: {e}")
        return await self._run_fallback_diarization(
            manifest_path, expected_speakers
        )
async def _run_fallback_diarization(
    self, manifest_path: str, expected_speakers: int
) -> Optional[dict[str, object]]:
    """Run basic voice activity detection as fallback when NeMo is unavailable.

    Speech is detected by thresholding short-time frame energy at
    ``mean + 1.5 * std``. Speaker labels are NOT derived from the audio:
    the label simply advances every 10 seconds, cycling through between two
    and four labels depending on ``expected_speakers``.

    Args:
        manifest_path: Manifest whose ``audio_filepath`` names the audio
        expected_speakers: Used only to size the label rotation

    Returns:
        ``{"segments": [...]}`` with start/end/speaker entries, or None when
        no sufficiently long speech segment was found or an error occurred
    """
    try:
        # Load manifest to get audio path
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest_data = json.load(f)
        audio_path = manifest_data["audio_filepath"]

        def basic_diarization() -> Optional[dict[str, object]]:
            try:
                # Use the shared configured rate (16 kHz by default) so this
                # stays in step with the rate the audio was prepared at.
                audio_data, sample_rate = librosa.load(
                    audio_path, sr=self.config.sample_rate, mono=True
                )
                duration = len(audio_data) / sample_rate

                # Simple voice activity detection using energy
                frame_length = int(0.025 * sample_rate)  # 25ms frames
                hop_length = int(0.010 * sample_rate)  # 10ms hop
                energy = np.array(
                    [
                        np.sum(audio_data[i : i + frame_length] ** 2)
                        for i in range(
                            0, len(audio_data) - frame_length, hop_length
                        )
                    ]
                )
                if energy.size == 0:
                    # Shorter than one frame: nothing to threshold, and
                    # mean/std of an empty array would only yield NaN.
                    logger.info(
                        "Fallback diarization skipped: audio shorter than one frame"
                    )
                    return None

                threshold = np.mean(energy) + 1.5 * np.std(energy)
                voice_frames = energy > threshold

                # Loop invariants
                min_duration = self.config.min_segment_duration
                speaker_count = max(2, min(expected_speakers, 4))
                frame_duration = hop_length / sample_rate
                # Change speaker every 10 seconds (basic heuristic)
                speaker_change_interval = 10.0

                segments = []
                in_speech = False
                start_time = 0.0
                current_speaker = 0
                next_speaker_change = speaker_change_interval

                for i, is_voice in enumerate(voice_frames):
                    current_time = i * frame_duration

                    # Simple speaker change heuristic
                    if current_time > next_speaker_change:
                        current_speaker = (current_speaker + 1) % speaker_count
                        next_speaker_change += speaker_change_interval

                    if is_voice and not in_speech:
                        # Start of speech
                        start_time = current_time
                        in_speech = True
                    elif not is_voice and in_speech:
                        # End of speech; keep it only if long enough
                        if current_time - start_time >= min_duration:
                            segments.append(
                                {
                                    "start": start_time,
                                    "end": current_time,
                                    "speaker": f"SPEAKER_{current_speaker:02d}",
                                }
                            )
                        in_speech = False

                # Handle case where speech continues to end
                if in_speech and duration - start_time >= min_duration:
                    segments.append(
                        {
                            "start": start_time,
                            "end": duration,
                            "speaker": f"SPEAKER_{current_speaker:02d}",
                        }
                    )

                logger.info(
                    f"Fallback diarization found {len(segments)} speech segments"
                )
                return {"segments": segments} if segments else None
            except Exception as e:
                logger.error(f"Fallback diarization failed: {e}")
                return None

        return await asyncio.get_running_loop().run_in_executor(
            None, basic_diarization
        )
    except Exception as e:
        logger.error(f"Failed to run fallback diarization: {e}")
        return None
async def _convert_nemo_output_to_segments(
    self, nemo_output: dict[str, object], original_path: str, prepared_path: str
) -> list[SpeakerSegment]:
    """Turn raw diarizer segments into SpeakerSegment objects with audio.

    Segments shorter than the configured minimum duration are dropped, as
    are segments that start beyond the end of the audio or have no samples.
    Each kept segment carries its own encoded audio slice.

    Args:
        nemo_output: Mapping with a "segments" list of start/end/speaker dicts
        original_path: Path of the source file (not used by this method)
        prepared_path: Path of the converted audio the samples are read from

    Returns:
        The converted segments, or an empty list if anything failed (logged)
    """
    try:
        converted = []
        raw_segments = nemo_output.get("segments", [])

        # Load audio for segment extraction
        samples, rate = await self._load_audio_async(prepared_path)

        for entry in raw_segments:
            seg_start = float(entry["start"])
            seg_end = float(entry["end"])
            label = entry["speaker"]

            # Filter segments that are too short
            if seg_end - seg_start < self.config.min_segment_duration:
                continue

            first_sample = int(seg_start * rate)
            last_sample = int(seg_end * rate)
            if first_sample >= len(samples) or last_sample <= first_sample:
                continue

            # Convert the slice to bytes for storage
            clip_bytes = await self.audio_processor.numpy_to_bytes(
                samples[first_sample:last_sample], rate
            )
            converted.append(
                SpeakerSegment(
                    start_time=seg_start,
                    end_time=seg_end,
                    speaker_label=label,
                    confidence=1.0,  # NeMo doesn't provide confidence scores directly
                    audio_data=clip_bytes,
                    user_id=None,  # Will be filled by speaker identification
                    needs_tagging=True,
                )
            )
        return converted
    except Exception as e:
        logger.error(f"Failed to convert NeMo output to segments: {e}")
        return []
async def _load_audio_async(self, audio_path: str) -> tuple[np.ndarray, int]:
    """Read an audio file in the default executor.

    Returns:
        ``(samples, sample_rate)`` as read by soundfile.
    """

    def read_samples() -> tuple[np.ndarray, int]:
        samples, rate = sf.read(audio_path)
        return samples, rate

    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, read_samples)
async def _get_audio_duration(self, audio_path: str) -> float:
    """Return the length of an audio file in seconds.

    Computed as frame count divided by sample rate, read in the default
    executor. Returns 0.0 (after logging) if the file cannot be read.
    """

    def get_duration() -> float:
        with sf.SoundFile(audio_path) as audio_file:
            frame_count = len(audio_file)
            rate = audio_file.samplerate
        return frame_count / rate

    try:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, get_duration)
    except Exception as e:
        logger.error(f"Failed to get audio duration: {e}")
        return 0.0
async def _identify_speakers(
    self, segments: list[SpeakerSegment], guild_id: int, participants: list[int]
) -> list[SpeakerSegment]:
    """Attempt to identify speakers using stored voice profiles.

    Segments are updated in place: a matched segment receives the profile's
    user ID, a label (username, or ``User_<id>`` when the username is
    empty), the match confidence, and ``needs_tagging = False``. Segments
    without audio, without an embedding or without a match are untouched.

    Args:
        segments: Segments to annotate
        guild_id: Guild whose speaker profiles are consulted
        participants: User IDs whose profiles are candidates

    Returns:
        The same list that was passed in; on error it is returned as it
        stood when the error occurred (logged)
    """
    try:
        if not segments:
            # Nothing to annotate, so skip the profile query
            return segments

        # Get known speaker profiles for participants
        speaker_profiles = await self.db_manager.execute_query(
            """
            SELECT user_id, voice_embedding, username
            FROM speaker_profiles
            WHERE user_id = ANY($1) AND guild_id = $2 AND voice_embedding IS NOT NULL
            """,
            participants,
            guild_id,
            fetch_all=True,
        )
        if not speaker_profiles:
            logger.debug("No speaker profiles available for identification")
            return segments

        # For each segment, try to match with known profiles
        for segment in segments:
            if not segment.audio_data:
                continue

            # Generate embedding for segment
            segment_embedding = await self._generate_voice_embedding(
                segment.audio_data
            )
            # Compare against None explicitly: the truth value of a numpy
            # array with more than one element raises ValueError.
            if segment_embedding is None:
                continue

            # Find best match among known profiles
            best_match = await self._find_best_speaker_match(
                segment_embedding, speaker_profiles
            )
            if not best_match:
                continue

            segment.user_id = best_match["user_id"]
            segment.speaker_label = (
                best_match["username"] or f"User_{best_match['user_id']}"
            )
            segment.confidence = best_match["confidence"]
            segment.needs_tagging = False
        return segments
    except Exception as e:
        logger.error(f"Failed to identify speakers: {e}")
        return segments
async def _generate_voice_embedding(
    self, audio_data: bytes
) -> Optional[np.ndarray]:
    """Placeholder for voice embedding generation.

    Not implemented yet: a speaker recognition model would be plugged in
    here. Always returns None, meaning "no embedding available".
    """
    try:
        logger.debug("Voice embedding generation not yet implemented")
    except Exception as e:
        logger.error(f"Failed to generate voice embedding: {e}")
    return None
async def _find_best_speaker_match(
    self, segment_embedding: np.ndarray, speaker_profiles: list[dict[str, object]]
) -> Optional[dict[str, object]]:
    """Placeholder for matching an embedding against stored profiles.

    Not implemented yet: always returns None, meaning "no match".
    """
    try:
        logger.debug("Speaker matching not yet implemented")
    except Exception as e:
        logger.error(f"Failed to find speaker match: {e}")
    return None
async def _store_diarization_result(
    self, result: DiarizationResult, guild_id: int, channel_id: int
) -> None:
    """Persist a diarization result and its segments.

    Inserts one ``speaker_diarizations`` row (storing the *count* of unique
    speakers), then one ``speaker_segments`` row per segment linked to it.
    Failures are logged and not raised.
    """
    try:
        # Main diarization record
        inserted_row = await self.db_manager.execute_query(
            """
            INSERT INTO speaker_diarizations
            (guild_id, channel_id, audio_file_path, total_duration, unique_speakers, processing_time)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id
            """,
            guild_id,
            channel_id,
            result.audio_file_path,
            result.total_duration,
            len(result.unique_speakers),
            result.processing_time,
            fetch_one=True,
        )
        diarization_id = inserted_row["id"]

        # One row per speaker segment
        for segment in result.speaker_segments:
            segment_values = (
                diarization_id,
                segment.start_time,
                segment.end_time,
                segment.speaker_label,
                segment.confidence,
                segment.user_id,
                segment.needs_tagging,
            )
            await self.db_manager.execute_query(
                """
                INSERT INTO speaker_segments
                (diarization_id, start_time, end_time, speaker_label, confidence, user_id, needs_tagging)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                """,
                *segment_values,
            )

        logger.debug(
            f"Stored diarization result {diarization_id} with {len(result.speaker_segments)} segments"
        )
    except Exception as e:
        logger.error(f"Failed to store diarization result: {e}")
async def _cleanup_temp_files(self, file_paths: list[str]) -> None:
    """Delete the given files in the default executor, best effort.

    Missing files are ignored. A failure on one file is logged as a warning
    and does not stop the remaining deletions.
    """

    def remove_all() -> None:
        for file_path in file_paths:
            try:
                Path(file_path).unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"Failed to cleanup temp file {file_path}: {e}")

    try:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, remove_all)
    except Exception as e:
        logger.error(f"Error during temp file cleanup: {e}")
async def _cache_cleanup_worker(self) -> None:
    """Background loop that evicts expired cache entries every 30 minutes.

    An entry is expired when its result timestamp is older than
    ``cache_expiry``. The loop ends when the task is cancelled; other errors
    are logged and the loop retries after the same interval.
    """
    while True:
        try:
            now = datetime.now(timezone.utc)
            # Collect first, delete second: never mutate while iterating
            stale_keys = [
                key
                for key, cached in self.result_cache.items()
                if now - cached.timestamp > self.cache_expiry
            ]
            for key in stale_keys:
                del self.result_cache[key]

            if stale_keys:
                logger.debug(
                    f"Cleaned up {len(stale_keys)} expired cache entries"
                )

            # Sleep for 30 minutes
            await asyncio.sleep(1800)
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in cache cleanup worker: {e}")
            await asyncio.sleep(1800)
# Public API methods (maintaining compatibility)
async def get_speaker_segments(
    self, audio_file_path: str
) -> Optional[list[SpeakerSegment]]:
    """Load the stored speaker segments for an audio file.

    Segments are returned in ascending start-time order. Audio bytes are
    not loaded; ``audio_data`` is left at its default.

    Returns:
        The segments (possibly empty), or None if the lookup failed (logged)
    """
    try:
        rows = await self.db_manager.execute_query(
            """
            SELECT ss.start_time, ss.end_time, ss.speaker_label, ss.confidence,
            ss.user_id, ss.needs_tagging
            FROM speaker_segments ss
            JOIN speaker_diarizations sd ON ss.diarization_id = sd.id
            WHERE sd.audio_file_path = $1
            ORDER BY ss.start_time
            """,
            audio_file_path,
            fetch_all=True,
        )
        return [
            SpeakerSegment(
                start_time=float(row["start_time"]),
                end_time=float(row["end_time"]),
                speaker_label=row["speaker_label"],
                confidence=float(row["confidence"]),
                user_id=row["user_id"],
                needs_tagging=row["needs_tagging"],
            )
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Failed to get speaker segments: {e}")
        return None
async def tag_speaker_segment(
    self, segment_id: int, user_id: int, username: str
) -> None:
    """Record a user-supplied identification for a stored segment.

    Sets the segment's user ID, replaces its label with ``username`` and
    clears its ``needs_tagging`` flag.

    Raises:
        Exception: Any database failure is logged and re-raised.
    """
    update_params = (user_id, username, segment_id)
    try:
        await self.db_manager.execute_query(
            """
            UPDATE speaker_segments
            SET user_id = $1, speaker_label = $2, needs_tagging = FALSE
            WHERE id = $3
            """,
            *update_params,
        )
        logger.info(
            f"Tagged speaker segment {segment_id} as user {user_id} ({username})"
        )
    except Exception as e:
        logger.error(f"Failed to tag speaker segment: {e}")
        raise
async def get_untagged_segments(
    self, guild_id: int, limit: int = 10
) -> list[dict[str, object]]:
    """List stored segments in a guild that still need a user tag.

    Rows are ordered by diarization timestamp (newest first), then by
    segment start time.

    Args:
        guild_id: Guild to search
        limit: Maximum number of rows to return

    Returns:
        One plain dict per row, or an empty list if the query failed (logged)
    """
    try:
        rows = await self.db_manager.execute_query(
            """
            SELECT ss.id, ss.start_time, ss.end_time, ss.speaker_label,
            sd.audio_file_path, sd.guild_id, sd.channel_id
            FROM speaker_segments ss
            JOIN speaker_diarizations sd ON ss.diarization_id = sd.id
            WHERE sd.guild_id = $1 AND ss.needs_tagging = TRUE
            ORDER BY sd.timestamp DESC, ss.start_time ASC
            LIMIT $2
            """,
            guild_id,
            limit,
            fetch_all=True,
        )
        untagged = []
        for row in rows:
            untagged.append(dict(row))
        return untagged
    except Exception as e:
        logger.error(f"Failed to get untagged segments: {e}")
        return []
async def check_health(self) -> dict[str, object]:
    """Report the service's current state.

    Returns:
        A dict describing initialization, model availability, device, queue
        and cache sizes. When at least one NeMo model is loaded the model
        names are included; otherwise ``fallback_mode`` is True. If building
        the report fails, ``{"error": ..., "healthy": False}`` is returned.
    """
    try:
        has_clustering = self.clustering_model is not None
        has_neural = self.neural_model is not None
        if NEMO_AVAILABLE:
            framework = "NVIDIA NeMo (with fallback)"
        else:
            framework = "Basic Audio Processing"

        health_status: dict[str, object] = {
            "initialized": self._initialized,
            "nemo_available": NEMO_AVAILABLE,
            "clustering_model_loaded": has_clustering,
            "neural_model_loaded": has_neural,
            "device": self.device,
            "queue_size": self.processing_queue.qsize(),
            "cache_size": len(self.result_cache),
            "framework": framework,
        }

        nemo_active = NEMO_AVAILABLE and (
            self.clustering_model or self.neural_model
        )
        if not nemo_active:
            health_status["fallback_mode"] = True
            return health_status

        health_status["vad_model"] = self.config.vad_model
        health_status["speaker_model"] = self.config.speaker_model
        health_status["neural_diarizer_model"] = self.config.neural_diarizer_model
        return health_status
    except Exception as e:
        return {"error": str(e), "healthy": False}
async def close(self) -> None:
    """Close diarization service and cleanup resources.

    Steps, each attempted even if a background task ended with an error:
    1. Cancel the processing worker, and the cache-cleanup task when a
       handle to it is stored on the instance.
    2. Resolve every request still waiting in the queue with None so its
       caller stops waiting.
    3. Clear the result cache and remove the temporary directory.

    Errors are logged and not raised.
    """
    try:
        background_tasks = (
            self._processing_task,
            # Only present if a handle to the cleanup task was retained
            getattr(self, "_cache_cleanup_task", None),
        )
        for task in background_tasks:
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # A worker that already died must not block the cleanup
                logger.warning(f"Background task ended with error: {e}")

        # With the worker gone nothing will ever resolve queued requests
        while not self.processing_queue.empty():
            pending_request = self.processing_queue.get_nowait()
            pending_future = pending_request.get("result_future")
            if pending_future is not None and not pending_future.done():
                pending_future.set_result(None)
            self.processing_queue.task_done()

        # Clear cache
        self.result_cache.clear()

        # Cleanup temp directory
        if self.temp_dir.exists():

            def cleanup_temp_dir() -> None:
                import shutil

                shutil.rmtree(self.temp_dir, ignore_errors=True)

            await asyncio.get_running_loop().run_in_executor(
                None, cleanup_temp_dir
            )

        logger.info("NeMo speaker diarization service closed")
    except Exception as e:
        logger.error(f"Error closing diarization service: {e}")