Files
noteflow/scripts/ab_streaming_harness.py
Travis Vasceannie 100ca5596b
Some checks failed
CI / test-python (push) Failing after 16m26s
CI / test-rust (push) Has been cancelled
CI / test-typescript (push) Has been cancelled
mac
2026-01-24 12:47:35 -05:00

557 lines
18 KiB
Python

#!/usr/bin/env python3
"""A/B harness for streaming configuration latency and WER comparisons."""
from __future__ import annotations
import argparse
import json
import re
import threading
import time
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Protocol, cast
from uuid import uuid4
import numpy as np
from numpy.typing import NDArray
from noteflow.config.settings import get_settings
from noteflow.grpc.client import NoteFlowClient
from noteflow.grpc.proto import noteflow_pb2
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.logging import LoggingConfig, configure_logging
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import KeyringKeyStore
if TYPE_CHECKING:
from noteflow.grpc.types import TranscriptSegment
# Named streaming configurations selectable via --preset-a / --preset-b.
# Each maps UpdateStreamingConfigurationRequest field names to values in
# seconds; every value grows from "responsive" through "balanced" to "accurate".
PRESETS: dict[str, dict[str, float]] = dict(
    responsive=dict(
        partial_cadence_seconds=0.8,
        min_partial_audio_seconds=0.3,
        max_segment_duration_seconds=15.0,
        min_speech_duration_seconds=0.2,
        trailing_silence_seconds=0.3,
        leading_buffer_seconds=0.1,
    ),
    balanced=dict(
        partial_cadence_seconds=1.5,
        min_partial_audio_seconds=0.5,
        max_segment_duration_seconds=30.0,
        min_speech_duration_seconds=0.3,
        trailing_silence_seconds=0.5,
        leading_buffer_seconds=0.2,
    ),
    accurate=dict(
        partial_cadence_seconds=2.5,
        min_partial_audio_seconds=0.8,
        max_segment_duration_seconds=45.0,
        min_speech_duration_seconds=0.4,
        trailing_silence_seconds=0.8,
        leading_buffer_seconds=0.25,
    ),
)
@dataclass(frozen=True, eq=False)
class AudioCase:
    """One audio input to replay through the streaming pipeline.

    ``eq=False`` keeps identity-based equality and hashing: the generated
    field-wise ``__eq__``/``__hash__`` would compare/hash the ndarray, which
    raises (ambiguous array truth value / unhashable ndarray).
    """

    # Human-readable name used in meeting titles and result labels.
    label: str
    # Mono float32 samples to stream.
    audio: NDArray[np.float32]
    # Samples per second of ``audio``.
    sample_rate: int
    # Reference transcript text for WER scoring, or None to skip WER.
    reference: str | None
@dataclass
class StreamingStats:
    """Counters and first-event timestamps gathered during one streaming run.

    In this script the timestamps are ``time.time()`` values written by the
    transcript callback and by the persisted-segment poller thread; both
    writers hold a shared lock while updating an instance.
    """

    # Time the first non-final (partial) transcript update arrived, if any.
    first_partial_at: float | None = None
    # Time the first final transcript update arrived, if any.
    first_final_at: float | None = None
    # Time the poller first saw segments stored for the meeting, if polling ran.
    first_segment_persisted_at: float | None = None
    # Number of partial transcript updates received.
    partial_count: int = 0
    # Number of final transcript updates received.
    final_count: int = 0
@dataclass
class RunResult:
    """Measurements for one (audio case, streaming config) run.

    Latencies are seconds measured from the moment audio streaming started;
    ``None`` means the corresponding event was never observed.
    """

    # "<case label>:<config label>", e.g. "sample:A".
    label: str
    # ID of the meeting created for this run.
    meeting_id: str
    # Length of the streamed audio in seconds (samples / sample_rate).
    audio_duration_s: float
    # Wall-clock time from streaming start until segments were fetched.
    wall_time_s: float
    # Delay until the first partial transcript update.
    first_partial_latency_s: float | None
    # Delay until the first final transcript update.
    first_final_latency_s: float | None
    # Delay until the poller first saw stored segments.
    first_segment_persisted_latency_s: float | None
    # Space-joined text of the stored segments after the run.
    transcript: str
    # Word error rate against the case reference, or None without a reference.
    wer: float | None
    # Number of stored segments fetched after the run.
    segments: int
class _RequestIdStub:
def __init__(self, stub: object, request_id: str) -> None:
self._stub = stub
self._metadata: tuple[tuple[str, str], ...] = (("x-request-id", request_id),)
def __getattr__(self, name: str) -> object:
attr = getattr(self._stub, name)
if not callable(attr):
return attr
def _wrapped(*args: object, **kwargs: object) -> object:
metadata = cast("tuple[tuple[str, str], ...] | None", kwargs.pop("metadata", None))
if metadata is not None:
kwargs["metadata"] = metadata + self._metadata
else:
kwargs["metadata"] = self._metadata
return attr(*args, **kwargs)
return _wrapped
def _attach_request_id(client: NoteFlowClient, request_id: str | None = None) -> str:
    """Route all of ``client``'s stub calls through a request-id proxy.

    Args:
        client: A connected client; its current stub is taken from
            ``require_connection()`` and replaced on the private ``_stub``
            attribute.
        request_id: ID to send; a random UUID4 string is used when empty/None.

    Returns:
        The request ID that will be attached to every call.
    """
    effective_id = request_id if request_id else str(uuid4())
    proxied = _RequestIdStub(client.require_connection(), effective_id)
    # The client exposes no public hook for the stub, so patch the private slot.
    setattr(client, "_stub", proxied)
    return effective_id
class StreamingConfigStub(Protocol):
    """Structural type for the two streaming-configuration RPCs this script uses."""

    def GetStreamingConfiguration(
        self,
        request: noteflow_pb2.GetStreamingConfigurationRequest,
    ) -> noteflow_pb2.GetStreamingConfigurationResponse:
        """Return the server's current streaming configuration."""
        ...

    def UpdateStreamingConfiguration(
        self,
        request: noteflow_pb2.UpdateStreamingConfigurationRequest,
    ) -> noteflow_pb2.UpdateStreamingConfigurationResponse:
        """Apply the streaming configuration carried by ``request``."""
        ...
def _normalize_text(text: str) -> list[str]:
cleaned = re.sub(r"[^a-z0-9']+", " ", text.lower())
return [token for token in cleaned.split() if token]
def _word_error_rate(reference: str, hypothesis: str) -> float:
    """Return the word error rate of ``hypothesis`` measured against ``reference``.

    Both strings are tokenized with ``_normalize_text``. The result is the
    word-level Levenshtein distance (substitutions + insertions + deletions)
    divided by the number of reference tokens, so it can exceed 1.0 when the
    hypothesis contains many insertions.

    An empty reference scores 0.0 against an empty hypothesis, else 1.0.
    """
    ref_tokens = _normalize_text(reference)
    hyp_tokens = _normalize_text(hypothesis)
    if not ref_tokens:
        return 0.0 if not hyp_tokens else 1.0
    # Only the previous DP row is needed, so memory is O(len(hyp)) instead of
    # a full len(ref) x len(hyp) matrix — relevant for long meeting transcripts.
    previous = list(range(len(hyp_tokens) + 1))
    for i, ref_word in enumerate(ref_tokens, start=1):
        current = [i] + [0] * len(hyp_tokens)
        for j, hyp_word in enumerate(hyp_tokens, start=1):
            substitution_cost = 0 if ref_word == hyp_word else 1
            current[j] = min(
                previous[j] + 1,  # deletion of ref_word
                current[j - 1] + 1,  # insertion of hyp_word
                previous[j - 1] + substitution_cost,  # match / substitution
            )
        previous = current
    return previous[-1] / len(ref_tokens)
def _load_wav(path: Path) -> tuple[NDArray[np.float32], int]:
with wave.open(str(path), "rb") as wav_file:
channels = wav_file.getnchannels()
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
frame_count = wav_file.getnframes()
raw = wav_file.readframes(frame_count)
if sample_width != 2:
raise ValueError("Only 16-bit PCM WAV files are supported")
pcm16 = np.frombuffer(raw, dtype=np.int16)
if channels > 1:
pcm16 = pcm16.reshape(-1, channels).mean(axis=1).astype(np.int16)
audio = pcm16.astype(np.float32) / 32767.0
return audio, sample_rate
def _load_meeting_audio(
    meeting_id: str,
    asset_path: str | None,
    meetings_dir: Path,
) -> tuple[NDArray[np.float32], int]:
    """Read a stored meeting's audio chunks and join them into one array.

    A ``MeetingAudioReader`` is built over ``meetings_dir`` with an
    ``AesGcmCryptoBox`` backed by ``KeyringKeyStore``.

    Args:
        meeting_id: Meeting whose audio should be loaded.
        asset_path: Optional asset path override passed to the reader.
        meetings_dir: Root directory of stored meetings.

    Returns:
        ``(audio, sample_rate)``. ``audio`` is the float32 concatenation of
        every chunk's ``frames`` (empty when the reader returns no chunks) and
        ``sample_rate`` is the reader's rate.
    """
    reader = MeetingAudioReader(AesGcmCryptoBox(KeyringKeyStore()), meetings_dir)
    chunks = reader.load_meeting_audio(meeting_id, asset_path)
    if chunks:
        merged = np.concatenate([chunk.frames for chunk in chunks]).astype(np.float32)
    else:
        merged = np.array([], dtype=np.float32)
    return merged, reader.sample_rate
def _chunk_audio(
audio: NDArray[np.float32],
sample_rate: int,
chunk_ms: int,
) -> Iterable[NDArray[np.float32]]:
chunk_size = max(1, int(sample_rate * (chunk_ms / 1000)))
for start in range(0, audio.shape[0], chunk_size):
yield audio[start : start + chunk_size]
def _poll_first_segment_persisted(
    server_address: str,
    meeting_id: str,
    interval_s: float,
    stats: StreamingStats,
    lock: threading.Lock,
    stop_event: threading.Event,
) -> None:
    """Record when the first stored segment for ``meeting_id`` becomes visible.

    Intended as a background-thread target. Opens its own client connection
    and calls ``get_meeting_segments`` every ``interval_s`` seconds until it
    returns segments or ``stop_event`` is set. The ``time.time()`` of the
    first non-empty response is stored in ``stats.first_segment_persisted_at``
    under ``lock``, unless a value is already present.

    Polling is best-effort: a non-positive interval or a failed connect makes
    it return immediately, and exceptions from individual fetches are treated
    as "nothing stored yet".
    """
    if interval_s <= 0:
        return
    poll_client = NoteFlowClient(server_address=server_address)
    if not poll_client.connect():
        return
    try:
        # Inside the try so the connection is closed even if this raises.
        _attach_request_id(poll_client)
        while not stop_event.is_set():
            try:
                segments = poll_client.get_meeting_segments(meeting_id)
            except Exception:
                # Deliberately best-effort: a failed fetch just means "retry".
                segments = None
            if segments:
                now = time.time()
                with lock:
                    if stats.first_segment_persisted_at is None:
                        stats.first_segment_persisted_at = now
                return
            # Event.wait returns as soon as stop_event is set, so shutdown is
            # not delayed by a full polling interval.
            stop_event.wait(interval_s)
    finally:
        poll_client.disconnect()
def _get_streaming_config(stub: StreamingConfigStub) -> dict[str, float]:
    """Fetch the server's current streaming configuration as a plain dict.

    The keys are the six streaming fields also used by ``PRESETS``, so the
    result can be passed back to ``_apply_streaming_config`` to restore it.
    """
    field_names = (
        "partial_cadence_seconds",
        "min_partial_audio_seconds",
        "max_segment_duration_seconds",
        "min_speech_duration_seconds",
        "trailing_silence_seconds",
        "leading_buffer_seconds",
    )
    current = stub.GetStreamingConfiguration(
        noteflow_pb2.GetStreamingConfigurationRequest()
    ).configuration
    return {name: getattr(current, name) for name in field_names}
def _apply_streaming_config(
    stub: StreamingConfigStub,
    config: dict[str, float],
) -> None:
    """Send ``config`` to the server as an UpdateStreamingConfiguration call.

    Each key of ``config`` becomes a keyword field of the update request; the
    RPC response is discarded.
    """
    stub.UpdateStreamingConfiguration(
        noteflow_pb2.UpdateStreamingConfigurationRequest(**config)
    )
def _read_reference(reference_path: str | None) -> str | None:
if not reference_path:
return None
path = Path(reference_path)
if not path.exists():
raise FileNotFoundError(f"Reference file not found: {path}")
return path.read_text(encoding="utf-8")
def _make_case(
    label: str,
    meeting_id: str | None,
    asset_path: str | None,
    wav_path: str | None,
    reference_path: str | None,
) -> AudioCase:
    """Build an ``AudioCase`` from a stored meeting or a WAV file.

    ``meeting_id`` takes precedence when both sources are given.

    Args:
        label: Case label used in meeting titles and result labels.
        meeting_id: Stored meeting to replay.
        asset_path: Optional asset path override for the meeting reader.
        wav_path: 16-bit PCM WAV file to stream instead of a meeting.
        reference_path: Optional reference transcript for WER.

    Raises:
        ValueError: If neither ``meeting_id`` nor ``wav_path`` is provided.
        FileNotFoundError: If ``reference_path`` is given but does not exist.
    """
    reference = _read_reference(reference_path)
    if meeting_id:
        # Settings are only needed to locate stored meetings; resolving them
        # here means WAV-only cases do not depend on get_settings() succeeding.
        meetings_dir = Path(get_settings().meetings_dir)
        audio, sample_rate = _load_meeting_audio(meeting_id, asset_path, meetings_dir)
    elif wav_path:
        audio, sample_rate = _load_wav(Path(wav_path))
    else:
        raise ValueError("Either meeting_id or wav_path must be provided.")
    return AudioCase(
        label=label,
        audio=audio,
        sample_rate=sample_rate,
        reference=reference,
    )
def _stop_segment_poller(
    stop_event: threading.Event,
    poll_thread: threading.Thread | None,
) -> None:
    """Signal the persisted-segment poller to exit and wait up to 2 s for it."""
    stop_event.set()
    if poll_thread is not None:
        poll_thread.join(timeout=2.0)


def _elapsed_since(start: float, moment: float | None) -> float | None:
    """Return ``moment - start``, or None when the event never happened."""
    return None if moment is None else moment - start


def _send_case_audio(
    client: NoteFlowClient,
    case: AudioCase,
    chunk_ms: int,
    realtime: bool,
    start_time: float,
    send_timeout_s: float,
) -> None:
    """Push every chunk of ``case.audio`` to the open stream.

    Raises:
        RuntimeError: If one chunk keeps being rejected for ``send_timeout_s``.
    """
    sent_samples = 0
    for chunk in _chunk_audio(case.audio, case.sample_rate, chunk_ms):
        if realtime:
            # Wait until this chunk's offset in the recording has elapsed.
            sleep_for = start_time + sent_samples / case.sample_rate - time.time()
            if sleep_for > 0:
                time.sleep(sleep_for)
        deadline = time.monotonic() + send_timeout_s
        while not client.send_audio(chunk, timestamp=time.time()):
            # send_audio returned False (chunk not accepted): retry, but bound
            # the wait so a dead stream cannot hang the harness forever.
            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f"Timed out sending audio chunk after {send_timeout_s:.1f}s"
                )
            time.sleep(0.01)
        sent_samples += chunk.shape[0]


def _run_streaming_case(
    client: NoteFlowClient,
    case: AudioCase,
    config_label: str,
    config: dict[str, float],
    chunk_ms: int,
    realtime: bool,
    final_wait_seconds: float,
    segment_poll_ms: int,
    *,
    send_timeout_s: float = 30.0,
) -> RunResult:
    """Stream one audio case under one streaming configuration and measure it.

    Applies ``config`` on the server, creates a new meeting, and streams
    ``case.audio`` in ``chunk_ms`` chunks (paced to real time if
    ``realtime``). It then waits ``final_wait_seconds``, stops the stream and
    the meeting, and fetches the stored segments. Latencies are measured from
    the moment streaming started.

    Args:
        client: Connected client; its ``on_transcript`` callback is replaced.
        case: Audio and optional reference transcript to stream.
        config_label: Short label ("A"/"B") used in titles and result labels.
        config: Streaming configuration fields to apply before the run.
        chunk_ms: Chunk length in milliseconds.
        realtime: Pace chunks to the audio's real-time rate.
        final_wait_seconds: Delay after the last chunk before stopping
            (negative values are treated as 0).
        segment_poll_ms: Poll interval for the persisted-segment latency
            measurement; 0 or less disables polling.
        send_timeout_s: How long to keep retrying a rejected chunk.

    Raises:
        RuntimeError: If the meeting cannot be created, streaming cannot be
            started, or a chunk cannot be sent within ``send_timeout_s``.
    """
    stub = client.require_connection()
    _apply_streaming_config(stub, config)
    meeting = client.create_meeting(f"AB {case.label} [{config_label}]")
    if meeting is None:
        raise RuntimeError("Failed to create meeting")

    stats = StreamingStats()
    lock = threading.Lock()
    poll_stop = threading.Event()
    poll_thread: threading.Thread | None = None

    def on_transcript(segment: TranscriptSegment) -> None:
        now = time.time()
        # Shares ``lock`` with the poller thread, which writes the same stats.
        with lock:
            if segment.is_final:
                stats.final_count += 1
                if stats.first_final_at is None:
                    stats.first_final_at = now
            else:
                stats.partial_count += 1
                if stats.first_partial_at is None:
                    stats.first_partial_at = now

    client.on_transcript = on_transcript
    if segment_poll_ms > 0:
        poll_thread = threading.Thread(
            target=_poll_first_segment_persisted,
            args=(
                client.server_address,
                meeting.id,
                segment_poll_ms / 1000.0,
                stats,
                lock,
                poll_stop,
            ),
            daemon=True,
        )
        poll_thread.start()

    if not client.start_streaming(meeting.id):
        # Don't leave the poller running against a meeting that will never
        # receive audio.
        _stop_segment_poller(poll_stop, poll_thread)
        raise RuntimeError("Failed to start streaming")

    start_time = time.time()
    try:
        _send_case_audio(client, case, chunk_ms, realtime, start_time, send_timeout_s)
    finally:
        try:
            # Give the server time to emit trailing partials/finals.
            time.sleep(max(0.0, final_wait_seconds))
            client.stop_streaming()
            client.stop_meeting(meeting.id)
        finally:
            # Stop the poller even if stopping the stream or meeting fails.
            _stop_segment_poller(poll_stop, poll_thread)

    segments = client.get_meeting_segments(meeting.id)
    transcript = " ".join(seg.text.strip() for seg in segments if seg.text.strip())
    end_time = time.time()

    # Snapshot under the lock: a poller that outlived its join timeout may
    # still be writing.
    with lock:
        first_partial_at = stats.first_partial_at
        first_final_at = stats.first_final_at
        first_persisted_at = stats.first_segment_persisted_at

    audio_duration = case.audio.shape[0] / case.sample_rate if case.sample_rate else 0.0
    wer = _word_error_rate(case.reference, transcript) if case.reference else None
    return RunResult(
        label=f"{case.label}:{config_label}",
        meeting_id=meeting.id,
        audio_duration_s=audio_duration,
        wall_time_s=end_time - start_time,
        first_partial_latency_s=_elapsed_since(start_time, first_partial_at),
        first_final_latency_s=_elapsed_since(start_time, first_final_at),
        first_segment_persisted_latency_s=_elapsed_since(start_time, first_persisted_at),
        transcript=transcript,
        wer=wer,
        segments=len(segments),
    )
def _optional_case_str(entry: dict[str, object], key: str, position: int) -> str | None:
    """Return ``entry[key]`` when it is a string or absent/null; reject other JSON types."""
    value = entry.get(key)
    if value is None or isinstance(value, str):
        return value
    raise ValueError(
        f"Case #{position}: '{key}' must be a string, got {type(value).__name__}."
    )


def _load_cases_from_json(path: str) -> list[AudioCase]:
    """Load audio cases from a JSON list of case objects.

    Each object may contain ``label``, ``meeting_id``, ``asset_path``,
    ``wav_path`` and ``reference_path``. An entry without a label (or with a
    null one) is labelled ``case-<position>`` (1-based), so results stay
    distinguishable.

    Raises:
        ValueError: If the JSON is not a list, an entry is not an object, or
            a path/ID field is neither a string nor null. Errors from
            ``_make_case`` also propagate.
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("Cases JSON must be a list of case objects.")

    # Validate every entry before loading any audio, so a bad entry late in
    # the file fails fast instead of after earlier cases were decoded.
    specs: list[tuple[str, dict[str, str | None]]] = []
    for position, item in enumerate(cast(list[object], raw), start=1):
        if not isinstance(item, dict):
            raise ValueError("Each case must be an object.")
        entry = cast(dict[str, object], item)
        raw_label = entry.get("label")
        label = f"case-{position}" if raw_label is None else str(raw_label)
        fields = {
            key: _optional_case_str(entry, key, position)
            for key in ("meeting_id", "asset_path", "wav_path", "reference_path")
        }
        specs.append((label, fields))

    return [_make_case(label=label, **fields) for label, fields in specs]
def _format_latency(value: float | None) -> str:
if value is None:
return "n/a"
return f"{value:.2f}s"
def _print_results(results: list[RunResult]) -> None:
    """Write a plain-text report block for each run to stdout.

    Each block starts with a blank line. The WER line appears only when the
    run had a reference transcript.
    """
    for res in results:
        report = [
            "",
            f"Case {res.label}",
            f" meeting_id: {res.meeting_id}",
            f" audio_duration_s: {res.audio_duration_s:.2f}",
            f" wall_time_s: {res.wall_time_s:.2f}",
            f" first_partial_latency: {_format_latency(res.first_partial_latency_s)}",
            f" first_final_latency: {_format_latency(res.first_final_latency_s)}",
            " first_segment_persisted_latency: "
            f"{_format_latency(res.first_segment_persisted_latency_s)}",
            f" segments: {res.segments}",
        ]
        if res.wer is not None:
            report.append(f" WER: {res.wer:.3f}")
        print("\n".join(report))
def _build_config(value: str | None, label: str) -> dict[str, float]:
    """Resolve a preset name or JSON config file into a streaming config dict.

    Args:
        value: A key of ``PRESETS`` or a path to a JSON object file. Numeric
            values in the file are kept as floats; other values (including
            booleans) are ignored.
        label: Config slot name ("A"/"B") used in error messages.

    Returns:
        A new dict the caller may freely modify.

    Raises:
        ValueError: If ``value`` is None, is neither a preset nor an existing
            file, or the file is not an object / has no numeric values.
    """
    if value is None:
        raise ValueError(f"Missing config for {label}")
    if value in PRESETS:
        # Copy so callers cannot mutate the shared module-level preset.
        return dict(PRESETS[value])
    path = Path(value)
    if not path.is_file():
        raise ValueError(f"Unknown preset or config path: {value}")
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(f"Config file must be an object: {path}")
    payload = cast(dict[str, object], data)
    config = {
        str(key): float(raw_value)
        for key, raw_value in payload.items()
        # bool subclasses int; JSON true/false must not become 1.0/0.0.
        if isinstance(raw_value, (int, float)) and not isinstance(raw_value, bool)
    }
    if not config:
        raise ValueError(f"Config file has no numeric values: {path}")
    return config
def main() -> None:
    """Run every case under configs A and B and print a comparison report.

    The server's streaming configuration is read before the runs and written
    back afterwards, even if a run fails. Results of completed runs are
    printed even when a later run raises.
    """
    parser = argparse.ArgumentParser(description="A/B harness for streaming config.")
    parser.add_argument("--server", default="localhost:50051")
    parser.add_argument("--meeting-id", help="Meeting ID to replay audio from.")
    parser.add_argument("--asset-path", help="Override meeting asset path.")
    parser.add_argument("--wav", help="WAV file to stream instead of a meeting.")
    parser.add_argument("--reference", help="Reference transcript text file.")
    parser.add_argument(
        "--cases",
        help="JSON file describing multiple cases (label, meeting_id/wav_path, reference_path).",
    )
    parser.add_argument("--preset-a", default="responsive")
    parser.add_argument("--preset-b", default="balanced")
    parser.add_argument("--chunk-ms", type=int, default=200)
    parser.add_argument("--realtime", action="store_true")
    parser.add_argument("--final-wait", type=float, default=2.0)
    parser.add_argument(
        "--segment-poll-ms",
        type=int,
        default=0,
        help="Poll for persisted segments to measure DB ingestion latency.",
    )
    args = parser.parse_args()
    # Report bad input as usage errors instead of tracebacks deep in a run.
    if not (args.cases or args.meeting_id or args.wav):
        parser.error("one of --cases, --meeting-id, or --wav is required")
    if args.chunk_ms <= 0:
        parser.error("--chunk-ms must be positive")
    if args.final_wait < 0:
        parser.error("--final-wait must be non-negative")

    configure_logging(LoggingConfig(level="INFO"))
    if args.cases:
        cases = _load_cases_from_json(args.cases)
    else:
        cases = [
            _make_case(
                label="sample",
                meeting_id=args.meeting_id,
                asset_path=args.asset_path,
                wav_path=args.wav,
                reference_path=args.reference,
            )
        ]
    configs = (
        ("A", _build_config(args.preset_a, "A")),
        ("B", _build_config(args.preset_b, "B")),
    )

    client = NoteFlowClient(server_address=args.server)
    if not client.connect():
        raise RuntimeError(f"Unable to connect to server at {args.server}")
    results: list[RunResult] = []
    try:
        _attach_request_id(client)
        stub = cast(StreamingConfigStub, client.require_connection())
        original_config = _get_streaming_config(stub)
        try:
            for case in cases:
                for config_label, config in configs:
                    results.append(
                        _run_streaming_case(
                            client,
                            case,
                            config_label,
                            config,
                            args.chunk_ms,
                            args.realtime,
                            args.final_wait,
                            args.segment_poll_ms,
                        )
                    )
        finally:
            # Leave the server as we found it, whatever happened above.
            _apply_streaming_config(stub, original_config)
    finally:
        try:
            client.disconnect()
        finally:
            # Report completed runs even if a later run failed or was interrupted.
            _print_results(results)


if __name__ == "__main__":
    main()