# --- Repository listing metadata (scraped from the hosting UI; not code) ---
# File: disbord/services/audio/transcription_service.py
# Commit: Travis Vasceannie 3acb779569 chore: remove .env.example and add new files for project structure
#   - Deleted .env.example file as it is no longer needed.
#   - Added .gitignore to manage ignored files and directories.
#   - Introduced CLAUDE.md for AI provider integration documentation.
#   - Created dev.sh for development setup and scripts.
#   - Updated Dockerfile and Dockerfile.production for improved build processes.
#   - Added multiple test files and directories for comprehensive testing.
#   - Introduced new utility and service files for enhanced functionality.
#   - Organized codebase with new directories and files for better maintainability.
# Date: 2025-08-27 23:00:19 -04:00
# 827 lines | 29 KiB | Python

"""
Transcription Service for Discord Voice Chat Quote Bot
Converts audio clips to text with speaker segment mapping, integrating with
speaker diarization results and AI providers for accurate transcription.
"""
import asyncio
import logging
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional
from core.ai_manager import AIProviderManager, TranscriptionResult
from core.database import DatabaseManager
from utils.audio_processor import AudioProcessor
# Temporary: Comment out due to ONNX/ml_dtypes compatibility issue
# from .speaker_diarization import SpeakerDiarizationService, SpeakerSegment, DiarizationResult
# Temporary stubs for speaker diarization classes
class SpeakerDiarizationService:
"""Temporary stub standing in for the real service in .speaker_diarization (import disabled above)."""
pass
class SpeakerSegment:
"""Temporary stub; the real class carries start_time/end_time/speaker_label/user_id, which this file reads."""
pass
class DiarizationResult:
"""Temporary stub; the real class exposes speaker_segments, read in _perform_transcription."""
pass
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
@dataclass
class TranscribedSegment:
    """A transcribed audio segment with speaker information.

    Attributes:
        start_time: Segment start offset in seconds from clip start.
        end_time: Segment end offset in seconds.
        speaker_label: Diarization label (or "SPEAKER_UNKNOWN" for whole-clip runs).
        text: Transcribed (stripped) text for the segment.
        confidence: Transcription confidence score.
        user_id: Discord user ID when the speaker was identified, else None.
        language: Detected language code (defaults to "en").
        word_count: Number of whitespace-separated words in `text`.
        is_quote_candidate: Set in place by the quote-candidate heuristics.
        provider: AI provider that produced this transcription.
        model: AI model that produced this transcription.
    """

    start_time: float
    end_time: float
    speaker_label: str
    text: str
    confidence: float
    user_id: Optional[int] = None
    language: str = "en"
    word_count: int = 0
    is_quote_candidate: bool = False
    # Fix: these two were previously attached dynamically by the transcribe
    # helpers and read back with getattr(..., "unknown"). Declaring them as
    # fields with the same defaults keeps that behavior while making the
    # contract explicit (dynamic assignment still works; getattr still
    # returns "unknown" when nothing was assigned).
    provider: str = "unknown"
    model: str = "unknown"
@dataclass
class TranscriptionSession:
"""Complete transcription session for an audio clip.

Aggregates the per-speaker transcriptions for one clip together with
processing metadata. Persisted by _store_transcription_session and cached
in memory keyed by clip_id until cache_expiry.
"""
clip_id: str  # Unique clip identifier ("{guild}_{channel}_{unix_ts}" when auto-generated)
guild_id: int  # Discord guild the clip came from
channel_id: int  # Discord channel the clip came from
audio_file_path: str  # Path of the source audio file on disk
total_duration: float  # Clip length in seconds (from the audio_processor probe)
transcribed_segments: list[TranscribedSegment]  # Per-segment transcriptions
processing_time: float  # Wall-clock seconds spent transcribing
ai_provider_used: str  # Provider attributed from the first segment ("unknown" if none)
ai_model_used: str  # Model attributed from the first segment
total_words: int  # Sum of word_count over all segments
timestamp: datetime  # UTC time when processing started
diarization_result: Optional[DiarizationResult] = None  # Raw diarization output; not persisted to the DB
class TranscriptionService:
"""
Audio transcription service with speaker segment mapping
Features:
- Multi-provider audio-to-text conversion
- Integration with speaker diarization
- Speaker segment text mapping
- Quote candidate identification
- Language detection and confidence scoring
- Caching and optimization
"""
def __init__(
self,
ai_manager: AIProviderManager,
db_manager: DatabaseManager,
speaker_diarization: SpeakerDiarizationService,
audio_processor: AudioProcessor,
):
"""Store collaborators and tuning defaults.

No I/O or task creation happens here; background tasks start in
initialize().

Args:
    ai_manager: Provider-agnostic AI facade; transcribe() is called on it.
    db_manager: Async DB accessor used to persist sessions and segments.
    speaker_diarization: Diarization service (stored but not referenced elsewhere in this file).
    audio_processor: Used to probe audio metadata via get_audio_info().
"""
self.ai_manager = ai_manager
self.db_manager = db_manager
self.speaker_diarization = speaker_diarization
self.audio_processor = audio_processor
# Transcription configuration
self.min_segment_duration = 0.5 # Minimum segment length to transcribe
self.max_segment_duration = (
30.0 # Maximum segment length for single transcription
)
self.quote_min_words = 3 # Minimum words to consider as quote candidate
self.quote_max_words = 100 # Maximum words for a single quote
self.confidence_threshold = 0.7 # Minimum confidence for reliable transcription
# Processing queues and caches
self.processing_queue = asyncio.Queue() # Requests consumed by _transcription_worker
self.transcription_cache: dict[str, TranscriptionSession] = {} # clip_id -> session
self.cache_expiry = timedelta(hours=1) # Cached sessions older than this are evicted
# Background tasks (created in initialize())
self._processing_task = None # Consumes processing_queue
self._cache_cleanup_task = None # Evicts expired cache entries
# Statistics
self.total_transcriptions = 0
self.total_processing_time = 0
self.provider_usage_stats: dict[str, int] = {}
self._initialized = False # Guards against double initialization
async def initialize(self):
    """Spin up the background worker tasks; safe to call more than once."""
    if self._initialized:
        return
    try:
        logger.info("Initializing transcription service...")
        # Launch the queue consumer and the cache janitor as long-lived tasks.
        for attr, coro in (
            ("_processing_task", self._transcription_worker()),
            ("_cache_cleanup_task", self._cache_cleanup_worker()),
        ):
            setattr(self, attr, asyncio.create_task(coro))
        self._initialized = True
        logger.info("Transcription service initialized successfully")
    except Exception as exc:
        logger.error(f"Failed to initialize transcription service: {exc}")
        raise
async def transcribe_audio_clip(
self,
audio_file_path: str,
guild_id: int,
channel_id: int,
diarization_result: Optional[DiarizationResult] = None,
clip_id: Optional[str] = None,
) -> Optional[TranscriptionSession]:
"""
Transcribe an audio clip with speaker segment mapping
Args:
audio_file_path: Path to the audio file
guild_id: Discord guild ID
channel_id: Discord channel ID
diarization_result: Speaker diarization results
clip_id: Optional clip identifier
Returns:
TranscriptionSession: Complete transcription results, or None when the
file is missing, has zero duration, or processing fails
"""
try:
# Lazy initialization: start the background workers on first use.
if not self._initialized:
await self.initialize()
# Generate clip ID if not provided
if not clip_id:
clip_id = f"{guild_id}_{channel_id}_{int(datetime.now(timezone.utc).timestamp())}"
# Check cache first. An expired entry is NOT removed here; eviction is
# left to _cache_cleanup_worker, so we simply fall through and redo work.
if clip_id in self.transcription_cache:
cached_session = self.transcription_cache[clip_id]
if (
datetime.now(timezone.utc) - cached_session.timestamp
< self.cache_expiry
):
logger.debug(f"Using cached transcription for {clip_id}")
return cached_session
# Validate audio file
if not os.path.exists(audio_file_path):
logger.error(f"Audio file not found: {audio_file_path}")
return None
# Probe the duration so zero-length clips can be rejected up front.
audio_info = await self.audio_processor.get_audio_info(audio_file_path)
total_duration = audio_info.get("duration", 0.0)
if total_duration == 0:
logger.warning(f"Audio file has zero duration: {audio_file_path}")
return None
# Hand off to the single background worker via queue + future so clips
# are transcribed one at a time in arrival order.
result_future = asyncio.Future()
await self.processing_queue.put(
{
"clip_id": clip_id,
"audio_file_path": audio_file_path,
"guild_id": guild_id,
"channel_id": channel_id,
"total_duration": total_duration,
"diarization_result": diarization_result,
"result_future": result_future,
}
)
# Wait for the worker to resolve the future with the finished session
# (or raise the worker's exception, caught below).
transcription_session = await result_future
# Cache result
if transcription_session:
self.transcription_cache[clip_id] = transcription_session
return transcription_session
except Exception as e:
logger.error(f"Failed to transcribe audio clip: {e}")
return None
async def _transcription_worker(self):
"""Background worker for processing transcription requests.

Consumes processing_queue forever; a None item is the shutdown sentinel
(see close()). Each request's result -- or the exception that broke it --
is delivered back to the awaiting caller through the request's future.
"""
logger.info("Transcription worker started")
while True:
try:
# Get next transcription request
request = await self.processing_queue.get()
if request is None: # Shutdown signal
break
try:
session = await self._perform_transcription(
request["clip_id"],
request["audio_file_path"],
request["guild_id"],
request["channel_id"],
request["total_duration"],
request["diarization_result"],
)
request["result_future"].set_result(session)
except Exception as e:
logger.error(f"Error processing transcription request: {e}")
# Propagate the failure to the caller awaiting the future.
request["result_future"].set_exception(e)
finally:
# Pair every get() with task_done() so Queue.join() can work.
self.processing_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in transcription worker: {e}")
# Brief pause so a persistent failure cannot spin this loop hot.
await asyncio.sleep(1)
async def _perform_transcription(
self,
clip_id: str,
audio_file_path: str,
guild_id: int,
channel_id: int,
total_duration: float,
diarization_result: Optional[DiarizationResult],
) -> TranscriptionSession:
"""Perform the actual transcription process.

Transcribes per speaker segment when diarization data is available,
otherwise the whole file as one segment; then flags quote candidates,
persists the session, and updates service counters.

Raises:
    Exception: re-raised so the worker can fail the caller's future.
"""
try:
start_time = datetime.now(timezone.utc)
processing_start = start_time.timestamp()
logger.info(f"Transcribing audio clip: {clip_id}")
transcribed_segments = []
if diarization_result and diarization_result.speaker_segments:
# Transcribe each speaker segment individually
transcribed_segments = await self._transcribe_speaker_segments(
audio_file_path, diarization_result.speaker_segments
)
else:
# Transcribe entire audio file as single segment
transcribed_segments = await self._transcribe_full_audio(
audio_file_path, total_duration
)
# Calculate statistics
processing_time = datetime.now(timezone.utc).timestamp() - processing_start
total_words = sum(segment.word_count for segment in transcribed_segments)
# Identify quote candidates (mutates segments in place)
await self._identify_quote_candidates(transcribed_segments)
# Determine AI provider used (from first successful transcription)
ai_provider_used = "unknown"
ai_model_used = "unknown"
if transcribed_segments:
# provider/model are attached dynamically by the transcribe helpers;
# getattr defaults cover the case where attachment did not happen.
ai_provider_used = getattr(
transcribed_segments[0], "provider", "unknown"
)
ai_model_used = getattr(transcribed_segments[0], "model", "unknown")
# Create transcription session
session = TranscriptionSession(
clip_id=clip_id,
guild_id=guild_id,
channel_id=channel_id,
audio_file_path=audio_file_path,
total_duration=total_duration,
transcribed_segments=transcribed_segments,
processing_time=processing_time,
ai_provider_used=ai_provider_used,
ai_model_used=ai_model_used,
total_words=total_words,
timestamp=start_time,
diarization_result=diarization_result,
)
# Persist; DB failures are logged (not raised) inside the helper, so
# they do not fail the transcription itself.
await self._store_transcription_session(session)
# Update statistics
self.total_transcriptions += 1
self.total_processing_time += processing_time
self.provider_usage_stats[ai_provider_used] = (
self.provider_usage_stats.get(ai_provider_used, 0) + 1
)
logger.info(
f"Transcription completed: {clip_id}, {len(transcribed_segments)} segments, "
f"{total_words} words, {processing_time:.2f}s"
)
return session
except Exception as e:
logger.error(f"Failed to perform transcription: {e}")
raise
async def _transcribe_speaker_segments(
self, audio_file_path: str, speaker_segments: list[SpeakerSegment]
) -> list[TranscribedSegment]:
"""Transcribe each diarized speaker segment individually.

Returns an empty list on error (logged, not raised).
NOTE(review): self.max_segment_duration is never enforced here -- a
segment longer than 30s is still sent as one transcription request;
confirm whether long segments should be chunked.
"""
try:
transcribed_segments = []
for segment in speaker_segments:
# Skip very short segments (below min_segment_duration)
segment_duration = segment.end_time - segment.start_time
if segment_duration < self.min_segment_duration:
continue
# Extract the audio slice with ffmpeg; None means extraction failed.
segment_audio = await self._extract_audio_segment(
audio_file_path, segment.start_time, segment.end_time
)
if not segment_audio:
continue
# Transcribe segment
transcription_result = await self._transcribe_audio_data(segment_audio)
if transcription_result and transcription_result.text.strip():
# Create transcribed segment
transcribed_segment = TranscribedSegment(
start_time=segment.start_time,
end_time=segment.end_time,
speaker_label=segment.speaker_label,
text=transcription_result.text.strip(),
confidence=transcription_result.confidence,
user_id=segment.user_id,
language=transcription_result.language,
word_count=len(transcription_result.text.split()),
)
# Store provider info (for statistics); attached dynamically and
# read back with getattr in _perform_transcription.
transcribed_segment.provider = getattr(
transcription_result, "provider", "unknown"
)
transcribed_segment.model = getattr(
transcription_result, "model", "unknown"
)
transcribed_segments.append(transcribed_segment)
return transcribed_segments
except Exception as e:
logger.error(f"Failed to transcribe speaker segments: {e}")
return []
async def _transcribe_full_audio(
    self, audio_file_path: str, duration: float
) -> list[TranscribedSegment]:
    """Transcribe the whole audio file as one SPEAKER_UNKNOWN segment.

    Returns a single-element list, or an empty list when transcription
    produced nothing or an error occurred (logged, not raised).
    """
    try:
        # Read the raw bytes and hand them to the AI provider.
        with open(audio_file_path, "rb") as audio_file:
            raw_audio = audio_file.read()
        result = await self._transcribe_audio_data(raw_audio)
        # Guard clause: bail out on failure or whitespace-only text.
        if not result or not result.text.strip():
            return []
        segment = TranscribedSegment(
            start_time=0.0,
            end_time=duration,
            speaker_label="SPEAKER_UNKNOWN",
            text=result.text.strip(),
            confidence=result.confidence,
            user_id=None,
            language=result.language,
            word_count=len(result.text.split()),
        )
        # Attach provider/model dynamically for the statistics code.
        segment.provider = getattr(result, "provider", "unknown")
        segment.model = getattr(result, "model", "unknown")
        return [segment]
    except Exception as exc:
        logger.error(f"Failed to transcribe full audio: {exc}")
        return []
async def _extract_audio_segment(
    self, audio_file_path: str, start_time: float, end_time: float
) -> Optional[bytes]:
    """Extract the [start_time, end_time] slice of an audio file as bytes.

    Runs ffmpeg in the default thread-pool executor so the event loop is
    not blocked by the subprocess call.

    Args:
        audio_file_path: Source audio file.
        start_time: Slice start in seconds.
        end_time: Slice end in seconds.

    Returns:
        The extracted audio bytes, or None when extraction fails.
    """
    # Use ffmpeg to extract segment
    import subprocess

    temp_path = None
    try:
        # ffmpeg writes the slice into a temp file that we read back.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
        cmd = [
            "ffmpeg",
            "-i",
            audio_file_path,
            "-ss",
            str(start_time),
            "-t",
            str(end_time - start_time),
            "-acodec",
            # NOTE(review): stream copy into a .wav container assumes the
            # input codec is WAV-compatible -- confirm the upstream format.
            "copy",
            "-y",
            temp_path,
        ]
        # get_running_loop() instead of the deprecated-in-coroutines
        # get_event_loop(); behavior is identical inside a running loop.
        result = await asyncio.get_running_loop().run_in_executor(
            None, lambda: subprocess.run(cmd, capture_output=True, text=True)
        )
        if result.returncode == 0 and os.path.exists(temp_path):
            with open(temp_path, "rb") as f:
                return f.read()
        logger.error(f"FFmpeg segment extraction failed: {result.stderr}")
        return None
    except Exception as e:
        logger.error(f"Failed to extract audio segment: {e}")
        return None
    finally:
        # Bug fix: the temp file previously leaked whenever an exception was
        # raised between its creation and the explicit unlink calls; a
        # finally block guarantees cleanup on every path.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
async def _transcribe_audio_data(
    self, audio_data: bytes
) -> Optional[TranscriptionResult]:
    """Send raw audio bytes through the AI manager; None on failure.

    The AI manager handles provider selection and fallback internally.
    """
    try:
        return await self.ai_manager.transcribe(audio_data)
    except Exception as exc:
        logger.error(f"Failed to transcribe audio data: {exc}")
        return None
async def _identify_quote_candidates(self, segments: list[TranscribedSegment]):
    """Flag segments that look like memorable quotes (mutates in place).

    A segment becomes a candidate when its word count lies within
    [quote_min_words, quote_max_words], its confidence meets
    confidence_threshold, and its lowercased text contains at least one
    conversational marker or emotional indicator.
    """
    # Hoisted loop invariants: these were previously rebuilt for every
    # qualifying segment on every call.
    # NOTE(review): matching is plain substring search, so e.g. "really"
    # also matches "unreally" -- confirm this looseness is intended.
    conversational_markers = (
        "!", "?", "haha", "lol", "omg", "wow", "really",
        "seriously", "actually", "honestly", "literally",
    )
    emotional_indicators = (
        "love", "hate", "amazing", "terrible", "awesome",
        "stupid", "crazy", "weird", "funny", "hilarious",
    )
    try:
        for segment in segments:
            # Word-count and confidence gate (guard clause).
            if not (
                self.quote_min_words <= segment.word_count <= self.quote_max_words
                and segment.confidence >= self.confidence_threshold
            ):
                continue
            text = segment.text.lower()
            # "!" and "?" are already in conversational_markers and are
            # unaffected by lower(), so the old extra re-check of
            # segment.text for "!"/"?" was redundant and has been removed.
            if any(marker in text for marker in conversational_markers) or any(
                indicator in text for indicator in emotional_indicators
            ):
                segment.is_quote_candidate = True
    except Exception as e:
        logger.error(f"Failed to identify quote candidates: {e}")
async def _store_transcription_session(self, session: TranscriptionSession):
"""Persist a session and its segments to the database.

Inserts one transcription_sessions row, then one transcribed_segments
row per segment. Errors are logged and swallowed so that a persistence
failure does not fail the transcription itself.
NOTE(review): the inserts are not wrapped in a transaction, so a
mid-way failure can leave a session row without segments -- confirm
whether that is acceptable.
"""
try:
# Store main transcription record; RETURNING id + fetch_one gives back
# the newly created primary key as a row mapping.
transcription_id = await self.db_manager.execute_query(
"""
INSERT INTO transcription_sessions
(clip_id, guild_id, channel_id, audio_file_path, total_duration,
processing_time, ai_provider_used, ai_model_used, total_words,
timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id
""",
session.clip_id,
session.guild_id,
session.channel_id,
session.audio_file_path,
session.total_duration,
session.processing_time,
session.ai_provider_used,
session.ai_model_used,
session.total_words,
session.timestamp,
fetch_one=True,
)
transcription_id = transcription_id["id"]
# Store transcribed segments
for segment in session.transcribed_segments:
await self.db_manager.execute_query(
"""
INSERT INTO transcribed_segments
(transcription_id, start_time, end_time, speaker_label, text,
confidence, user_id, language, word_count, is_quote_candidate)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
""",
transcription_id,
segment.start_time,
segment.end_time,
segment.speaker_label,
segment.text,
segment.confidence,
segment.user_id,
segment.language,
segment.word_count,
segment.is_quote_candidate,
)
logger.debug(
f"Stored transcription session {session.clip_id} with {len(session.transcribed_segments)} segments"
)
except Exception as e:
logger.error(f"Failed to store transcription session: {e}")
async def get_transcription_by_clip_id(
self, clip_id: str
) -> Optional[TranscriptionSession]:
"""Get stored transcription session by clip ID.

Checks the in-memory cache first, then rebuilds the session and its
segments (ordered by start_time) from the database. A rebuilt session
has diarization_result=None, since diarization output is not persisted.
Returns None when the clip is unknown or on error.
"""
try:
# Check cache first
if clip_id in self.transcription_cache:
return self.transcription_cache[clip_id]
# Query database
session_data = await self.db_manager.execute_query(
"""
SELECT * FROM transcription_sessions WHERE clip_id = $1
""",
clip_id,
fetch_one=True,
)
if not session_data:
return None
# Get transcribed segments
segments_data = await self.db_manager.execute_query(
"""
SELECT * FROM transcribed_segments
WHERE transcription_id = $1
ORDER BY start_time
""",
session_data["id"],
fetch_all=True,
)
# Reconstruct session; numeric columns are cast back to float since
# the driver may return them as Decimal or str.
segments = []
for seg_data in segments_data:
segment = TranscribedSegment(
start_time=float(seg_data["start_time"]),
end_time=float(seg_data["end_time"]),
speaker_label=seg_data["speaker_label"],
text=seg_data["text"],
confidence=float(seg_data["confidence"]),
user_id=seg_data["user_id"],
language=seg_data["language"],
word_count=seg_data["word_count"],
is_quote_candidate=seg_data["is_quote_candidate"],
)
segments.append(segment)
session = TranscriptionSession(
clip_id=session_data["clip_id"],
guild_id=session_data["guild_id"],
channel_id=session_data["channel_id"],
audio_file_path=session_data["audio_file_path"],
total_duration=float(session_data["total_duration"]),
transcribed_segments=segments,
processing_time=float(session_data["processing_time"]),
ai_provider_used=session_data["ai_provider_used"],
ai_model_used=session_data["ai_model_used"],
total_words=session_data["total_words"],
timestamp=session_data["timestamp"],
)
# Cache the result; its expiry is judged against the DB timestamp.
self.transcription_cache[clip_id] = session
return session
except Exception as e:
logger.error(f"Failed to get transcription by clip ID: {e}")
return None
async def get_quote_candidates(
self, guild_id: int, hours_back: int = 24
) -> list[TranscribedSegment]:
"""Get quote candidates from recent transcriptions.

Args:
    guild_id: Guild to search.
    hours_back: Look-back window in hours (default 24).

Returns:
    Flagged segments, newest session first and ordered by start time
    within each session; empty list on error.
NOTE(review): "SELECT ts.*, tss.*" joins two tables whose column names
can overlap (e.g. "id"); which value the row mapping returns depends on
the DB driver -- verify the segment columns win for the keys read below.
"""
try:
since_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
results = await self.db_manager.execute_query(
"""
SELECT ts.*, tss.*
FROM transcribed_segments tss
JOIN transcription_sessions ts ON tss.transcription_id = ts.id
WHERE ts.guild_id = $1
AND ts.timestamp > $2
AND tss.is_quote_candidate = TRUE
ORDER BY ts.timestamp DESC, tss.start_time ASC
""",
guild_id,
since_time,
fetch_all=True,
)
candidates = []
for result in results:
# Rebuild dataclass instances from the joined rows.
segment = TranscribedSegment(
start_time=float(result["start_time"]),
end_time=float(result["end_time"]),
speaker_label=result["speaker_label"],
text=result["text"],
confidence=float(result["confidence"]),
user_id=result["user_id"],
language=result["language"],
word_count=result["word_count"],
is_quote_candidate=result["is_quote_candidate"],
)
candidates.append(segment)
return candidates
except Exception as e:
logger.error(f"Failed to get quote candidates: {e}")
return []
async def _cache_cleanup_worker(self):
    """Evict transcription cache entries older than cache_expiry, every 30 minutes."""
    while True:
        try:
            now = datetime.now(timezone.utc)
            # Collect keys first; never delete while iterating the dict.
            stale = [
                clip_id
                for clip_id, session in self.transcription_cache.items()
                if now - session.timestamp > self.cache_expiry
            ]
            for clip_id in stale:
                del self.transcription_cache[clip_id]
            if stale:
                logger.debug(
                    f"Cleaned up {len(stale)} expired transcription cache entries"
                )
            # Sleep for 30 minutes between sweeps.
            await asyncio.sleep(1800)
        except asyncio.CancelledError:
            break
        except Exception as exc:
            logger.error(f"Error in transcription cache cleanup worker: {exc}")
            await asyncio.sleep(1800)
async def get_transcription_stats(self) -> dict[str, object]:
    """Return a snapshot of service counters, cache size, and queue depth."""
    try:
        completed = self.total_transcriptions
        # Guard against division by zero before any work has been done.
        mean_time = self.total_processing_time / completed if completed > 0 else 0.0
        return {
            "total_transcriptions": completed,
            "total_processing_time": self.total_processing_time,
            "average_processing_time": mean_time,
            "cache_size": len(self.transcription_cache),
            "queue_size": self.processing_queue.qsize(),
            # Shallow copy so callers cannot mutate our counters.
            "provider_usage": dict(self.provider_usage_stats),
        }
    except Exception as exc:
        logger.error(f"Failed to get transcription stats: {exc}")
        return {}
async def check_health(self) -> dict[str, object]:
    """Report liveness info for this service and its AI backend.

    On any failure, returns {"error": ..., "healthy": False} instead of raising.
    """
    try:
        status: dict[str, object] = {
            "initialized": self._initialized,
            "processing_queue_size": self.processing_queue.qsize(),
            "cache_size": len(self.transcription_cache),
            "total_transcriptions": self.total_transcriptions,
        }
        # Fold in the AI manager's own health verdict.
        ai_health = await self.ai_manager.check_health()
        status["ai_manager_healthy"] = ai_health.get("healthy", False)
        return status
    except Exception as exc:
        return {"error": str(exc), "healthy": False}
async def close(self):
    """Shut down the background tasks and clear the cache.

    Sends the queue's None shutdown sentinel, cancels both workers, and
    awaits their completion (cancellation errors are swallowed via
    return_exceptions=True). Errors are logged, never raised.
    """
    try:
        logger.info("Closing transcription service...")
        # Stop background tasks
        if self._processing_task:
            await self.processing_queue.put(None)  # Signal shutdown
            self._processing_task.cancel()
        if self._cache_cleanup_task:
            self._cache_cleanup_task.cancel()
        # Bug fix: previously both attributes were passed to gather()
        # unconditionally, so when only one task existed the other was None
        # and gather() raised TypeError ("a coroutine or awaitable is
        # required"), skipping the await entirely. Gather only real tasks.
        pending = [
            task
            for task in (self._processing_task, self._cache_cleanup_task)
            if task is not None
        ]
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        # Clear cache
        self.transcription_cache.clear()
        logger.info("Transcription service closed")
    except Exception as e:
        logger.error(f"Error closing transcription service: {e}")