- Deleted .env.example file as it is no longer needed. - Added .gitignore to manage ignored files and directories. - Introduced CLAUDE.md for AI provider integration documentation. - Created dev.sh for development setup and scripts. - Updated Dockerfile and Dockerfile.production for improved build processes. - Added multiple test files and directories for comprehensive testing. - Introduced new utility and service files for enhanced functionality. - Organized codebase with new directories and files for better maintainability.
1114 lines
41 KiB
Python
1114 lines
41 KiB
Python
"""
|
|
Speaker Diarization Service for Discord Voice Chat Quote Bot
|
|
|
|
Integrates NVIDIA NeMo for automatic speaker separation and labeling.
|
|
Provides speaker segments that can be mapped to Discord users.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from core.consent_manager import ConsentManager
|
|
from core.database import DatabaseManager
|
|
from utils.audio_processor import AudioProcessor
|
|
|
|
# The logger is created before the optional import below so that the import
# guard can report through it.
logger = logging.getLogger(__name__)

# NeMo is an optional dependency. When it cannot be imported, every NeMo name
# is bound to None and NEMO_AVAILABLE is False so the rest of the module can
# branch on availability instead of failing at import time.
try:
    from nemo.collections.asr.models import ClusteringDiarizer, NeuralDiarizer
    from nemo.collections.asr.models.label_models import (
        EncDecSpeakerLabelModel as EncDecDiarLabelModel,
    )
    from nemo.utils import logging as nemo_logging
except (ImportError, AttributeError) as e:
    logger.warning(f"NeMo not available: {e}. Using fallback implementation.")
    ClusteringDiarizer = None
    NeuralDiarizer = None
    EncDecDiarLabelModel = None
    nemo_logging = None
    NEMO_AVAILABLE = False
else:
    NEMO_AVAILABLE = True
    logger.info("NVIDIA NeMo is available for speaker diarization")
|
|
|
|
|
|
@dataclass
class SpeakerSegment:
    """One contiguous stretch of audio attributed to a single speaker."""

    # Segment boundaries; produced from RTTM start/duration values elsewhere
    # in this module (presumably seconds -- confirm against the producer).
    start_time: float
    end_time: float

    # Diarizer-assigned label, later overwritten with a username when the
    # speaker is identified or tagged.
    speaker_label: str
    confidence: float

    # Encoded audio for this segment, when it was extracted.
    audio_data: Optional[bytes] = None

    # User id of the identified speaker; None until identified or tagged.
    user_id: Optional[int] = None

    # True while the segment still needs a user to tag it.
    needs_tagging: bool = True
|
|
|
|
|
|
@dataclass
class DiarizationResult:
    """Everything produced by diarizing one audio clip."""

    # Path of the clip that was diarized.
    audio_file_path: str

    # Length of the processed audio.
    total_duration: float

    # Per-speaker segments and the distinct labels that occur in them.
    speaker_segments: list[SpeakerSegment]
    unique_speakers: list[str]

    # Time spent diarizing, and when the result was created. The timestamp is
    # compared against the cache expiry elsewhere in this module.
    processing_time: float
    timestamp: datetime
|
|
|
|
|
|
class NeMoDiarizationConfig:
    """Holds diarization settings and renders them as OmegaConf configs."""

    def __init__(self, device: str = "cuda"):
        """Store the target device and the default model/processing settings.

        Args:
            device: Device name to record on the config (default ``"cuda"``).
        """
        self.device = device

        # Model identifiers passed to NeMo as ``model_path`` values.
        self.vad_model = "vad_multilingual_marblenet"
        self.speaker_model = "titanet_large"
        self.neural_diarizer_model = "diar_msdd_telephonic"

        # Processing parameters
        self.sample_rate = 16000
        self.min_segment_duration = 1.0
        self.max_speakers = 8
        self.window_length = 1.5
        self.shift_length = 0.75

        # VAD parameters
        self.vad_onset = 0.8
        self.vad_offset = 0.6
        self.vad_pad_offset = -0.05

        # Clustering parameters
        self.oracle_num_speakers = False
        self.max_num_speakers = 8
        self.enhanced_count_thres = 80
        self.sparse_search_volume = 30

    def get_clustering_config(self) -> DictConfig:
        """Build the configuration used for the clustering diarizer.

        Returns:
            A ``DictConfig`` with a single top-level ``diarizer`` section.
            ``manifest_filepath`` and ``out_dir`` are left as ``None``.
        """
        vad_section = {
            "model_path": self.vad_model,
            "parameters": {
                "onset": self.vad_onset,
                "offset": self.vad_offset,
                "pad_offset": self.vad_pad_offset,
                "min_duration_on": 0.1,
                "min_duration_off": 0.1,
            },
        }
        embedding_section = {
            "model_path": self.speaker_model,
            "parameters": {
                "window_length_in_sec": self.window_length,
                "shift_length_in_sec": self.shift_length,
                "multiscale_weights": None,
                "save_embeddings": False,
            },
        }
        clustering_section = {
            "parameters": {
                "oracle_num_speakers": self.oracle_num_speakers,
                "max_num_speakers": self.max_num_speakers,
                "enhanced_count_thres": self.enhanced_count_thres,
                "sparse_search_volume": self.sparse_search_volume,
            }
        }
        diarizer_section = {
            "manifest_filepath": None,
            "out_dir": None,
            "oracle_vad": False,
            "collar": 0.25,
            "ignore_overlap": True,
            "vad": vad_section,
            "speaker_embeddings": embedding_section,
            "clustering": clustering_section,
        }
        return OmegaConf.create({"diarizer": diarizer_section})

    def get_neural_config(self) -> DictConfig:
        """Build the configuration used for the neural diarizer.

        Returns:
            The clustering configuration with an extra ``diarizer.msdd_model``
            section naming the neural diarizer model.
        """
        neural_cfg = self.get_clustering_config()
        neural_cfg.diarizer.msdd_model = {
            "model_path": self.neural_diarizer_model,
            "parameters": {"sigmoid_threshold": [0.7, 1.0]},
        }
        return neural_cfg
|
|
|
|
|
|
class SpeakerDiarizationService:
    """
    Speaker diarization service using NVIDIA NeMo.

    Features:
    - Automatic speaker separation using clustering and neural approaches
    - Speaker labeling and tracking
    - Integration with consent management
    - Support for user-assisted tagging
    - Caching of diarization results
    - GPU acceleration support

    Requests submitted through ``process_audio_clip`` are placed on an
    ``asyncio.Queue`` and handled one at a time by a background worker task.
    When NeMo (or its models) cannot be loaded, an energy-based fallback is
    used instead.
    """

    def __init__(
        self,
        db_manager: DatabaseManager,
        consent_manager: ConsentManager,
        audio_processor: AudioProcessor,
    ):
        """Store collaborators and set up configuration, cache and temp dir.

        Args:
            db_manager: Used to read speaker profiles and store results.
            consent_manager: Queried for per-user recording consent.
            audio_processor: Used to encode extracted segment audio to bytes.
        """
        self.db_manager = db_manager
        self.consent_manager = consent_manager
        self.audio_processor = audio_processor

        # Device configuration
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing NeMo diarization service on {self.device}")

        # Configuration
        self.config = NeMoDiarizationConfig(device=self.device)

        # Service parameters (for test compatibility)
        self.min_speakers = 1
        self.max_speakers = self.config.max_speakers
        self.min_segment_duration = self.config.min_segment_duration

        # Models
        self.clustering_model: Optional[ClusteringDiarizer] = None
        self.neural_model: Optional[NeuralDiarizer] = None

        # Processing queues and caching
        self.processing_queue: asyncio.Queue[dict[str, object]] = asyncio.Queue()
        self.result_cache: dict[str, DiarizationResult] = {}
        self.cache_expiry = timedelta(hours=2)

        # State management. Both background tasks are referenced here so they
        # cannot be garbage-collected while running and so close() can cancel
        # them.
        self._initialized = False
        self._processing_task: Optional[asyncio.Task[None]] = None
        self._cache_cleanup_task: Optional[asyncio.Task[None]] = None

        # Temporary directory for audio files
        self.temp_dir = Path(tempfile.mkdtemp(prefix="nemo_diarization_"))

    def _start_background_tasks(self) -> None:
        """Start the processing and cache-cleanup workers if not running."""
        if self._processing_task is None or self._processing_task.done():
            self._processing_task = asyncio.create_task(self._processing_worker())
        if self._cache_cleanup_task is None or self._cache_cleanup_task.done():
            self._cache_cleanup_task = asyncio.create_task(
                self._cache_cleanup_worker()
            )

    async def initialize(self) -> None:
        """Initialize the diarization models and workers.

        A dependency problem (``ImportError``/``AttributeError``) while
        preparing NeMo is logged and the service continues in fallback mode.
        The background workers are started in that case too, because
        ``process_audio_clip`` depends on the processing worker to resolve
        its future.

        Raises:
            Exception: Any other initialization error is logged and re-raised.
        """
        if self._initialized:
            return

        try:
            logger.info("Initializing NeMo speaker diarization service...")

            nemo_ready = True
            try:
                # Suppress NeMo logging noise if available
                if nemo_logging:
                    nemo_logging.setLevel(logging.WARNING)

                # Load diarization models
                await self._load_diarization_models()
            except (ImportError, AttributeError) as e:
                nemo_ready = False
                logger.warning(
                    f"Failed to initialize NeMo diarization service due to dependency issue: {e}"
                )
                logger.info("Continuing with fallback audio processing capabilities")

            # close() removes the temp directory; recreate it so the service
            # can be initialized again afterwards.
            self.temp_dir.mkdir(parents=True, exist_ok=True)

            self._start_background_tasks()
            self._initialized = True

            if nemo_ready:
                logger.info(
                    f"NeMo diarization service initialized successfully on {self.device}"
                )

        except Exception as e:
            logger.error(f"Failed to initialize NeMo diarization service: {e}")
            raise

    async def _load_nemo_models(self) -> bool:
        """Load NeMo models (compatibility method for tests).

        Returns:
            True if ``_load_diarization_models`` completed without raising,
            False otherwise.
        """
        try:
            await self._load_diarization_models()
            return True
        except Exception as e:
            logger.warning(f"Failed to load NeMo models: {e}")
            return False

    async def _load_diarization_models(self) -> None:
        """Load NeMo diarization models or use fallback.

        Each model is constructed independently in the default executor; a
        model that fails to load is left as ``None``. Failures are logged,
        not raised.
        """
        if not NEMO_AVAILABLE:
            logger.warning("NeMo not available. Using basic audio processing fallback.")
            self.clustering_model = None
            self.neural_model = None
            return

        try:
            # Load models in thread pool to avoid blocking
            def load_models() -> (
                tuple[Optional[ClusteringDiarizer], Optional[NeuralDiarizer]]
            ):
                try:
                    # Create clustering diarizer
                    clustering_config = self.config.get_clustering_config()
                    clustering_model = ClusteringDiarizer(
                        cfg=clustering_config.diarizer
                    )
                    logger.info("Clustering diarizer loaded successfully")
                except Exception as e:
                    logger.warning(f"Failed to load clustering diarizer: {e}")
                    clustering_model = None

                try:
                    # Create neural diarizer for advanced scenarios
                    neural_config = self.config.get_neural_config()
                    neural_model = NeuralDiarizer(cfg=neural_config.diarizer)
                    logger.info("Neural diarizer loaded successfully")
                except Exception as e:
                    logger.warning(f"Failed to load neural diarizer: {e}")
                    neural_model = None

                return clustering_model, neural_model

            self.clustering_model, self.neural_model = (
                await asyncio.get_running_loop().run_in_executor(None, load_models)
            )

            if not self.clustering_model and not self.neural_model:
                logger.warning("No NeMo models loaded. Using fallback implementation.")

        except Exception as e:
            logger.error(f"Failed to load NeMo models: {e}")
            logger.warning("Falling back to basic audio processing.")

    async def process_audio_clip(
        self,
        audio_file_path: str,
        guild_id: int,
        channel_id: int,
        participants: list[int],
    ) -> Optional[DiarizationResult]:
        """
        Process audio clip for speaker diarization.

        Consent is validated before the cache is consulted, and the cache key
        is derived from the consented users, so a cached result is never
        returned for a participant set whose consent has since changed.

        Args:
            audio_file_path: Path to audio file
            guild_id: Discord guild ID
            channel_id: Discord channel ID
            participants: List of user IDs who were in the channel

        Returns:
            DiarizationResult or None if processing failed
        """
        try:
            if not self._initialized:
                await self.initialize()

            # Validate consent for all participants
            consented_users = []
            for user_id in participants:
                has_consent = await self.consent_manager.has_recording_consent(
                    user_id, guild_id
                )
                if has_consent:
                    consented_users.append(user_id)

            if not consented_users:
                logger.info(
                    f"No consented users in channel {channel_id}, skipping diarization"
                )
                return None

            # Check cache. When every participant has consented this key is
            # identical to one built from the full participant list.
            cache_key = f"{audio_file_path}_{hash(tuple(consented_users))}"
            cached_result = self.result_cache.get(cache_key)
            if (
                cached_result is not None
                and datetime.now(timezone.utc) - cached_result.timestamp
                < self.cache_expiry
            ):
                logger.debug(f"Using cached diarization result for {audio_file_path}")
                return cached_result

            # Queue for processing; the worker resolves this future.
            result_future: asyncio.Future[Optional[DiarizationResult]] = (
                asyncio.get_running_loop().create_future()
            )
            await self.processing_queue.put(
                {
                    "audio_file_path": audio_file_path,
                    "guild_id": guild_id,
                    "channel_id": channel_id,
                    "participants": consented_users,
                    "result_future": result_future,
                }
            )

            # Wait for processing result
            result = await result_future

            # Cache result
            if result:
                self.result_cache[cache_key] = result

            return result

        except Exception as e:
            logger.error(f"Failed to process audio clip for diarization: {e}")
            return None

    async def _processing_worker(self) -> None:
        """Background worker for processing diarization requests.

        Each request's future is always completed: with the result, with the
        raised exception, or with a ``RuntimeError`` if the worker is
        cancelled while the request is in flight. Futures that are already
        done (for example because the caller was cancelled) are left alone.
        """
        while True:
            try:
                # Get next processing request
                request = await self.processing_queue.get()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in diarization processing worker: {e}")
                await asyncio.sleep(1)
                continue

            result_future = request["result_future"]
            try:
                result = await self._perform_diarization(
                    request["audio_file_path"],
                    request["guild_id"],
                    request["channel_id"],
                    request["participants"],
                )
                if not result_future.done():
                    result_future.set_result(result)

            except asyncio.CancelledError:
                # Do not leave the caller awaiting a future nobody will set.
                if not result_future.done():
                    result_future.set_exception(
                        RuntimeError("Diarization worker stopped before completion")
                    )
                break

            except Exception as e:
                logger.error(f"Error processing diarization request: {e}")
                if not result_future.done():
                    result_future.set_exception(e)

            finally:
                self.processing_queue.task_done()

    async def _perform_diarization(
        self,
        audio_file_path: str,
        guild_id: int,
        channel_id: int,
        participants: list[int],
    ) -> Optional[DiarizationResult]:
        """Perform actual speaker diarization using NeMo models.

        Returns:
            The stored ``DiarizationResult``, or ``None`` when audio
            preparation fails, no speakers are detected, or an error occurs.
            Temporary files created along the way are removed on every path.
        """
        temp_files: list[str] = []
        try:
            start_time = datetime.now(timezone.utc)

            # Prepare audio file for NeMo processing
            prepared_audio_path = await self._prepare_audio_file(audio_file_path)
            if not prepared_audio_path:
                return None
            temp_files.append(prepared_audio_path)

            # Create manifest file for NeMo
            manifest_path = await self._create_manifest_file(prepared_audio_path)
            temp_files.append(manifest_path)

            # Perform diarization using available model
            diarization_output = await self._run_nemo_diarization(
                manifest_path, len(participants)
            )

            if not diarization_output:
                logger.warning(f"No speakers detected in {audio_file_path}")
                return None

            # Convert NeMo output to our format
            speaker_segments = await self._convert_nemo_output_to_segments(
                diarization_output, audio_file_path, prepared_audio_path
            )

            # Attempt speaker identification
            speaker_segments = await self._identify_speakers(
                speaker_segments, guild_id, participants
            )

            # Get unique speakers
            unique_speakers = list(
                {segment.speaker_label for segment in speaker_segments}
            )

            # Calculate total duration
            total_duration = await self._get_audio_duration(prepared_audio_path)
            processing_time = (datetime.now(timezone.utc) - start_time).total_seconds()

            result = DiarizationResult(
                audio_file_path=audio_file_path,
                total_duration=total_duration,
                speaker_segments=speaker_segments,
                unique_speakers=unique_speakers,
                processing_time=processing_time,
                timestamp=datetime.now(timezone.utc),
            )

            # Store result in database
            await self._store_diarization_result(result, guild_id, channel_id)

            logger.info(
                f"Diarization complete: {len(unique_speakers)} speakers, "
                f"{len(speaker_segments)} segments, {processing_time:.2f}s"
            )

            return result

        except Exception as e:
            logger.error(f"Failed to perform NeMo diarization: {e}")
            return None

        finally:
            # Cleanup temporary files on success, early return and failure.
            if temp_files:
                await self._cleanup_temp_files(temp_files)

    async def _prepare_audio_file(self, audio_file_path: str) -> Optional[str]:
        """Prepare audio file for NeMo processing (convert to 16kHz mono WAV).

        Returns:
            Path of the converted WAV inside ``self.temp_dir``, or ``None`` if
            loading or conversion failed.
        """
        try:

            def convert_audio() -> str:
                # Load audio file
                audio_data, sample_rate = librosa.load(audio_file_path, sr=None)

                # Convert to mono if needed
                if len(audio_data.shape) > 1:
                    audio_data = librosa.to_mono(audio_data)

                # Resample to the configured rate for NeMo
                if sample_rate != self.config.sample_rate:
                    audio_data = librosa.resample(
                        audio_data,
                        orig_sr=sample_rate,
                        target_sr=self.config.sample_rate,
                    )

                # Save as WAV file in temp directory
                output_path = (
                    self.temp_dir / f"processed_{Path(audio_file_path).stem}.wav"
                )
                sf.write(str(output_path), audio_data, self.config.sample_rate)

                return str(output_path)

            return await asyncio.get_running_loop().run_in_executor(
                None, convert_audio
            )

        except Exception as e:
            logger.error(f"Failed to prepare audio file {audio_file_path}: {e}")
            return None

    async def _create_manifest_file(self, audio_path: str) -> str:
        """Create manifest file required by NeMo.

        Writes a single JSON object followed by a newline to
        ``manifest_<audio stem>.json`` in ``self.temp_dir``.

        Returns:
            Path of the manifest file as a string.

        Raises:
            Exception: Any error while writing is logged and re-raised.
        """
        try:
            manifest_data = {
                "audio_filepath": audio_path,
                "offset": 0,
                "duration": None,
                "label": "infer",
                "text": "-",
                "num_speakers": None,
                "rttm_filepath": None,
                "uem_filepath": None,
            }

            manifest_path = self.temp_dir / f"manifest_{Path(audio_path).stem}.json"

            with open(manifest_path, "w", encoding="utf-8") as f:
                json.dump(manifest_data, f)
                f.write("\n")  # NeMo expects newline-separated JSON

            return str(manifest_path)

        except Exception as e:
            logger.error(f"Failed to create manifest file: {e}")
            raise

    @staticmethod
    def _parse_rttm_file(rttm_path: Path) -> list[dict[str, object]]:
        """Parse an RTTM file into ``{"start", "end", "speaker"}`` dicts.

        Fields 4 and 5 of each line are read as start and duration and field
        8 as the speaker id. Blank lines and lines with fewer than eight
        whitespace-separated fields are skipped.
        """
        segments: list[dict[str, object]] = []
        with open(rttm_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 8:
                    continue
                start = float(parts[3])
                duration = float(parts[4])
                segments.append(
                    {"start": start, "end": start + duration, "speaker": parts[7]}
                )
        return segments

    async def _run_nemo_diarization(
        self, manifest_path: str, expected_speakers: int
    ) -> Optional[dict[str, object]]:
        """Run NeMo diarization or fallback implementation.

        The neural diarizer is used when it is loaded and at least three
        speakers are expected; otherwise the clustering diarizer is used.
        Whenever NeMo is unavailable or produces no result, the energy-based
        fallback is used.

        Returns:
            ``{"segments": [...]}`` or ``None`` if nothing was detected.
        """
        # If NeMo is not available, use basic voice activity detection
        if not NEMO_AVAILABLE or (not self.clustering_model and not self.neural_model):
            return await self._run_fallback_diarization(
                manifest_path, expected_speakers
            )

        try:

            def run_diarization() -> Optional[dict[str, object]]:
                try:
                    # Choose model based on availability and expected speakers
                    if self.neural_model and expected_speakers >= 3:
                        model = self.neural_model
                        logger.debug("Using neural diarizer for complex scenario")
                    elif self.clustering_model:
                        model = self.clustering_model
                        logger.debug("Using clustering diarizer")
                    else:
                        return None

                    # Update model configuration with manifest path
                    # NOTE(review): assumes the model exposes these fields
                    # directly on ``cfg`` -- confirm against the NeMo version
                    # in use.
                    model.cfg.manifest_filepath = manifest_path
                    model.cfg.out_dir = str(self.temp_dir)

                    # Run diarization
                    model.diarize()

                    # Load results from RTTM file. Only the leading
                    # "manifest_" is removed so an audio stem that itself
                    # contains "manifest_" still maps to the right file.
                    audio_stem = Path(manifest_path).stem.removeprefix("manifest_")
                    rttm_path = self.temp_dir / "pred_rttms" / f"{audio_stem}.rttm"

                    if not rttm_path.exists():
                        logger.error(f"RTTM file not found: {rttm_path}")
                        return None

                    return {"segments": self._parse_rttm_file(rttm_path)}

                except Exception as e:
                    logger.error(f"NeMo diarization failed: {e}")
                    return None

            result = await asyncio.get_running_loop().run_in_executor(
                None, run_diarization
            )

            # If NeMo failed, fall back to basic implementation
            if result is None:
                logger.warning("NeMo diarization failed, using fallback")
                return await self._run_fallback_diarization(
                    manifest_path, expected_speakers
                )

            return result

        except Exception as e:
            logger.error(f"Failed to run NeMo diarization: {e}")
            return await self._run_fallback_diarization(
                manifest_path, expected_speakers
            )

    async def _run_fallback_diarization(
        self, manifest_path: str, expected_speakers: int
    ) -> Optional[dict[str, object]]:
        """Run basic voice activity detection as fallback when NeMo is unavailable.

        Speech is detected by thresholding per-frame energy at
        ``mean + 1.5 * std``. Speaker labels are not inferred from the audio:
        the label simply rotates every 10 seconds across
        ``max(2, min(expected_speakers, 4))`` labels.

        Returns:
            ``{"segments": [...]}`` or ``None`` when no segment of at least
            ``min_segment_duration`` was found or an error occurred.
        """
        try:
            # Load manifest to get audio path
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest_data = json.load(f)

            audio_path = manifest_data["audio_filepath"]

            def basic_diarization() -> Optional[dict[str, object]]:
                try:
                    # Load audio
                    audio_data, sample_rate = librosa.load(
                        audio_path, sr=self.config.sample_rate, mono=True
                    )
                    duration = len(audio_data) / sample_rate

                    # Simple voice activity detection using energy
                    frame_length = int(0.025 * sample_rate)  # 25ms frames
                    hop_length = int(0.010 * sample_rate)  # 10ms hop

                    # Calculate energy in each frame
                    energy = np.array(
                        [
                            np.sum(audio_data[i : i + frame_length] ** 2)
                            for i in range(
                                0, len(audio_data) - frame_length, hop_length
                            )
                        ]
                    )

                    # Audio shorter than one frame yields no frames; skip the
                    # statistics instead of taking the mean of an empty array.
                    if energy.size == 0:
                        logger.info("Fallback diarization found 0 speech segments")
                        return None

                    threshold = np.mean(energy) + 1.5 * np.std(energy)
                    voice_frames = energy > threshold

                    # Create segments for voice activity
                    segments: list[dict[str, object]] = []
                    min_duration = self.config.min_segment_duration
                    speaker_pool = max(2, min(expected_speakers, 4))
                    # Change speaker every 10 seconds (basic heuristic)
                    speaker_change_interval = 10.0
                    next_speaker_change = speaker_change_interval
                    frame_duration = hop_length / sample_rate

                    in_speech = False
                    start_time = 0
                    current_speaker = 0

                    def add_segment(seg_start: float, seg_end: float) -> None:
                        # Keep only segments meeting the minimum duration.
                        if seg_end - seg_start >= min_duration:
                            segments.append(
                                {
                                    "start": seg_start,
                                    "end": seg_end,
                                    "speaker": f"SPEAKER_{current_speaker:02d}",
                                }
                            )

                    for i, is_voice in enumerate(voice_frames):
                        current_time = i * frame_duration

                        # Simple speaker change heuristic
                        if current_time > next_speaker_change:
                            current_speaker = (current_speaker + 1) % speaker_pool
                            next_speaker_change += speaker_change_interval

                        if is_voice and not in_speech:
                            # Start of speech
                            start_time = current_time
                            in_speech = True
                        elif not is_voice and in_speech:
                            # End of speech
                            add_segment(start_time, current_time)
                            in_speech = False

                    # Handle case where speech continues to end
                    if in_speech:
                        add_segment(start_time, duration)

                    logger.info(
                        f"Fallback diarization found {len(segments)} speech segments"
                    )
                    return {"segments": segments} if segments else None

                except Exception as e:
                    logger.error(f"Fallback diarization failed: {e}")
                    return None

            return await asyncio.get_running_loop().run_in_executor(
                None, basic_diarization
            )

        except Exception as e:
            logger.error(f"Failed to run fallback diarization: {e}")
            return None

    async def _convert_nemo_output_to_segments(
        self, nemo_output: dict[str, object], original_path: str, prepared_path: str
    ) -> list[SpeakerSegment]:
        """Convert NeMo diarization output to SpeakerSegment objects.

        Segments shorter than ``min_segment_duration`` and segments starting
        beyond the end of the audio are dropped. ``original_path`` is not
        used by this method.

        Returns:
            The converted segments, or an empty list on error.
        """
        try:
            segments = []
            nemo_segments = nemo_output.get("segments", [])

            # Load audio for segment extraction
            audio_data, sample_rate = await self._load_audio_async(prepared_path)

            for seg in nemo_segments:
                start_time = float(seg["start"])
                end_time = float(seg["end"])
                speaker_label = seg["speaker"]

                # Filter segments that are too short
                if end_time - start_time < self.config.min_segment_duration:
                    continue

                # Extract audio data for this segment
                start_sample = int(start_time * sample_rate)
                end_sample = int(end_time * sample_rate)

                if start_sample >= len(audio_data) or end_sample <= start_sample:
                    continue

                segment_audio = audio_data[start_sample:end_sample]

                # Convert to bytes for storage
                audio_bytes = await self.audio_processor.numpy_to_bytes(
                    segment_audio, sample_rate
                )

                segments.append(
                    SpeakerSegment(
                        start_time=start_time,
                        end_time=end_time,
                        speaker_label=speaker_label,
                        confidence=1.0,  # NeMo doesn't provide confidence scores directly
                        audio_data=audio_bytes,
                        user_id=None,  # Will be filled by speaker identification
                        needs_tagging=True,
                    )
                )

            return segments

        except Exception as e:
            logger.error(f"Failed to convert NeMo output to segments: {e}")
            return []

    async def _load_audio_async(self, audio_path: str) -> tuple[np.ndarray, int]:
        """Load audio file asynchronously.

        Returns:
            ``(audio_data, sample_rate)`` as returned by ``soundfile.read``.
        """

        def load_audio() -> tuple[np.ndarray, int]:
            audio_data, sample_rate = sf.read(audio_path)
            return audio_data, sample_rate

        return await asyncio.get_running_loop().run_in_executor(None, load_audio)

    async def _get_audio_duration(self, audio_path: str) -> float:
        """Get audio file duration.

        Returns:
            Frame count divided by sample rate, or ``0.0`` on error.
        """
        try:

            def get_duration() -> float:
                with sf.SoundFile(audio_path) as f:
                    return len(f) / f.samplerate

            return await asyncio.get_running_loop().run_in_executor(
                None, get_duration
            )

        except Exception as e:
            logger.error(f"Failed to get audio duration: {e}")
            return 0.0

    async def _identify_speakers(
        self, segments: list[SpeakerSegment], guild_id: int, participants: list[int]
    ) -> list[SpeakerSegment]:
        """Attempt to identify speakers using stored voice profiles.

        Segments with a matching profile are updated in place (user id,
        label, confidence, ``needs_tagging=False``).

        Returns:
            The same ``segments`` list; unchanged when no profiles exist or an
            error occurs.
        """
        try:
            # Get known speaker profiles for participants
            speaker_profiles = await self.db_manager.execute_query(
                """
                SELECT user_id, voice_embedding, username
                FROM speaker_profiles
                WHERE user_id = ANY($1) AND guild_id = $2 AND voice_embedding IS NOT NULL
                """,
                participants,
                guild_id,
                fetch_all=True,
            )

            if not speaker_profiles:
                logger.debug("No speaker profiles available for identification")
                return segments

            # For each segment, try to match with known profiles
            for segment in segments:
                if not segment.audio_data:
                    continue

                # Generate embedding for segment
                segment_embedding = await self._generate_voice_embedding(
                    segment.audio_data
                )

                # Compare against None explicitly: the truth value of a
                # multi-element numpy array raises ValueError.
                if segment_embedding is None:
                    continue

                # Find best match among known profiles
                best_match = await self._find_best_speaker_match(
                    segment_embedding, speaker_profiles
                )

                if best_match:
                    segment.user_id = best_match["user_id"]
                    segment.speaker_label = (
                        best_match["username"] or f"User_{best_match['user_id']}"
                    )
                    segment.confidence = best_match["confidence"]
                    segment.needs_tagging = False

            return segments

        except Exception as e:
            logger.error(f"Failed to identify speakers: {e}")
            return segments

    async def _generate_voice_embedding(
        self, audio_data: bytes
    ) -> Optional[np.ndarray]:
        """Generate voice embedding for speaker identification.

        Not implemented yet: always returns ``None``.
        """
        try:
            # This would integrate with a speaker recognition model
            # For now, return None to indicate no embedding available
            logger.debug("Voice embedding generation not yet implemented")
            return None

        except Exception as e:
            logger.error(f"Failed to generate voice embedding: {e}")
            return None

    async def _find_best_speaker_match(
        self, segment_embedding: np.ndarray, speaker_profiles: list[dict[str, object]]
    ) -> Optional[dict[str, object]]:
        """Find best matching speaker profile.

        Not implemented yet: always returns ``None``. ``_identify_speakers``
        reads the ``user_id``, ``username`` and ``confidence`` keys of a
        non-empty result.
        """
        try:
            # This would implement speaker matching logic
            logger.debug("Speaker matching not yet implemented")
            return None

        except Exception as e:
            logger.error(f"Failed to find speaker match: {e}")
            return None

    async def _store_diarization_result(
        self, result: DiarizationResult, guild_id: int, channel_id: int
    ) -> None:
        """Store diarization result in database.

        Inserts one ``speaker_diarizations`` row and one ``speaker_segments``
        row per segment. Errors are logged and not raised.
        """
        try:
            # Store main diarization record
            row = await self.db_manager.execute_query(
                """
                INSERT INTO speaker_diarizations
                (guild_id, channel_id, audio_file_path, total_duration, unique_speakers, processing_time)
                VALUES ($1, $2, $3, $4, $5, $6)
                RETURNING id
                """,
                guild_id,
                channel_id,
                result.audio_file_path,
                result.total_duration,
                len(result.unique_speakers),
                result.processing_time,
                fetch_one=True,
            )

            diarization_id = row["id"]

            # Store speaker segments
            for segment in result.speaker_segments:
                await self.db_manager.execute_query(
                    """
                    INSERT INTO speaker_segments
                    (diarization_id, start_time, end_time, speaker_label, confidence, user_id, needs_tagging)
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                    """,
                    diarization_id,
                    segment.start_time,
                    segment.end_time,
                    segment.speaker_label,
                    segment.confidence,
                    segment.user_id,
                    segment.needs_tagging,
                )

            logger.debug(
                f"Stored diarization result {diarization_id} with {len(result.speaker_segments)} segments"
            )

        except Exception as e:
            logger.error(f"Failed to store diarization result: {e}")

    async def _cleanup_temp_files(self, file_paths: list[str]) -> None:
        """Clean up temporary files.

        Missing files are ignored; a failure on one file is logged and does
        not stop the remaining deletions.
        """
        try:

            def cleanup() -> None:
                for file_path in file_paths:
                    try:
                        Path(file_path).unlink(missing_ok=True)
                    except Exception as e:
                        logger.warning(f"Failed to cleanup temp file {file_path}: {e}")

            await asyncio.get_running_loop().run_in_executor(None, cleanup)

        except Exception as e:
            logger.error(f"Error during temp file cleanup: {e}")

    async def _cache_cleanup_worker(self) -> None:
        """Background worker to clean up expired cache entries.

        Runs every 30 minutes until cancelled.
        """
        while True:
            try:
                current_time = datetime.now(timezone.utc)
                expired_keys = [
                    key
                    for key, result in self.result_cache.items()
                    if current_time - result.timestamp > self.cache_expiry
                ]

                for key in expired_keys:
                    del self.result_cache[key]

                if expired_keys:
                    logger.debug(
                        f"Cleaned up {len(expired_keys)} expired cache entries"
                    )

                # Sleep for 30 minutes
                await asyncio.sleep(1800)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cache cleanup worker: {e}")
                await asyncio.sleep(1800)

    # Public API methods (maintaining compatibility)

    async def get_speaker_segments(
        self, audio_file_path: str
    ) -> Optional[list[SpeakerSegment]]:
        """Get speaker segments for an audio file.

        Returns:
            Stored segments ordered by start time (without audio data), or
            ``None`` if the query failed.
        """
        try:
            results = await self.db_manager.execute_query(
                """
                SELECT ss.start_time, ss.end_time, ss.speaker_label, ss.confidence,
                       ss.user_id, ss.needs_tagging
                FROM speaker_segments ss
                JOIN speaker_diarizations sd ON ss.diarization_id = sd.id
                WHERE sd.audio_file_path = $1
                ORDER BY ss.start_time
                """,
                audio_file_path,
                fetch_all=True,
            )

            return [
                SpeakerSegment(
                    start_time=float(row["start_time"]),
                    end_time=float(row["end_time"]),
                    speaker_label=row["speaker_label"],
                    confidence=float(row["confidence"]),
                    user_id=row["user_id"],
                    needs_tagging=row["needs_tagging"],
                )
                for row in results
            ]

        except Exception as e:
            logger.error(f"Failed to get speaker segments: {e}")
            return None

    async def tag_speaker_segment(
        self, segment_id: int, user_id: int, username: str
    ) -> None:
        """Tag a speaker segment with user identification (user-assisted).

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        try:
            await self.db_manager.execute_query(
                """
                UPDATE speaker_segments
                SET user_id = $1, speaker_label = $2, needs_tagging = FALSE
                WHERE id = $3
                """,
                user_id,
                username,
                segment_id,
            )

            logger.info(
                f"Tagged speaker segment {segment_id} as user {user_id} ({username})"
            )

        except Exception as e:
            logger.error(f"Failed to tag speaker segment: {e}")
            raise

    async def get_untagged_segments(
        self, guild_id: int, limit: int = 10
    ) -> list[dict[str, object]]:
        """Get untagged speaker segments for user assistance.

        Returns:
            Up to ``limit`` rows as plain dicts, or an empty list on error.
        """
        try:
            results = await self.db_manager.execute_query(
                """
                SELECT ss.id, ss.start_time, ss.end_time, ss.speaker_label,
                       sd.audio_file_path, sd.guild_id, sd.channel_id
                FROM speaker_segments ss
                JOIN speaker_diarizations sd ON ss.diarization_id = sd.id
                WHERE sd.guild_id = $1 AND ss.needs_tagging = TRUE
                ORDER BY sd.timestamp DESC, ss.start_time ASC
                LIMIT $2
                """,
                guild_id,
                limit,
                fetch_all=True,
            )

            return [dict(result) for result in results]

        except Exception as e:
            logger.error(f"Failed to get untagged segments: {e}")
            return []

    async def check_health(self) -> dict[str, object]:
        """Check health of diarization service.

        Returns:
            A status dict; on error ``{"error": ..., "healthy": False}``.
        """
        try:
            health_status: dict[str, object] = {
                "initialized": self._initialized,
                "nemo_available": NEMO_AVAILABLE,
                "clustering_model_loaded": self.clustering_model is not None,
                "neural_model_loaded": self.neural_model is not None,
                "device": self.device,
                "queue_size": self.processing_queue.qsize(),
                "cache_size": len(self.result_cache),
                "framework": (
                    "NVIDIA NeMo (with fallback)"
                    if NEMO_AVAILABLE
                    else "Basic Audio Processing"
                ),
            }

            if NEMO_AVAILABLE and (self.clustering_model or self.neural_model):
                health_status.update(
                    {
                        "vad_model": self.config.vad_model,
                        "speaker_model": self.config.speaker_model,
                        "neural_diarizer_model": self.config.neural_diarizer_model,
                    }
                )
            else:
                health_status["fallback_mode"] = True

            return health_status

        except Exception as e:
            return {"error": str(e), "healthy": False}

    async def close(self) -> None:
        """Close diarization service and cleanup resources.

        Cancels both background tasks, fails any request still waiting in the
        queue so its caller does not hang, clears the cache, removes the temp
        directory and marks the service uninitialized. Errors are logged and
        not raised.
        """
        try:
            # A later process_audio_clip call must re-run initialize() rather
            # than enqueue work for a stopped worker.
            self._initialized = False

            for task in (self._processing_task, self._cache_cleanup_task):
                if task is None:
                    continue
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            self._processing_task = None
            self._cache_cleanup_task = None

            # Fail requests that were queued but never picked up.
            while not self.processing_queue.empty():
                pending = self.processing_queue.get_nowait()
                pending_future = pending.get("result_future")
                if pending_future is not None and not pending_future.done():
                    pending_future.set_exception(
                        RuntimeError("Diarization service closed before processing")
                    )
                self.processing_queue.task_done()

            # Clear cache
            self.result_cache.clear()

            # Cleanup temp directory
            if self.temp_dir.exists():

                def cleanup_temp_dir() -> None:
                    import shutil

                    shutil.rmtree(self.temp_dir, ignore_errors=True)

                await asyncio.get_running_loop().run_in_executor(
                    None, cleanup_temp_dir
                )

            logger.info("NeMo speaker diarization service closed")

        except Exception as e:
            logger.error(f"Error closing diarization service: {e}")
|