diff --git a/.claudectx/codefixes.md b/.claudectx/codefixes.md index 68845ee..83482c3 100644 --- a/.claudectx/codefixes.md +++ b/.claudectx/codefixes.md @@ -1,116 +1,55 @@ -This is a comprehensive code review of the NoteFlow Tauri backend based on the provided source files. +# Codefixes Tracking (Jan 19, 2026) -Overall, the codebase demonstrates **high maturity** and **advanced engineering practices**. It handles complex problems like audio clock drift, bidirectional gRPC streaming, and cross-platform resource management with sophistication rarely seen in typical Tauri apps. +## Completion Status -Here is the detailed review categorized by domain. +| Item | Description | Status | +|------|-------------|--------| +| 1 | Audio chunk serialization optimization | ✅ COMPLETED | +| 2 | Client backpressure throttling | ✅ COMPLETED | +| 3 | Whisper VAD disabled for streaming | ✅ COMPLETED | +| 4 | Per-segment DB overhead reduction | ✅ COMPLETED | +| 5 | Auto-trigger offline diarization | ✅ COMPLETED | +| 6 | Diarization turn pruning tuning | ✅ COMPLETED | +| 7 | Audio format conversion tests | ✅ COMPLETED | +| 8 | Stream state consolidation | ✅ COMPLETED | +| 9 | Adaptive partial cadence | ✅ COMPLETED | +| 10 | TS object spread (no issue) | ✅ N/A | +| 11 | Pre-existing test fixes (sprint 15 + sync) | ✅ COMPLETED | -### 1. Concurrency & Performance +--- -**Strengths:** -* **Drift Compensation (`src/audio/drift_compensation/`):** The implementation of `DriftDetector` using linear regression and `AdaptiveResampler` (via `rubato`) is excellent. This is critical for preventing "robotic" audio artifacts when mixing USB microphones with system loopback (which operate on different hardware clocks). -* **Atomic State Transitions:** The `StreamManager` (`src/grpc/streaming/manager.rs`) uses atomic state checks and a timeout mechanism (`STARTING_STATE_TIMEOUT_SECS`) to prevent race conditions during stream initialization. -* **Proto Compliance Tests:** The usage of macros in `src/grpc/proto_compliance_tests.rs` to ensure internal Rust types match generated Protobuf types is a fantastic defensive programming strategy. +## Summary of Items -**Critical Findings:** +1. **Audio serialization**: Optimize JS→Rust→gRPC path (Float32Array→Vec→bytes) ✅ + - Added `bytemuck` crate for zero-copy byte casting + - Replaced per-element `flat_map` with `bytemuck::cast_slice` (O(1) cast + single memcpy) + - Added compile-time endianness check for safety + - Added 3 unit tests verifying identical output + - Fixed Rust quality warning: extracted `TEST_SAMPLE_COUNT` constant +2. **Backpressure**: Add real throttling when server reports congestion ✅ + - Added `THROTTLE_THRESHOLD_MS` (3 seconds) and `THROTTLE_RESUME_DELAY_MS` (500ms) + - `send()` returns false when throttled, drops chunks + - 10 new throttle behavior tests added +3. **Whisper VAD**: Pass vad_filter=False for streaming segments ✅ +4. **DB overhead**: Cache meeting info in stream state, reduce per-segment commits ✅ + - Added `meeting_db_id` caching in `MeetingStreamState` + - `_ensure_meeting_db_id()` fetches on first segment only + - Fixed type errors with proper `UnitOfWork` and `AsrResult` types +5. **Auto diarization**: Trigger offline refinement after recording stop ✅ + - Added `diarization_auto_refine` setting (default False) + - Config flows through `ServicesConfig` to servicer + - `auto_trigger_diarization_refinement()` in `_jobs.py` handles job creation + - Triggered from `start_post_processing()` after meeting stops +6. 
**Pruning tuning**: Reduce streaming diarization window (preview mode) ✅ + - Reduced `_MAX_TURN_AGE_MINUTES` from 15 to 5 + - Reduced `_MAX_TURN_COUNT` from 5000 to 1000 +7. **Audio tests**: Add missing tests for resample/format validation ✅ + - Created 34 tests in `tests/grpc/test_audio_processing.py` + - TestResampleAudio (8), TestDecodeAudioChunk (4), TestConvertAudioFormat (8), TestValidateStreamFormat (14) +8. **Stream state**: Migrate to MeetingStreamState as single source of truth ✅ +9. **Adaptive cadence**: Apply multiplier to partial cadence under congestion ✅ +10. **TS spread**: No issue found in current codebase ✅ +11. **Pre-existing test fixes**: Fixed failing tests discovered during validation ✅ + - `test_sprint_15_1_critical_bugs.py`: Fixed path prefixes (`_mixins` → `mixins`) + - `test_sync_orchestration.py`: Fixed `error_message` → `error_code` in protocol/assertion -1. **Blocking I/O in Async Context (Keychain):** - In `src/commands/recording/session/start.rs`, the function `start_recording` calls: - ```rust - // Line 197 - if let Err(err) = state.crypto.ensure_initialized() { ... } - ``` - `CryptoManager::ensure_initialized` interacts with the OS Keychain (via `keyring` crate). On macOS/Linux, this can trigger a blocking UI prompt for the password. Since `start_recording` is an `async fn` running on the Tokio runtime, this can block the executor thread, freezing other async tasks (like heartbeats or UI events). - * **Recommendation:** Wrap this specific call in `task::spawn_blocking`. - -2. **Unbounded Memory Growth during Recording:** - In `src/state/app_state.rs`, `session_audio_buffer` is defined as `RwLock<Vec<TimestampedAudio>>`. - In `src/commands/recording/session/processing.rs`: - ```rust - // Line 71 - buffer.push(TimestampedAudio { ... }); - ``` - This vector grows indefinitely until the recording stops. For 48kHz float audio, this consumes approx **11.5 MB per minute**. A 2-hour recording will consume ~1.4 GB of RAM *just for this buffer*, causing potential OOM crashes on lower-end machines. - * **Recommendation:** Implement a ring buffer or page audio to disk (temp files) once it exceeds a certain threshold (e.g., 50MB), keeping only the waveform data needed for the visualizer in memory. - -3. **Potential Deadlocks with `parking_lot::RwLock`:** - You are using `parking_lot::RwLock` inside `AppState`. Unlike `tokio::sync::RwLock`, these locks are synchronous. If you hold a `write()` lock across an `.await` point, you will deadlock the thread. - * *Scan:* `src/commands/recording/session/chunks.rs` (Line 24) correctly drops the lock before awaiting `audio_tx.send()`. - * *Scan:* `src/commands/playback/audio.rs` (Line 48) clones the buffer while holding the read lock. This is safe but memory intensive (see point #2). - -### 2. Audio Architecture - -**Strengths:** -* **Dual Capture (`src/commands/recording/dual_capture.rs`):** The logic to handle Windows WASAPI loopback vs. standard input APIs is handled cleanly. -* **Resiliency:** The `DroppedChunkTracker` in `capture.rs` provides excellent feedback to the UI without spamming events (throttled to 1s). - -**Improvements:** - -1. **Audio Normalization for Storage:** - In `src/commands/recording/dual_capture.rs`, line 332: - ```rust - let _gain_applied = normalize_for_asr(&mut chunk); - ``` - This modifies the audio chunk *in-place* before sending it to the `capture_tx`. This channel feeds both the ASR stream *and* the file writer. - * **Issue:** You are saving dynamically compressed/normalized audio to disk (`.nfaudio`). 
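A minimal sketch of the split recommended below (normalize only the ASR-bound copy); `Chunk`, `route_chunk`, and the peak normalizer here are illustrative stand-ins, not the app's actual types:

```rust
use std::sync::mpsc::Sender;

struct Chunk {
    samples: Vec<f32>,
}

// Illustrative peak normalizer; the real normalize_for_asr may differ.
fn normalize_for_asr(samples: &mut [f32]) {
    let peak = samples.iter().fold(0.0_f32, |max, s| max.max(s.abs()));
    if peak > 0.0 {
        for s in samples.iter_mut() {
            *s /= peak;
        }
    }
}

// Normalize only the copy bound for ASR; the original samples go to the
// file writer untouched, preserving the recording's dynamic range.
fn route_chunk(chunk: &Chunk, asr_tx: &Sender<Chunk>, disk_buffer: &mut Vec<f32>) {
    let mut asr_chunk = Chunk { samples: chunk.samples.clone() };
    normalize_for_asr(&mut asr_chunk.samples);
    let _ = asr_tx.send(asr_chunk);
    disk_buffer.extend_from_slice(&chunk.samples);
}
```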
While good for ASR, this destroys the dynamic range of the original recording permanently. - * **Recommendation:** Apply normalization only to the copy sent to the gRPC stream, or store the gain factor as metadata if you want to apply it non-destructively during playback. - -2. **Panic in Audio Thread:** - In `src/audio/capture.rs`, the error callback simply logs: - ```rust - move |err| { tracing::error!("Audio capture error: {}", err); } - ``` - If the device is unplugged, `cpal` streams usually terminate. The `AudioCapture` struct doesn't seem to have a mechanism to signal the main application state that capture has died unexpectedly, leading to a "zombie" recording state (UI thinks it's recording, but no data is flowing). - * **Recommendation:** Pass a `mpsc::Sender` for errors to the stream builder to notify `AppState` of fatal device errors. - -### 3. Security - -**Strengths:** -* **Lazy Crypto Init:** The `CryptoManager` correctly defers sensitive keychain access until user action. -* **Encryption:** Using `AES-256-GCM` (`aes_gcm` crate) is the industry standard. -* **OIDC/OAuth:** The implementation of the Loopback IP flow (`src/oauth_loopback.rs`) is compliant with modern RFCs (PKCE support is referenced in types). - -**Concerns:** - -1. **Secret Scrubbing:** - In `src/commands/preferences.rs`, `save_preferences` logs: - ```rust - // Line 16-21 - tracing::trace!(... "Preferences requested"); - ``` - Ensure that `UserPreferences` Debug/Display traits do not leak `api_key` or `client_secret` if they are added to the `extra` hashmap or `ai_config` JSON blob. - -2. **Loopback Server Binding:** - In `src/oauth_loopback.rs`: - ```rust - const LOOPBACK_BIND_ADDR: &str = "127.0.0.1:0"; - ``` - This binds to IPv4. On some dual-stack systems, if the browser resolves `localhost` to `::1` (IPv6), the callback might fail. - * **Recommendation:** Explicitly use `127.0.0.1` in the redirect URI (which you do), but verify that the OS doesn't force IPv6 for localhost lookups if you ever switch to using "localhost" string. Current implementation looks safe. - -### 4. Code Quality & Maintenance - -**Strengths:** -* **Modular gRPC Client:** Splitting the gRPC client into traits/modules (`src/grpc/client/*.rs`) prevents the "God Object" anti-pattern common in generated clients. -* **Type Safety:** The strict separation of Proto types vs. Domain types (`src/grpc/types/`) with explicit converters prevents leaking generated code details into the frontend logic. -* **Testing:** The `harness.rs` and integration tests are very thorough. - -**Nitpicks:** - -* **Magic Numbers:** - `client/src-tauri/src/audio/loader.rs`: - ```rust - const MAX_SAMPLES: usize = 1_000_000_000; - ``` - While defined as a constant, loading 1 billion samples (4GB) into a `Vec` will almost certainly crash the app before the check is reached due to allocation failure on most consumer desktops. A lower, streamed limit is safer. - -* **Error Handling:** - In `src/commands/recording/session/start.rs`, errors are classified and emitted. However, `log_recording_start_failure` effectively swallows the error context for the logger, then the error is returned to the frontend. Ensure the frontend actually displays these specific error categories (e.g., `PolicyBlocked`). - -### 5. Summary of Recommendations - -1. **High Priority:** Wrap `state.crypto.ensure_initialized()` in `spawn_blocking` to prevent freezing the UI/Event loop during keychain prompts. -2. 
**High Priority:** Implement a ring buffer or disk-paging for `session_audio_buffer` to prevent OOM on long recordings. -3. **Medium Priority:** Split the audio pipeline so normalization only applies to the ASR stream, preserving original dynamics for the saved file. -4. **Low Priority:** Implement a fatal error channel from `cpal` callbacks to the main state to auto-stop recording on device disconnection. - -**Verdict:** This is high-quality Rust code. The architecture is sound, but the memory management strategy for audio buffers needs adjustment for production use cases involving long sessions. \ No newline at end of file diff --git a/client/src-tauri/Cargo.lock b/client/src-tauri/Cargo.lock index ece1f8b..0cac8b4 100644 --- a/client/src-tauri/Cargo.lock +++ b/client/src-tauri/Cargo.lock @@ -2964,6 +2964,7 @@ dependencies = [ "alsa", "async-stream", "base64 0.22.1", + "bytemuck", "chrono", "cpal", "directories", diff --git a/client/src-tauri/Cargo.toml b/client/src-tauri/Cargo.toml index ffbdcf8..f590720 100644 --- a/client/src-tauri/Cargo.toml +++ b/client/src-tauri/Cargo.toml @@ -62,6 +62,9 @@ thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +# === Zero-copy byte casting === +bytemuck = { version = "1.16", features = ["extern_crate_alloc"] } + # === Utilities === base64 = "0.22" uuid = { version = "1.10", features = ["v4", "serde"] } diff --git a/client/src-tauri/src/commands/recording/session/start.rs b/client/src-tauri/src/commands/recording/session/start.rs index a10d5b8..9fc4244 100644 --- a/client/src-tauri/src/commands/recording/session/start.rs +++ b/client/src-tauri/src/commands/recording/session/start.rs @@ -1,5 +1,10 @@ //! Start recording handler. +// Audio data is sent as raw bytes and interpreted as f32 little-endian on the server. +// This compile-time check ensures we're on a compatible architecture. +#[cfg(not(target_endian = "little"))] +compile_error!("Audio byte conversion assumes little-endian architecture"); + use std::sync::Arc; use std::time::{Duration, Instant}; @@ -299,12 +304,11 @@ pub async fn start_recording( let audio_tx_clone = audio_tx.clone(); let conversion_task = tauri::async_runtime::spawn(async move { while let Some(chunk) = capture_rx.recv().await { - // Convert f32 samples to bytes (little-endian) - let bytes: Vec<u8> = chunk - .samples - .iter() - .flat_map(|&s| s.to_le_bytes()) - .collect(); + // Zero-copy reinterpret f32 slice as bytes, then copy to Vec. + // This is O(1) for the cast + single memcpy for the Vec allocation, + // much faster than per-element iteration with flat_map. + // Safety: f32 is Pod (plain old data) and we're on little-endian. + let bytes: Vec<u8> = bytemuck::cast_slice::<f32, u8>(&chunk.samples).to_vec(); let stream_chunk = AudioStreamChunk { audio_data: bytes,
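The `compile_error!` guard above simply refuses to build on big-endian targets. If portability is ever needed, one alternative (a sketch under that assumption, not the committed change) keeps the cast as the little-endian fast path and falls back to explicit serialization elsewhere:

```rust
/// Convert f32 samples to the little-endian wire format.
fn samples_to_le_bytes(samples: &[f32]) -> Vec<u8> {
    if cfg!(target_endian = "little") {
        // Fast path: the in-memory layout already matches the wire format.
        bytemuck::cast_slice::<f32, u8>(samples).to_vec()
    } else {
        // Portable fallback: serialize each sample explicitly.
        samples.iter().flat_map(|s| s.to_le_bytes()).collect()
    }
}
```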
diff --git a/client/src-tauri/src/commands/recording/tests.rs b/client/src-tauri/src/commands/recording/tests.rs index 8b70e68..fe98484 100644 --- a/client/src-tauri/src/commands/recording/tests.rs +++ b/client/src-tauri/src/commands/recording/tests.rs @@ -5,6 +5,9 @@ use crate::crypto::CryptoBox; use crate::grpc::types::results::TimestampedAudio; use std::time::{SystemTime, UNIX_EPOCH}; +/// Sample count for bytemuck size verification tests. +const TEST_SAMPLE_COUNT: usize = 100; + fn temp_audio_path() -> std::path::PathBuf { let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -91,3 +94,34 @@ fn decode_input_device_id_rejects_output_ids() { let parsed = decode_input_device_id("output:1:Speakers"); assert_eq!(parsed, None); } + +#[test] +fn bytemuck_f32_to_bytes_matches_manual_conversion() { + // Test that bytemuck::cast_slice produces the same output as manual to_le_bytes() + let samples: Vec<f32> = vec![1.0, -1.0, 0.5, -0.5, 0.0, std::f32::consts::PI]; + + // Manual conversion (old method) + let manual_bytes: Vec<u8> = samples.iter().flat_map(|&s| s.to_le_bytes()).collect(); + + // Bytemuck conversion (new method) + let bytemuck_bytes: Vec<u8> = bytemuck::cast_slice::<f32, u8>(&samples).to_vec(); + + assert_eq!( + manual_bytes, bytemuck_bytes, + "bytemuck conversion must match manual conversion" + ); +} + +#[test] +fn bytemuck_f32_to_bytes_handles_empty() { + let samples: Vec<f32> = vec![]; + let bytes: Vec<u8> = bytemuck::cast_slice::<f32, u8>(&samples).to_vec(); + assert!(bytes.is_empty()); +} + +#[test] +fn bytemuck_f32_to_bytes_size_is_correct() { + let samples: Vec<f32> = vec![1.0; TEST_SAMPLE_COUNT]; + let bytes: Vec<u8> = bytemuck::cast_slice::<f32, u8>(&samples).to_vec(); + assert_eq!(bytes.len(), samples.len() * std::mem::size_of::<f32>()); +} diff --git a/client/src/api/tauri-adapter/index.ts b/client/src/api/tauri-adapter/index.ts index f37a6db..45d9165 100644 --- a/client/src/api/tauri-adapter/index.ts +++ b/client/src/api/tauri-adapter/index.ts @@ -3,6 +3,8 @@ export { initializeTauriAPI, isTauriEnvironment } from './environment'; export { CONGESTION_DISPLAY_THRESHOLD_MS, CONSECUTIVE_FAILURE_THRESHOLD, + THROTTLE_RESUME_DELAY_MS, + THROTTLE_THRESHOLD_MS, TauriTranscriptionStream, } from './stream'; export type { diff --git a/client/src/api/tauri-adapter/stream.ts b/client/src/api/tauri-adapter/stream.ts index 9770f6d..3143cef 100644 --- a/client/src/api/tauri-adapter/stream.ts +++ b/client/src/api/tauri-adapter/stream.ts @@ -16,6 +16,12 @@ export const CONSECUTIVE_FAILURE_THRESHOLD = 3; /** Threshold in milliseconds before showing buffering indicator (2 seconds). */ export const CONGESTION_DISPLAY_THRESHOLD_MS = Timing.TWO_SECONDS_MS; +/** Threshold in milliseconds of continuous congestion before throttling sends (3 seconds). */ +export const THROTTLE_THRESHOLD_MS = Timing.THREE_SECONDS_MS; + +/** Delay in milliseconds after congestion clears before resuming sends (500ms). */ +export const THROTTLE_RESUME_DELAY_MS = 500; + /** Real-time transcription stream using Tauri events. */ export class TauriTranscriptionStream { private unlistenFn: (() => void) | null = null; @@ -36,6 +42,12 @@ export class TauriTranscriptionStream { /** Whether the stream has been closed (prevents late listeners). */ private isClosed = false; + /** Whether audio sending is currently throttled due to prolonged congestion. */ + private isThrottled = false; + + /** Timer for resuming sends after congestion clears (null if not pending). */ + private throttleResumeTimer: ReturnType<typeof setTimeout> | null = null; + /** Queue for ordered, backpressure-aware chunk transmission. */ private readonly sendQueue: StreamingQueue; private readonly drainTimeoutMs = 5000; @@ -76,14 +88,23 @@ export class TauriTranscriptionStream { return this.sendQueue.currentDepth; } + /** Whether audio sending is currently throttled due to congestion. */ + getIsThrottled(): boolean { + return this.isThrottled; + } + /** * Send an audio chunk to the transcription service. 
* * Chunks are queued and sent in order with backpressure protection. - * Returns false if the queue is full (severe backpressure). + * Returns false if the stream is closed, throttled, or the queue is full. + * + * When throttled due to prolonged congestion, chunks are dropped to prevent + * overwhelming the server. The stream will automatically resume when + * congestion clears. */ send(chunk: AudioChunk): boolean { - if (this.isClosed) { + if (this.isClosed || this.isThrottled) { return false; } @@ -171,13 +192,32 @@ export class TauriTranscriptionStream { return; } - const { is_congested } = event.payload; + const { is_congested, congested_duration_ms } = event.payload; if (is_congested) { // Start tracking congestion if not already this.congestionStartTime ??= Date.now(); const duration = Date.now() - this.congestionStartTime; + // Clear any pending resume timer since we're still congested + this.clearThrottleResumeTimer(); + + // Enable throttling if congestion exceeds threshold + if (duration >= THROTTLE_THRESHOLD_MS || congested_duration_ms >= THROTTLE_THRESHOLD_MS) { + if (!this.isThrottled) { + this.isThrottled = true; + addClientLog({ + level: 'warning', + source: 'api', + message: 'Audio stream throttled due to prolonged congestion', + metadata: { + meeting_id: this.meetingId, + duration_ms: String(Math.max(duration, congested_duration_ms)), + }, + }); + } + } + // Only show buffering after threshold is exceeded if (duration >= CONGESTION_DISPLAY_THRESHOLD_MS && !this.isShowingBuffering) { this.isShowingBuffering = true; @@ -187,7 +227,10 @@ export class TauriTranscriptionStream { this.congestionCallback?.({ isBuffering: true, duration }); } } else { - // Congestion cleared + // Congestion cleared - schedule resume with delay + this.scheduledThrottleResume(); + + // Update UI immediately when congestion clears if (this.isShowingBuffering) { this.isShowingBuffering = false; this.congestionCallback?.({ isBuffering: false, duration: 0 }); @@ -210,6 +253,38 @@ export class TauriTranscriptionStream { }); } + /** Clear the throttle resume timer if pending. */ + private clearThrottleResumeTimer(): void { + if (this.throttleResumeTimer !== null) { + clearTimeout(this.throttleResumeTimer); + this.throttleResumeTimer = null; + } + } + + /** Schedule resumption of sending after congestion clears. */ + private scheduledThrottleResume(): void { + if (!this.isThrottled) { + return; // Not currently throttled, nothing to resume + } + + // Clear any existing timer to reset the delay + this.clearThrottleResumeTimer(); + + // Schedule resume after delay to ensure congestion is truly cleared + this.throttleResumeTimer = setTimeout(() => { + this.throttleResumeTimer = null; + if (!this.isClosed && this.isThrottled) { + this.isThrottled = false; + addClientLog({ + level: 'info', + source: 'api', + message: 'Audio stream throttle released - resuming sends', + metadata: { meeting_id: this.meetingId }, + }); + } + }, THROTTLE_RESUME_DELAY_MS); + } + /** * Close the stream and stop recording. 
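* Drains pending chunks (bounded by the drain timeout), clears the throttle-resume timer, and resets congestion and throttle state before invoking STOP_RECORDING.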
* * */ async close(): Promise<void> { this.isClosed = true; + // Clear throttle resume timer + this.clearThrottleResumeTimer(); + // Drain the send queue to ensure all pending chunks are transmitted try { const drainPromise = this.sendQueue.drain(); @@ -244,9 +322,10 @@ this.healthUnlistenFn(); this.healthUnlistenFn = null; } - // Reset congestion state + // Reset congestion and throttle state this.congestionStartTime = null; this.isShowingBuffering = false; + this.isThrottled = false; try { await this.invoke(TauriCommands.STOP_RECORDING, { meeting_id: this.meetingId }); diff --git a/client/src/api/tauri-transcription-stream.test.ts b/client/src/api/tauri-transcription-stream.test.ts index 2d87094..cd66a1c 100644 --- a/client/src/api/tauri-transcription-stream.test.ts +++ b/client/src/api/tauri-transcription-stream.test.ts @@ -1,6 +1,8 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { CONSECUTIVE_FAILURE_THRESHOLD, + THROTTLE_RESUME_DELAY_MS, + THROTTLE_THRESHOLD_MS, TauriEvents, TauriTranscriptionStream, type TauriInvoke, @@ -223,4 +225,303 @@ describe('TauriTranscriptionStream', () => { expect(callback).not.toHaveBeenCalled(); }); }); + + describe('throttle behavior', () => { + let healthEventHandler: ((event: { payload: { + meeting_id: string; + is_congested: boolean; + processing_delay_ms: number; + queue_depth: number; + congested_duration_ms: number; + } }) => void) | null = null; + + beforeEach(() => { + vi.useFakeTimers(); + healthEventHandler = null; + const listenMock = vi.fn(async (eventName: string, handler: (event: { payload: unknown }) => void) => { + if (eventName === TauriEvents.STREAM_HEALTH) { + healthEventHandler = handler as typeof healthEventHandler; + } + return () => { healthEventHandler = null; }; + }); + mockListen = listenMock as unknown as TauriListen; + stream = new TauriTranscriptionStream('meeting-123', mockInvoke, mockListen); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('getIsThrottled() returns false initially', () => { + expect(stream.getIsThrottled()).toBe(false); + }); + + it('does not throttle for short congestion', () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Simulate congestion that lasts less than threshold + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 100, + queue_depth: 5, + congested_duration_ms: 1000, // 1 second, below 3s threshold + }, + }); + + expect(stream.getIsThrottled()).toBe(false); + + // Advance time, but still below threshold + vi.advanceTimersByTime(1000); + + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 100, + queue_depth: 5, + congested_duration_ms: 2000, + }, + }); + + expect(stream.getIsThrottled()).toBe(false); + }); + + it('throttles when congested_duration_ms exceeds threshold', () => { + const mockAddClientLog = vi.mocked(addClientLog); + mockAddClientLog.mockClear(); + + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Simulate prolonged congestion via server-reported duration + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + 
expect(stream.getIsThrottled()).toBe(true); + + // Verify logging + expect(mockAddClientLog).toHaveBeenCalledWith( + expect.objectContaining({ + level: 'warning', + message: 'Audio stream throttled due to prolonged congestion', + }) + ); + }); + + it('throttles when local congestion tracking exceeds threshold', () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Start congestion + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 100, + queue_depth: 5, + congested_duration_ms: 0, + }, + }); + + expect(stream.getIsThrottled()).toBe(false); + + // Advance time past threshold + vi.advanceTimersByTime(THROTTLE_THRESHOLD_MS + 100); + + // Continue congestion + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 100, + queue_depth: 10, + congested_duration_ms: 1000, // Server reports shorter duration + }, + }); + + // Should be throttled based on local tracking + expect(stream.getIsThrottled()).toBe(true); + }); + + it('send() returns false when throttled', () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Trigger throttle + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + expect(stream.getIsThrottled()).toBe(true); + + // Attempt to send - should be rejected + const result = stream.send({ + meeting_id: 'meeting-123', + audio_data: new Float32Array([0.5]), + timestamp: 1, + }); + + expect(result).toBe(false); + expect(mockInvoke).not.toHaveBeenCalled(); + }); + + it('resumes sending after congestion clears and delay passes', () => { + const mockAddClientLog = vi.mocked(addClientLog); + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Trigger throttle + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + expect(stream.getIsThrottled()).toBe(true); + + // Congestion clears + mockAddClientLog.mockClear(); + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: false, + processing_delay_ms: 0, + queue_depth: 0, + congested_duration_ms: 0, + }, + }); + + // Still throttled during delay + expect(stream.getIsThrottled()).toBe(true); + + // Advance past resume delay + vi.advanceTimersByTime(THROTTLE_RESUME_DELAY_MS + 10); + + // Now should be unthrottled + expect(stream.getIsThrottled()).toBe(false); + + // Verify logging + expect(mockAddClientLog).toHaveBeenCalledWith( + expect.objectContaining({ + level: 'info', + message: 'Audio stream throttle released - resuming sends', + }) + ); + }); + + it('cancels resume timer if congestion returns', () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Trigger throttle + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + expect(stream.getIsThrottled()).toBe(true); + + // Congestion clears + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: false, + processing_delay_ms: 0, + queue_depth: 0, + congested_duration_ms: 0, + }, + }); + + // Advance partially through delay + 
vi.advanceTimersByTime(THROTTLE_RESUME_DELAY_MS / 2); + + // Still throttled + expect(stream.getIsThrottled()).toBe(true); + + // Congestion returns before delay completes + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 300, + queue_depth: 15, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 200, + }, + }); + + // Advance past what would have been the resume time + vi.advanceTimersByTime(THROTTLE_RESUME_DELAY_MS); + + // Should still be throttled because congestion returned + expect(stream.getIsThrottled()).toBe(true); + }); + + it('ignores health events for different meeting IDs', () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Health event for different meeting + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-456', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + // Should NOT be throttled + expect(stream.getIsThrottled()).toBe(false); + }); + + it('clears throttle state on close()', async () => { + const congestionCallback = vi.fn(); + stream.onCongestion(congestionCallback); + + // Trigger throttle + healthEventHandler?.({ + payload: { + meeting_id: 'meeting-123', + is_congested: true, + processing_delay_ms: 500, + queue_depth: 20, + congested_duration_ms: THROTTLE_THRESHOLD_MS + 100, + }, + }); + + expect(stream.getIsThrottled()).toBe(true); + + // Close the stream + await stream.close(); + + // Throttle state should be cleared + expect(stream.getIsThrottled()).toBe(false); + }); + }); }); diff --git a/src/noteflow/config/settings/_main.py b/src/noteflow/config/settings/_main.py index ec4f587..8709d42 100644 --- a/src/noteflow/config/settings/_main.py +++ b/src/noteflow/config/settings/_main.py @@ -155,6 +155,10 @@ class Settings(TriggerSettings): bool, Field(default=True, description="Enable post-meeting diarization refinement"), ] + diarization_auto_refine: Annotated[ + bool, + Field(default=False, description="Auto-trigger offline diarization after recording stops"), + ] diarization_job_ttl_hours: Annotated[ int, Field( diff --git a/src/noteflow/grpc/config/cli.py b/src/noteflow/grpc/config/cli.py index 540a0a4..a893ff3 100644 --- a/src/noteflow/grpc/config/cli.py +++ b/src/noteflow/grpc/config/cli.py @@ -147,6 +147,7 @@ def _build_diarization_config( min_speakers=settings.diarization_min_speakers if settings else None, max_speakers=settings.diarization_max_speakers if settings else None, refinement_enabled=settings.diarization_refinement_enabled if settings else True, + auto_refine=settings.diarization_auto_refine if settings else False, ) def build_config_from_args(args: argparse.Namespace, settings: Settings | None) -> GrpcServerConfig: diff --git a/src/noteflow/grpc/config/config.py b/src/noteflow/grpc/config/config.py index 2bc8628..32878f5 100644 --- a/src/noteflow/grpc/config/config.py +++ b/src/noteflow/grpc/config/config.py @@ -59,6 +59,7 @@ class DiarizationConfig: min_speakers: Minimum expected speakers for offline diarization. max_speakers: Maximum expected speakers for offline diarization. refinement_enabled: Whether to allow diarization refinement RPCs. + auto_refine: Auto-trigger offline diarization refinement after recording stops. 
""" enabled: bool = False @@ -68,6 +69,7 @@ class DiarizationConfig: min_speakers: int | None = None max_speakers: int | None = None refinement_enabled: bool = True + auto_refine: bool = False @dataclass(frozen=True, slots=True) @@ -152,6 +154,7 @@ class ServicesConfig: summarization_service: Service for generating meeting summaries. diarization_engine: Engine for speaker identification. diarization_refinement_enabled: Whether to allow post-meeting diarization refinement. + diarization_auto_refine: Auto-trigger offline diarization after recording stops. ner_service: Service for named entity extraction. calendar_service: Service for OAuth and calendar event fetching. webhook_service: Service for webhook event notifications. @@ -162,6 +165,7 @@ class ServicesConfig: summarization_service: SummarizationService | None = None diarization_engine: DiarizationEngine | None = None diarization_refinement_enabled: bool = True + diarization_auto_refine: bool = False ner_service: NerService | None = None calendar_service: CalendarService | None = None webhook_service: WebhookService | None = None diff --git a/src/noteflow/grpc/mixins/_servicer_state.py b/src/noteflow/grpc/mixins/_servicer_state.py index 6e2fe3c..c6f3c88 100644 --- a/src/noteflow/grpc/mixins/_servicer_state.py +++ b/src/noteflow/grpc/mixins/_servicer_state.py @@ -55,6 +55,7 @@ class ServicerState(Protocol): identity_service: IdentityService hf_token_service: HfTokenService | None diarization_refinement_enabled: bool + diarization_auto_refine: bool # Audio writers audio_writers: dict[str, MeetingAudioWriter] diff --git a/src/noteflow/grpc/mixins/diarization/__init__.py b/src/noteflow/grpc/mixins/diarization/__init__.py index 0d7d1fd..e3d0868 100644 --- a/src/noteflow/grpc/mixins/diarization/__init__.py +++ b/src/noteflow/grpc/mixins/diarization/__init__.py @@ -1,9 +1,11 @@ """Speaker diarization mixin package for gRPC service.""" +from ._jobs import auto_trigger_diarization_refinement from ._mixin import DiarizationMixin from ._types import DIARIZATION_TIMEOUT_SECONDS __all__ = [ "DIARIZATION_TIMEOUT_SECONDS", "DiarizationMixin", + "auto_trigger_diarization_refinement", ] diff --git a/src/noteflow/grpc/mixins/diarization/_jobs.py b/src/noteflow/grpc/mixins/diarization/_jobs.py index 5b10ebe..4d39342 100644 --- a/src/noteflow/grpc/mixins/diarization/_jobs.py +++ b/src/noteflow/grpc/mixins/diarization/_jobs.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from typing import TYPE_CHECKING -from uuid import uuid4 +from uuid import UUID, uuid4 import grpc @@ -186,6 +186,95 @@ async def _prepare_diarization_job( return job_id, meeting.duration_seconds if meeting else None + + +def _should_auto_refine(host: ServicerHost, meeting_id: str) -> bool: + """Check if auto-refinement should proceed based on host configuration.""" + if not host.diarization_auto_refine: + logger.debug("Auto-diarization refinement disabled", meeting_id=meeting_id) + return False + + if host.diarization_engine is None: + logger.debug("Auto-diarization refinement skipped: no engine", meeting_id=meeting_id) + return False + + return True + + +async def _create_auto_diarization_job( + host: ServicerHost, + meeting_id: str, +) -> str | None: + """Create and persist a diarization job for auto-refinement. + + Returns: + Job ID if created successfully, None if skipped. 
+ """ + from noteflow.domain.value_objects import MeetingId + + async with host.create_repository_provider() as repo: + if not repo.supports_diarization_jobs: + logger.debug("Auto-diarization skipped: database required", meeting_id=meeting_id) + return None + + active_job = await repo.diarization_jobs.get_active_for_meeting(meeting_id) + if active_job is not None: + logger.debug( + "Auto-diarization skipped: job active", + meeting_id=meeting_id, + active_job_id=active_job.job_id, + ) + return None + + parsed_id = MeetingId(UUID(meeting_id)) + meeting = await repo.meetings.get(parsed_id) + if meeting is None: + logger.warning("Auto-diarization skipped: meeting not found", meeting_id=meeting_id) + return None + + job_id = str(uuid4()) + persisted = await _create_and_persist_job( + job_id, meeting_id, meeting.duration_seconds, repo + ) + if not persisted: + logger.warning("Auto-diarization skipped: persist failed", meeting_id=meeting_id) + return None + + return job_id + +async def auto_trigger_diarization_refinement( + host: ServicerHost, + meeting_id: str, +) -> str | None: + """Auto-trigger diarization refinement after recording stops. + + This is an internal function called from post-processing that doesn't + require gRPC context. It performs minimal validation and starts the + job if preconditions are met. + + Args: + host: The servicer host with diarization capabilities. + meeting_id: The meeting ID to process. + + Returns: + Job ID if diarization was started, None if skipped or failed. + """ + if not _should_auto_refine(host, meeting_id): + return None + + job_id = await _create_auto_diarization_job(host, meeting_id) + if job_id is None: + return None + + await update_processing_status( + host.create_repository_provider, + meeting_id, + ProcessingStatusUpdate(step="diarization", status=ProcessingStepStatus.RUNNING), + ) + _schedule_diarization_task(host, job_id, None, meeting_id) + logger.info("Auto-diarization refinement triggered", meeting_id=meeting_id, job_id=job_id) + return job_id + class JobsMixin(JobStatusMixin): """Mixin providing diarization job management.""" diff --git a/src/noteflow/grpc/mixins/diarization/_streaming.py b/src/noteflow/grpc/mixins/diarization/_streaming.py index 5834ff6..f8c84dc 100644 --- a/src/noteflow/grpc/mixins/diarization/_streaming.py +++ b/src/noteflow/grpc/mixins/diarization/_streaming.py @@ -23,10 +23,16 @@ if TYPE_CHECKING: logger = get_logger(__name__) +# Streaming diarization window tuning constants. +# These limits are intentionally conservative because streaming diarization +# serves as a real-time preview only. The offline refinement path (via +# RefineSpeakerDiarization) is the quality path that processes the full +# audio file with higher accuracy. Keeping the streaming window small +# reduces memory pressure during long meetings. _SECONDS_PER_MINUTE: Final[int] = 60 -_MAX_TURN_AGE_MINUTES: Final[int] = 15 +_MAX_TURN_AGE_MINUTES: Final[int] = 5 _MAX_TURN_AGE_SECONDS: Final[int] = _MAX_TURN_AGE_MINUTES * _SECONDS_PER_MINUTE -_MAX_TURN_COUNT: Final[int] = 5_000 +_MAX_TURN_COUNT: Final[int] = 1_000 # Minimum samples required for diarization processing. 
# At 16kHz, 160 samples = 10ms - filters out bootstrap/handshake chunks diff --git a/src/noteflow/grpc/mixins/meeting/_post_processing.py b/src/noteflow/grpc/mixins/meeting/_post_processing.py index 6e4f4a4..9d1b7e4 100644 --- a/src/noteflow/grpc/mixins/meeting/_post_processing.py +++ b/src/noteflow/grpc/mixins/meeting/_post_processing.py @@ -210,37 +210,45 @@ async def start_post_processing( ) -> asyncio.Task[None] | None: """Spawn background task for post-meeting processing. - Starts auto-summarization and meeting completion as a fire-and-forget task. - Returns the task handle for testing/monitoring, or None if summarization - is not configured. + Starts auto-summarization, auto-diarization refinement, and meeting completion + as a fire-and-forget task. Returns the task handle for testing/monitoring, + or None if no post-processing is configured. Args: host: The servicer host. meeting_id: The meeting ID to process. Returns: - The spawned asyncio Task, or None if summarization service unavailable. + The spawned asyncio Task, or None if no post-processing needed. """ - service = host.summarization_service - if service is None: + from ..diarization._jobs import auto_trigger_diarization_refinement + + summarization_service = host.summarization_service + has_auto_diarization = host.diarization_auto_refine and host.diarization_engine is not None + + if summarization_service is None and not has_auto_diarization: logger.debug( - "Post-processing: summarization not configured, skipping", + "Post-processing: no services configured, skipping", meeting_id=meeting_id, ) return None - # Capture narrowed type for closure - summarization_service: SummarizationService = service + async def _run_post_processing() -> None: + """Run post-processing tasks sequentially.""" + # Auto-diarization refinement (fire-and-forget background job) + if has_auto_diarization: + await auto_trigger_diarization_refinement(host, meeting_id) + + # Generate summary and complete meeting + if summarization_service is not None: + await _generate_summary_and_complete(host, meeting_id, summarization_service) async def _run_with_error_handling() -> None: """Wrapper to catch and log any errors.""" try: - await _generate_summary_and_complete(host, meeting_id, summarization_service) + await _run_post_processing() except Exception: - logger.exception( - "Post-processing failed", - meeting_id=meeting_id, - ) + logger.exception("Post-processing failed", meeting_id=meeting_id) task = asyncio.create_task(_run_with_error_handling()) task.add_done_callback(lambda t: _post_processing_task_done_callback(t, meeting_id)) diff --git a/src/noteflow/grpc/mixins/streaming/_asr.py b/src/noteflow/grpc/mixins/streaming/_asr.py index 1e356d6..7df3013 100644 --- a/src/noteflow/grpc/mixins/streaming/_asr.py +++ b/src/noteflow/grpc/mixins/streaming/_asr.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol, cast @@ -10,10 +10,13 @@ import numpy as np from numpy.typing import NDArray from noteflow.domain.entities import Segment +from noteflow.domain.ports import UnitOfWork from noteflow.domain.value_objects import MeetingId +from noteflow.infrastructure.asr import AsrResult from noteflow.infrastructure.logging import get_logger from ...proto import noteflow_pb2 +from ...stream_state import MeetingStreamState from ..converters import ( create_segment_from_asr, parse_meeting_id_or_none, @@ 
-30,48 +33,119 @@ class _SpeakerAssignable(Protocol): def maybe_assign_speaker(self, meeting_id: str, segment: Segment) -> None: ... -class _SegmentRepository(Protocol): - async def add(self, meeting_id: MeetingId, segment: Segment) -> None: ... - - -class _SegmentAddable(Protocol): - @property - def segments(self) -> _SegmentRepository: ... - - -class _MeetingWithId(Protocol): - @property - def id(self) -> MeetingId: ... - - @property - def next_segment_id(self) -> int: ... - - -class _AsrResultLike(Protocol): - @property - def text(self) -> str: ... - - @property - def start(self) -> float: ... - - @property - def end(self) -> float: ... - - @dataclass(frozen=True, slots=True) class _SegmentBuildContext: """Context for building segments from ASR results. Groups related parameters to reduce function signature complexity. + Uses cached meeting_db_id to avoid fetching meeting on every segment. """ host: ServicerHost - repo: _SegmentAddable - meeting: _MeetingWithId - meeting_id: str + repo: UnitOfWork + meeting_db_id: MeetingId + meeting_id_str: str segment_start_time: float +async def _ensure_meeting_db_id( + host: ServicerHost, + meeting_id: str, + parsed_meeting_id: MeetingId, + state: MeetingStreamState, +) -> MeetingId | None: + """Ensure meeting_db_id is cached, fetching from DB on first call. + + Args: + host: The servicer host. + meeting_id: String meeting identifier for logging. + parsed_meeting_id: Parsed MeetingId value object. + state: Stream state to cache the meeting_db_id. + + Returns: + The cached or fetched MeetingId, or None if meeting not found. + """ + if state.meeting_db_id is not None: + # Use cached meeting ID (avoid DB fetch) + return MeetingId(state.meeting_db_id) + + # First segment: fetch meeting and cache its ID + async with host.create_repository_provider() as repo: + meeting = await repo.meetings.get(parsed_meeting_id) + if meeting is None: + logger.error("Meeting not found during ASR processing", meeting_id=meeting_id) + return None + + # Cache meeting ID in stream state for subsequent segments + state.meeting_db_id = meeting.id + # Initialize segment sequence from meeting if not already set + if state.next_segment_sequence == 0: + state.next_segment_sequence = meeting.next_segment_id + + return meeting.id + + +@dataclass(frozen=True, slots=True) +class _ProcessingPrerequisites: + """Validated prerequisites for audio segment processing.""" + + parsed_meeting_id: MeetingId + state: MeetingStreamState + + +@dataclass(frozen=True, slots=True) +class _TranscriptionContext: + """Context for transcription and persistence operations.""" + + host: ServicerHost + meeting_id: str + meeting_db_id: MeetingId + segment_start_time: float + + +def _validate_prerequisites( + host: ServicerHost, + meeting_id: str, +) -> _ProcessingPrerequisites | None: + """Validate all prerequisites for segment processing.""" + parsed_meeting_id = _validate_meeting_id(meeting_id) + if parsed_meeting_id is None: + return None + + state = host.get_stream_state(meeting_id) + if state is None: + logger.error("Stream state not found during ASR processing", meeting_id=meeting_id) + return None + + return _ProcessingPrerequisites(parsed_meeting_id=parsed_meeting_id, state=state) + + +async def _transcribe_and_persist( + ctx: _TranscriptionContext, + audio: NDArray[np.float32], +) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]: + """Transcribe audio and persist segments to database.""" + asr_engine = ctx.host.asr_engine + if asr_engine is None: + return + + results = await 
asr_engine.transcribe_async(audio) + + async with ctx.host.create_repository_provider() as repo: + build_ctx = _SegmentBuildContext( + host=ctx.host, + repo=repo, + meeting_db_id=ctx.meeting_db_id, + meeting_id_str=ctx.meeting_id, + segment_start_time=ctx.segment_start_time, + ) + segments_to_add = await _build_segments_from_results(build_ctx, results) + if segments_to_add: + await repo.commit() + for _, update in segments_to_add: + yield update + + async def process_audio_segment( host: ServicerHost, meeting_id: str, @@ -80,6 +154,9 @@ async def process_audio_segment( ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]: """Process a complete audio segment through ASR. + Uses cached meeting_db_id from stream state to reduce per-segment DB overhead. + Only fetches meeting from DB on first segment to initialize the cache. + Args: host: The servicer host. meeting_id: Meeting identifier. @@ -90,32 +167,30 @@ async def process_audio_segment( TranscriptUpdates for transcribed segments. """ if len(audio) == 0: - return # Empty audio is not an error, just nothing to process + return - asr_engine = host.asr_engine - if asr_engine is None: + if host.asr_engine is None: logger.error("ASR engine unavailable during segment processing", meeting_id=meeting_id) return - parsed_meeting_id = _validate_meeting_id(meeting_id) - if parsed_meeting_id is None: - return # Already logged in _validate_meeting_id + prereqs = _validate_prerequisites(host, meeting_id) + if prereqs is None: + return - async with host.create_repository_provider() as repo: - meeting = await repo.meetings.get(parsed_meeting_id) - if meeting is None: - logger.error("Meeting not found during ASR processing", meeting_id=meeting_id) - return - results = await asr_engine.transcribe_async(audio) - ctx = _SegmentBuildContext( - host=host, repo=repo, meeting=meeting, - meeting_id=meeting_id, segment_start_time=segment_start_time, - ) - segments_to_add = await _build_segments_from_results(ctx, results) - if segments_to_add: - await repo.commit() - for _, update in segments_to_add: - yield update + meeting_db_id = await _ensure_meeting_db_id( + host, meeting_id, prereqs.parsed_meeting_id, prereqs.state + ) + if meeting_db_id is None: + return + + ctx = _TranscriptionContext( + host=host, + meeting_id=meeting_id, + meeting_db_id=meeting_db_id, + segment_start_time=segment_start_time, + ) + async for update in _transcribe_and_persist(ctx, audio): + yield update def _validate_meeting_id(meeting_id: str) -> MeetingId | None: @@ -128,12 +203,12 @@ def _validate_meeting_id(meeting_id: str) -> MeetingId | None: async def _build_segments_from_results( ctx: _SegmentBuildContext, - results: list[_AsrResultLike], + results: Sequence[AsrResult], ) -> list[tuple[Segment, noteflow_pb2.TranscriptUpdate]]: """Build and persist segments from ASR results. Args: - ctx: Context with host, repo, meeting, and timing info. + ctx: Context with host, repo, meeting_db_id, and timing info. results: ASR transcription results to process. 
Returns: @@ -145,19 +220,20 @@ async def _build_segments_from_results( if not result.text or not result.text.strip(): logger.debug( "Skipping empty ASR result", - meeting_id=ctx.meeting_id, + meeting_id=ctx.meeting_id_str, start=result.start, end=result.end, ) continue - segment_id = ctx.host.next_segment_id(ctx.meeting_id, fallback=ctx.meeting.next_segment_id) + # Use host.next_segment_id with fallback=0 since state caches the sequence + segment_id = ctx.host.next_segment_id(ctx.meeting_id_str, fallback=0) segment = create_segment_from_asr( - ctx.meeting.id, segment_id, result, ctx.segment_start_time + ctx.meeting_db_id, segment_id, result, ctx.segment_start_time ) - _assign_speaker_if_available(ctx.host, ctx.meeting_id, segment) - await ctx.repo.segments.add(ctx.meeting.id, segment) - segments_to_add.append((segment, segment_to_proto_update(ctx.meeting_id, segment))) + _assign_speaker_if_available(ctx.host, ctx.meeting_id_str, segment) + await ctx.repo.segments.add(ctx.meeting_db_id, segment) + segments_to_add.append((segment, segment_to_proto_update(ctx.meeting_id_str, segment))) return segments_to_add diff --git a/src/noteflow/grpc/server/internal/bootstrap.py b/src/noteflow/grpc/server/internal/bootstrap.py index a9122a4..81db387 100644 --- a/src/noteflow/grpc/server/internal/bootstrap.py +++ b/src/noteflow/grpc/server/internal/bootstrap.py @@ -58,6 +58,7 @@ async def create_services( summarization_service=summarization_service, diarization_engine=create_diarization_engine(config.diarization), diarization_refinement_enabled=config.diarization.refinement_enabled, + diarization_auto_refine=config.diarization.auto_refine, ner_service=create_ner_service(session_factory, settings), calendar_service=await create_calendar_service(session_factory, settings), webhook_service=await create_webhook_service(session_factory, settings) diff --git a/src/noteflow/grpc/server/internal/services.py b/src/noteflow/grpc/server/internal/services.py index 5bab4d3..36edd15 100644 --- a/src/noteflow/grpc/server/internal/services.py +++ b/src/noteflow/grpc/server/internal/services.py @@ -30,6 +30,7 @@ class _ServerState(Protocol): session_factory: async_sessionmaker[AsyncSession] | None diarization_engine: DiarizationEngine | None diarization_refinement_enabled: bool + diarization_auto_refine: bool ner_service: NerService | None calendar_service: CalendarService | None webhook_service: WebhookService | None @@ -105,6 +106,7 @@ def build_servicer( summarization_service=state.summarization_service, diarization_engine=state.diarization_engine, diarization_refinement_enabled=state.diarization_refinement_enabled, + diarization_auto_refine=state.diarization_auto_refine, ner_service=state.ner_service, calendar_service=state.calendar_service, webhook_service=state.webhook_service, diff --git a/src/noteflow/grpc/service.py b/src/noteflow/grpc/service.py index c8d6bff..421cb96 100644 --- a/src/noteflow/grpc/service.py +++ b/src/noteflow/grpc/service.py @@ -164,6 +164,7 @@ class NoteFlowServicer( self.summarization_service = services.summarization_service self.diarization_engine = services.diarization_engine self.diarization_refinement_enabled = services.diarization_refinement_enabled + self.diarization_auto_refine = services.diarization_auto_refine self.ner_service = services.ner_service self.calendar_service = services.calendar_service self.webhook_service = services.webhook_service diff --git a/src/noteflow/grpc/startup/startup.py b/src/noteflow/grpc/startup/startup.py index 8b856bf..8dd280a 100644 --- 
a/src/noteflow/grpc/startup/startup.py +++ b/src/noteflow/grpc/startup/startup.py @@ -97,6 +97,12 @@ class DiarizationConfigLike(Protocol): @property def max_speakers(self) -> int | None: ... + @property + def refinement_enabled(self) -> bool: ... + + @property + def auto_refine(self) -> bool: ... + class GrpcServerConfigLike(Protocol): """Protocol for gRPC server configuration objects.""" diff --git a/src/noteflow/grpc/stream_state.py b/src/noteflow/grpc/stream_state.py index 3b90f3c..cea18e6 100644 --- a/src/noteflow/grpc/stream_state.py +++ b/src/noteflow/grpc/stream_state.py @@ -8,6 +8,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import TYPE_CHECKING +from uuid import UUID if TYPE_CHECKING: from noteflow.infrastructure.asr import Segmenter, StreamingVad @@ -51,6 +52,10 @@ class MeetingStreamState: stop_requested: bool = False audio_write_failed: bool = False + # Cached meeting info to reduce per-segment DB overhead + meeting_db_id: UUID | None = None + next_segment_sequence: int = 0 + def increment_segment_id(self) -> int: """Get current segment ID and increment counter. diff --git a/tests/grpc/test_audio_processing.py b/tests/grpc/test_audio_processing.py new file mode 100644 index 0000000..5216dee --- /dev/null +++ b/tests/grpc/test_audio_processing.py @@ -0,0 +1,421 @@ +"""Tests for audio processing helper functions in grpc/mixins/_audio_processing.py. + +Tests cover: +- resample_audio: Linear interpolation resampling +- decode_audio_chunk: Bytes to numpy array conversion +- convert_audio_format: Downmixing and resampling pipeline +- validate_stream_format: Format validation and mid-stream checks +""" + +from __future__ import annotations + +from typing import Final + +import numpy as np +import pytest +from numpy.typing import NDArray + +from noteflow.grpc.mixins._audio_processing import ( + StreamFormatValidation, + convert_audio_format, + decode_audio_chunk, + resample_audio, + validate_stream_format, +) + +# Audio test constants +SAMPLE_RATE_8K: Final[int] = 8000 +SAMPLE_RATE_16K: Final[int] = 16000 +SAMPLE_RATE_44K: Final[int] = 44100 +SAMPLE_RATE_48K: Final[int] = 48000 + +MONO_CHANNELS: Final[int] = 1 +STEREO_CHANNELS: Final[int] = 2 + +SUPPORTED_RATES: Final[frozenset[int]] = frozenset({SAMPLE_RATE_16K, SAMPLE_RATE_44K, SAMPLE_RATE_48K}) +DEFAULT_SAMPLE_RATE: Final[int] = SAMPLE_RATE_16K + + +def generate_sine_wave( + frequency_hz: float, + duration_seconds: float, + sample_rate: int, +) -> NDArray[np.float32]: + """Generate a sine wave test signal. + + Args: + frequency_hz: Frequency of the sine wave. + duration_seconds: Duration in seconds. + sample_rate: Sample rate in Hz. + + Returns: + Float32 numpy array containing the sine wave. 
+ """ + num_samples = int(duration_seconds * sample_rate) + t = np.arange(num_samples) / sample_rate + return np.sin(2 * np.pi * frequency_hz * t).astype(np.float32) + + +class TestResampleAudio: + """Tests for resample_audio function.""" + + def test_upsample_8k_to_16k_preserves_duration(self) -> None: + """Upsampling from 8kHz to 16kHz preserves audio duration.""" + duration_seconds = 0.1 + original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_8K) + expected_length = int(duration_seconds * SAMPLE_RATE_16K) + + resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K) + + assert resampled.shape[0] == expected_length, ( + f"Expected {expected_length} samples, got {resampled.shape[0]}" + ) + + def test_downsample_48k_to_16k_preserves_duration(self) -> None: + """Downsampling from 48kHz to 16kHz preserves audio duration.""" + duration_seconds = 0.1 + original = generate_sine_wave(440.0, duration_seconds, SAMPLE_RATE_48K) + expected_length = int(duration_seconds * SAMPLE_RATE_16K) + + resampled = resample_audio(original, SAMPLE_RATE_48K, SAMPLE_RATE_16K) + + assert resampled.shape[0] == expected_length, ( + f"Expected {expected_length} samples, got {resampled.shape[0]}" + ) + + def test_same_rate_returns_original_unchanged(self) -> None: + """Resampling with same source and destination rate returns original.""" + original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K) + + resampled = resample_audio(original, SAMPLE_RATE_16K, SAMPLE_RATE_16K) + + assert resampled is original, "Same rate should return original array" + + def test_empty_audio_returns_empty_array(self) -> None: + """Resampling empty audio returns empty array.""" + empty_audio = np.array([], dtype=np.float32) + + resampled = resample_audio(empty_audio, SAMPLE_RATE_8K, SAMPLE_RATE_16K) + + assert resampled is empty_audio, "Empty audio should return original empty array" + + def test_resampled_output_is_float32(self) -> None: + """Resampled audio maintains float32 dtype.""" + original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_8K) + + resampled = resample_audio(original, SAMPLE_RATE_8K, SAMPLE_RATE_16K) + + assert resampled.dtype == np.float32, ( + f"Expected float32 dtype, got {resampled.dtype}" + ) + + @pytest.mark.parametrize( + ("src_rate", "dst_rate", "expected_ratio"), + [ + pytest.param(SAMPLE_RATE_8K, SAMPLE_RATE_16K, 2.0, id="upsample-2x"), + pytest.param(SAMPLE_RATE_48K, SAMPLE_RATE_16K, 1 / 3, id="downsample-3x"), + pytest.param(SAMPLE_RATE_44K, SAMPLE_RATE_16K, 16000 / 44100, id="downsample-44k-to-16k"), + ], + ) + def test_resample_length_matches_ratio( + self, + src_rate: int, + dst_rate: int, + expected_ratio: float, + ) -> None: + """Resampled length matches the rate ratio.""" + num_samples = 1000 + original = np.random.rand(num_samples).astype(np.float32) + expected_length = round(num_samples * expected_ratio) + + resampled = resample_audio(original, src_rate, dst_rate) + + assert resampled.shape[0] == expected_length, ( + f"Expected {expected_length} samples for ratio {expected_ratio}, got {resampled.shape[0]}" + ) + + +class TestDecodeAudioChunk: + """Tests for decode_audio_chunk function.""" + + def test_float32_roundtrip_preserves_values(self) -> None: + """Encoding to bytes and decoding preserves float32 values.""" + original = np.array([0.5, -0.5, 1.0, -1.0, 0.0], dtype=np.float32) + audio_bytes = original.tobytes() + + decoded = decode_audio_chunk(audio_bytes) + + assert decoded is not None, "Decoded audio should not be None" + np.testing.assert_array_equal(decoded, original, 
err_msg="Roundtrip should preserve values") + + def test_empty_bytes_returns_none(self) -> None: + """Decoding empty bytes returns None.""" + empty_bytes = b"" + + result = decode_audio_chunk(empty_bytes) + + assert result is None, "Empty bytes should return None" + + def test_decoded_dtype_is_float32(self) -> None: + """Decoded array has float32 dtype.""" + original = np.array([0.1, 0.2, 0.3], dtype=np.float32) + audio_bytes = original.tobytes() + + decoded = decode_audio_chunk(audio_bytes) + + assert decoded is not None, "Decoded audio should not be None" + assert decoded.dtype == np.float32, f"Expected float32, got {decoded.dtype}" + + def test_large_chunk_decode(self) -> None: + """Decode large audio chunk successfully.""" + num_samples = 16000 # 1 second at 16kHz + original = np.random.rand(num_samples).astype(np.float32) + audio_bytes = original.tobytes() + + decoded = decode_audio_chunk(audio_bytes) + + assert decoded is not None, "Decoded audio should not be None" + assert decoded.shape[0] == num_samples, ( + f"Expected {num_samples} samples, got {decoded.shape[0]}" + ) + + +class TestConvertAudioFormat: + """Tests for convert_audio_format function.""" + + def test_stereo_to_mono_averages_channels(self) -> None: + """Stereo to mono conversion averages left and right channels.""" + # Create interleaved stereo: [L0, R0, L1, R1, ...] + # Left channel: 1.0, Right channel: 0.0 -> Average: 0.5 + left_samples = np.ones(100, dtype=np.float32) + right_samples = np.zeros(100, dtype=np.float32) + stereo = np.empty(200, dtype=np.float32) + stereo[0::2] = left_samples + stereo[1::2] = right_samples + + mono = convert_audio_format(stereo, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K) + + expected_value = 0.5 + np.testing.assert_allclose(mono, expected_value, rtol=1e-5, err_msg="Stereo should average to 0.5") + + def test_mono_unchanged_when_single_channel(self) -> None: + """Mono audio passes through without modification when channels=1.""" + original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_16K) + + result = convert_audio_format(original, SAMPLE_RATE_16K, MONO_CHANNELS, SAMPLE_RATE_16K) + + np.testing.assert_array_equal(result, original, err_msg="Mono should pass through unchanged") + + def test_resample_during_format_conversion(self) -> None: + """Format conversion performs resampling when rates differ.""" + original = generate_sine_wave(440.0, 0.1, SAMPLE_RATE_48K) + expected_length = int(0.1 * SAMPLE_RATE_16K) + + result = convert_audio_format(original, SAMPLE_RATE_48K, MONO_CHANNELS, SAMPLE_RATE_16K) + + assert result.shape[0] == expected_length, ( + f"Expected {expected_length} samples after resampling, got {result.shape[0]}" + ) + + def test_stereo_downmix_then_resample(self) -> None: + """Format conversion downmixes stereo then resamples.""" + duration_seconds = 0.1 + # Stereo at 48kHz + stereo_samples = int(duration_seconds * SAMPLE_RATE_48K * STEREO_CHANNELS) + stereo = np.random.rand(stereo_samples).astype(np.float32) + expected_mono_length = int(duration_seconds * SAMPLE_RATE_16K) + + result = convert_audio_format(stereo, SAMPLE_RATE_48K, STEREO_CHANNELS, SAMPLE_RATE_16K) + + assert result.shape[0] == expected_mono_length, ( + f"Expected {expected_mono_length} samples, got {result.shape[0]}" + ) + + def test_raises_on_buffer_not_divisible_by_channels(self) -> None: + """Raises ValueError when buffer size not divisible by channel count.""" + odd_buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32) # 3 samples, 2 channels + + with pytest.raises(ValueError, match="not divisible 
by channel count"): + convert_audio_format(odd_buffer, SAMPLE_RATE_16K, STEREO_CHANNELS, SAMPLE_RATE_16K) + + @pytest.mark.parametrize( + ("channels",), + [ + pytest.param(2, id="stereo"), + pytest.param(4, id="quad"), + pytest.param(6, id="5.1-surround"), + ], + ) + def test_multichannel_downmix(self, channels: int) -> None: + """Multi-channel audio downmixes correctly to mono.""" + num_frames = 100 + # All channels have value 1.0, so average should be 1.0 + multichannel = np.ones(num_frames * channels, dtype=np.float32) + + mono = convert_audio_format(multichannel, SAMPLE_RATE_16K, channels, SAMPLE_RATE_16K) + + assert mono.shape[0] == num_frames, f"Expected {num_frames} mono samples" + np.testing.assert_allclose(mono, 1.0, rtol=1e-5, err_msg="Mono average should be 1.0") + + +class TestValidateStreamFormat: + """Tests for validate_stream_format function.""" + + def test_valid_format_returns_normalized_values(self) -> None: + """Valid format request returns normalized rate and channels.""" + request = StreamFormatValidation( + sample_rate=SAMPLE_RATE_16K, + channels=MONO_CHANNELS, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, + existing_format=None, + ) + + rate, channels = validate_stream_format(request) + + assert rate == SAMPLE_RATE_16K, f"Expected rate {SAMPLE_RATE_16K}, got {rate}" + assert channels == MONO_CHANNELS, f"Expected channels {MONO_CHANNELS}, got {channels}" + + def test_zero_sample_rate_uses_default(self) -> None: + """Zero sample rate falls back to default sample rate.""" + request = StreamFormatValidation( + sample_rate=0, + channels=MONO_CHANNELS, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, + existing_format=None, + ) + + rate, _ = validate_stream_format(request) + + assert rate == DEFAULT_SAMPLE_RATE, ( + f"Expected default rate {DEFAULT_SAMPLE_RATE}, got {rate}" + ) + + def test_zero_channels_defaults_to_mono(self) -> None: + """Zero channels defaults to mono (1 channel).""" + request = StreamFormatValidation( + sample_rate=SAMPLE_RATE_16K, + channels=0, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, + existing_format=None, + ) + + _, channels = validate_stream_format(request) + + assert channels == MONO_CHANNELS, f"Expected mono, got {channels} channels" + + def test_raises_on_unsupported_sample_rate(self) -> None: + """Raises ValueError for unsupported sample rate.""" + unsupported_rate = 22050 + request = StreamFormatValidation( + sample_rate=unsupported_rate, + channels=MONO_CHANNELS, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, + existing_format=None, + ) + + with pytest.raises(ValueError, match="Unsupported sample_rate"): + validate_stream_format(request) + + def test_raises_on_negative_channels(self) -> None: + """Raises ValueError for negative channel count.""" + request = StreamFormatValidation( + sample_rate=SAMPLE_RATE_16K, + channels=-1, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, + existing_format=None, + ) + + with pytest.raises(ValueError, match="channels must be >= 1"): + validate_stream_format(request) + + def test_raises_on_mid_stream_rate_change(self) -> None: + """Raises ValueError when sample rate changes mid-stream.""" + existing_rate = SAMPLE_RATE_44K + new_rate = SAMPLE_RATE_16K + request = StreamFormatValidation( + sample_rate=new_rate, + channels=MONO_CHANNELS, + default_sample_rate=DEFAULT_SAMPLE_RATE, + supported_sample_rates=SUPPORTED_RATES, 
+
+
+class TestValidateStreamFormat:
+    """Tests for validate_stream_format function."""
+
+    def test_valid_format_returns_normalized_values(self) -> None:
+        """Valid format request returns normalized rate and channels."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        rate, channels = validate_stream_format(request)
+
+        assert rate == SAMPLE_RATE_16K, f"Expected rate {SAMPLE_RATE_16K}, got {rate}"
+        assert channels == MONO_CHANNELS, f"Expected channels {MONO_CHANNELS}, got {channels}"
+
+    def test_zero_sample_rate_uses_default(self) -> None:
+        """Zero sample rate falls back to default sample rate."""
+        request = StreamFormatValidation(
+            sample_rate=0,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        rate, _ = validate_stream_format(request)
+
+        assert rate == DEFAULT_SAMPLE_RATE, (
+            f"Expected default rate {DEFAULT_SAMPLE_RATE}, got {rate}"
+        )
+
+    def test_zero_channels_defaults_to_mono(self) -> None:
+        """Zero channels defaults to mono (1 channel)."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=0,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        _, channels = validate_stream_format(request)
+
+        assert channels == MONO_CHANNELS, f"Expected mono, got {channels} channels"
+
+    def test_raises_on_unsupported_sample_rate(self) -> None:
+        """Raises ValueError for unsupported sample rate."""
+        unsupported_rate = 22050
+        request = StreamFormatValidation(
+            sample_rate=unsupported_rate,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        with pytest.raises(ValueError, match="Unsupported sample_rate"):
+            validate_stream_format(request)
+
+    def test_raises_on_negative_channels(self) -> None:
+        """Raises ValueError for negative channel count."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=-1,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        with pytest.raises(ValueError, match="channels must be >= 1"):
+            validate_stream_format(request)
+
+    def test_raises_on_mid_stream_rate_change(self) -> None:
+        """Raises ValueError when sample rate changes mid-stream."""
+        existing_rate = SAMPLE_RATE_44K
+        new_rate = SAMPLE_RATE_16K
+        request = StreamFormatValidation(
+            sample_rate=new_rate,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=(existing_rate, MONO_CHANNELS),
+        )
+
+        with pytest.raises(ValueError, match="cannot change mid-stream"):
+            validate_stream_format(request)
+
+    def test_raises_on_mid_stream_channel_change(self) -> None:
+        """Raises ValueError when channel count changes mid-stream."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=STEREO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
+        )
+
+        with pytest.raises(ValueError, match="cannot change mid-stream"):
+            validate_stream_format(request)
+
+    def test_accepts_matching_existing_format(self) -> None:
+        """Accepts format when it matches existing stream format."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=(SAMPLE_RATE_16K, MONO_CHANNELS),
+        )
+
+        rate, channels = validate_stream_format(request)
+
+        assert rate == SAMPLE_RATE_16K, "Rate should match existing"
+        assert channels == MONO_CHANNELS, "Channels should match existing"
+
+    @pytest.mark.parametrize(
+        ("sample_rate",),
+        [
+            pytest.param(SAMPLE_RATE_16K, id="16kHz"),
+            pytest.param(SAMPLE_RATE_44K, id="44.1kHz"),
+            pytest.param(SAMPLE_RATE_48K, id="48kHz"),
+        ],
+    )
+    def test_accepts_all_supported_rates(self, sample_rate: int) -> None:
+        """All rates in supported_sample_rates are accepted."""
+        request = StreamFormatValidation(
+            sample_rate=sample_rate,
+            channels=MONO_CHANNELS,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        rate, _ = validate_stream_format(request)
+
+        assert rate == sample_rate, f"Expected rate {sample_rate} to be accepted"
+
+    @pytest.mark.parametrize(
+        ("channels",),
+        [
+            pytest.param(1, id="mono"),
+            pytest.param(2, id="stereo"),
+            pytest.param(6, id="5.1-surround"),
+        ],
+    )
+    def test_accepts_positive_channel_counts(self, channels: int) -> None:
+        """Positive channel counts are accepted."""
+        request = StreamFormatValidation(
+            sample_rate=SAMPLE_RATE_16K,
+            channels=channels,
+            default_sample_rate=DEFAULT_SAMPLE_RATE,
+            supported_sample_rates=SUPPORTED_RATES,
+            existing_format=None,
+        )
+
+        _, result_channels = validate_stream_format(request)
+
+        assert result_channels == channels, f"Expected {channels} channels"
diff --git a/tests/grpc/test_meeting_mixin.py b/tests/grpc/test_meeting_mixin.py
index 51ec5f9..54ac35b 100644
--- a/tests/grpc/test_meeting_mixin.py
+++ b/tests/grpc/test_meeting_mixin.py
@@ -134,6 +134,8 @@ class MockMeetingMixinServicerHost(MeetingMixin):
         self.webhook_service = webhook_service
         self.project_service = None
         self.summarization_service = None  # Post-processing disabled in tests
+        self.diarization_auto_refine = False  # Auto-diarization disabled in tests
+        self.diarization_engine = None
 
     def create_repository_provider(self) -> MockMeetingRepositoryProvider:
         """Create mock repository provider context manager."""
protocol_path = Path("src/noteflow/grpc/mixins/protocols.py") content = protocol_path.read_text() assert "stream_init_lock: asyncio.Lock" in content, ( @@ -99,7 +99,7 @@ class TestStreamInitRaceCondition: def test_stream_init_uses_lock(self) -> None: """Verify _init_stream_for_meeting uses the lock.""" # Check in the session manager module (streaming package) - session_path = Path("src/noteflow/grpc/_mixins/streaming/_session.py") + session_path = Path("src/noteflow/grpc/mixins/streaming/_session.py") content = session_path.read_text() assert "async with host.stream_init_lock:" in content, ( @@ -112,7 +112,7 @@ class TestStopMeetingIdempotency: def test_idempotency_guard_in_code(self) -> None: """Verify the idempotency guard code exists in StopMeeting.""" - meeting_path = Path("src/noteflow/grpc/_mixins/meeting/meeting_mixin.py") + meeting_path = Path("src/noteflow/grpc/mixins/meeting/meeting_mixin.py") content = meeting_path.read_text() # Verify the idempotency guard pattern exists @@ -136,7 +136,7 @@ class TestStopMeetingIdempotency: After refactoring, the helper function returns the domain entity directly, and the caller (StopMeeting) converts to proto. """ - meeting_path = Path("src/noteflow/grpc/_mixins/meeting/meeting_mixin.py") + meeting_path = Path("src/noteflow/grpc/mixins/meeting/meeting_mixin.py") content = meeting_path.read_text() # Verify the guard pattern exists: checks terminal_states and returns context.meeting diff --git a/tests/grpc/test_sync_orchestration.py b/tests/grpc/test_sync_orchestration.py index 6178446..5a5876e 100644 --- a/tests/grpc/test_sync_orchestration.py +++ b/tests/grpc/test_sync_orchestration.py @@ -70,7 +70,7 @@ class _GetSyncStatusResponse(Protocol): status: str items_synced: int items_total: int - error_message: str + error_code: int # SyncErrorCode enum value duration_ms: int expires_at: str @@ -506,7 +506,10 @@ class TestSyncErrorHandling: status = await await_sync_completion(servicer, start.sync_run_id, context) assert status.status == "error", "Sync should report error" - assert "OAuth token expired" in status.error_message, "Error should be captured" + # Auth errors should return SYNC_ERROR_CODE_AUTH_REQUIRED + assert status.error_code == noteflow_pb2.SYNC_ERROR_CODE_AUTH_REQUIRED, ( + f"Expected auth_required error code, got {status.error_code}" + ) @pytest.mark.asyncio async def test_first_sync_fails(