#!/usr/bin/env python3
"""A/B harness for streaming configuration latency and WER comparisons."""

from __future__ import annotations

import argparse
import json
import re
import threading
import time
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Protocol, cast

import numpy as np
from numpy.typing import NDArray

from noteflow.config.settings import get_settings
from noteflow.grpc.client import NoteFlowClient
from noteflow.grpc.proto import noteflow_pb2
from noteflow.infrastructure.audio.reader import MeetingAudioReader
from noteflow.infrastructure.logging import LoggingConfig, configure_logging
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import KeyringKeyStore

if TYPE_CHECKING:
    from noteflow.grpc.types import TranscriptSegment

# Preset bundles: shorter cadences and silence windows favor responsiveness;
# longer ones trade latency for fewer, more stable segments.
PRESETS: dict[str, dict[str, float]] = {
    "responsive": {
        "partial_cadence_seconds": 0.8,
        "min_partial_audio_seconds": 0.3,
        "max_segment_duration_seconds": 15.0,
        "min_speech_duration_seconds": 0.2,
        "trailing_silence_seconds": 0.3,
        "leading_buffer_seconds": 0.1,
    },
    "balanced": {
        "partial_cadence_seconds": 1.5,
        "min_partial_audio_seconds": 0.5,
        "max_segment_duration_seconds": 30.0,
        "min_speech_duration_seconds": 0.3,
        "trailing_silence_seconds": 0.5,
        "leading_buffer_seconds": 0.2,
    },
    "accurate": {
        "partial_cadence_seconds": 2.5,
        "min_partial_audio_seconds": 0.8,
        "max_segment_duration_seconds": 45.0,
        "min_speech_duration_seconds": 0.4,
        "trailing_silence_seconds": 0.8,
        "leading_buffer_seconds": 0.25,
    },
}
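
# --preset-a/--preset-b also accept a path to a JSON object mapping any subset
# of the keys above to numbers (see _build_config), e.g. a hypothetical
# custom.json:
#
#   {"partial_cadence_seconds": 1.0, "trailing_silence_seconds": 0.4}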


@dataclass(frozen=True)
class AudioCase:
    """A single audio input plus optional reference transcript."""

    label: str
    audio: NDArray[np.float32]
    sample_rate: int
    reference: str | None


@dataclass
class StreamingStats:
    """Callback-side counters; guarded by a lock in _run_streaming_case."""

    first_partial_at: float | None = None
    first_final_at: float | None = None
    partial_count: int = 0
    final_count: int = 0


@dataclass
class RunResult:
    """Metrics for one (case, config) streaming run."""

    label: str
    meeting_id: str
    audio_duration_s: float
    wall_time_s: float
    first_partial_latency_s: float | None
    first_final_latency_s: float | None
    transcript: str
    wer: float | None
    segments: int


class StreamingConfigStub(Protocol):
    """Structural type for the subset of the gRPC stub this harness uses."""

    def GetStreamingConfiguration(
        self,
        request: noteflow_pb2.GetStreamingConfigurationRequest,
    ) -> noteflow_pb2.GetStreamingConfigurationResponse: ...

    def UpdateStreamingConfiguration(
        self,
        request: noteflow_pb2.UpdateStreamingConfigurationRequest,
    ) -> noteflow_pb2.UpdateStreamingConfigurationResponse: ...


def _normalize_text(text: str) -> list[str]:
    # Lowercase, keep only alphanumerics and apostrophes, then tokenize.
    cleaned = re.sub(r"[^a-z0-9']+", " ", text.lower())
    return [token for token in cleaned.split() if token]


def _word_error_rate(reference: str, hypothesis: str) -> float:
    """Word error rate: token-level Levenshtein distance over reference length."""
    ref_tokens = _normalize_text(reference)
    hyp_tokens = _normalize_text(hypothesis)
    if not ref_tokens:
        return 0.0 if not hyp_tokens else 1.0

    # Classic dynamic-programming edit distance; substitutions, insertions,
    # and deletions all cost 1.
    rows = len(ref_tokens) + 1
    cols = len(hyp_tokens) + 1
    dp = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        dp[i][0] = i
    for j in range(cols):
        dp[0][j] = j
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if ref_tokens[i - 1] == hyp_tokens[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + cost,  # match or substitution
            )
    return dp[-1][-1] / len(ref_tokens)
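
# Worked example: reference "the quick brown fox" vs. hypothesis "the quick
# fox" is one deletion across four reference tokens, so the WER is 0.25.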


def _load_wav(path: Path) -> tuple[NDArray[np.float32], int]:
    with wave.open(str(path), "rb") as wav_file:
        channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        sample_width = wav_file.getsampwidth()
        frame_count = wav_file.getnframes()
        raw = wav_file.readframes(frame_count)

    if sample_width != 2:
        raise ValueError("Only 16-bit PCM WAV files are supported")

    pcm16 = np.frombuffer(raw, dtype=np.int16)
    if channels > 1:
        # Downmix interleaved channels to mono by averaging.
        pcm16 = pcm16.reshape(-1, channels).mean(axis=1).astype(np.int16)
    # Scale int16 samples into float32 in [-1.0, 1.0].
    audio = pcm16.astype(np.float32) / 32767.0
    return audio, sample_rate


def _load_meeting_audio(
    meeting_id: str,
    asset_path: str | None,
    meetings_dir: Path,
) -> tuple[NDArray[np.float32], int]:
    # Decrypt stored meeting audio using the key held in the OS keyring.
    crypto = AesGcmCryptoBox(KeyringKeyStore())
    reader = MeetingAudioReader(crypto, meetings_dir)
    chunks = reader.load_meeting_audio(meeting_id, asset_path)
    if not chunks:
        return np.array([], dtype=np.float32), reader.sample_rate
    audio = np.concatenate([chunk.frames for chunk in chunks]).astype(np.float32)
    return audio, reader.sample_rate


def _chunk_audio(
    audio: NDArray[np.float32],
    sample_rate: int,
    chunk_ms: int,
) -> Iterable[NDArray[np.float32]]:
    chunk_size = max(1, int(sample_rate * (chunk_ms / 1000)))
    for start in range(0, audio.shape[0], chunk_size):
        yield audio[start : start + chunk_size]
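
# For example, at a 16 kHz sample rate the default --chunk-ms of 200 yields
# 3200-sample chunks (the final chunk may be shorter).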


def _get_streaming_config(stub: StreamingConfigStub) -> dict[str, float]:
    # Snapshot the server's current values so main() can restore them on exit.
    response = stub.GetStreamingConfiguration(noteflow_pb2.GetStreamingConfigurationRequest())
    config = response.configuration
    return {
        "partial_cadence_seconds": config.partial_cadence_seconds,
        "min_partial_audio_seconds": config.min_partial_audio_seconds,
        "max_segment_duration_seconds": config.max_segment_duration_seconds,
        "min_speech_duration_seconds": config.min_speech_duration_seconds,
        "trailing_silence_seconds": config.trailing_silence_seconds,
        "leading_buffer_seconds": config.leading_buffer_seconds,
    }


def _apply_streaming_config(
    stub: StreamingConfigStub,
    config: dict[str, float],
) -> None:
    request = noteflow_pb2.UpdateStreamingConfigurationRequest(**config)
    stub.UpdateStreamingConfiguration(request)


def _read_reference(reference_path: str | None) -> str | None:
    if not reference_path:
        return None
    path = Path(reference_path)
    if not path.exists():
        raise FileNotFoundError(f"Reference file not found: {path}")
    return path.read_text(encoding="utf-8")


def _make_case(
    label: str,
    meeting_id: str | None,
    asset_path: str | None,
    wav_path: str | None,
    reference_path: str | None,
) -> AudioCase:
    settings = get_settings()
    reference = _read_reference(reference_path)
    if meeting_id:
        audio, sample_rate = _load_meeting_audio(
            meeting_id,
            asset_path,
            Path(settings.meetings_dir),
        )
    elif wav_path:
        audio, sample_rate = _load_wav(Path(wav_path))
    else:
        raise ValueError("Either meeting_id or wav_path must be provided.")

    return AudioCase(
        label=label,
        audio=audio,
        sample_rate=sample_rate,
        reference=reference,
    )
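
# When both meeting_id and wav_path are supplied, meeting_id wins and the WAV
# path is ignored, per the if/elif order above.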


def _run_streaming_case(
    client: NoteFlowClient,
    case: AudioCase,
    config_label: str,
    config: dict[str, float],
    chunk_ms: int,
    realtime: bool,
    final_wait_seconds: float,
) -> RunResult:
    stub = client.require_connection()
    _apply_streaming_config(stub, config)

    meeting_title = f"AB {case.label} [{config_label}]"
    meeting = client.create_meeting(meeting_title)
    if meeting is None:
        raise RuntimeError("Failed to create meeting")

    stats = StreamingStats()
    lock = threading.Lock()

    def on_transcript(segment: TranscriptSegment) -> None:
        # Record first-arrival times; the lock guards concurrent callbacks.
        now = time.time()
        with lock:
            if segment.is_final:
                stats.final_count += 1
                if stats.first_final_at is None:
                    stats.first_final_at = now
            else:
                stats.partial_count += 1
                if stats.first_partial_at is None:
                    stats.first_partial_at = now

    client.on_transcript = on_transcript
    if not client.start_streaming(meeting.id):
        raise RuntimeError("Failed to start streaming")

    start_time = time.time()
    sent_samples = 0

    try:
        for chunk in _chunk_audio(case.audio, case.sample_rate, chunk_ms):
            if realtime:
                # Pace sends so audio goes out no faster than real time.
                target_time = start_time + (sent_samples / case.sample_rate)
                sleep_for = target_time - time.time()
                if sleep_for > 0:
                    time.sleep(sleep_for)

            # Retry until the client's send buffer accepts the chunk.
            while not client.send_audio(chunk, timestamp=time.time()):
                time.sleep(0.01)

            sent_samples += chunk.shape[0]
    finally:
        # Grace period for the server to emit trailing final segments.
        time.sleep(final_wait_seconds)
        client.stop_streaming()
        client.stop_meeting(meeting.id)

    # Capture wall time before fetching segments so retrieval RPCs do not skew it.
    end_time = time.time()
    segments = client.get_meeting_segments(meeting.id)
    transcript = " ".join(seg.text.strip() for seg in segments if seg.text.strip())
    audio_duration = case.audio.shape[0] / case.sample_rate if case.sample_rate else 0.0

    first_partial_latency = (
        (stats.first_partial_at - start_time) if stats.first_partial_at is not None else None
    )
    first_final_latency = (
        (stats.first_final_at - start_time) if stats.first_final_at is not None else None
    )
    wer = _word_error_rate(case.reference, transcript) if case.reference else None

    return RunResult(
        label=f"{case.label}:{config_label}",
        meeting_id=meeting.id,
        audio_duration_s=audio_duration,
        wall_time_s=end_time - start_time,
        first_partial_latency_s=first_partial_latency,
        first_final_latency_s=first_final_latency,
        transcript=transcript,
        wer=wer,
        segments=len(segments),
    )
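
# Measurement notes: first_*_latency_s are relative to the first audio send, so
# with --realtime they approximate user-perceived latency; wall_time_s includes
# the --final-wait grace period.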


def _load_cases_from_json(path: str) -> list[AudioCase]:
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("Cases JSON must be a list of case objects.")
    payload = cast(list[object], raw)
    entries: list[dict[str, object]] = []
    for item in payload:
        if not isinstance(item, dict):
            raise ValueError("Each case must be an object.")
        entries.append(cast(dict[str, object], item))
    cases: list[AudioCase] = []
    for entry_dict in entries:
        cases.append(
            _make_case(
                label=str(entry_dict.get("label", "case")),
                meeting_id=cast(str | None, entry_dict.get("meeting_id")),
                asset_path=cast(str | None, entry_dict.get("asset_path")),
                wav_path=cast(str | None, entry_dict.get("wav_path")),
                reference_path=cast(str | None, entry_dict.get("reference_path")),
            )
        )
    return cases
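
# Example cases file (hypothetical cases.json); each entry needs a label plus
# either meeting_id (optionally with asset_path) or wav_path, and may add a
# reference_path for WER scoring:
#
#   [
#     {"label": "standup", "meeting_id": "m-123", "reference_path": "ref.txt"},
#     {"label": "clip", "wav_path": "clip.wav"}
#   ]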


def _format_latency(value: float | None) -> str:
    if value is None:
        return "n/a"
    return f"{value:.2f}s"


def _print_results(results: list[RunResult]) -> None:
    for result in results:
        print("")
        print(f"Case {result.label}")
        print(f" meeting_id: {result.meeting_id}")
        print(f" audio_duration_s: {result.audio_duration_s:.2f}")
        print(f" wall_time_s: {result.wall_time_s:.2f}")
        print(f" first_partial_latency: {_format_latency(result.first_partial_latency_s)}")
        print(f" first_final_latency: {_format_latency(result.first_final_latency_s)}")
        print(f" segments: {result.segments}")
        if result.wer is not None:
            print(f" WER: {result.wer:.3f}")


def _build_config(value: str | None, label: str) -> dict[str, float]:
    if value is None:
        raise ValueError(f"Missing config for {label}")
    if value in PRESETS:
        # Copy so callers cannot mutate the shared preset definition.
        return dict(PRESETS[value])
    path = Path(value)
    if path.exists():
        data = json.loads(path.read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            raise ValueError(f"Config file must be an object: {path}")
        payload = cast(dict[str, object], data)
        config: dict[str, float] = {}
        for key, raw_value in payload.items():
            if isinstance(raw_value, (int, float)):
                config[str(key)] = float(raw_value)
        if not config:
            raise ValueError(f"Config file has no numeric values: {path}")
        return config
    raise ValueError(f"Unknown preset or config path: {value}")


def main() -> None:
    parser = argparse.ArgumentParser(description="A/B harness for streaming config.")
    parser.add_argument("--server", default="localhost:50051")
    parser.add_argument("--meeting-id", help="Meeting ID to replay audio from.")
    parser.add_argument("--asset-path", help="Override meeting asset path.")
    parser.add_argument("--wav", help="WAV file to stream instead of a meeting.")
    parser.add_argument("--reference", help="Reference transcript text file.")
    parser.add_argument(
        "--cases",
        help="JSON file describing multiple cases (label, meeting_id/wav_path, reference_path).",
    )
    parser.add_argument("--preset-a", default="responsive")
    parser.add_argument("--preset-b", default="balanced")
    parser.add_argument("--chunk-ms", type=int, default=200)
    parser.add_argument("--realtime", action="store_true")
    parser.add_argument("--final-wait", type=float, default=2.0)
    args = parser.parse_args()

    configure_logging(LoggingConfig(level="INFO"))

    if args.cases:
        cases = _load_cases_from_json(args.cases)
    else:
        cases = [
            _make_case(
                label="sample",
                meeting_id=args.meeting_id,
                asset_path=args.asset_path,
                wav_path=args.wav,
                reference_path=args.reference,
            )
        ]

    config_a = _build_config(args.preset_a, "A")
    config_b = _build_config(args.preset_b, "B")

    client = NoteFlowClient(server_address=args.server)
    if not client.connect():
        raise RuntimeError(f"Unable to connect to server at {args.server}")

    stub = cast(StreamingConfigStub, client.require_connection())
    original_config = _get_streaming_config(stub)

    results: list[RunResult] = []
    try:
        # Run every case under config A, then config B.
        for case in cases:
            for config_label, config in (("A", config_a), ("B", config_b)):
                results.append(
                    _run_streaming_case(
                        client,
                        case,
                        config_label,
                        config,
                        args.chunk_ms,
                        args.realtime,
                        args.final_wait,
                    )
                )
    finally:
        # Always restore the server's original streaming configuration.
        _apply_streaming_config(stub, original_config)
        client.disconnect()

    _print_results(results)


if __name__ == "__main__":
    main()