noteflow/scripts/ab_streaming_harness.py
Travis Vasceannie 1ce24cdf7b feat: reorganize Claude hooks and add RAG documentation structure with error handling policies
- Moved all hookify configuration files from `.claude/` to `.claude/hooks/` subdirectory for better organization
- Added four new blocking hooks to prevent common error handling anti-patterns:
  - `block-broad-exception-handler`: Prevents catching generic `Exception` with only logging
  - `block-datetime-now-fallback`: Blocks returning `datetime.now()` as fallback on parse failures to prevent data corruption
  - `block-default
2026-01-15 15:58:06 +00:00


#!/usr/bin/env python3
"""A/B harness for streaming configuration latency and WER comparisons."""
from __future__ import annotations

import argparse
import json
import re
import threading
import time
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Protocol, cast

import numpy as np
from numpy.typing import NDArray

from noteflow.config.settings import get_settings
from noteflow.grpc.client import NoteFlowClient
from noteflow.grpc.proto import noteflow_pb2
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.logging import LoggingConfig, configure_logging
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import KeyringKeyStore
if TYPE_CHECKING:
    from noteflow.grpc.types import TranscriptSegment

PRESETS: dict[str, dict[str, float]] = {
    "responsive": {
        "partial_cadence_seconds": 0.8,
        "min_partial_audio_seconds": 0.3,
        "max_segment_duration_seconds": 15.0,
        "min_speech_duration_seconds": 0.2,
        "trailing_silence_seconds": 0.3,
        "leading_buffer_seconds": 0.1,
    },
    "balanced": {
        "partial_cadence_seconds": 1.5,
        "min_partial_audio_seconds": 0.5,
        "max_segment_duration_seconds": 30.0,
        "min_speech_duration_seconds": 0.3,
        "trailing_silence_seconds": 0.5,
        "leading_buffer_seconds": 0.2,
    },
    "accurate": {
        "partial_cadence_seconds": 2.5,
        "min_partial_audio_seconds": 0.8,
        "max_segment_duration_seconds": 45.0,
        "min_speech_duration_seconds": 0.4,
        "trailing_silence_seconds": 0.8,
        "leading_buffer_seconds": 0.25,
    },
}
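
# The preset names reflect a latency/accuracy trade-off: "responsive" emits
# partials on a short cadence with small speech/silence windows, while
# "accurate" buffers longer segments before finalizing; "balanced" sits between.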


@dataclass(frozen=True)
class AudioCase:
    label: str
    audio: NDArray[np.float32]
    sample_rate: int
    reference: str | None


@dataclass
class StreamingStats:
    first_partial_at: float | None = None
    first_final_at: float | None = None
    partial_count: int = 0
    final_count: int = 0


@dataclass
class RunResult:
    label: str
    meeting_id: str
    audio_duration_s: float
    wall_time_s: float
    first_partial_latency_s: float | None
    first_final_latency_s: float | None
    transcript: str
    wer: float | None
    segments: int


class StreamingConfigStub(Protocol):
    def GetStreamingConfiguration(
        self,
        request: noteflow_pb2.GetStreamingConfigurationRequest,
    ) -> noteflow_pb2.GetStreamingConfigurationResponse: ...

    def UpdateStreamingConfiguration(
        self,
        request: noteflow_pb2.UpdateStreamingConfigurationRequest,
    ) -> noteflow_pb2.UpdateStreamingConfigurationResponse: ...


def _normalize_text(text: str) -> list[str]:
    cleaned = re.sub(r"[^a-z0-9']+", " ", text.lower())
    return [token for token in cleaned.split() if token]
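
# e.g. _normalize_text("Hello, World!") == ["hello", "world"]: punctuation is
# collapsed to whitespace and empty tokens are dropped before WER scoring.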


def _word_error_rate(reference: str, hypothesis: str) -> float:
    ref_tokens = _normalize_text(reference)
    hyp_tokens = _normalize_text(hypothesis)
    if not ref_tokens:
        return 0.0 if not hyp_tokens else 1.0
    # Token-level Levenshtein distance, normalized by reference length.
    rows = len(ref_tokens) + 1
    cols = len(hyp_tokens) + 1
    dp = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        dp[i][0] = i
    for j in range(cols):
        dp[0][j] = j
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if ref_tokens[i - 1] == hyp_tokens[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + cost,  # substitution (or match)
            )
    return dp[-1][-1] / len(ref_tokens)
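
# Worked example: reference "the cat sat" vs. hypothesis "the cat" is one
# deletion over three reference tokens, so the WER is 1 / 3 ≈ 0.333.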


def _load_wav(path: Path) -> tuple[NDArray[np.float32], int]:
    with wave.open(str(path), "rb") as wav_file:
        channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        sample_width = wav_file.getsampwidth()
        frame_count = wav_file.getnframes()
        raw = wav_file.readframes(frame_count)
    if sample_width != 2:
        raise ValueError("Only 16-bit PCM WAV files are supported")
    pcm16 = np.frombuffer(raw, dtype=np.int16)
    if channels > 1:
        # Downmix multi-channel audio to mono by averaging the channels.
        pcm16 = pcm16.reshape(-1, channels).mean(axis=1).astype(np.int16)
    audio = pcm16.astype(np.float32) / 32767.0
    return audio, sample_rate


def _load_meeting_audio(
    meeting_id: str,
    asset_path: str | None,
    meetings_dir: Path,
) -> tuple[NDArray[np.float32], int]:
    crypto = AesGcmCryptoBox(KeyringKeyStore())
    reader = MeetingAudioReader(crypto, meetings_dir)
    chunks = reader.load_meeting_audio(meeting_id, asset_path)
    if not chunks:
        return np.array([], dtype=np.float32), reader.sample_rate
    audio = np.concatenate([chunk.frames for chunk in chunks]).astype(np.float32)
    return audio, reader.sample_rate


def _chunk_audio(
    audio: NDArray[np.float32],
    sample_rate: int,
    chunk_ms: int,
) -> Iterable[NDArray[np.float32]]:
    chunk_size = max(1, int(sample_rate * (chunk_ms / 1000)))
    for start in range(0, audio.shape[0], chunk_size):
        yield audio[start : start + chunk_size]
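
# For example, at a 16 kHz sample rate with the default --chunk-ms of 200, each
# yielded chunk holds 16000 * 0.2 = 3200 samples; the final chunk may be shorter.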


def _get_streaming_config(stub: StreamingConfigStub) -> dict[str, float]:
    response = stub.GetStreamingConfiguration(noteflow_pb2.GetStreamingConfigurationRequest())
    config = response.configuration
    return {
        "partial_cadence_seconds": config.partial_cadence_seconds,
        "min_partial_audio_seconds": config.min_partial_audio_seconds,
        "max_segment_duration_seconds": config.max_segment_duration_seconds,
        "min_speech_duration_seconds": config.min_speech_duration_seconds,
        "trailing_silence_seconds": config.trailing_silence_seconds,
        "leading_buffer_seconds": config.leading_buffer_seconds,
    }


def _apply_streaming_config(
    stub: StreamingConfigStub,
    config: dict[str, float],
) -> None:
    request = noteflow_pb2.UpdateStreamingConfigurationRequest(**config)
    stub.UpdateStreamingConfiguration(request)


def _read_reference(reference_path: str | None) -> str | None:
    if not reference_path:
        return None
    path = Path(reference_path)
    if not path.exists():
        raise FileNotFoundError(f"Reference file not found: {path}")
    return path.read_text(encoding="utf-8")


def _make_case(
    label: str,
    meeting_id: str | None,
    asset_path: str | None,
    wav_path: str | None,
    reference_path: str | None,
) -> AudioCase:
    settings = get_settings()
    reference = _read_reference(reference_path)
    if meeting_id:
        audio, sample_rate = _load_meeting_audio(
            meeting_id,
            asset_path,
            Path(settings.meetings_dir),
        )
    elif wav_path:
        audio, sample_rate = _load_wav(Path(wav_path))
    else:
        raise ValueError("Either meeting_id or wav_path must be provided.")
    return AudioCase(
        label=label,
        audio=audio,
        sample_rate=sample_rate,
        reference=reference,
    )


def _run_streaming_case(
    client: NoteFlowClient,
    case: AudioCase,
    config_label: str,
    config: dict[str, float],
    chunk_ms: int,
    realtime: bool,
    final_wait_seconds: float,
) -> RunResult:
    stub = client.require_connection()
    _apply_streaming_config(stub, config)
    meeting_title = f"AB {case.label} [{config_label}]"
    meeting = client.create_meeting(meeting_title)
    if meeting is None:
        raise RuntimeError("Failed to create meeting")
    stats = StreamingStats()
    lock = threading.Lock()

    def on_transcript(segment: TranscriptSegment) -> None:
        # Record when the first partial and first final segment arrive.
        now = time.time()
        with lock:
            if segment.is_final:
                stats.final_count += 1
                if stats.first_final_at is None:
                    stats.first_final_at = now
            else:
                stats.partial_count += 1
                if stats.first_partial_at is None:
                    stats.first_partial_at = now

    client.on_transcript = on_transcript
    if not client.start_streaming(meeting.id):
        raise RuntimeError("Failed to start streaming")
    start_time = time.time()
    sent_samples = 0
    try:
        for chunk in _chunk_audio(case.audio, case.sample_rate, chunk_ms):
            if realtime:
                # Pace uploads so audio is sent no faster than wall-clock playback.
                target_time = start_time + (sent_samples / case.sample_rate)
                sleep_for = target_time - time.time()
                if sleep_for > 0:
                    time.sleep(sleep_for)
            # Retry until the client accepts the chunk.
            while not client.send_audio(chunk, timestamp=time.time()):
                time.sleep(0.01)
            sent_samples += chunk.shape[0]
    finally:
        # Grace period so the server can flush final segments before teardown.
        time.sleep(final_wait_seconds)
        client.stop_streaming()
        client.stop_meeting(meeting.id)
    segments = client.get_meeting_segments(meeting.id)
    transcript = " ".join(seg.text.strip() for seg in segments if seg.text.strip())
    end_time = time.time()
    audio_duration = case.audio.shape[0] / case.sample_rate if case.sample_rate else 0.0
    first_partial_latency = (
        (stats.first_partial_at - start_time) if stats.first_partial_at else None
    )
    first_final_latency = (
        (stats.first_final_at - start_time) if stats.first_final_at else None
    )
    wer = _word_error_rate(case.reference, transcript) if case.reference else None
    return RunResult(
        label=f"{case.label}:{config_label}",
        meeting_id=meeting.id,
        audio_duration_s=audio_duration,
        wall_time_s=end_time - start_time,
        first_partial_latency_s=first_partial_latency,
        first_final_latency_s=first_final_latency,
        transcript=transcript,
        wer=wer,
        segments=len(segments),
    )
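
# Note: wall_time_s includes the --final-wait grace period and teardown, so it
# slightly overstates pure streaming time relative to audio_duration_s.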


def _load_cases_from_json(path: str) -> list[AudioCase]:
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("Cases JSON must be a list of case objects.")
    payload = cast(list[object], raw)
    entries: list[dict[str, object]] = []
    for item in payload:
        if not isinstance(item, dict):
            raise ValueError("Each case must be an object.")
        entries.append(cast(dict[str, object], item))
    cases: list[AudioCase] = []
    for entry_dict in entries:
        cases.append(
            _make_case(
                label=str(entry_dict.get("label", "case")),
                meeting_id=cast(str | None, entry_dict.get("meeting_id")),
                asset_path=cast(str | None, entry_dict.get("asset_path")),
                wav_path=cast(str | None, entry_dict.get("wav_path")),
                reference_path=cast(str | None, entry_dict.get("reference_path")),
            )
        )
    return cases
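
# Example --cases file (labels and paths are illustrative):
# [
#   {"label": "standup", "wav_path": "audio/standup.wav", "reference_path": "refs/standup.txt"},
#   {"label": "weekly", "meeting_id": "a1b2c3", "reference_path": "refs/weekly.txt"}
# ]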


def _format_latency(value: float | None) -> str:
    if value is None:
        return "n/a"
    return f"{value:.2f}s"


def _print_results(results: list[RunResult]) -> None:
    for result in results:
        print("")
        print(f"Case {result.label}")
        print(f" meeting_id: {result.meeting_id}")
        print(f" audio_duration_s: {result.audio_duration_s:.2f}")
        print(f" wall_time_s: {result.wall_time_s:.2f}")
        print(f" first_partial_latency: {_format_latency(result.first_partial_latency_s)}")
        print(f" first_final_latency: {_format_latency(result.first_final_latency_s)}")
        print(f" segments: {result.segments}")
        if result.wer is not None:
            print(f" WER: {result.wer:.3f}")


def _build_config(value: str | None, label: str) -> dict[str, float]:
    if value is None:
        raise ValueError(f"Missing config for {label}")
    if value in PRESETS:
        return PRESETS[value]
    path = Path(value)
    if path.exists():
        data = json.loads(path.read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            raise ValueError(f"Config file must be an object: {path}")
        payload = cast(dict[str, object], data)
        config: dict[str, float] = {}
        for key, raw_value in payload.items():
            if isinstance(raw_value, (int, float)):
                config[str(key)] = float(raw_value)
        if not config:
            raise ValueError(f"Config file has no numeric values: {path}")
        return config
    raise ValueError(f"Unknown preset or config path: {value}")


def main() -> None:
    parser = argparse.ArgumentParser(description="A/B harness for streaming config.")
    parser.add_argument("--server", default="localhost:50051")
    parser.add_argument("--meeting-id", help="Meeting ID to replay audio from.")
    parser.add_argument("--asset-path", help="Override meeting asset path.")
    parser.add_argument("--wav", help="WAV file to stream instead of a meeting.")
    parser.add_argument("--reference", help="Reference transcript text file.")
    parser.add_argument(
        "--cases",
        help="JSON file describing multiple cases (label, meeting_id/wav_path, reference_path).",
    )
    parser.add_argument("--preset-a", default="responsive")
    parser.add_argument("--preset-b", default="balanced")
    parser.add_argument("--chunk-ms", type=int, default=200)
    parser.add_argument("--realtime", action="store_true")
    parser.add_argument("--final-wait", type=float, default=2.0)
    args = parser.parse_args()
    configure_logging(LoggingConfig(level="INFO"))
    if args.cases:
        cases = _load_cases_from_json(args.cases)
    else:
        cases = [
            _make_case(
                label="sample",
                meeting_id=args.meeting_id,
                asset_path=args.asset_path,
                wav_path=args.wav,
                reference_path=args.reference,
            )
        ]
    config_a = _build_config(args.preset_a, "A")
    config_b = _build_config(args.preset_b, "B")
    client = NoteFlowClient(server_address=args.server)
    if not client.connect():
        raise RuntimeError(f"Unable to connect to server at {args.server}")
    stub = cast(StreamingConfigStub, client.require_connection())
    # Snapshot the server's current configuration so it can be restored afterwards.
    original_config = _get_streaming_config(stub)
    results: list[RunResult] = []
    try:
        for case in cases:
            results.append(
                _run_streaming_case(
                    client,
                    case,
                    "A",
                    config_a,
                    args.chunk_ms,
                    args.realtime,
                    args.final_wait,
                )
            )
            results.append(
                _run_streaming_case(
                    client,
                    case,
                    "B",
                    config_b,
                    args.chunk_ms,
                    args.realtime,
                    args.final_wait,
                )
            )
    finally:
        # Always restore the original server configuration, even on failure.
        _apply_streaming_config(stub, original_config)
        client.disconnect()
    _print_results(results)


if __name__ == "__main__":
    main()