#!/usr/bin/env python3 """A/B harness for streaming configuration latency and WER comparisons.""" from __future__ import annotations import argparse import json import re import threading import time import wave from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Iterable, Protocol, cast from uuid import uuid4 import numpy as np from numpy.typing import NDArray from noteflow.config.settings import get_settings from noteflow.grpc.client import NoteFlowClient from noteflow.grpc.proto import noteflow_pb2 from noteflow.infrastructure.audio.reader import MeetingAudioReader from noteflow.infrastructure.logging import LoggingConfig, configure_logging from noteflow.infrastructure.security.crypto import AesGcmCryptoBox from noteflow.infrastructure.security.keystore import KeyringKeyStore if TYPE_CHECKING: from noteflow.grpc.types import TranscriptSegment PRESETS: dict[str, dict[str, float]] = { "responsive": { "partial_cadence_seconds": 0.8, "min_partial_audio_seconds": 0.3, "max_segment_duration_seconds": 15.0, "min_speech_duration_seconds": 0.2, "trailing_silence_seconds": 0.3, "leading_buffer_seconds": 0.1, }, "balanced": { "partial_cadence_seconds": 1.5, "min_partial_audio_seconds": 0.5, "max_segment_duration_seconds": 30.0, "min_speech_duration_seconds": 0.3, "trailing_silence_seconds": 0.5, "leading_buffer_seconds": 0.2, }, "accurate": { "partial_cadence_seconds": 2.5, "min_partial_audio_seconds": 0.8, "max_segment_duration_seconds": 45.0, "min_speech_duration_seconds": 0.4, "trailing_silence_seconds": 0.8, "leading_buffer_seconds": 0.25, }, } @dataclass(frozen=True) class AudioCase: label: str audio: NDArray[np.float32] sample_rate: int reference: str | None @dataclass class StreamingStats: first_partial_at: float | None = None first_final_at: float | None = None first_segment_persisted_at: float | None = None partial_count: int = 0 final_count: int = 0 @dataclass class RunResult: label: str meeting_id: str audio_duration_s: float wall_time_s: float first_partial_latency_s: float | None first_final_latency_s: float | None first_segment_persisted_latency_s: float | None transcript: str wer: float | None segments: int class _RequestIdStub: def __init__(self, stub: object, request_id: str) -> None: self._stub = stub self._metadata: tuple[tuple[str, str], ...] = (("x-request-id", request_id),) def __getattr__(self, name: str) -> object: attr = getattr(self._stub, name) if not callable(attr): return attr def _wrapped(*args: object, **kwargs: object) -> object: metadata = cast("tuple[tuple[str, str], ...] | None", kwargs.pop("metadata", None)) if metadata is not None: kwargs["metadata"] = metadata + self._metadata else: kwargs["metadata"] = self._metadata return attr(*args, **kwargs) return _wrapped def _attach_request_id(client: NoteFlowClient, request_id: str | None = None) -> str: rid = request_id or str(uuid4()) stub = client.require_connection() setattr(client, "_stub", _RequestIdStub(stub, rid)) return rid class StreamingConfigStub(Protocol): def GetStreamingConfiguration( self, request: noteflow_pb2.GetStreamingConfigurationRequest, ) -> noteflow_pb2.GetStreamingConfigurationResponse: ... def UpdateStreamingConfiguration( self, request: noteflow_pb2.UpdateStreamingConfigurationRequest, ) -> noteflow_pb2.UpdateStreamingConfigurationResponse: ... def _normalize_text(text: str) -> list[str]: cleaned = re.sub(r"[^a-z0-9']+", " ", text.lower()) return [token for token in cleaned.split() if token] def _word_error_rate(reference: str, hypothesis: str) -> float: ref_tokens = _normalize_text(reference) hyp_tokens = _normalize_text(hypothesis) if not ref_tokens: return 0.0 if not hyp_tokens else 1.0 rows = len(ref_tokens) + 1 cols = len(hyp_tokens) + 1 dp = [[0] * cols for _ in range(rows)] for i in range(rows): dp[i][0] = i for j in range(cols): dp[0][j] = j for i in range(1, rows): for j in range(1, cols): cost = 0 if ref_tokens[i - 1] == hyp_tokens[j - 1] else 1 dp[i][j] = min( dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost, ) return dp[-1][-1] / len(ref_tokens) def _load_wav(path: Path) -> tuple[NDArray[np.float32], int]: with wave.open(str(path), "rb") as wav_file: channels = wav_file.getnchannels() sample_rate = wav_file.getframerate() sample_width = wav_file.getsampwidth() frame_count = wav_file.getnframes() raw = wav_file.readframes(frame_count) if sample_width != 2: raise ValueError("Only 16-bit PCM WAV files are supported") pcm16 = np.frombuffer(raw, dtype=np.int16) if channels > 1: pcm16 = pcm16.reshape(-1, channels).mean(axis=1).astype(np.int16) audio = pcm16.astype(np.float32) / 32767.0 return audio, sample_rate def _load_meeting_audio( meeting_id: str, asset_path: str | None, meetings_dir: Path, ) -> tuple[NDArray[np.float32], int]: crypto = AesGcmCryptoBox(KeyringKeyStore()) reader = MeetingAudioReader(crypto, meetings_dir) chunks = reader.load_meeting_audio(meeting_id, asset_path) if not chunks: return np.array([], dtype=np.float32), reader.sample_rate audio = np.concatenate([chunk.frames for chunk in chunks]).astype(np.float32) return audio, reader.sample_rate def _chunk_audio( audio: NDArray[np.float32], sample_rate: int, chunk_ms: int, ) -> Iterable[NDArray[np.float32]]: chunk_size = max(1, int(sample_rate * (chunk_ms / 1000))) for start in range(0, audio.shape[0], chunk_size): yield audio[start : start + chunk_size] def _poll_first_segment_persisted( server_address: str, meeting_id: str, interval_s: float, stats: StreamingStats, lock: threading.Lock, stop_event: threading.Event, ) -> None: if interval_s <= 0: return poll_client = NoteFlowClient(server_address=server_address) if not poll_client.connect(): return _attach_request_id(poll_client) try: while not stop_event.is_set(): try: segments = poll_client.get_meeting_segments(meeting_id) except Exception: time.sleep(interval_s) continue if segments: now = time.time() with lock: if stats.first_segment_persisted_at is None: stats.first_segment_persisted_at = now break time.sleep(interval_s) finally: poll_client.disconnect() def _get_streaming_config(stub: StreamingConfigStub) -> dict[str, float]: response = stub.GetStreamingConfiguration(noteflow_pb2.GetStreamingConfigurationRequest()) config = response.configuration return { "partial_cadence_seconds": config.partial_cadence_seconds, "min_partial_audio_seconds": config.min_partial_audio_seconds, "max_segment_duration_seconds": config.max_segment_duration_seconds, "min_speech_duration_seconds": config.min_speech_duration_seconds, "trailing_silence_seconds": config.trailing_silence_seconds, "leading_buffer_seconds": config.leading_buffer_seconds, } def _apply_streaming_config( stub: StreamingConfigStub, config: dict[str, float], ) -> None: request = noteflow_pb2.UpdateStreamingConfigurationRequest(**config) stub.UpdateStreamingConfiguration(request) def _read_reference(reference_path: str | None) -> str | None: if not reference_path: return None path = Path(reference_path) if not path.exists(): raise FileNotFoundError(f"Reference file not found: {path}") return path.read_text(encoding="utf-8") def _make_case( label: str, meeting_id: str | None, asset_path: str | None, wav_path: str | None, reference_path: str | None, ) -> AudioCase: settings = get_settings() reference = _read_reference(reference_path) if meeting_id: audio, sample_rate = _load_meeting_audio( meeting_id, asset_path, Path(settings.meetings_dir), ) elif wav_path: audio, sample_rate = _load_wav(Path(wav_path)) else: raise ValueError("Either meeting_id or wav_path must be provided.") return AudioCase( label=label, audio=audio, sample_rate=sample_rate, reference=reference, ) def _run_streaming_case( client: NoteFlowClient, case: AudioCase, config_label: str, config: dict[str, float], chunk_ms: int, realtime: bool, final_wait_seconds: float, segment_poll_ms: int, ) -> RunResult: stub = client.require_connection() _apply_streaming_config(stub, config) meeting_title = f"AB {case.label} [{config_label}]" meeting = client.create_meeting(meeting_title) if meeting is None: raise RuntimeError("Failed to create meeting") stats = StreamingStats() lock = threading.Lock() poll_stop = threading.Event() poll_thread: threading.Thread | None = None def on_transcript(segment: TranscriptSegment) -> None: now = time.time() with lock: if segment.is_final: stats.final_count += 1 if stats.first_final_at is None: stats.first_final_at = now else: stats.partial_count += 1 if stats.first_partial_at is None: stats.first_partial_at = now client.on_transcript = on_transcript if segment_poll_ms > 0: poll_thread = threading.Thread( target=_poll_first_segment_persisted, args=( client.server_address, meeting.id, segment_poll_ms / 1000.0, stats, lock, poll_stop, ), daemon=True, ) poll_thread.start() if not client.start_streaming(meeting.id): raise RuntimeError("Failed to start streaming") start_time = time.time() sent_samples = 0 try: for chunk in _chunk_audio(case.audio, case.sample_rate, chunk_ms): if realtime: target_time = start_time + (sent_samples / case.sample_rate) sleep_for = target_time - time.time() if sleep_for > 0: time.sleep(sleep_for) while not client.send_audio(chunk, timestamp=time.time()): time.sleep(0.01) sent_samples += chunk.shape[0] finally: time.sleep(final_wait_seconds) client.stop_streaming() client.stop_meeting(meeting.id) poll_stop.set() if poll_thread is not None: poll_thread.join(timeout=2.0) segments = client.get_meeting_segments(meeting.id) transcript = " ".join(seg.text.strip() for seg in segments if seg.text.strip()) end_time = time.time() audio_duration = case.audio.shape[0] / case.sample_rate if case.sample_rate else 0.0 first_partial_latency = ( (stats.first_partial_at - start_time) if stats.first_partial_at else None ) first_final_latency = ( (stats.first_final_at - start_time) if stats.first_final_at else None ) first_segment_latency = ( (stats.first_segment_persisted_at - start_time) if stats.first_segment_persisted_at else None ) wer = _word_error_rate(case.reference, transcript) if case.reference else None return RunResult( label=f"{case.label}:{config_label}", meeting_id=meeting.id, audio_duration_s=audio_duration, wall_time_s=end_time - start_time, first_partial_latency_s=first_partial_latency, first_final_latency_s=first_final_latency, first_segment_persisted_latency_s=first_segment_latency, transcript=transcript, wer=wer, segments=len(segments), ) def _load_cases_from_json(path: str) -> list[AudioCase]: raw = json.loads(Path(path).read_text(encoding="utf-8")) if not isinstance(raw, list): raise ValueError("Cases JSON must be a list of case objects.") payload = cast(list[object], raw) entries: list[dict[str, object]] = [] for item in payload: if not isinstance(item, dict): raise ValueError("Each case must be an object.") entries.append(cast(dict[str, object], item)) cases: list[AudioCase] = [] for entry_dict in entries: cases.append( _make_case( label=str(entry_dict.get("label", "case")), meeting_id=cast(str | None, entry_dict.get("meeting_id")), asset_path=cast(str | None, entry_dict.get("asset_path")), wav_path=cast(str | None, entry_dict.get("wav_path")), reference_path=cast(str | None, entry_dict.get("reference_path")), ) ) return cases def _format_latency(value: float | None) -> str: if value is None: return "n/a" return f"{value:.2f}s" def _print_results(results: list[RunResult]) -> None: for result in results: print("") print(f"Case {result.label}") print(f" meeting_id: {result.meeting_id}") print(f" audio_duration_s: {result.audio_duration_s:.2f}") print(f" wall_time_s: {result.wall_time_s:.2f}") print(f" first_partial_latency: {_format_latency(result.first_partial_latency_s)}") print(f" first_final_latency: {_format_latency(result.first_final_latency_s)}") print( " first_segment_persisted_latency: " f"{_format_latency(result.first_segment_persisted_latency_s)}" ) print(f" segments: {result.segments}") if result.wer is not None: print(f" WER: {result.wer:.3f}") def _build_config(value: str | None, label: str) -> dict[str, float]: if value is None: raise ValueError(f"Missing config for {label}") if value in PRESETS: return PRESETS[value] path = Path(value) if path.exists(): data = json.loads(path.read_text(encoding="utf-8")) if not isinstance(data, dict): raise ValueError(f"Config file must be an object: {path}") payload = cast(dict[str, object], data) config: dict[str, float] = {} for key, raw_value in payload.items(): if isinstance(raw_value, (int, float)): config[str(key)] = float(raw_value) if not config: raise ValueError(f"Config file has no numeric values: {path}") return config raise ValueError(f"Unknown preset or config path: {value}") def main() -> None: parser = argparse.ArgumentParser(description="A/B harness for streaming config.") parser.add_argument("--server", default="localhost:50051") parser.add_argument("--meeting-id", help="Meeting ID to replay audio from.") parser.add_argument("--asset-path", help="Override meeting asset path.") parser.add_argument("--wav", help="WAV file to stream instead of a meeting.") parser.add_argument("--reference", help="Reference transcript text file.") parser.add_argument( "--cases", help="JSON file describing multiple cases (label, meeting_id/wav_path, reference_path).", ) parser.add_argument("--preset-a", default="responsive") parser.add_argument("--preset-b", default="balanced") parser.add_argument("--chunk-ms", type=int, default=200) parser.add_argument("--realtime", action="store_true") parser.add_argument("--final-wait", type=float, default=2.0) parser.add_argument( "--segment-poll-ms", type=int, default=0, help="Poll for persisted segments to measure DB ingestion latency.", ) args = parser.parse_args() configure_logging(LoggingConfig(level="INFO")) if args.cases: cases = _load_cases_from_json(args.cases) else: cases = [ _make_case( label="sample", meeting_id=args.meeting_id, asset_path=args.asset_path, wav_path=args.wav, reference_path=args.reference, ) ] config_a = _build_config(args.preset_a, "A") config_b = _build_config(args.preset_b, "B") client = NoteFlowClient(server_address=args.server) if not client.connect(): raise RuntimeError(f"Unable to connect to server at {args.server}") _attach_request_id(client) stub = cast(StreamingConfigStub, client.require_connection()) original_config = _get_streaming_config(stub) results: list[RunResult] = [] try: for case in cases: results.append( _run_streaming_case( client, case, "A", config_a, args.chunk_ms, args.realtime, args.final_wait, args.segment_poll_ms, ) ) results.append( _run_streaming_case( client, case, "B", config_b, args.chunk_ms, args.realtime, args.final_wait, args.segment_poll_ms, ) ) finally: _apply_streaming_config(stub, original_config) client.disconnect() _print_results(results) if __name__ == "__main__": main()