Files
noteflow/spikes/spike_03_asr_latency/engine_impl.py
Travis Vasceannie af1285b181 Add initial project structure and files
- Introduced .python-version for Python version management.
- Added AGENTS.md for documentation on agent usage and best practices.
- Created alembic.ini for database migration configurations.
- Implemented main.py as the entry point for the application.
- Established pyproject.toml for project dependencies and configurations.
- Initialized README.md for project overview.
- Generated uv.lock for dependency locking.
- Documented milestones and specifications in docs/milestones.md and docs/spec.md.
- Created logs/status_line.json for logging status information.
- Added initial spike implementations for UI tray hotkeys, audio capture, ASR latency, and encryption validation.
- Set up NoteFlow core structure in src/noteflow with necessary modules and services.
- Developed test suite in tests directory for application, domain, infrastructure, and integration testing.
- Included initial migration scripts in infrastructure/persistence/migrations for database setup.
- Established security protocols in infrastructure/security for key management and encryption.
- Implemented audio infrastructure for capturing and processing audio data.
- Created converters for ASR and ORM in infrastructure/converters.
- Added export functionality for different formats in infrastructure/export.
- Ensured all new files are included in the repository for future development.
2025-12-17 18:28:59 +00:00

179 lines
4.9 KiB
Python

"""ASR engine implementation using faster-whisper.
Provides Whisper-based transcription with word-level timestamps.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray
from .dto import AsrResult, WordTiming
logger = logging.getLogger(__name__)
# Whisper checkpoint names accepted by load_model(); the ".en" variants are
# English-only models, the rest are multilingual.
VALID_MODEL_SIZES: Final[tuple[str, ...]] = (
    "tiny", "tiny.en",
    "base", "base.en",
    "small", "small.en",
    "medium", "medium.en",
    "large-v1", "large-v2", "large-v3",
)
class FasterWhisperEngine:
    """faster-whisper based ASR engine.

    Uses CTranslate2 for efficient Whisper inference on CPU or GPU. The
    model is loaded lazily via :meth:`load_model`, so constructing the
    engine is cheap and does not require faster-whisper to be importable.
    """

    def __init__(
        self,
        compute_type: str = "int8",
        device: str = "cpu",
        num_workers: int = 1,
    ) -> None:
        """Initialize the engine.

        Args:
            compute_type: Computation type ("int8", "float16", "float32").
            device: Device to use ("cpu" or "cuda").
            num_workers: Number of worker threads.
        """
        self._compute_type = compute_type
        self._device = device
        self._num_workers = num_workers
        # WhisperModel instance; None until load_model() succeeds.
        self._model = None
        self._model_size: str | None = None

    def load_model(self, model_size: str = "base") -> None:
        """Load the ASR model.

        Args:
            model_size: Model size (e.g., "tiny", "base", "small").

        Raises:
            ValueError: If model_size is invalid.
            RuntimeError: If model loading fails.
        """
        # Validate BEFORE importing faster_whisper so an invalid size fails
        # fast with the documented ValueError — even when the optional
        # dependency is not installed — instead of surfacing an ImportError.
        if model_size not in VALID_MODEL_SIZES:
            raise ValueError(
                f"Invalid model size: {model_size}. "
                f"Valid sizes: {', '.join(VALID_MODEL_SIZES)}"
            )
        from faster_whisper import WhisperModel

        logger.info(
            "Loading Whisper model '%s' on %s with %s compute...",
            model_size,
            self._device,
            self._compute_type,
        )
        try:
            self._model = WhisperModel(
                model_size,
                device=self._device,
                compute_type=self._compute_type,
                num_workers=self._num_workers,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e
        # Record the size only after a successful load so is_loaded and
        # model_size stay consistent on failure.
        self._model_size = model_size
        logger.info("Model loaded successfully")

    def transcribe(
        self,
        audio: NDArray[np.float32],
        language: str | None = None,
        *,
        beam_size: int = 5,
        vad_filter: bool = True,
    ) -> Iterator[AsrResult]:
        """Transcribe audio and yield results.

        Args:
            audio: Audio samples as float32 array (16kHz mono, normalized).
            language: Optional language code (e.g., "en").
            beam_size: Beam width for decoding (default 5, the previous
                hard-coded value).
            vad_filter: Filter out non-speech via VAD (default True, the
                previous hard-coded value).

        Yields:
            AsrResult segments with word-level timestamps.

        Raises:
            RuntimeError: If no model is loaded. Raised eagerly at call
                time — not deferred to the first iteration, which is what a
                plain generator function would do.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self._run_transcription(
            audio, language, beam_size=beam_size, vad_filter=vad_filter
        )

    def _run_transcription(
        self,
        audio: NDArray[np.float32],
        language: str | None,
        *,
        beam_size: int,
        vad_filter: bool,
    ) -> Iterator[AsrResult]:
        """Generator doing the actual work; assumes a model is loaded."""
        segments, info = self._model.transcribe(
            audio,
            language=language,
            word_timestamps=True,
            beam_size=beam_size,
            vad_filter=vad_filter,
        )
        logger.debug(
            "Detected language: %s (prob: %.2f)",
            info.language,
            info.language_probability,
        )
        for segment in segments:
            # segment.words may be None when word timestamps are unavailable.
            words = tuple(
                WordTiming(
                    word=word.word,
                    start=word.start,
                    end=word.end,
                    probability=word.probability,
                )
                for word in (segment.words or ())
            )
            yield AsrResult(
                text=segment.text.strip(),
                start=segment.start,
                end=segment.end,
                words=words,
                language=info.language,
                language_probability=info.language_probability,
                avg_logprob=segment.avg_logprob,
                no_speech_prob=segment.no_speech_prob,
            )

    @property
    def is_loaded(self) -> bool:
        """Return True if model is loaded."""
        return self._model is not None

    @property
    def model_size(self) -> str | None:
        """Return the loaded model size, or None if not loaded."""
        return self._model_size

    def unload(self) -> None:
        """Unload the model to free memory."""
        self._model = None
        self._model_size = None
        logger.info("Model unloaded")

    @property
    def compute_type(self) -> str:
        """Return the compute type."""
        return self._compute_type

    @property
    def device(self) -> str:
        """Return the device."""
        return self._device