"""Audio capture protocols and data types for Spike 2.

These protocols define the contracts for audio capture components that will be
promoted to src/noteflow/audio/ after validation.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Protocol

import numpy as np
from numpy.typing import NDArray


@dataclass(frozen=True)
class AudioDeviceInfo:
    """Immutable description of a single audio input device."""

    device_id: int
    name: str
    channels: int
    sample_rate: int
    is_default: bool


@dataclass
class TimestampedAudio:
    """A chunk of audio frames together with its capture timestamp.

    Raises:
        ValueError: If ``duration`` or ``timestamp`` is negative or NaN.
    """

    frames: NDArray[np.float32]
    timestamp: float  # Monotonic time when captured
    duration: float  # Duration in seconds

    def __post_init__(self) -> None:
        """Validate the timing fields right after construction."""
        # Written as ``not (x >= 0)`` rather than ``x < 0``: NaN compares
        # False against everything, so the ``<`` form would let it through.
        if not (self.duration >= 0):
            raise ValueError("Duration must be non-negative")
        if not (self.timestamp >= 0):
            raise ValueError("Timestamp must be non-negative")


# Callback invoked with (frames, timestamp) for each captured audio chunk.
AudioFrameCallback = Callable[[NDArray[np.float32], float], None]


class AudioCapture(Protocol):
    """Contract for audio input capture.

    Implementations should handle device enumeration, stream management,
    and device change detection.
    """

    def list_devices(self) -> list[AudioDeviceInfo]:
        """List available audio input devices.

        Returns:
            One AudioDeviceInfo per available input device.
        """
        ...

    def start(
        self,
        device_id: int | None,
        on_frames: AudioFrameCallback,
        sample_rate: int = 16000,
        channels: int = 1,
        chunk_duration_ms: int = 100,
    ) -> None:
        """Start capturing audio from the specified device.

        Args:
            device_id: Device ID to capture from, or None for the default
                device.
            on_frames: Callback receiving (frames, timestamp) for each chunk.
            sample_rate: Sample rate in Hz (default 16kHz for ASR).
            channels: Number of channels (default 1 for mono).
            chunk_duration_ms: Duration of each audio chunk in milliseconds.

        Raises:
            RuntimeError: If already capturing.
            ValueError: If device_id is invalid.
        """
        ...

    def stop(self) -> None:
        """Stop audio capture.

        Safe to call even if not capturing.
        """
        ...

    def is_capturing(self) -> bool:
        """Report whether audio is currently being captured.

        Returns:
            True if capture is active.
        """
        ...


class AudioLevelProvider(Protocol):
    """Contract for computing audio levels (VU meter data)."""

    def get_rms(self, frames: NDArray[np.float32]) -> float:
        """Calculate the RMS level of a block of audio frames.

        Args:
            frames: Audio samples as float32 array (normalized -1.0 to 1.0).

        Returns:
            RMS level normalized to the 0.0-1.0 range.
        """
        ...

    def get_db(self, frames: NDArray[np.float32]) -> float:
        """Calculate the dB level of a block of audio frames.

        Args:
            frames: Audio samples as float32 array (normalized -1.0 to 1.0).

        Returns:
            Level in dB (typically in the -60 to 0 range).
        """
        ...


class RingBuffer(Protocol):
    """Contract for a timestamped audio ring buffer.

    Ring buffers store recent audio with timestamps for ASR processing
    and playback sync.
    """

    def push(self, audio: TimestampedAudio) -> None:
        """Add audio to the buffer.

        Old audio is discarded if the buffer exceeds max_duration.

        Args:
            audio: Timestamped audio chunk to add.
        """
        ...

    def get_window(self, duration_seconds: float) -> list[TimestampedAudio]:
        """Get the last N seconds of audio.

        Args:
            duration_seconds: How many seconds of audio to retrieve.

        Returns:
            TimestampedAudio chunks, ordered oldest to newest.
        """
        ...

    def clear(self) -> None:
        """Clear all audio from the buffer."""
        ...

    @property
    def duration(self) -> float:
        """Total duration of buffered audio in seconds."""
        ...

    @property
    def max_duration(self) -> float:
        """Maximum buffer duration in seconds."""
        ...