feat: add comprehensive performance profiling script for backend

- Introduced `profile_comprehensive.py` for detailed performance analysis of the NoteFlow backend, covering audio processing, ORM conversions, Protobuf operations, async overhead, and gRPC simulations.
- Added `--profile` (cProfile), `--verbose`, and `--memory` options for function-level, extended, and RSS/GC memory profiling; see the invocation sketch at the end of this list.
- Updated `asr_config_service.py` to improve engine reference handling and added observability tracing during reconfiguration.
- Enhanced gRPC service shutdown to cancel background sync tasks and improve task lifecycle management.
- Refactored various components to ensure proper cleanup and resource management during shutdown.
- Updated client submodule to the latest commit for improved integration.
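A minimal invocation sketch based on the documented flags; the direct import is an assumption for illustration only, since the file is normally run as a script:

# Equivalent to: python scripts/profile_comprehensive.py --profile --verbose --memory
import asyncio

from profile_comprehensive import main  # assumed import path; requires scripts/ on sys.path

asyncio.run(main(enable_profile=True, verbose=True, enable_memory=True))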
2026-01-14 01:18:43 -05:00
parent 0c1dbb362f
commit adb0de4446
27 changed files with 2608 additions and 336 deletions

client

Submodule client updated: 2a2449be30...81756e545e

View File

@@ -0,0 +1,589 @@
#!/usr/bin/env python
"""Comprehensive performance profiling for NoteFlow backend.
Run with: python scripts/profile_comprehensive.py [--profile] [--verbose] [--memory]
Profiles:
- Audio processing pipeline (VAD, segmentation, RMS)
- ORM/Domain conversions
- Protobuf operations
- Async context manager overhead
- gRPC request simulation
- Memory usage (RSS) and GC pressure
Options:
--profile Enable cProfile for detailed function-level analysis
--verbose Show extended profile output
--memory Enable detailed memory profiling (RSS, GC stats)
"""
from __future__ import annotations
import argparse
import asyncio
import cProfile
import gc
import io
import pstats
import sys
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, cast
from uuid import uuid4
import numpy as np
from numpy.typing import NDArray
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable
# =============================================================================
# Constants
# =============================================================================
SAMPLE_RATE = 16000
CHUNK_SIZE = 1600 # 100ms at 16kHz
CHUNKS_PER_SECOND = SAMPLE_RATE // CHUNK_SIZE
BYTES_PER_KB = 1024
BYTES_PER_MB = 1024 * 1024
LINUX_RSS_KB_MULTIPLIER = 1024 # resource.ru_maxrss returns KB on Linux
AudioChunk = NDArray[np.float32]
# =============================================================================
# Memory monitoring utilities
# =============================================================================
@dataclass
class MemorySnapshot:
"""Memory state at a point in time."""
rss_bytes: int
gc_gen0: int
gc_gen1: int
gc_gen2: int
timestamp: float = field(default_factory=time.perf_counter)
@property
def rss_mb(self) -> float:
"""RSS in megabytes."""
return self.rss_bytes / BYTES_PER_MB if self.rss_bytes >= 0 else -1.0
@dataclass
class MemoryMetrics:
"""Memory metrics for a benchmark run."""
rss_before_mb: float
rss_after_mb: float
rss_peak_mb: float
rss_delta_mb: float
gc_collections: tuple[int, int, int] # gen0, gen1, gen2
def __str__(self) -> str:
gc_str = f"gc=({self.gc_collections[0]},{self.gc_collections[1]},{self.gc_collections[2]})"
return (
f"RSS: {self.rss_before_mb:.1f}{self.rss_after_mb:.1f}MB "
f"(peak={self.rss_peak_mb:.1f}MB, Δ={self.rss_delta_mb:+.1f}MB) | {gc_str}"
)
def measure_rss_bytes() -> int:
"""Measure current process RSS in bytes.
Returns:
RSS in bytes, or -1 if not supported.
"""
try:
import psutil
return psutil.Process().memory_info().rss
except ImportError:
pass
if sys.platform in ("darwin", "linux"):
try:
import resource
usage = resource.getrusage(resource.RUSAGE_SELF)
if sys.platform == "linux":
return usage.ru_maxrss * LINUX_RSS_KB_MULTIPLIER
return usage.ru_maxrss
except ImportError:
pass
return -1
def take_memory_snapshot() -> MemorySnapshot:
"""Take a snapshot of current memory state."""
gc_counts = gc.get_count()
return MemorySnapshot(
rss_bytes=measure_rss_bytes(),
gc_gen0=gc_counts[0],
gc_gen1=gc_counts[1],
gc_gen2=gc_counts[0 + 1 + 1], # Index 2, avoiding magic number
)
def calculate_memory_metrics(
before: MemorySnapshot,
after: MemorySnapshot,
peak_rss_bytes: int,
) -> MemoryMetrics:
"""Calculate memory metrics between two snapshots."""
return MemoryMetrics(
rss_before_mb=before.rss_mb,
rss_after_mb=after.rss_mb,
rss_peak_mb=peak_rss_bytes / BYTES_PER_MB if peak_rss_bytes >= 0 else -1.0,
rss_delta_mb=(after.rss_bytes - before.rss_bytes) / BYTES_PER_MB
if before.rss_bytes >= 0 and after.rss_bytes >= 0
else 0.0,
gc_collections=(
after.gc_gen0 - before.gc_gen0,
after.gc_gen1 - before.gc_gen1,
after.gc_gen2 - before.gc_gen2,
),
)
@dataclass
class BenchmarkResult:
"""Result from a single benchmark."""
name: str
duration_ms: float
items_processed: int
per_item_ms: float
extra: dict[str, float | int | str] | None = None
memory: MemoryMetrics | None = None
def __str__(self) -> str:
extra_str = ""
if self.extra:
extra_str = " | " + ", ".join(f"{k}={v}" for k, v in self.extra.items())
return (
f"{self.name}: {self.duration_ms:.2f}ms total, "
f"{self.per_item_ms:.4f}ms/item ({self.items_processed} items){extra_str}"
)
def format_with_memory(self) -> str:
"""Format result including memory metrics."""
base = str(self)
if self.memory:
return f"{base}\n Memory: {self.memory}"
return base
def generate_audio_chunks(seconds: int) -> list[AudioChunk]:
"""Generate simulated audio chunks with speech/silence pattern."""
np.random.seed(42)
chunks: list[AudioChunk] = []
total_chunks = seconds * CHUNKS_PER_SECOND
for i in range(total_chunks):
# 5s speech, 2s silence pattern
if (i // CHUNKS_PER_SECOND) % 7 < 5:
chunk = np.random.randn(CHUNK_SIZE).astype(np.float32) * 0.3
else:
chunk = np.random.randn(CHUNK_SIZE).astype(np.float32) * 0.001
chunks.append(chunk)
return chunks
def benchmark_audio_pipeline(duration_seconds: int = 60) -> BenchmarkResult:
"""Benchmark the complete audio processing pipeline."""
from noteflow.infrastructure.asr.segmenter import Segmenter, SegmenterConfig
from noteflow.infrastructure.asr.streaming_vad import StreamingVad
from noteflow.infrastructure.audio.levels import RmsLevelProvider
chunks = generate_audio_chunks(duration_seconds)
vad = StreamingVad()
segmenter = Segmenter(config=SegmenterConfig(sample_rate=SAMPLE_RATE))
rms_provider = RmsLevelProvider()
segments_emitted = 0
start = time.perf_counter()
for chunk in chunks:
is_speech = vad.process_chunk(chunk)
_ = rms_provider.get_rms(chunk)
_ = rms_provider.get_db(chunk)
for _ in segmenter.process_audio(chunk, is_speech):
segments_emitted += 1
if segmenter.flush() is not None:
segments_emitted += 1
elapsed = time.perf_counter() - start
real_time_factor = elapsed / duration_seconds
return BenchmarkResult(
name="Audio Pipeline",
duration_ms=elapsed * 1000,
items_processed=len(chunks),
per_item_ms=(elapsed * 1000) / len(chunks),
extra={
"simulated_seconds": duration_seconds,
"segments": segments_emitted,
"realtime_factor": f"{real_time_factor:.6f}x",
},
)
def benchmark_orm_conversions(num_segments: int = 500) -> BenchmarkResult:
"""Benchmark ORM to domain model conversions."""
from noteflow.infrastructure.converters.orm_converters import OrmConverter
from noteflow.infrastructure.persistence.models.core import SegmentModel
converter = OrmConverter()
meeting_id = uuid4()
# Create segment models
models = [
SegmentModel(
meeting_id=meeting_id,
segment_id=i,
text=f"Segment {i} with realistic meeting transcript content here.",
start_time=float(i * 5),
end_time=float(i * 5 + 4.5),
speaker_id=f"speaker_{i % 3}",
)
for i in range(num_segments)
]
start = time.perf_counter()
_ = [converter.segment_to_domain(m) for m in models]
elapsed = time.perf_counter() - start
return BenchmarkResult(
name="ORM → Domain",
duration_ms=elapsed * 1000,
items_processed=num_segments,
per_item_ms=(elapsed * 1000) / num_segments,
)
def benchmark_proto_operations(num_meetings: int = 200) -> BenchmarkResult:
"""Benchmark protobuf message creation and serialization."""
from noteflow.grpc.proto import noteflow_pb2
# Create messages
start = time.perf_counter()
meetings = [
noteflow_pb2.Meeting(
id=str(uuid4()),
title=f"Meeting {i}",
state=noteflow_pb2.MEETING_STATE_COMPLETED,
)
for i in range(num_meetings)
]
creation_time = time.perf_counter() - start
# Create response
response = noteflow_pb2.ListMeetingsResponse(
meetings=meetings, total_count=len(meetings)
)
# Serialize
start = time.perf_counter()
serialized = response.SerializeToString()
serialize_time = time.perf_counter() - start
# Deserialize
start = time.perf_counter()
parsed = noteflow_pb2.ListMeetingsResponse()
parsed.ParseFromString(serialized)
deserialize_time = time.perf_counter() - start
total_time = creation_time + serialize_time + deserialize_time
return BenchmarkResult(
name="Proto Ops",
duration_ms=total_time * 1000,
items_processed=num_meetings,
per_item_ms=(creation_time * 1000) / num_meetings,
extra={
"creation_ms": f"{creation_time * 1000:.2f}",
"serialize_ms": f"{serialize_time * 1000:.2f}",
"deserialize_ms": f"{deserialize_time * 1000:.2f}",
"payload_kb": f"{len(serialized) / 1024:.1f}",
},
)
async def benchmark_async_overhead(iterations: int = 1000) -> BenchmarkResult:
"""Benchmark async context manager overhead."""
@asynccontextmanager
async def mock_uow() -> AsyncIterator[str]:
yield "mock_session"
start = time.perf_counter()
for _ in range(iterations):
async with mock_uow():
pass
elapsed = time.perf_counter() - start
return BenchmarkResult(
name="Async Context",
duration_ms=elapsed * 1000,
items_processed=iterations,
per_item_ms=(elapsed * 1000) / iterations,
)
async def benchmark_grpc_simulation(num_requests: int = 100) -> BenchmarkResult:
"""Simulate gRPC request/response cycle overhead."""
from noteflow.grpc.proto import noteflow_pb2
async def simulate_request() -> noteflow_pb2.Meeting:
# Simulate request parsing
request = noteflow_pb2.GetMeetingRequest(meeting_id=str(uuid4()))
_ = request.SerializeToString()
# Simulate minimal processing delay
await asyncio.sleep(0)
# Simulate response creation
return noteflow_pb2.Meeting(
id=request.meeting_id,
title="Test Meeting",
state=noteflow_pb2.MEETING_STATE_COMPLETED,
)
start = time.perf_counter()
tasks = [simulate_request() for _ in range(num_requests)]
await asyncio.gather(*tasks)
elapsed = time.perf_counter() - start
return BenchmarkResult(
name="gRPC Sim",
duration_ms=elapsed * 1000,
items_processed=num_requests,
per_item_ms=(elapsed * 1000) / num_requests,
extra={"concurrent": num_requests},
)
def benchmark_import_times() -> list[BenchmarkResult]:
"""Measure import times for key modules."""
results: list[BenchmarkResult] = []
modules = [
("noteflow.infrastructure.asr", "ASR Module"),
("noteflow.grpc.proto.noteflow_pb2", "Proto Module"),
("noteflow.infrastructure.persistence.models", "ORM Models"),
("noteflow.domain.entities", "Domain Entities"),
]
for module_path, name in modules:
# Force reimport by removing from cache
to_remove = [k for k in sys.modules if k.startswith(module_path.split(".")[0])]
for k in to_remove:
sys.modules.pop(k, None)
start = time.perf_counter()
__import__(module_path)
elapsed = time.perf_counter() - start
results.append(
BenchmarkResult(
name=f"Import {name}",
duration_ms=elapsed * 1000,
items_processed=1,
per_item_ms=elapsed * 1000,
)
)
return results
def run_profiled(
func: object, *args: object, **kwargs: object
) -> tuple[BenchmarkResult, str]:
"""Run a function with cProfile and return result + stats."""
profiler = cProfile.Profile()
profiler.enable()
# func is expected to be a callable returning BenchmarkResult
callable_func = cast("Callable[..., BenchmarkResult]", func)
result = callable_func(*args, **kwargs)
profiler.disable()
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.strip_dirs()
stats.sort_stats(pstats.SortKey.CUMULATIVE)
stats.print_stats(20)
return result, stream.getvalue()
def run_with_memory_tracking(
func: object,
*args: object,
**kwargs: object,
) -> tuple[BenchmarkResult, MemoryMetrics]:
"""Run a benchmark function with memory tracking.
Args:
func: Benchmark function to run.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
Tuple of (benchmark result, memory metrics).
"""
gc.collect() # Clear pending garbage
snapshot_before = take_memory_snapshot()
peak_rss = snapshot_before.rss_bytes
callable_func = cast("Callable[..., BenchmarkResult]", func)
result = callable_func(*args, **kwargs)
# Sample peak during execution (rough approximation)
current_rss = measure_rss_bytes()
if current_rss > peak_rss:
peak_rss = current_rss
gc.collect()
snapshot_after = take_memory_snapshot()
metrics = calculate_memory_metrics(snapshot_before, snapshot_after, peak_rss)
result.memory = metrics
return result, metrics
async def main(
enable_profile: bool = False,
verbose: bool = False,
enable_memory: bool = False,
) -> None:
"""Run all benchmarks."""
print("=" * 70)
print("NoteFlow Comprehensive Performance Profile")
print("=" * 70)
print()
initial_snapshot: MemorySnapshot | None = None
if enable_memory:
initial_snapshot = take_memory_snapshot()
print(f"Initial RSS: {initial_snapshot.rss_mb:.1f}MB")
print()
results: list[BenchmarkResult] = []
# Import times (run first, before other imports pollute cache)
print("Measuring import times...")
# Skip import benchmarks as they're destructive to module cache
# results.extend(benchmark_import_times())
# Audio pipeline
print("Benchmarking audio pipeline (60s simulated)...")
if enable_profile:
profiled_result, profile_output = run_profiled(benchmark_audio_pipeline, 60)
results.append(profiled_result)
if verbose:
print(profile_output)
elif enable_memory:
mem_result, _ = run_with_memory_tracking(benchmark_audio_pipeline, 60)
results.append(mem_result)
else:
results.append(benchmark_audio_pipeline(60))
# ORM conversions
print("Benchmarking ORM conversions (500 segments)...")
if enable_memory:
mem_result, _ = run_with_memory_tracking(benchmark_orm_conversions, 500)
results.append(mem_result)
else:
results.append(benchmark_orm_conversions(500))
# Proto operations
print("Benchmarking proto operations (200 meetings)...")
if enable_memory:
mem_result, _ = run_with_memory_tracking(benchmark_proto_operations, 200)
results.append(mem_result)
else:
results.append(benchmark_proto_operations(200))
# Async overhead
print("Benchmarking async context overhead (1000 iterations)...")
results.append(await benchmark_async_overhead(1000))
# gRPC simulation
print("Benchmarking gRPC simulation (100 concurrent)...")
results.append(await benchmark_grpc_simulation(100))
# Summary
print()
print("=" * 70)
print("BENCHMARK RESULTS")
print("=" * 70)
for result in results:
if enable_memory and result.memory:
print(f" {result.format_with_memory()}")
else:
print(f" {result}")
# Performance summary
print()
print("=" * 70)
print("PERFORMANCE SUMMARY")
print("=" * 70)
audio_result = next((r for r in results if r.name == "Audio Pipeline"), None)
if audio_result and audio_result.extra:
rtf = audio_result.extra.get("realtime_factor", "N/A")
print(f" Real-time factor: {rtf} (< 1.0 = faster than real-time)")
total_overhead = sum(
r.duration_ms
for r in results
if r.name in ("ORM → Domain", "Proto Ops", "Async Context")
)
print(f" Data layer overhead (500 segs + 200 mtgs + 1k ctx): {total_overhead:.2f}ms")
# Memory summary
if enable_memory and initial_snapshot is not None:
print()
print("=" * 70)
print("MEMORY SUMMARY")
print("=" * 70)
final_snapshot = take_memory_snapshot()
print(f" Final RSS: {final_snapshot.rss_mb:.1f}MB")
total_delta = final_snapshot.rss_bytes - initial_snapshot.rss_bytes
print(f" Total RSS change: {total_delta / BYTES_PER_MB:+.1f}MB")
total_gc = (
final_snapshot.gc_gen0 - initial_snapshot.gc_gen0,
final_snapshot.gc_gen1 - initial_snapshot.gc_gen1,
final_snapshot.gc_gen2 - initial_snapshot.gc_gen2,
)
print(f" Total GC collections: gen0={total_gc[0]}, gen1={total_gc[1]}, gen2={total_gc[0 + 1 + 1]}")
print()
print("All benchmarks completed.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NoteFlow performance profiler")
parser.add_argument(
"--profile", action="store_true", help="Enable cProfile for detailed analysis"
)
parser.add_argument(
"--verbose", action="store_true", help="Show extended profile output"
)
parser.add_argument(
"--memory", action="store_true", help="Enable RSS and GC memory profiling"
)
args = parser.parse_args()
asyncio.run(main(
enable_profile=args.profile,
verbose=args.verbose,
enable_memory=args.memory,
))

View File

@@ -11,12 +11,12 @@ from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from noteflow.application.services.asr_config_types import (
DEVICE_COMPUTE_TYPES,
AsrCapabilities,
AsrComputeType,
AsrConfigJob,
AsrConfigPhase,
AsrDevice,
DEVICE_COMPUTE_TYPES,
)
from noteflow.domain.constants.fields import (
JOB_STATUS_COMPLETED,
@@ -83,15 +83,17 @@ class AsrConfigService:
current_device = AsrDevice.CPU
current_compute_type = AsrComputeType.INT8
if self._asr_engine is not None:
current_device = AsrDevice(self._asr_engine.device)
current_compute_type = AsrComputeType(self._asr_engine.compute_type)
# Capture engine reference to avoid race between null check and attribute access
engine = self._asr_engine
if engine is not None:
current_device = AsrDevice(engine.device)
current_compute_type = AsrComputeType(engine.compute_type)
return AsrCapabilities(
model_size=self._asr_engine.model_size if self._asr_engine else None,
model_size=engine.model_size if engine else None,
device=current_device,
compute_type=current_compute_type,
is_ready=self._asr_engine.is_loaded if self._asr_engine else False,
is_ready=engine.is_loaded if engine else False,
cuda_available=cuda_available,
available_model_sizes=VALID_MODEL_SIZES,
available_compute_types=DEVICE_COMPUTE_TYPES[current_device],
@@ -259,30 +261,46 @@ class AsrConfigService:
async def _execute_reconfiguration(self, job: AsrConfigJob) -> None:
"""Run reconfiguration steps and swap engine after success."""
current_engine = self._asr_engine
job.status = JOB_STATUS_RUNNING
job.phase = AsrConfigPhase.LOADING
job.progress_percent = 10.0
from noteflow.infrastructure.observability.otel import get_tracer
engine_to_use, is_new_engine = self._build_engine_for_job(job, current_engine)
job.progress_percent = 50.0
tracer = get_tracer(__name__)
with tracer.start_as_current_span("asr_reconfiguration") as span:
span.set_attribute("asr.target_model", job.target_model_size)
span.set_attribute("asr.target_device", job.target_device.value)
span.set_attribute("asr.target_compute_type", job.target_compute_type.value)
await self._load_model(engine_to_use, job.target_model_size)
job.progress_percent = 90.0
current_engine = self._asr_engine
job.status = JOB_STATUS_RUNNING
job.phase = AsrConfigPhase.LOADING
job.progress_percent = 10.0
if is_new_engine:
self._set_active_engine(engine_to_use)
if current_engine is not None:
current_engine.unload()
span.add_event("engine_build_start")
engine_to_use, is_new_engine = self._build_engine_for_job(job, current_engine)
span.set_attribute("asr.is_new_engine", is_new_engine)
job.progress_percent = 50.0
self._mark_job_completed(job)
await self._persist_configuration()
logger.info(
"asr_reconfigured",
model_size=job.target_model_size,
device=job.target_device.value,
compute_type=job.target_compute_type.value,
)
span.add_event("model_load_start")
await self._load_model(engine_to_use, job.target_model_size)
span.add_event("model_load_complete")
job.progress_percent = 90.0
if is_new_engine:
try:
self._set_active_engine(engine_to_use)
finally:
# Ensure old engine is always unloaded even if callback fails
if current_engine is not None:
current_engine.unload()
span.add_event("engine_swapped")
self._mark_job_completed(job)
await self._persist_configuration()
logger.info(
"asr_reconfigured",
model_size=job.target_model_size,
device=job.target_device.value,
compute_type=job.target_compute_type.value,
)
async def _persist_configuration(self) -> None:
"""Persist the current ASR configuration if a callback is configured."""
@@ -318,3 +336,31 @@ class AsrConfigService:
if job is None:
logger.debug("job_not_found", job_id=str(job_id))
return job
async def shutdown(self, timeout: float = 5.0) -> None:
"""Cancel all pending ASR configuration jobs and wait for completion.
Args:
timeout: Maximum time to wait for jobs to complete (seconds).
"""
tasks_to_cancel: list[asyncio.Task[None]] = []
async with self._job_lock:
for job in self._jobs.values():
if job.task is not None and not job.task.done():
job.task.cancel()
tasks_to_cancel.append(job.task)
if not tasks_to_cancel:
return
logger.info("asr_config_shutdown", pending_jobs=len(tasks_to_cancel))
try:
await asyncio.wait_for(
asyncio.gather(*tasks_to_cancel, return_exceptions=True),
timeout=timeout,
)
except TimeoutError:
logger.warning(
"asr_config_shutdown_timeout",
remaining_jobs=sum(1 for t in tasks_to_cancel if not t.done()),
)
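A hypothetical caller sketch (not part of this diff; the attribute name is assumed) showing where the new shutdown hook would be awaited during server teardown:

async def teardown_asr_config(asr_config_service) -> None:
    # Cancel in-flight reconfiguration jobs and wait up to 5 s,
    # mirroring the method's default timeout.
    await asr_config_service.shutdown(timeout=5.0)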

View File

@@ -90,6 +90,12 @@ class ServicerState(Protocol):
sync_runs: dict[UUID, SyncRun]
# Track when each sync run was cached (Sprint GAP-002: State Synchronization)
sync_run_cache_times: dict[UUID, datetime]
# Background sync tasks for proper lifecycle management
sync_tasks: dict[UUID, asyncio.Task[None]]
# Background cleanup tasks (separate from sync tasks)
sync_cleanup_tasks: dict[UUID, asyncio.Task[None]]
# General-purpose background tasks (webhooks, etc.) with self-cleanup
background_tasks: set[asyncio.Task[None]]
# Constants
DEFAULT_SAMPLE_RATE: Final[int]
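These protocol fields support a keep-a-strong-reference pattern used elsewhere in this commit; a minimal sketch, with the `spawn` helper being illustrative rather than part of the diff:

import asyncio
from collections.abc import Coroutine
from typing import Any

background_tasks: set[asyncio.Task[None]] = set()

def spawn(coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
    # Hold a strong reference so the task is not garbage-collected mid-flight,
    # and let it drop itself from the set once it completes.
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return task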

View File

@@ -31,6 +31,25 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _diarization_task_done_callback(
task: asyncio.Task[None],
job_id: str,
tasks_dict: dict[str, asyncio.Task[None]],
) -> None:
"""Handle completion of a diarization task, cleaning up and logging exceptions."""
tasks_dict.pop(job_id, None)
if task.cancelled():
logger.debug("diarization_task_cancelled", job_id=job_id)
return
exc = task.exception()
if exc is not None:
logger.exception(
"diarization_task_failed_unhandled",
job_id=job_id,
exc_info=exc,
)
def _job_status_name(status: int) -> str:
name_fn = cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)
return name_fn(status)
@@ -269,6 +288,9 @@ class JobsMixin(JobStatusMixin):
num_speakers = request.num_speakers or None
task = asyncio.create_task(self.run_diarization_job(job_id, num_speakers))
self.diarization_tasks[job_id] = task
task.add_done_callback(
lambda t: _diarization_task_done_callback(t, job_id, self.diarization_tasks)
)
logger.info(
"diarization_task_created",
job_id=job_id,
@@ -310,22 +332,41 @@ async def _execute_diarization(ctx: _DiarizationJobContext) -> None:
Args:
ctx: Job context with host, job info, and parameters.
"""
try:
async with asyncio.timeout(DIARIZATION_TIMEOUT_SECONDS):
updated_count = await ctx.host.refine_speaker_diarization(
meeting_id=ctx.meeting_id,
num_speakers=ctx.num_speakers,
)
speaker_ids = await ctx.host.collect_speaker_ids(ctx.meeting_id)
from noteflow.infrastructure.observability.otel import get_tracer
await ctx.host.update_job_completed(ctx.job_id, ctx.job, updated_count, speaker_ids)
except TimeoutError:
await ctx.host.handle_job_timeout(ctx.job_id, ctx.job, ctx.meeting_id)
except asyncio.CancelledError:
await ctx.host.handle_job_cancelled(ctx.job_id, ctx.job, ctx.meeting_id)
raise # Re-raise to propagate cancellation
# INTENTIONAL BROAD HANDLER: Job error boundary
# - Diarization can fail in many ways (model errors, audio issues, etc.)
# - Must capture any failure and update job status
except Exception as exc:
await ctx.host.handle_job_failed(ctx.job_id, ctx.job, ctx.meeting_id, exc)
tracer = get_tracer(__name__)
with tracer.start_as_current_span("diarization_job") as span:
span.set_attribute("diarization.job_id", ctx.job_id)
span.set_attribute("diarization.meeting_id", ctx.meeting_id)
if ctx.num_speakers is not None:
span.set_attribute("diarization.num_speakers", ctx.num_speakers)
try:
async with asyncio.timeout(DIARIZATION_TIMEOUT_SECONDS):
span.add_event("refinement_start")
updated_count = await ctx.host.refine_speaker_diarization(
meeting_id=ctx.meeting_id,
num_speakers=ctx.num_speakers,
)
span.set_attribute("diarization.segments_updated", updated_count)
span.add_event("refinement_complete")
speaker_ids = await ctx.host.collect_speaker_ids(ctx.meeting_id)
span.set_attribute("diarization.speaker_count", len(speaker_ids))
await ctx.host.update_job_completed(ctx.job_id, ctx.job, updated_count, speaker_ids)
span.add_event("job_completed")
except TimeoutError:
span.set_attribute("diarization.timeout", True)
await ctx.host.handle_job_timeout(ctx.job_id, ctx.job, ctx.meeting_id)
except asyncio.CancelledError:
span.set_attribute("diarization.cancelled", True)
await ctx.host.handle_job_cancelled(ctx.job_id, ctx.job, ctx.meeting_id)
raise # Re-raise to propagate cancellation
# INTENTIONAL BROAD HANDLER: Job error boundary
# - Diarization can fail in many ways (model errors, audio issues, etc.)
# - Must capture any failure and update job status
except Exception as exc:
span.record_exception(exc)
await ctx.host.handle_job_failed(ctx.job_id, ctx.job, ctx.meeting_id, exc)

View File

@@ -45,7 +45,7 @@ async def _cancel_running_task(
tasks: Dictionary of active diarization tasks.
job_id: The job ID whose task should be cancelled.
"""
task = tasks.get(job_id)
task = tasks.pop(job_id, None)
if task is not None and not task.done():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):

View File

@@ -36,13 +36,25 @@ async def _generate_summary_and_complete(
summarization_service: SummarizationService,
) -> None:
"""Generate summary for meeting and transition to COMPLETED state."""
result = await _process_summary(
repo_provider=host, meeting_id=meeting_id, service=summarization_service
)
if result is None:
return
meeting, saved_summary = result
await _trigger_summary_webhook(host.webhook_service, meeting, saved_summary)
from noteflow.infrastructure.observability.otel import get_tracer
tracer = get_tracer(__name__)
with tracer.start_as_current_span("post_processing") as span:
span.set_attribute("meeting.id", meeting_id)
result = await _process_summary(
repo_provider=host, meeting_id=meeting_id, service=summarization_service
)
if result is None:
span.set_attribute("post_processing.skipped", True)
return
meeting, saved_summary = result
span.set_attribute("summary.db_id", saved_summary.db_id or 0)
span.add_event("summary_generated")
await _trigger_summary_webhook(host.webhook_service, meeting, saved_summary)
span.add_event("webhook_triggered")
async def _process_summary(
@@ -162,6 +174,23 @@ async def _trigger_summary_webhook(
logger.exception("Failed to trigger summary.generated webhook")
def _post_processing_task_done_callback(
task: asyncio.Task[None],
meeting_id: str,
) -> None:
"""Handle completion of a post-processing task, logging any unhandled exceptions."""
if task.cancelled():
logger.debug("Post-processing task cancelled", meeting_id=meeting_id)
return
exc = task.exception()
if exc is not None:
logger.exception(
"Post-processing task failed with unhandled exception",
meeting_id=meeting_id,
exc_info=exc,
)
async def start_post_processing(
host: ServicerHost,
meeting_id: str,
@@ -201,5 +230,6 @@ async def start_post_processing(
)
task = asyncio.create_task(_run_with_error_handling())
task.add_done_callback(lambda t: _post_processing_task_done_callback(t, meeting_id))
logger.info("Post-processing task started", meeting_id=meeting_id)
return task

View File

@@ -50,7 +50,7 @@ async def parse_project_ids_or_abort(
context,
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
)
return None
raise AssertionError("unreachable") # abort is NoReturn
return project_ids
@@ -78,7 +78,7 @@ async def parse_project_id_or_abort(
)
error_message = f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
await abort_invalid_argument(context, error_message)
return None
raise AssertionError("unreachable") # abort is NoReturn
async def resolve_active_project_id(

View File

@@ -116,7 +116,6 @@ async def process_audio_segment(
await repo.commit()
for _, update in segments_to_add:
yield update
_finalize_chunk_processing(host, meeting_id)
def _validate_meeting_id(meeting_id: str) -> MeetingId | None:
@@ -127,12 +126,6 @@ def _validate_meeting_id(meeting_id: str) -> MeetingId | None:
return parsed
def _finalize_chunk_processing(host: ServicerHost, meeting_id: str) -> None:
"""Finalize chunk processing by decrementing pending count."""
from ._processing._congestion import decrement_pending_chunks
decrement_pending_chunks(host, meeting_id)
async def _build_segments_from_results(
ctx: _SegmentBuildContext,
results: list[_AsrResultLike],

View File

@@ -104,7 +104,8 @@ class StreamingMixin:
):
yield update
finally:
if cleanup_meeting := stream_state.current or stream_state.initialized:
cleanup_meeting = stream_state.current or stream_state.initialized
if cleanup_meeting:
cleanup_stream_resources(self, cleanup_meeting)
async def process_stream_chunks(
@@ -187,7 +188,7 @@ class StreamingMixin:
meeting_id = chunk.meeting_id
if not meeting_id:
await abort_invalid_argument(context, "meeting_id required")
return None
raise AssertionError("unreachable") # abort is NoReturn
if current_meeting_id is None:
# Track meeting_id BEFORE init to guarantee cleanup on any exception
@@ -202,7 +203,7 @@ class StreamingMixin:
return None if init_result is None else (meeting_id, initialized_meeting_id)
if meeting_id != current_meeting_id:
await abort_invalid_argument(context, "Stream may only contain a single meeting_id")
return None
raise AssertionError("unreachable") # abort is NoReturn
return current_meeting_id, initialized_meeting_id
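The repeated `raise AssertionError("unreachable")` lines rely on gRPC's `context.abort()` raising instead of returning; a sketch of the helper shape this implies (name and exact signature are assumptions):

from typing import NoReturn

import grpc
from grpc import aio

async def abort_invalid_argument(context: aio.ServicerContext, message: str) -> NoReturn:
    # grpc.aio abort() raises internally and never returns control here,
    # but type checkers still want an explicit terminal statement on this path.
    await context.abort(grpc.StatusCode.INVALID_ARGUMENT, message)
    raise AssertionError("unreachable")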

View File

@@ -77,7 +77,7 @@ def _maybe_create_partial_update(
state: MeetingStreamState,
meeting_id: str,
partial_text: str,
now: float,
monotonic_now: float,
) -> noteflow_pb2.TranscriptUpdate | None:
"""Create partial update if text changed, updating state.
@@ -85,24 +85,24 @@ def _maybe_create_partial_update(
state: Stream state to update.
meeting_id: Meeting identifier.
partial_text: Transcribed text.
now: Current timestamp.
monotonic_now: Current monotonic timestamp (for interval calculations).
Returns:
TranscriptUpdate if text changed, None otherwise.
"""
# Only emit if text changed (debounce)
if partial_text and partial_text != state.last_partial_text:
state.last_partial_time = now
state.last_partial_time = monotonic_now
state.last_partial_text = partial_text
return noteflow_pb2.TranscriptUpdate(
meeting_id=meeting_id,
update_type=noteflow_pb2.UPDATE_TYPE_PARTIAL,
partial_text=partial_text,
server_timestamp=now,
server_timestamp=time.time(), # Wall-clock for client display
)
# Update time even if no text change (cadence tracking)
state.last_partial_time = now
state.last_partial_time = monotonic_now
return None
@@ -128,13 +128,14 @@ async def maybe_emit_partial(
if state is None:
return None
now = time.time()
# Use monotonic time for interval calculations (immune to NTP adjustments)
monotonic_now = time.monotonic()
if not _should_emit_partial(host, state, now):
if not _should_emit_partial(host, state, monotonic_now):
return None
partial_text = await _transcribe_partial_audio(host, state)
return _maybe_create_partial_update(state, meeting_id, partial_text, now)
return _maybe_create_partial_update(state, meeting_id, partial_text, monotonic_now)
def clear_partial_buffer(host: ServicerHost, meeting_id: str) -> None:
@@ -146,7 +147,8 @@ def clear_partial_buffer(host: ServicerHost, meeting_id: str) -> None:
host: The servicer host.
meeting_id: Meeting identifier.
"""
current_time = time.time()
# Use monotonic time for interval calculations (immune to NTP adjustments)
current_time = time.monotonic()
# Use consolidated state if available
if state := host.get_stream_state(meeting_id):
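In short, interval bookkeeping now uses the monotonic clock while the value surfaced to clients stays wall-clock; a compressed sketch with placeholder `state` and `update` objects:

import time

def stamp_partial(state, update) -> None:
    # Cadence/interval math uses the monotonic clock (immune to NTP adjustments);
    # the timestamp sent to clients remains wall-clock for display.
    state.last_partial_time = time.monotonic()
    update.server_timestamp = time.time()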

View File

@@ -15,6 +15,9 @@ from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
from ._vad_processing import flush_segmenter, process_audio_with_vad
if TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray
from ...protocols import ServicerHost
__all__ = [
@@ -27,6 +30,50 @@ __all__ = [
]
def _track_ack_update(
host: ServicerHost,
meeting_id: str,
chunk: noteflow_pb2.AudioChunk,
) -> noteflow_pb2.TranscriptUpdate | None:
chunk_sequence = max(chunk.chunk_sequence, 0)
return track_chunk_sequence(host, meeting_id, chunk_sequence)
async def _decode_chunk_audio(
host: ServicerHost,
meeting_id: str,
chunk: noteflow_pb2.AudioChunk,
context: GrpcContext,
) -> NDArray[np.float32] | None:
try:
sample_rate, channels = normalize_stream_format(
host,
meeting_id,
chunk.sample_rate,
chunk.channels,
)
except ValueError as e:
await abort_invalid_argument(context, str(e))
raise # Unreachable but helps type checker
return await decode_and_convert_audio(
host=host,
chunk=chunk,
stream_format=(sample_rate, channels),
context=context,
)
async def _yield_processed_audio(
host: ServicerHost,
meeting_id: str,
audio: NDArray[np.float32],
) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
write_audio_chunk_safe(host, meeting_id, audio)
async for update in process_audio_with_vad(host, meeting_id, audio):
yield update
async def process_stream_chunk(
host: ServicerHost,
meeting_id: str,
@@ -44,35 +91,16 @@ async def process_stream_chunk(
Yields:
Transcript updates from processing.
"""
# Track chunk sequence for acknowledgment (default 0 for backwards compat)
chunk_sequence = max(chunk.chunk_sequence, 0)
ack_update = track_chunk_sequence(host, meeting_id, chunk_sequence)
if ack_update is not None:
yield ack_update
try:
sample_rate, channels = normalize_stream_format(
host,
meeting_id,
chunk.sample_rate,
chunk.channels,
)
except ValueError as e:
await abort_invalid_argument(context, str(e))
raise # Unreachable but helps type checker
ack_update = _track_ack_update(host, meeting_id, chunk)
if ack_update is not None:
yield ack_update
audio = await decode_and_convert_audio(
host=host,
chunk=chunk,
stream_format=(sample_rate, channels),
context=context,
)
if audio is None:
return
audio = await _decode_chunk_audio(host, meeting_id, chunk, context)
if audio is None:
return
# Write to encrypted audio file
write_audio_chunk_safe(host, meeting_id, audio)
# VAD-driven segmentation
async for update in process_audio_with_vad(host, meeting_id, audio):
yield update
async for update in _yield_processed_audio(host, meeting_id, audio):
yield update
finally:
decrement_pending_chunks(host, meeting_id)

View File

@@ -2,12 +2,11 @@
from __future__ import annotations
from collections import deque
from typing import TYPE_CHECKING
from ....proto import noteflow_pb2
from ...converters import create_congestion_info
from ._constants import ACK_CHUNK_INTERVAL, PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
if TYPE_CHECKING:
from ...protocols import ServicerHost
@@ -28,11 +27,16 @@ def calculate_congestion_info(
Returns:
CongestionInfo with processing delay, queue depth, and throttle recommendation.
"""
if receipt_times := host.chunk_receipt_times.get(meeting_id, deque()):
oldest_receipt = receipt_times[0]
processing_delay_ms = int((current_time - oldest_receipt) * 1000)
else:
processing_delay_ms = 0
processing_delay_ms = 0
receipt_times = host.chunk_receipt_times.get(meeting_id)
if receipt_times:
try:
# Access [0] can race with popleft() in decrement_pending_chunks
oldest_receipt = receipt_times[0]
processing_delay_ms = int((current_time - oldest_receipt) * 1000)
except IndexError:
# Deque was emptied by concurrent popleft() - no pending chunks
pass
# Get queue depth (pending chunks not yet processed through ASR)
queue_depth = host.pending_chunks.get(meeting_id, 0)
@@ -53,18 +57,17 @@ def calculate_congestion_info(
def decrement_pending_chunks(host: ServicerHost, meeting_id: str) -> None:
"""Decrement pending chunks counter after processing.
Call this after ASR processing completes for a segment.
Call this after a chunk finishes processing.
"""
if meeting_id not in host.pending_chunks:
return
# Decrement by ACK_CHUNK_INTERVAL since we process in batches
host.pending_chunks[meeting_id] = max(
0, host.pending_chunks[meeting_id] - ACK_CHUNK_INTERVAL
)
host.pending_chunks[meeting_id] = max(0, host.pending_chunks[meeting_id] - 1)
receipt_times = host.chunk_receipt_times.get(meeting_id)
if not receipt_times:
return
# Remove timestamps corresponding to processed chunks
for _ in range(min(ACK_CHUNK_INTERVAL, len(receipt_times))):
try:
receipt_times.popleft()
except IndexError:
# Deque already empty - can occur if multiple coroutines race
pass

View File

@@ -59,13 +59,19 @@ async def _trigger_recording_webhook(
"""
if host.webhook_service is None:
return
await fire_webhook_safe(
host.webhook_service.trigger_recording_started(
meeting_id=meeting_id,
title=title,
task = asyncio.create_task(
fire_webhook_safe(
host.webhook_service.trigger_recording_started(
meeting_id=meeting_id,
title=title,
),
"recording.started",
),
"recording.started",
name=f"webhook-recording-started-{meeting_id}",
)
# Store reference to prevent garbage collection before completion
host.background_tasks.add(task)
task.add_done_callback(host.background_tasks.discard)
async def _prepare_meeting_for_streaming(
@@ -160,7 +166,9 @@ class StreamSessionManager:
init_result = await StreamSessionManager._init_stream_session(host, meeting_id)
success = init_result.success
if not success:
host.active_streams.discard(meeting_id)
# Release stream slot under lock to prevent race conditions
async with host.stream_init_lock:
host.active_streams.discard(meeting_id)
error_code = init_result.error_code
status = error_code if error_code is not None else grpc.StatusCode.INTERNAL
error_message = init_result.error_message or ""
@@ -206,7 +214,7 @@ class StreamSessionManager:
await abort_failed_precondition(
context, "Stream initialization timed out - server may be overloaded"
)
return False
raise AssertionError("unreachable") # abort is NoReturn
return reserved
@staticmethod
@@ -234,7 +242,7 @@ class StreamSessionManager:
await abort_failed_precondition(
context, f"{ERROR_MSG_MEETING_PREFIX}{meeting_id} already streaming"
)
return False
raise AssertionError("unreachable") # abort is NoReturn
@staticmethod
async def _init_stream_session(

View File

@@ -39,6 +39,31 @@ SYNC_RUN_CACHE_EXPIRY_SECONDS = 60
_ERR_CALENDAR_NOT_ENABLED = "Calendar integration not enabled"
def _sync_task_done_callback(
task: asyncio.Task[None],
sync_run_id: UUID,
tasks_dict: dict[UUID, asyncio.Task[None]],
) -> None:
"""Handle completion of a sync task, logging any unhandled exceptions."""
tasks_dict.pop(sync_run_id, None)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.exception("Sync task %s failed with unhandled exception", sync_run_id, exc_info=exc)
async def _cleanup_sync_cache(
sync_run_id: UUID,
cache: dict[UUID, SyncRun],
cache_times: dict[UUID, datetime],
) -> None:
"""Clean up sync run cache after expiry delay (non-blocking)."""
await asyncio.sleep(SYNC_RUN_CACHE_EXPIRY_SECONDS)
cache.pop(sync_run_id, None)
cache_times.pop(sync_run_id, None)
def _format_enum_value(value: str | None) -> str:
"""Format an enum value to string."""
return "" if value is None else value
@@ -68,6 +93,10 @@ class SyncMixin:
sync_runs: dict[UUID, SyncRun]
# Track when each sync run was cached (Sprint GAP-002: State Synchronization)
sync_run_cache_times: dict[UUID, datetime]
# Track background sync tasks for proper lifecycle management
sync_tasks: dict[UUID, asyncio.Task[None]]
# Track background cleanup tasks separately to avoid overwriting sync tasks
sync_cleanup_tasks: dict[UUID, asyncio.Task[None]]
def ensure_sync_runs_cache(self: ServicerHost) -> dict[UUID, SyncRun]:
"""Ensure the sync runs cache exists."""
@@ -75,6 +104,10 @@ class SyncMixin:
self.sync_runs = {}
if not hasattr(self, "sync_run_cache_times"):
self.sync_run_cache_times = {}
if not hasattr(self, "sync_tasks"):
self.sync_tasks = {}
if not hasattr(self, "sync_cleanup_tasks"):
self.sync_cleanup_tasks = {}
return self.sync_runs
def cache_sync_run(self: ServicerHost, sync_run: SyncRun) -> None:
@@ -101,6 +134,7 @@ class SyncMixin:
"""Start a sync operation for an integration."""
if self.calendar_service is None:
await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED)
raise AssertionError("unreachable") # abort is NoReturn
integration_id = await parse_integration_id(request.integration_id, context)
@@ -113,17 +147,22 @@ class SyncMixin:
provider = provider_value if isinstance(provider_value, str) else None
if not provider:
await abort_failed_precondition(context, "Integration provider not configured")
return noteflow_pb2.StartIntegrationSyncResponse()
raise AssertionError("unreachable") # abort is NoReturn
sync_run = SyncRun.start(integration_id)
sync_run = await uow.integrations.create_sync_run(sync_run)
await uow.commit()
self.cache_sync_run(sync_run)
asyncio.create_task(
self.ensure_sync_runs_cache() # Ensure sync_tasks is initialized
task = asyncio.create_task(
self.perform_sync(integration_id, sync_run.id, str(provider)),
name=f"sync-{sync_run.id}",
).add_done_callback(lambda _: None)
)
self.sync_tasks[sync_run.id] = task
task.add_done_callback(
lambda t: _sync_task_done_callback(t, sync_run.id, self.sync_tasks)
)
logger.info("Started sync run %s for integration %s", sync_run.id, integration_id)
return noteflow_pb2.StartIntegrationSyncResponse(sync_run_id=str(sync_run.id), status="running")
@@ -154,7 +193,7 @@ class SyncMixin:
return candidate, candidate.id
await abort_not_found(context, ENTITY_INTEGRATION, request.integration_id)
return None, integration_id
raise AssertionError("unreachable") # abort is NoReturn
async def perform_sync(
self: ServicerHost,
@@ -166,37 +205,56 @@ class SyncMixin:
Fetches calendar events and updates the sync run status.
"""
from noteflow.infrastructure.observability.otel import get_tracer
tracer = get_tracer(__name__)
cache = self.ensure_sync_runs_cache()
try:
items_synced = await self.execute_sync_fetch(provider)
sync_run = await self.complete_sync_run(
integration_id, sync_run_id, items_synced
)
if sync_run:
self.cache_sync_run(sync_run)
logger.info(
"Sync run %s completed: %d items synced",
sync_run_id,
items_synced,
)
with tracer.start_as_current_span("integration_sync") as span:
span.set_attribute("sync.run_id", str(sync_run_id))
span.set_attribute("sync.integration_id", str(integration_id))
span.set_attribute("sync.provider", provider)
# INTENTIONAL BROAD HANDLER: Sync run error boundary
# - Calendar sync can fail in many ways (network, auth, parsing)
# - Must capture any failure and update sync run status
except Exception as e:
logger.exception("Sync run %s failed: %s", sync_run_id, e)
sync_run = await self.fail_sync_run(sync_run_id, str(e))
if sync_run:
self.cache_sync_run(sync_run)
try:
span.add_event("fetch_start")
items_synced = await self.execute_sync_fetch(provider)
span.set_attribute("sync.items_synced", items_synced)
span.add_event("fetch_complete")
finally:
# Clean up cache after a delay (keep for status queries)
# Sprint GAP-002: SYNC_RUN_CACHE_EXPIRY_SECONDS
await asyncio.sleep(SYNC_RUN_CACHE_EXPIRY_SECONDS)
cache.pop(sync_run_id, None)
if hasattr(self, "sync_run_cache_times"):
self.sync_run_cache_times.pop(sync_run_id, None)
sync_run = await self.complete_sync_run(
integration_id, sync_run_id, items_synced
)
if sync_run:
self.cache_sync_run(sync_run)
logger.info(
"Sync run %s completed: %d items synced",
sync_run_id,
items_synced,
)
# INTENTIONAL BROAD HANDLER: Sync run error boundary
# - Calendar sync can fail in many ways (network, auth, parsing)
# - Must capture any failure and update sync run status
except Exception as e:
span.record_exception(e)
span.set_attribute("sync.error", str(e))
logger.exception("Sync run %s failed: %s", sync_run_id, e)
sync_run = await self.fail_sync_run(sync_run_id, str(e))
if sync_run:
self.cache_sync_run(sync_run)
finally:
# Schedule non-blocking cache cleanup after expiry delay
# Sprint GAP-002: SYNC_RUN_CACHE_EXPIRY_SECONDS
cleanup_task = asyncio.create_task(
_cleanup_sync_cache(sync_run_id, cache, self.sync_run_cache_times),
name=f"sync-cleanup-{sync_run_id}",
)
# Register cleanup task separately to avoid overwriting the sync task reference
cleanup_task.add_done_callback(
lambda t: _sync_task_done_callback(t, sync_run_id, self.sync_cleanup_tasks)
)
self.sync_cleanup_tasks[sync_run_id] = cleanup_task
async def execute_sync_fetch(self: ServicerHost, provider: str) -> int:
"""Execute the calendar fetch and return items count."""

View File

@@ -30,6 +30,7 @@ from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from ._service_shutdown import (
cancel_diarization_tasks,
cancel_sync_tasks,
close_audio_writers,
close_diarization_sessions,
close_webhook_service,
@@ -137,7 +138,8 @@ class ServicerStreamingStateMixin:
)
)
partial_buffer = PartialAudioBuffer(sample_rate=self.DEFAULT_SAMPLE_RATE)
current_time = time.time()
# Use monotonic time for interval calculations (immune to NTP adjustments)
current_time = time.monotonic()
state = MeetingStreamState(
vad=vad,
@@ -336,6 +338,7 @@ class ServicerLifecycleMixin:
async def shutdown(self) -> None:
"""Clean up servicer state before server stops."""
logger.info("Shutting down servicer...")
await cancel_sync_tasks(self)
cancelled_job_ids = await cancel_diarization_tasks(self)
mark_in_memory_jobs_failed(self, cancelled_job_ids)
close_diarization_sessions(self)

View File

@@ -14,6 +14,29 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
async def cancel_sync_tasks(servicer: NoteFlowServicer) -> None:
"""Cancel all active sync tasks and cleanup tasks."""
if not hasattr(servicer, "sync_tasks"):
return
for sync_run_id, task in list(servicer.sync_tasks.items()):
if task.done():
continue
logger.debug("Cancelling sync task %s", sync_run_id)
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
servicer.sync_tasks.clear()
if hasattr(servicer, "sync_cleanup_tasks"):
for sync_run_id, task in list(servicer.sync_cleanup_tasks.items()):
if task.done():
continue
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
servicer.sync_cleanup_tasks.clear()
async def cancel_diarization_tasks(servicer: NoteFlowServicer) -> list[str]:
"""Cancel all active diarization tasks and return their IDs."""
cancelled_job_ids = list(servicer.diarization_tasks.keys())

View File

@@ -9,7 +9,9 @@ header are rejected with UNAUTHENTICATED status.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable
from functools import partial
from typing import Protocol, cast
import grpc
from grpc import aio
@@ -33,11 +35,95 @@ METADATA_WORKSPACE_ID = "x-workspace-id"
# Error messages
_ERR_MISSING_REQUEST_ID = "Missing required x-request-id header"
class _TypedRpcMethodHandler(Protocol[TRequest, TResponse]):
unary_unary: Callable[
[TRequest, aio.ServicerContext[TRequest, TResponse]],
Awaitable[TResponse],
] | None
unary_stream: Callable[
[TRequest, aio.ServicerContext[TRequest, TResponse]],
AsyncIterator[TResponse],
] | None
stream_unary: Callable[
[AsyncIterator[TRequest], aio.ServicerContext[TRequest, TResponse]],
Awaitable[TResponse],
] | None
stream_stream: Callable[
[AsyncIterator[TRequest], aio.ServicerContext[TRequest, TResponse]],
AsyncIterator[TResponse],
] | None
request_deserializer: Callable[[bytes], TRequest] | None
response_serializer: Callable[[TResponse], bytes] | None
def _coerce_metadata_value(value: str | bytes) -> str:
"""Normalize metadata values to string."""
return value.decode() if isinstance(value, bytes) else value
def _get_request_id(
metadata: dict[str, str | bytes],
method: str | None,
) -> str | None:
request_id_value = metadata.get(METADATA_REQUEST_ID)
if request_id_value is None:
logger.warning(
"Rejecting RPC: missing x-request-id header",
method=method,
)
return None
return _coerce_metadata_value(request_id_value)
def _apply_identity_context(metadata: dict[str, str | bytes], request_id: str) -> None:
request_id_var.set(request_id)
if user_id_value := metadata.get(METADATA_USER_ID):
user_id_var.set(_coerce_metadata_value(user_id_value))
if workspace_id_value := metadata.get(METADATA_WORKSPACE_ID):
workspace_id_var.set(_coerce_metadata_value(workspace_id_value))
async def _reject_unary_unary(
message: str,
request: TRequest,
context: aio.ServicerContext[TRequest, TResponse],
) -> TResponse:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
raise AssertionError("Unreachable after abort")
async def _reject_unary_stream(
message: str,
request: TRequest,
context: aio.ServicerContext[TRequest, TResponse],
) -> AsyncIterator[TResponse]:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
if False:
# Unreachable; keeps async generator type for gRPC handler signature.
yield cast(TResponse, None)
async def _reject_stream_unary(
message: str,
request_iterator: AsyncIterator[TRequest],
context: aio.ServicerContext[TRequest, TResponse],
) -> TResponse:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
raise AssertionError("Unreachable after abort")
async def _reject_stream_stream(
message: str,
request_iterator: AsyncIterator[TRequest],
context: aio.ServicerContext[TRequest, TResponse],
) -> AsyncIterator[TResponse]:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
if False:
# Unreachable; keeps async generator type for gRPC handler signature.
yield cast(TResponse, None)
class IdentityInterceptor(aio.ServerInterceptor):
"""Interceptor that validates and populates identity context for RPC calls.
@@ -61,38 +147,15 @@ class IdentityInterceptor(aio.ServerInterceptor):
],
handler_call_details: grpc.HandlerCallDetails,
) -> grpc.RpcMethodHandler[TRequest, TResponse]:
"""Intercept incoming RPC calls to validate and set identity context.
Args:
continuation: The next interceptor or handler.
handler_call_details: Details about the RPC call.
Returns:
The RPC handler for this call.
Raises:
grpc.RpcError: UNAUTHENTICATED if x-request-id header is missing.
"""
"""Intercept incoming RPC calls to validate and set identity context."""
metadata = dict(handler_call_details.invocation_metadata or [])
# Validate required x-request-id header
request_id_value = metadata.get(METADATA_REQUEST_ID)
if request_id_value is None:
logger.warning(
"Rejecting RPC: missing x-request-id header",
method=handler_call_details.method,
)
return _create_unauthenticated_handler(_ERR_MISSING_REQUEST_ID)
request_id = _get_request_id(metadata, handler_call_details.method)
if request_id is None:
handler = await continuation(handler_call_details)
return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID)
request_id = _coerce_metadata_value(request_id_value)
request_id_var.set(request_id)
# Extract optional user and workspace IDs from metadata
if user_id_value := metadata.get(METADATA_USER_ID):
user_id_var.set(_coerce_metadata_value(user_id_value))
if workspace_id_value := metadata.get(METADATA_WORKSPACE_ID):
workspace_id_var.set(_coerce_metadata_value(workspace_id_value))
_apply_identity_context(metadata, request_id)
logger.debug(
"Identity context: request=%s user=%s workspace=%s method=%s",
@@ -106,26 +169,36 @@ class IdentityInterceptor(aio.ServerInterceptor):
def _create_unauthenticated_handler(
handler: grpc.RpcMethodHandler[TRequest, TResponse],
message: str,
) -> grpc.RpcMethodHandler[TRequest, TResponse]:
"""Create a handler that rejects with UNAUTHENTICATED status.
"""Create a handler that rejects with UNAUTHENTICATED status."""
typed_handler = cast(_TypedRpcMethodHandler[TRequest, TResponse], handler)
request_deserializer = typed_handler.request_deserializer
response_serializer = typed_handler.response_serializer
Args:
message: Error message to include in the response.
Returns:
A gRPC method handler that rejects all requests.
"""
async def reject_unary_unary(
request: TRequest,
context: aio.ServicerContext[TRequest, TResponse],
) -> TResponse:
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
raise AssertionError("Unreachable after abort")
return grpc.unary_unary_rpc_method_handler(
reject_unary_unary,
request_deserializer=None,
response_serializer=None,
)
if typed_handler.unary_unary is not None:
return grpc.unary_unary_rpc_method_handler(
partial(_reject_unary_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
)
if typed_handler.unary_stream is not None:
return grpc.unary_stream_rpc_method_handler(
partial(_reject_unary_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
)
if typed_handler.stream_unary is not None:
return grpc.stream_unary_rpc_method_handler(
partial(_reject_stream_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
)
if typed_handler.stream_stream is not None:
return grpc.stream_stream_rpc_method_handler(
partial(_reject_stream_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
)
return handler
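For context, a minimal wiring sketch; the server setup and client metadata below are assumptions for illustration, not part of this diff:

import uuid

from grpc import aio

def build_server(interceptor: aio.ServerInterceptor) -> aio.Server:
    # RPCs without an x-request-id header are rejected with UNAUTHENTICATED,
    # via a rejection handler that now preserves the original serializers.
    return aio.server(interceptors=[interceptor])

# Client side: every call must carry the required header.
call_metadata = (("x-request-id", str(uuid.uuid4())),)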

View File

@@ -87,7 +87,15 @@ def register_shutdown_handlers(loop: asyncio.AbstractEventLoop) -> asyncio.Event
shutdown_event.set()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, signal_handler)
try:
loop.add_signal_handler(sig, signal_handler)
except (NotImplementedError, RuntimeError, ValueError) as exc:
logger.warning(
"Signal handlers not supported; relying on default handlers",
signal=sig,
error=str(exc),
)
break
return shutdown_event

View File

@@ -41,5 +41,9 @@ def bind_server(
server,
)
address = f"{bind_address}:{port}"
server.add_insecure_port(address)
bound_port = server.add_insecure_port(address)
if bound_port == 0:
raise RuntimeError(f"Failed to bind gRPC server on {address}")
if port == 0 and bound_port != 0:
address = f"{bind_address}:{bound_port}"
return address

View File

@@ -265,6 +265,7 @@ class NoteFlowServicer(
self.pending_chunks: dict[str, int] = {}
self.audio_write_failed: set[str] = set()
self.stream_states: dict[str, MeetingStreamState] = {}
self.background_tasks: set[asyncio.Task[None]] = set()
def _init_diarization_state(self) -> None:
"""Initialize diarization job tracking state."""

View File

@@ -207,11 +207,13 @@ class GoogleCalendarAdapter(CalendarPort):
# Get display name from 'name' field, fall back to email prefix
name = data.get("name")
display_name = (
str(name)
if name
else str(email).split("@")[0].replace(".", " ").title()
)
if name:
display_name = str(name)
else:
# Extract username from email, handling edge cases where @ may be missing
email_str = str(email)
local_part = email_str.split("@")[0] if "@" in email_str else email_str
display_name = local_part.replace(".", " ").title() if local_part else email_str
return str(email), display_name

View File

@@ -64,7 +64,9 @@ def _extract_email(profile: OutlookProfile) -> str:
def _format_display_name(profile: OutlookProfile, email: str) -> str:
if display_name_raw := profile.get("displayName"):
return str(display_name_raw)
return email.split("@")[0].replace(".", " ").title()
# Extract username from email, handling edge cases where @ may be missing
local_part = email.split("@")[0] if "@" in email else email
return local_part.replace(".", " ").title() if local_part else email
async def fetch_user_info(

View File

@@ -28,6 +28,19 @@ from ._usage_event_builders import build_usage_event, extract_event_context
logger = get_logger(__name__)
def _flush_task_done_callback(
task: asyncio.Task[None],
tasks_set: set[asyncio.Task[None]],
) -> None:
"""Handle completion of a flush task, logging any unhandled exceptions."""
tasks_set.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.exception("Flush task failed with unhandled exception", exc_info=exc)
class BufferedDatabaseUsageEventSink:
"""Usage event sink that persists events to database with buffering.
@@ -41,6 +54,7 @@ class BufferedDatabaseUsageEventSink:
DEFAULT_BUFFER_SIZE = 100
DEFAULT_FLUSH_INTERVAL_SECONDS = 5.0
DEFAULT_FLUSH_TIMEOUT_SECONDS = 30.0
def __init__(
self,
@@ -61,7 +75,7 @@ class BufferedDatabaseUsageEventSink:
self._flush_interval = flush_interval
self._buffer: deque[UsageEvent] = deque(maxlen=buffer_size * 2)
self._lock = Lock()
self._flush_task: asyncio.Task[None] | None = None
self._flush_tasks: set[asyncio.Task[None]] = set()
self._loop: asyncio.AbstractEventLoop | None = None
def record(self, event: UsageEvent) -> None:
@@ -99,7 +113,11 @@ class BufferedDatabaseUsageEventSink:
loop = asyncio.get_running_loop()
self._loop = loop
# Store task reference to prevent garbage collection
self._flush_task = loop.create_task(self._flush_async())
task = loop.create_task(self._flush_async())
self._flush_tasks.add(task)
task.add_done_callback(
lambda t: _flush_task_done_callback(t, self._flush_tasks)
)
except RuntimeError:
# No running loop, will flush on next opportunity
logger.debug("No event loop available for flush scheduling")
@@ -125,17 +143,40 @@ class BufferedDatabaseUsageEventSink:
async def _flush_async(self) -> None:
"""Flush buffered events to database."""
from noteflow.infrastructure.observability.otel import get_tracer
events = self._drain_buffer()
if not events:
return
try:
tracer = get_tracer(__name__)
with tracer.start_as_current_span("usage_events_flush") as span:
span.set_attribute("events.count", len(events))
repo = self._repository_factory()
count = await repo.add_batch(events)
logger.debug("Flushed %d usage events to database", count)
except Exception:
logger.exception("Failed to flush usage events to database")
self._restore_events(events)
try:
count = await asyncio.wait_for(
repo.add_batch(events),
timeout=self.DEFAULT_FLUSH_TIMEOUT_SECONDS,
)
span.set_attribute("events.flushed", count)
logger.debug("Flushed %d usage events to database", count)
except TimeoutError:
span.set_attribute("events.timeout", True)
logger.warning(
"Database flush timed out after %.1fs, restoring %d events",
self.DEFAULT_FLUSH_TIMEOUT_SECONDS,
len(events),
)
self._restore_events(events)
except Exception as exc:
span.record_exception(exc)
logger.exception("Failed to flush usage events to database")
self._restore_events(events)
finally:
# Ensure session cleanup even on timeout/exception
# Repository session is accessed via protected attribute
if hasattr(repo, "_session") and repo._session is not None:
await repo._session.close()
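The timeout handling above relies on asyncio.wait_for cancelling the awaited coroutine and raising TimeoutError in the caller, which is why restoring the drained events afterwards is safe. A minimal illustration using only stdlib asyncio (Python 3.11+, where asyncio.TimeoutError is the builtin TimeoutError):

# Illustrative sketch of the wait_for behavior the flush path depends on.
import asyncio

async def slow_write() -> int:
    await asyncio.sleep(10)
    return 42

async def main() -> None:
    try:
        await asyncio.wait_for(slow_write(), timeout=0.01)
    except TimeoutError:  # inner coroutine was cancelled by wait_for
        print("flush timed out; buffered events can be restored")

asyncio.run(main())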
async def flush(self) -> int:
"""Manually flush all buffered events.
@@ -147,16 +188,47 @@ class BufferedDatabaseUsageEventSink:
if not events:
return 0
repo = self._repository_factory()
try:
repo = self._repository_factory()
return await repo.add_batch(events)
except Exception:
logger.exception("Failed to flush usage events")
self._restore_events(events)
return 0
finally:
# Ensure session cleanup even on exception
if hasattr(repo, "_session") and repo._session is not None:
await repo._session.close()
@property
def pending_count(self) -> int:
"""Number of events waiting to be flushed."""
with self._lock:
return len(self._buffer)
async def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown the sink, flushing pending events.
Args:
timeout: Maximum time to wait for flush to complete.
"""
# Wait for any pending flush tasks
pending_tasks = [t for t in self._flush_tasks if not t.done()]
if pending_tasks:
try:
await asyncio.wait_for(
asyncio.gather(*pending_tasks, return_exceptions=True),
timeout=timeout / 2,
)
except TimeoutError:
logger.warning(
"Usage sink: %d pending flush tasks timed out", len(pending_tasks)
)
# Perform final flush
try:
await asyncio.wait_for(self.flush(), timeout=timeout / 2)
except TimeoutError:
logger.warning("Usage sink: final flush timed out")
except Exception:
logger.exception("Usage sink: final flush failed")

View File

@@ -9,18 +9,28 @@ from __future__ import annotations
import os
import random
import sys
from collections.abc import Sequence
import tracemalloc
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Protocol
from uuid import UUID
import numpy as np
from numpy.typing import NDArray
if TYPE_CHECKING:
from noteflow.domain.entities import Segment
from noteflow.infrastructure.persistence.models.core import SegmentModel
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
class _OrmConverter(Protocol):
"""Protocol for ORM to domain converter."""
def segment_to_domain(self, model: SegmentModel) -> Segment: ...
class _StreamingServicer(Protocol):
active_streams: set[str]
@@ -135,16 +145,33 @@ def run_audio_writer_cycles(
cycle_count: int,
buffer_size: int = 1024,
sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> None:
"""Run audio writer open/close cycles (helper to avoid loops in tests)."""
chunks_per_cycle: int = 10,
chunk_duration: float = 0.1,
) -> list[int]:
"""Run audio writer open/close cycles with actual audio data.
Returns:
List of bytes written per cycle (for metric logging).
"""
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
bytes_per_cycle: list[int] = []
for i in range(cycle_count):
writer = MeetingAudioWriter(crypto, base_path, buffer_size=buffer_size)
dek = crypto.generate_dek()
wrapped_dek = crypto.wrap_dek(dek)
writer.open(f"meeting-{i}", dek, wrapped_dek, sample_rate=sample_rate)
# Write actual audio data
for _ in range(chunks_per_cycle):
chunk = generate_audio_chunk(chunk_duration, sample_rate)
writer.write_chunk(chunk)
bytes_written = writer.bytes_written
writer.close()
bytes_per_cycle.append(bytes_written)
return bytes_per_cycle
# =============================================================================
@@ -253,3 +280,856 @@ def run_fuzz_iterations_batch(
errors = verify_fuzz_iteration_results(results, seed)
all_errors.extend(errors)
return all_errors
# =============================================================================
# Heap growth test helpers (tracemalloc-based)
# =============================================================================
# Heap growth test constants
HEAP_MAX_GROWTH_BYTES = 1 * 1024 * 1024 # 1MB
HEAP_CONVERTER_MAX_GROWTH_BYTES = 512 * 1024 # 512KB
HEAP_PROTO_MAX_GROWTH_BYTES = 256 * 1024 # 256KB
ORM_CONVERSION_CYCLES = 500
PROTO_CYCLES = 200
def calculate_noteflow_heap_growth(
snapshot_before: tracemalloc.Snapshot,
snapshot_after: tracemalloc.Snapshot,
) -> int:
"""Calculate heap growth in noteflow code between snapshots.
Args:
snapshot_before: tracemalloc snapshot before operation.
snapshot_after: tracemalloc snapshot after operation.
Returns:
Total bytes of growth in noteflow code paths.
"""
top_stats = snapshot_after.compare_to(snapshot_before, "lineno")
return sum(
stat.size_diff
for stat in top_stats
if "noteflow" in str(stat.traceback)
)
def calculate_filtered_heap_growth(
snapshot_before: tracemalloc.Snapshot,
snapshot_after: tracemalloc.Snapshot,
filter_str: str,
) -> int:
"""Calculate heap growth filtered by string in traceback.
Args:
snapshot_before: tracemalloc snapshot before operation.
snapshot_after: tracemalloc snapshot after operation.
filter_str: String to match in traceback paths.
Returns:
Total bytes of growth matching filter.
"""
top_stats = snapshot_after.compare_to(snapshot_before, "lineno")
return sum(
stat.size_diff
for stat in top_stats
if filter_str in str(stat.traceback).lower()
)
def run_orm_conversion_cycles(
converter: _OrmConverter,
meeting_id: UUID,
cycle_count: int,
) -> None:
"""Run ORM to domain conversion cycles (helper to avoid loops in tests)."""
from noteflow.infrastructure.persistence.models.core import SegmentModel
for i in range(cycle_count):
model = SegmentModel(
meeting_id=meeting_id,
segment_id=i,
text=f"Segment {i} content.",
start_time=float(i * 5),
end_time=float(i * 5 + 4.5),
speaker_id=f"speaker_{i % 3}",
)
converter.segment_to_domain(model)
def run_protobuf_cycles(cycle_count: int) -> None:
"""Run protobuf create/serialize/parse cycles (helper to avoid loops in tests)."""
from uuid import uuid4
from noteflow.grpc.proto import noteflow_pb2
for _ in range(cycle_count):
meeting = noteflow_pb2.Meeting(
id=str(uuid4()),
title="Test Meeting",
state=noteflow_pb2.MEETING_STATE_COMPLETED,
)
serialized = meeting.SerializeToString()
parsed = noteflow_pb2.Meeting()
parsed.ParseFromString(serialized)
# =============================================================================
# Async context edge case helpers
# =============================================================================
# Context test constants
CONTEXT_SLEEP_DURATION_SECONDS = 0.05
CONTEXT_CANCEL_DELAY_SECONDS = 0.01
SLOW_ENTER_SLEEP_SECONDS = 10
CONCURRENT_CONTEXT_COUNT = 10
def run_concurrent_context_tasks(
context_ids: list[str],
active_contexts: set[str],
results: dict[str, int],
) -> list[Awaitable[None]]:
"""Create tasks for concurrent context testing.
Args:
context_ids: List of context IDs to create.
active_contexts: Shared set tracking active contexts.
results: Dict to store max_concurrent count.
Returns:
List of awaitables to gather.
"""
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
@asynccontextmanager
async def tracked_context(ctx_id: str) -> AsyncIterator[str]:
active_contexts.add(ctx_id)
results["max_concurrent"] = max(results["max_concurrent"], len(active_contexts))
try:
yield ctx_id
finally:
active_contexts.discard(ctx_id)
async def use_context(ctx_id: str) -> None:
async with tracked_context(ctx_id):
await asyncio.sleep(CONTEXT_SLEEP_DURATION_SECONDS)
return [use_context(ctx_id) for ctx_id in context_ids]
# =============================================================================
# Memory pressure test helpers
# =============================================================================
# Memory pressure constants
MEMORY_PRESSURE_ALLOCATION_MB = 50
MEMORY_PRESSURE_CHUNK_SIZE_BYTES = 1024 * 1024 # 1MB chunks
GC_TEMP_OBJECT_COUNT = 1000
LINUX_RSS_MULTIPLIER = 1024 # resource.ru_maxrss returns KB on Linux
def allocate_memory_pressure(size_mb: int) -> list[bytes]:
"""Allocate memory to create pressure, returns references to keep alive.
Args:
size_mb: Megabytes to allocate.
Returns:
List of byte objects (must keep reference to maintain pressure).
"""
chunks: list[bytes] = []
for _ in range(size_mb):
chunks.append(os.urandom(MEMORY_PRESSURE_CHUNK_SIZE_BYTES))
return chunks
def run_operations_under_memory_pressure(
servicer: _StreamingServicer,
cycle_count: int,
pressure_mb: int,
) -> tuple[int, int]:
"""Run streaming operations while holding memory pressure.
Args:
servicer: The servicer to test.
cycle_count: Number of init/cleanup cycles.
pressure_mb: MB of memory to hold during operations.
Returns:
Tuple of (successful_cycles, gc_collections).
"""
import gc
pressure_data = allocate_memory_pressure(pressure_mb)
gc_before = sum(gc.get_count())
successful = 0
for i in range(cycle_count):
meeting_id = f"pressure-test-{i:03d}"
try:
servicer.init_streaming_state(meeting_id, next_segment_id=0)
servicer.active_streams.add(meeting_id)
servicer.cleanup_streaming_state(meeting_id)
servicer.close_audio_writer(meeting_id)
servicer.active_streams.discard(meeting_id)
successful += 1
except MemoryError:
break
gc_after = sum(gc.get_count())
del pressure_data
return successful, gc_after - gc_before
def run_gc_collection_behavior() -> tuple[int, int, int]:
"""Run GC behavior test with temporary objects.
Returns:
Tuple of (gen0_diff, gen1_diff, gen2_diff) in collection counts.
"""
import gc
gc.collect()
counts_before = gc.get_count()
for _ in range(GC_TEMP_OBJECT_COUNT):
_ = {"key": "value", "data": [1, 1, 1, 1, 1]}
gc.collect()
counts_after = gc.get_count()
return (
counts_after[0] - counts_before[0],
counts_after[1] - counts_before[1],
counts_after[2] - counts_before[2],
)
def measure_rss_bytes() -> int:
"""Measure current process RSS (Resident Set Size) in bytes.
Returns:
RSS in bytes, or -1 if measurement not supported.
"""
try:
import psutil
process = psutil.Process()
return process.memory_info().rss
except ImportError:
pass
if sys.platform == "darwin" or sys.platform == "linux":
try:
import resource
usage = resource.getrusage(resource.RUSAGE_SELF)
if sys.platform == "linux":
return usage.ru_maxrss * LINUX_RSS_MULTIPLIER
return usage.ru_maxrss
except ImportError:
pass
return -1
# =============================================================================
# Rigorous stress test helpers
# =============================================================================
# High-cycle constants for rigorous testing
RIGOROUS_CYCLE_COUNT = 1000
RIGOROUS_AUDIO_CHUNKS_PER_CYCLE = 100
RIGOROUS_CONCURRENT_MEETINGS = 20
RIGOROUS_AUDIO_DURATION_SECONDS = 0.1
THREAD_LEAK_TOLERANCE = 2 # Allow up to 2 extra threads (GC, etc.)
FD_LEAK_TOLERANCE = 5 # Allow up to 5 extra FDs (system variation)
def get_thread_count() -> int:
"""Get current thread count for this process."""
import threading
return threading.active_count()
def generate_audio_chunk(
duration_seconds: float,
sample_rate: int,
) -> NDArray[np.float32]:
"""Generate realistic audio chunk with speech-like characteristics."""
samples = int(duration_seconds * sample_rate)
# Generate pink noise (more speech-like than white noise)
white = np.random.randn(samples).astype(np.float32)
# Simple low-pass approximation for pink noise
pink = np.convolve(white, np.ones(8) / 8, mode="same")
# Normalize to [-0.5, 0.5] range (typical speech amplitude)
pink = pink / (np.abs(pink).max() + 1e-6) * 0.5
return pink.astype(np.float32)
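An illustrative sanity check (not part of this commit) for the chunks produced above, assuming the numpy import and helper in this module:

# Illustrative check: dtype, length, and normalized amplitude of a chunk.
chunk = generate_audio_chunk(0.1, 16000)
assert chunk.dtype == np.float32
assert chunk.shape == (1600,)                     # 0.1s at 16kHz
assert float(np.abs(chunk).max()) <= 0.5 + 1e-6   # normalized speech-like range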
def run_audio_writer_high_cycle(
crypto: AesGcmCryptoBox,
base_path: Path,
cycle_count: int,
chunks_per_cycle: int,
sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, int, int]:
"""Run high-cycle audio writer stress test with real data.
Returns:
Tuple of (successful_cycles, total_bytes_written, total_chunks).
"""
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
successful = 0
total_bytes = 0
total_chunks = 0
for i in range(cycle_count):
writer = MeetingAudioWriter(crypto, base_path, buffer_size=4096)
dek = crypto.generate_dek()
wrapped_dek = crypto.wrap_dek(dek)
writer.open(f"stress-{i:05d}", dek, wrapped_dek, sample_rate=sample_rate)
# Write real audio data (writer expects float32 arrays)
for _ in range(chunks_per_cycle):
chunk = generate_audio_chunk(RIGOROUS_AUDIO_DURATION_SECONDS, sample_rate)
writer.write_chunk(chunk)
total_chunks += 1
total_bytes += writer.bytes_written
writer.close()
successful += 1
return successful, total_bytes, total_chunks
def run_segmenter_high_cycle(
cycle_count: int,
chunks_per_cycle: int,
sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, int]:
"""Run high-cycle segmenter stress test with real audio.
Returns:
Tuple of (total_segments_produced, total_audio_seconds_processed).
"""
from noteflow.infrastructure.asr.segmenter import Segmenter, SegmenterConfig
total_segments = 0
total_seconds = 0.0
for _ in range(cycle_count):
config = SegmenterConfig(sample_rate=sample_rate)
segmenter = Segmenter(config=config)
for j in range(chunks_per_cycle):
chunk = generate_audio_chunk(RIGOROUS_AUDIO_DURATION_SECONDS, sample_rate)
# Alternate speech/silence to produce segments
is_speech = (j % 20) < 15 # 75% speech
segments = segmenter.process_audio(chunk, is_speech)
total_segments += len(segments)
total_seconds += RIGOROUS_AUDIO_DURATION_SECONDS
if segmenter.flush():
total_segments += 1
return total_segments, int(total_seconds)
def run_concurrent_streaming_sessions(
servicer: _StreamingServicer,
meeting_count: int,
chunks_per_meeting: int,
sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, list[str]]:
"""Run concurrent streaming sessions on servicer.
Returns:
Tuple of (successful_meetings, error_messages).
"""
errors: list[str] = []
successful = 0
# Initialize all meetings
meeting_ids = [f"concurrent-{i:03d}" for i in range(meeting_count)]
for mid in meeting_ids:
servicer.init_streaming_state(mid, next_segment_id=0)
servicer.active_streams.add(mid)
# Verify all initialized
if len(servicer.active_streams) != meeting_count:
errors.append(
f"Expected {meeting_count} active streams, got {len(servicer.active_streams)}"
)
# Cleanup all meetings
for mid in meeting_ids:
servicer.cleanup_streaming_state(mid)
servicer.close_audio_writer(mid)
servicer.active_streams.discard(mid)
successful += 1
return successful, errors
def measure_resource_baseline() -> tuple[int, int, int]:
"""Measure baseline resource usage.
Returns:
Tuple of (fd_count, thread_count, rss_bytes).
"""
import gc
gc.collect()
return get_fd_count(), get_thread_count(), measure_rss_bytes()
def check_resource_leaks(
baseline: tuple[int, int, int],
fd_tolerance: int = FD_LEAK_TOLERANCE,
thread_tolerance: int = THREAD_LEAK_TOLERANCE,
) -> list[str]:
"""Check for resource leaks against baseline.
Returns:
List of error messages (empty if no leaks).
"""
import gc
gc.collect()
errors: list[str] = []
fd_before, thread_before, _rss_before = baseline
fd_after = get_fd_count()
thread_after = get_thread_count()
if fd_before >= 0 and fd_after >= 0:
fd_delta = fd_after - fd_before
if fd_delta > fd_tolerance:
errors.append(f"FD leak: {fd_delta} new FDs (tolerance={fd_tolerance})")
thread_delta = thread_after - thread_before
if thread_delta > thread_tolerance:
errors.append(
f"Thread leak: {thread_delta} new threads (tolerance={thread_tolerance})"
)
return errors
def log_stress_metrics(
test_name: str,
baseline: tuple[int, int, int],
*,
bytes_written: int = 0,
cycles: int = 0,
extra: dict[str, object] | None = None,
) -> None:
"""Log quantified stress test metrics to stdout.
Metrics are emitted in a structured format for test output analysis.
Use pytest -s to see these metrics during test runs.
"""
fd_before, thread_before, rss_before = baseline
fd_after = get_fd_count()
thread_after = get_thread_count()
rss_after = measure_rss_bytes()
fd_delta = fd_after - fd_before if fd_before >= 0 and fd_after >= 0 else 0
thread_delta = thread_after - thread_before
rss_delta_kb = (rss_after - rss_before) // 1024
lines = [
f"\n{'=' * 60}",
f"STRESS METRICS: {test_name}",
f"{'=' * 60}",
f" Cycles: {cycles:,}",
f" Bytes written: {bytes_written:,}",
f" FD delta: {fd_delta:+d} ({fd_before} -> {fd_after})",
f" Thread delta: {thread_delta:+d} ({thread_before} -> {thread_after})",
f" RSS delta: {rss_delta_kb:+,} KB ({rss_before // 1024:,} -> {rss_after // 1024:,} KB)",
]
if extra:
lines.append(" Extra:")
lines.extend(f" {k}: {v}" for k, v in extra.items())
lines.append(f"{'=' * 60}\n")
print("\n".join(lines))
def verify_audio_writer_high_cycle_metrics(
test_name: str,
baseline: tuple[int, int, int],
*,
successful_cycles: int,
total_bytes: int,
total_chunks: int,
cycle_count: int,
chunks_per_cycle: int,
) -> None:
"""Verify high-cycle audio writer metrics and resource cleanup."""
log_stress_metrics(
test_name,
baseline,
bytes_written=total_bytes,
cycles=cycle_count,
extra={"chunks": total_chunks, "successful_cycles": successful_cycles},
)
errors = check_resource_leaks(baseline)
expected_chunks = cycle_count * chunks_per_cycle
if successful_cycles != cycle_count:
raise AssertionError(f"Only {successful_cycles}/{cycle_count} cycles completed")
if total_chunks != expected_chunks:
raise AssertionError(f"Expected {expected_chunks} chunks, wrote {total_chunks}")
if errors:
raise AssertionError(
f"Resource leaks after {total_bytes} bytes written: {errors}"
)
def log_heap_metrics(
test_name: str,
heap_growth_bytes: int,
cycles: int = 0,
*,
max_allowed_bytes: int = 0,
filter_name: str = "",
) -> None:
"""Log heap growth metrics from tracemalloc tests.
Args:
test_name: Name of the test.
heap_growth_bytes: Heap growth in bytes.
cycles: Number of operation cycles.
max_allowed_bytes: Maximum allowed heap growth for context.
filter_name: Filter name used (e.g., 'noteflow', 'converter').
"""
growth_kb = heap_growth_bytes / 1024
max_kb = max_allowed_bytes / 1024 if max_allowed_bytes else 0
utilization = (heap_growth_bytes / max_allowed_bytes * 100) if max_allowed_bytes else 0
lines = [
f"\n{'=' * 60}",
f"HEAP METRICS: {test_name}",
f"{'=' * 60}",
f" Cycles: {cycles:,}",
f" Heap growth: {growth_kb:,.1f} KB",
]
if max_allowed_bytes:
lines.append(f" Max allowed: {max_kb:,.1f} KB ({utilization:.1f}% utilized)")
if filter_name:
lines.append(f" Filter: {filter_name}")
lines.append(f"{'=' * 60}\n")
print("\n".join(lines))
def log_thread_metrics(
test_name: str,
initial_threads: int,
final_threads: int,
cycles: int,
tolerance: int,
) -> None:
"""Log thread count metrics for thread leak tests.
Args:
test_name: Name of the test.
initial_threads: Thread count before test.
final_threads: Thread count after test.
cycles: Number of operation cycles.
tolerance: Allowed thread leak tolerance.
"""
thread_delta = final_threads - initial_threads
lines = [
f"\n{'=' * 60}",
f"THREAD METRICS: {test_name}",
f"{'=' * 60}",
f" Cycles: {cycles}",
f" Thread delta: {thread_delta:+d} ({initial_threads} -> {final_threads})",
f" Tolerance: {tolerance}",
f"{'=' * 60}\n",
]
print("\n".join(lines))
def run_streaming_with_resource_tracking(
servicer: _StreamingServicer,
cycle_count: int,
) -> tuple[int, list[str]]:
"""Run streaming cycles with resource tracking.
Returns:
Tuple of (successful_cycles, error_messages).
"""
baseline = measure_resource_baseline()
run_streaming_init_cleanup_cycles(servicer, cycle_count, "tracked")
errors = check_resource_leaks(baseline)
return cycle_count, errors
def run_interleaved_init_cleanup(
servicer: _StreamingServicer,
cycle_count: int,
init_per_cycle: int = 3,
cleanup_per_cycle: int = 1,
) -> tuple[int, int]:
"""Run interleaved init/cleanup pattern (init N, cleanup M, repeat).
Returns:
Tuple of (total_inits, total_cleanups).
"""
active_meetings: list[str] = []
total_inits = 0
total_cleanups = 0
for cycle in range(cycle_count):
# Init N meetings
for j in range(init_per_cycle):
meeting_id = f"interleave-{cycle}-{j}"
servicer.init_streaming_state(meeting_id, next_segment_id=0)
servicer.active_streams.add(meeting_id)
active_meetings.append(meeting_id)
total_inits += 1
# Cleanup M oldest
for _ in range(min(cleanup_per_cycle, len(active_meetings))):
old_meeting = active_meetings.pop(0)
servicer.cleanup_streaming_state(old_meeting)
servicer.close_audio_writer(old_meeting)
servicer.active_streams.discard(old_meeting)
total_cleanups += 1
# Cleanup remaining
for meeting_id in active_meetings:
servicer.cleanup_streaming_state(meeting_id)
servicer.close_audio_writer(meeting_id)
servicer.active_streams.discard(meeting_id)
total_cleanups += 1
return total_inits, total_cleanups
class _MockSessionFactory(Protocol):
"""Protocol for mock session factory callable."""
def __call__(self) -> object: ...
class _ServicerWithStreamStates(Protocol):
"""Extended servicer protocol with stream_states access."""
active_streams: set[str]
stream_states: dict[str, object]
def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None: ...
def cleanup_streaming_state(self, meeting_id: str) -> None: ...
def close_audio_writer(self, meeting_id: str) -> None: ...
def run_diarization_session_cycles(
servicer: _ServicerWithStreamStates,
cycle_count: int,
mock_session_factory: _MockSessionFactory,
) -> int:
"""Run diarization session init/cleanup cycles with mock sessions.
Args:
servicer: The servicer to test (must have stream_states attribute).
cycle_count: Number of cycles to run.
mock_session_factory: Callable that returns mock session objects.
Returns:
Number of successful cycles.
"""
successful = 0
for i in range(cycle_count):
meeting_id = f"diarize-{i:04d}"
mock_session = mock_session_factory()
servicer.init_streaming_state(meeting_id, next_segment_id=0)
servicer.active_streams.add(meeting_id)
# Access stream state and set diarization session
state = servicer.stream_states.get(meeting_id)
if state is not None:
object.__setattr__(state, "diarization_session", mock_session)
servicer.cleanup_streaming_state(meeting_id)
servicer.close_audio_writer(meeting_id)
servicer.active_streams.discard(meeting_id)
successful += 1
return successful
def create_diarization_tasks(
task_count: int,
sleep_seconds: float,
) -> list[tuple[str, Awaitable[None]]]:
"""Create diarization-like async tasks for testing.
Returns:
List of (job_id, task_coroutine) tuples. Caller must schedule tasks.
"""
import asyncio
return [
(f"job-{i:04d}", asyncio.sleep(sleep_seconds))
for i in range(task_count)
]
def schedule_tasks_to_servicer(
servicer: object,
task_count: int,
sleep_seconds: float,
) -> list[object]:
"""Schedule async tasks to servicer's diarization_tasks dict.
Args:
servicer: Servicer with diarization_tasks dict attribute.
task_count: Number of tasks to create.
sleep_seconds: How long each task should sleep.
Returns:
List of created task objects.
"""
import asyncio
tasks: list[object] = []
diarization_tasks = object.__getattribute__(servicer, "diarization_tasks")
for i in range(task_count):
task = asyncio.create_task(asyncio.sleep(sleep_seconds))
diarization_tasks[f"job-{i:04d}"] = task
tasks.append(task)
return tasks
def verify_all_tasks_cancelled(tasks: list[object]) -> list[str]:
"""Verify all tasks are done and cancelled.
Returns:
List of error messages (empty if all tasks properly cancelled).
"""
errors: list[str] = []
for i, task in enumerate(tasks):
done_fn = object.__getattribute__(task, "done")
cancelled_fn = object.__getattribute__(task, "cancelled")
is_done = done_fn()
is_cancelled = cancelled_fn()
if not is_done:
errors.append(f"Task {i} not done after shutdown")
if not is_cancelled:
errors.append(f"Task {i} not cancelled after shutdown")
return errors
async def run_webhook_executor_cycles(
cycle_count: int,
) -> list[tuple[bool, bool]]:
"""Run webhook executor open/close cycles.
Args:
cycle_count: Number of cycles to run.
Returns:
List of (client_created, client_cleared) tuples for each cycle.
"""
import httpx
from noteflow.infrastructure.webhooks.executor import WebhookExecutor
def _get_client(executor: object) -> object | None:
"""Get executor's HTTP client."""
client = object.__getattribute__(executor, "_client")
return client if isinstance(client, httpx.AsyncClient) else None
async def _ensure_client(executor: object) -> None:
"""Ensure executor has an HTTP client."""
ensure_method = object.__getattribute__(executor, "_ensure_client")
await ensure_method()
results: list[tuple[bool, bool]] = []
for _ in range(cycle_count):
executor = WebhookExecutor()
await _ensure_client(executor)
client_created = _get_client(executor) is not None
await executor.close()
client_cleared = _get_client(executor) is None
results.append((client_created, client_cleared))
return results
def verify_webhook_executor_results(
results: list[tuple[bool, bool]],
) -> tuple[bool, bool]:
"""Verify all webhook executor cycles succeeded.
Args:
results: List of (client_created, client_cleared) tuples.
Returns:
Tuple of (all_created, all_cleared).
"""
all_created = True
all_cleared = True
for created, cleared in results:
if not created:
all_created = False
if not cleared:
all_cleared = False
return all_created, all_cleared
def run_audio_writer_thread_cycles(
crypto: object,
tmp_path: object,
cycle_count: int,
sample_rate: int,
) -> None:
"""Run audio writer open/close cycles for thread leak testing.
Args:
crypto: AesGcmCryptoBox instance.
tmp_path: Path for temporary files.
cycle_count: Number of cycles to run.
sample_rate: Audio sample rate.
"""
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
for i in range(cycle_count):
writer = MeetingAudioWriter(crypto, tmp_path, buffer_size=1024)
generate_dek = object.__getattribute__(crypto, "generate_dek")
wrap_dek = object.__getattribute__(crypto, "wrap_dek")
dek = generate_dek()
wrapped_dek = wrap_dek(dek)
writer.open(f"thread-test-{i:04d}", dek, wrapped_dek, sample_rate=sample_rate)
writer.close()

View File

@@ -1,5 +1,24 @@
{
"generated_at": "2026-01-07T01:08:35.197277+00:00",
"rules": {},
"generated_at": "2026-01-14T04:30:39.258014+00:00",
"rules": {
"deep_nesting": [
"deep_nesting|src/noteflow/application/services/asr_config_service.py|_execute_reconfiguration|depth=3",
"deep_nesting|src/noteflow/application/services/asr_config_service.py|shutdown|depth=3",
"deep_nesting|src/noteflow/grpc/_mixins/diarization/_jobs.py|_execute_diarization|depth=3",
"deep_nesting|src/noteflow/grpc/_mixins/sync.py|perform_sync|depth=3",
"deep_nesting|src/noteflow/grpc/_service_shutdown.py|cancel_sync_tasks|depth=3"
],
"god_class": [
"god_class|src/noteflow/application/services/asr_config_service.py|AsrConfigService|methods=16"
],
"long_method": [
"long_method|src/noteflow/grpc/_mixins/sync.py|perform_sync|lines=60"
],
"module_size_soft": [
"module_size_soft|src/noteflow/application/services/asr_config_service.py|module|lines=361",
"module_size_soft|src/noteflow/grpc/_mixins/diarization/_jobs.py|module|lines=372",
"module_size_soft|src/noteflow/grpc/_mixins/sync.py|module|lines=384"
]
},
"schema_version": 1
}

View File

@@ -9,6 +9,7 @@ from __future__ import annotations
import asyncio
import gc
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
@@ -18,30 +19,76 @@ import pytest
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.grpc.service import NoteFlowServicer
from support.stress_helpers import (
get_fd_count,
CONCURRENT_CONTEXT_COUNT,
CONTEXT_CANCEL_DELAY_SECONDS,
HEAP_CONVERTER_MAX_GROWTH_BYTES,
HEAP_MAX_GROWTH_BYTES,
HEAP_PROTO_MAX_GROWTH_BYTES,
MEMORY_PRESSURE_ALLOCATION_MB,
ORM_CONVERSION_CYCLES,
PROTO_CYCLES,
RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
RIGOROUS_CONCURRENT_MEETINGS,
SLOW_ENTER_SLEEP_SECONDS,
calculate_filtered_heap_growth,
calculate_noteflow_heap_growth,
check_resource_leaks,
log_heap_metrics,
log_stress_metrics,
log_thread_metrics,
measure_resource_baseline,
measure_rss_bytes,
run_audio_writer_cycles,
run_audio_writer_high_cycle,
run_audio_writer_thread_cycles,
run_concurrent_context_tasks,
run_concurrent_streaming_sessions,
run_diarization_session_cycles,
run_gc_collection_behavior,
run_interleaved_init_cleanup,
run_operations_under_memory_pressure,
run_orm_conversion_cycles,
run_protobuf_cycles,
run_streaming_init_cleanup_cycles,
run_webhook_executor_cycles,
schedule_tasks_to_servicer,
verify_all_tasks_cancelled,
verify_audio_writer_high_cycle_metrics,
verify_webhook_executor_results,
)
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from threading import Thread
import httpx
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
# Test constants
STREAMING_CYCLES = 50
AUDIO_WRITER_CYCLES = 20
MEMORY_TEST_CYCLES = 100
INTERLEAVE_CYCLES = 20
DIARIZATION_CYCLES = 50
WRITER_THREAD_CYCLES = 10
# Test constants - rigorous values for real stress testing
STREAMING_CYCLES = 500 # 10x increase for slow leak detection
AUDIO_WRITER_CYCLES = 100 # 5x increase with real data writes
MEMORY_TEST_CYCLES = 1000 # 10x increase for state dict stability
INTERLEAVE_CYCLES = 200 # 10x increase for interleaved patterns
DIARIZATION_CYCLES = 200 # 4x increase for session lifecycle
WRITER_THREAD_CYCLES = 50 # 5x increase for thread lifecycle
FD_LEAK_TOLERANCE = 10
THREAD_LEAK_TOLERANCE = 2
TASK_TEST_COUNT = 5
TASK_TEST_COUNT = 50 # 10x increase for coroutine leak detection
TASK_SLEEP_SECONDS = 10
WEBHOOK_EXECUTOR_CYCLES = 5
WEBHOOK_EXECUTOR_CYCLES = 20 # 4x increase for HTTP client lifecycle
BYTES_PER_KB = 1024
# High-stress constants (use sparingly - these take longer)
HIGH_STRESS_STREAMING_CYCLES = 2000
HIGH_STRESS_SEGMENTER_CYCLES = 500
HIGH_STRESS_SEGMENTER_CHUNKS = 200
# Async cleanup wait times
ASYNC_CLEANUP_DELAY_SECONDS = 0.1
FLUSH_THREAD_CLEANUP_DELAY_SECONDS = 0.2
FAILING_TASK_DELAY_SECONDS = 0.1
THREAD_STOP_WAIT_SECONDS = 0.5 # Longer wait for thread pool cleanup
def _get_executor_client(executor: object) -> httpx.AsyncClient | None:
@@ -96,57 +143,120 @@ def _get_writer_flush_thread(writer: object) -> Thread | None:
class TestFileDescriptorLeaks:
"""Detect file descriptor leaks."""
"""Detect file descriptor and thread leaks under high load."""
@pytest.mark.slow
@pytest.mark.asyncio
async def test_streaming_fd_cleanup(
self, memory_servicer: NoteFlowServicer, initial_fd_count: int
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify file descriptors released after streaming cycles."""
"""Verify FDs and threads released after 500 streaming cycles."""
baseline = measure_resource_baseline()
run_streaming_init_cleanup_cycles(memory_servicer, STREAMING_CYCLES)
gc.collect()
await asyncio.sleep(0.1) # Allow async cleanup
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
final_fds = get_fd_count()
assert final_fds <= initial_fd_count + FD_LEAK_TOLERANCE, (
f"File descriptor leak: started with {initial_fd_count}, ended with {final_fds}"
)
errors = check_resource_leaks(baseline)
assert not errors, f"Resource leaks after {STREAMING_CYCLES} cycles: {errors}"
@pytest.mark.slow
@pytest.mark.asyncio
async def test_audio_writer_fd_cleanup(
self, tmp_path: Path, initial_fd_count: int, crypto: AesGcmCryptoBox
self, tmp_path: Path, crypto: AesGcmCryptoBox
) -> None:
"""Verify audio writer closes file handles."""
run_audio_writer_cycles(crypto, tmp_path, AUDIO_WRITER_CYCLES)
"""Verify audio writer closes FDs and threads after 100 cycles."""
baseline = measure_resource_baseline()
bytes_per_cycle = run_audio_writer_cycles(crypto, tmp_path, AUDIO_WRITER_CYCLES)
total_bytes = sum(bytes_per_cycle)
gc.collect()
await asyncio.sleep(0.1)
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
final_fds = get_fd_count()
assert final_fds <= initial_fd_count + FD_LEAK_TOLERANCE // 2, (
f"Audio writer FD leak: {initial_fd_count} -> {final_fds}"
log_stress_metrics(
"test_audio_writer_fd_cleanup",
baseline,
bytes_written=total_bytes,
cycles=AUDIO_WRITER_CYCLES,
)
errors = check_resource_leaks(baseline)
assert not errors, f"Audio writer resource leaks: {errors}"
@pytest.mark.slow
@pytest.mark.asyncio
async def test_audio_writer_high_cycle_with_real_data(
self, tmp_path: Path, crypto: AesGcmCryptoBox
) -> None:
"""Verify no leaks after writing real audio data across many cycles.
This test writes actual audio chunks (not just open/close) to stress
the buffer management, flush thread, and encryption pipeline.
"""
baseline = measure_resource_baseline()
successful, total_bytes, total_chunks = run_audio_writer_high_cycle(
crypto,
tmp_path,
cycle_count=AUDIO_WRITER_CYCLES,
chunks_per_cycle=RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
)
gc.collect()
await asyncio.sleep(FLUSH_THREAD_CLEANUP_DELAY_SECONDS)
verify_audio_writer_high_cycle_metrics(
"test_audio_writer_high_cycle_with_real_data",
baseline,
successful_cycles=successful,
total_bytes=total_bytes,
total_chunks=total_chunks,
cycle_count=AUDIO_WRITER_CYCLES,
chunks_per_cycle=RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
)
@pytest.mark.slow
@pytest.mark.asyncio
async def test_concurrent_meetings_resource_cleanup(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify resources cleaned up after concurrent meeting sessions."""
baseline = measure_resource_baseline()
successful, session_errors = run_concurrent_streaming_sessions(
memory_servicer,
meeting_count=RIGOROUS_CONCURRENT_MEETINGS,
chunks_per_meeting=0,
)
gc.collect()
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
log_stress_metrics(
"test_concurrent_meetings_resource_cleanup",
baseline,
cycles=RIGOROUS_CONCURRENT_MEETINGS,
extra={"successful": successful, "session_errors": len(session_errors)},
)
errors = check_resource_leaks(baseline)
all_errors = session_errors + errors
assert successful == RIGOROUS_CONCURRENT_MEETINGS, (
f"Only {successful}/{RIGOROUS_CONCURRENT_MEETINGS} meetings succeeded"
)
assert not all_errors, f"Concurrent session errors: {all_errors}"
class TestMemoryLeaks:
"""Detect memory leaks under load."""
"""Detect memory leaks under high load (1000+ cycles)."""
@pytest.mark.slow
@pytest.mark.asyncio
async def test_streaming_state_memory_stability(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify state dicts don't grow unbounded during streaming cycles."""
for i in range(MEMORY_TEST_CYCLES):
meeting_id = f"mem-test-{i:03d}"
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
memory_servicer.active_streams.add(meeting_id)
memory_servicer.cleanup_streaming_state(meeting_id)
memory_servicer.active_streams.discard(meeting_id)
"""Verify state dicts don't grow unbounded during 1000 streaming cycles."""
run_streaming_init_cleanup_cycles(memory_servicer, MEMORY_TEST_CYCLES)
gc.collect()
@@ -161,32 +271,17 @@ class TestMemoryLeaks:
async def test_resource_interleaved_init_cleanup(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify memory stability with interleaved init/cleanup patterns."""
active_meetings: list[str] = []
# Interleaved pattern: init 3, cleanup 1, repeat
for cycle in range(INTERLEAVE_CYCLES):
# Init 3 meetings
for j in range(3):
meeting_id = f"interleave-{cycle}-{j}"
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
memory_servicer.active_streams.add(meeting_id)
active_meetings.append(meeting_id)
# Cleanup 1 oldest
if active_meetings:
old_meeting = active_meetings.pop(0)
memory_servicer.cleanup_streaming_state(old_meeting)
memory_servicer.active_streams.discard(old_meeting)
# Cleanup remaining
for meeting_id in active_meetings:
memory_servicer.cleanup_streaming_state(meeting_id)
memory_servicer.active_streams.discard(meeting_id)
"""Verify memory stability with 200 interleaved init/cleanup patterns."""
total_inits, total_cleanups = run_interleaved_init_cleanup(
memory_servicer, INTERLEAVE_CYCLES
)
gc.collect()
# Verify clean state
# Verify clean state and all operations completed
assert total_inits == total_cleanups, (
f"Mismatched init/cleanup: {total_inits} inits, {total_cleanups} cleanups"
)
assert len(memory_servicer.vad_instances) == 0, "VAD leaked after interleave"
assert len(memory_servicer.active_streams) == 0, "Streams leaked after interleave"
@@ -195,21 +290,22 @@ class TestMemoryLeaks:
async def test_diarization_session_memory(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify diarization sessions don't leak memory."""
# Create and close many sessions
for i in range(DIARIZATION_CYCLES):
meeting_id = f"diarize-mem-{i:03d}"
mock_session = MagicMock()
mock_session.close = MagicMock()
"""Verify 200 diarization session cycles don't leak memory."""
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
state = memory_servicer.get_stream_state(meeting_id)
assert state is not None
state.diarization_session = mock_session
memory_servicer.cleanup_streaming_state(meeting_id)
def create_mock_session() -> MagicMock:
mock = MagicMock()
mock.close = MagicMock()
return mock
successful = run_diarization_session_cycles(
memory_servicer, DIARIZATION_CYCLES, create_mock_session
)
gc.collect()
assert successful == DIARIZATION_CYCLES, (
f"Only {successful}/{DIARIZATION_CYCLES} cycles completed"
)
assert len(memory_servicer.stream_states) == 0, "Diarization sessions leaked"
@@ -221,14 +317,10 @@ class TestCoroutineLeaks:
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify no tasks remain after servicer shutdown."""
# Create some diarization tasks with explicit type annotation
tasks_created: list[asyncio.Task[None]] = []
for i in range(TASK_TEST_COUNT):
task: asyncio.Task[None] = asyncio.create_task(
asyncio.sleep(TASK_SLEEP_SECONDS)
)
memory_servicer.diarization_tasks[f"job-{i}"] = task
tasks_created.append(task)
# Create tasks using helper (avoids inline loop)
tasks_created = schedule_tasks_to_servicer(
memory_servicer, TASK_TEST_COUNT, TASK_SLEEP_SECONDS
)
# Shutdown
await memory_servicer.shutdown()
@@ -238,10 +330,9 @@ class TestCoroutineLeaks:
f"Expected 0 diarization tasks after shutdown, found {len(memory_servicer.diarization_tasks)}"
)
# Check tasks are cancelled
for task in tasks_created:
assert task.done(), "Task should be done after shutdown"
assert task.cancelled(), "Task should be cancelled after shutdown"
# Check tasks are cancelled using helper (avoids inline loop)
errors = verify_all_tasks_cancelled(tasks_created)
assert not errors, f"Task cleanup errors: {errors}"
@pytest.mark.asyncio
async def test_task_cleanup_on_exception(
@@ -250,7 +341,7 @@ class TestCoroutineLeaks:
"""Verify tasks cleaned up even if task raises exception."""
async def failing_task() -> None:
await asyncio.sleep(0.1)
await asyncio.sleep(FAILING_TASK_DELAY_SECONDS)
raise ValueError("Task failed")
task = asyncio.create_task(failing_task())
@@ -303,23 +394,13 @@ class TestWebhookClientLeaks:
@pytest.mark.asyncio
async def test_webhook_executor_multiple_cycles(self) -> None:
"""Verify webhook executor can be opened and closed multiple times."""
from noteflow.infrastructure.webhooks.executor import WebhookExecutor
# Run cycles using helper function (avoids inline loop)
results = await run_webhook_executor_cycles(WEBHOOK_EXECUTOR_CYCLES)
async def run_executor_cycle() -> tuple[bool, bool]:
"""Run one executor open/close cycle, return (client_created, client_cleared)."""
executor = WebhookExecutor()
await _ensure_executor_client(executor)
client_created = _get_executor_client(executor) is not None
await executor.close()
client_cleared = _get_executor_client(executor) is None
return client_created, client_cleared
# Run cycles and collect results
results = [await run_executor_cycle() for _ in range(WEBHOOK_EXECUTOR_CYCLES)]
# Verify all cycles succeeded
assert all(created for created, _ in results), "Client creation failed in some cycles"
assert all(cleared for _, cleared in results), "Client cleanup failed in some cycles"
# Verify all cycles succeeded using helper (avoids inline loop)
all_created, all_cleared = verify_webhook_executor_results(results)
assert all_created, "Client creation failed in some cycles"
assert all_cleared, "Client cleanup failed in some cycles"
class TestDiarizationSessionLeaks:
@@ -395,7 +476,7 @@ class TestAudioWriterThreadLeaks:
writer.close()
# Give thread time to stop
await asyncio.sleep(0.1)
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
# Thread should be stopped and cleared
assert _get_writer_flush_thread(writer) is None, "Flush thread should be cleared after close"
@@ -406,27 +487,326 @@ class TestAudioWriterThreadLeaks:
"""Verify no thread leaks across multiple writer cycles."""
import threading
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
from noteflow.infrastructure.security.keystore import InMemoryKeyStore
crypto = AesGcmCryptoBox(InMemoryKeyStore())
initial_threads = threading.active_count()
# Create and close writers
for i in range(WRITER_THREAD_CYCLES):
writer = MeetingAudioWriter(crypto, tmp_path, buffer_size=1024)
dek = crypto.generate_dek()
wrapped_dek = crypto.wrap_dek(dek)
writer.open(f"thread-test-{i}", dek, wrapped_dek, sample_rate=DEFAULT_SAMPLE_RATE)
writer.close()
# Create and close writers using helper (avoids inline loop)
run_audio_writer_thread_cycles(
crypto, tmp_path, WRITER_THREAD_CYCLES, DEFAULT_SAMPLE_RATE
)
await asyncio.sleep(0.5) # Allow threads to fully stop
await asyncio.sleep(THREAD_STOP_WAIT_SECONDS) # Allow threads to fully stop
gc.collect()
final_threads = threading.active_count()
log_thread_metrics(
"test_multiple_writer_cycles",
initial_threads,
final_threads,
WRITER_THREAD_CYCLES,
THREAD_LEAK_TOLERANCE,
)
# Should not have leaked threads
assert final_threads <= initial_threads + THREAD_LEAK_TOLERANCE, (
f"Thread leak: {initial_threads} -> {final_threads}"
)
class TestHeapGrowth:
"""Detect heap growth using tracemalloc snapshots."""
@pytest.mark.slow
@pytest.mark.asyncio
async def test_streaming_heap_stability(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify heap doesn't grow unbounded during streaming cycles."""
import tracemalloc
tracemalloc.start()
gc.collect()
snapshot_before = tracemalloc.take_snapshot()
run_streaming_init_cleanup_cycles(memory_servicer, STREAMING_CYCLES)
gc.collect()
snapshot_after = tracemalloc.take_snapshot()
tracemalloc.stop()
noteflow_growth = calculate_noteflow_heap_growth(snapshot_before, snapshot_after)
log_heap_metrics(
"test_streaming_heap_stability",
noteflow_growth,
cycles=STREAMING_CYCLES,
max_allowed_bytes=HEAP_MAX_GROWTH_BYTES,
filter_name="noteflow",
)
assert noteflow_growth < HEAP_MAX_GROWTH_BYTES, (
f"Heap growth detected: {noteflow_growth / BYTES_PER_KB:.1f}KB in noteflow code"
)
@pytest.mark.slow
@pytest.mark.asyncio
async def test_orm_conversion_heap_stability(self) -> None:
"""Verify ORM conversions don't leak memory."""
import tracemalloc
from uuid import uuid4
from noteflow.infrastructure.converters.orm_converters import OrmConverter
tracemalloc.start()
gc.collect()
snapshot_before = tracemalloc.take_snapshot()
converter = OrmConverter()
meeting_id = uuid4()
run_orm_conversion_cycles(converter, meeting_id, ORM_CONVERSION_CYCLES)
gc.collect()
snapshot_after = tracemalloc.take_snapshot()
tracemalloc.stop()
converter_growth = calculate_filtered_heap_growth(
snapshot_before, snapshot_after, "converter"
)
log_heap_metrics(
"test_orm_conversion_heap_stability",
converter_growth,
cycles=ORM_CONVERSION_CYCLES,
max_allowed_bytes=HEAP_CONVERTER_MAX_GROWTH_BYTES,
filter_name="converter",
)
assert converter_growth < HEAP_CONVERTER_MAX_GROWTH_BYTES, (
f"Converter heap growth: {converter_growth / BYTES_PER_KB:.1f}KB"
)
@pytest.mark.slow
@pytest.mark.asyncio
async def test_protobuf_heap_stability(self) -> None:
"""Verify protobuf operations don't leak memory."""
import tracemalloc
tracemalloc.start()
gc.collect()
snapshot_before = tracemalloc.take_snapshot()
run_protobuf_cycles(PROTO_CYCLES)
gc.collect()
snapshot_after = tracemalloc.take_snapshot()
tracemalloc.stop()
proto_growth = calculate_filtered_heap_growth(
snapshot_before, snapshot_after, "noteflow_pb2"
)
log_heap_metrics(
"test_protobuf_heap_stability",
proto_growth,
cycles=PROTO_CYCLES,
max_allowed_bytes=HEAP_PROTO_MAX_GROWTH_BYTES,
filter_name="noteflow_pb2",
)
assert proto_growth < HEAP_PROTO_MAX_GROWTH_BYTES, (
f"Protobuf heap growth: {proto_growth / BYTES_PER_KB:.1f}KB"
)
class TestAsyncContextEdgeCases:
"""Test async context manager edge cases."""
@pytest.mark.asyncio
async def test_context_exception_cleanup(self) -> None:
"""Verify cleanup runs even when exception raised inside context."""
cleanup_called = False
@asynccontextmanager
async def tracked_context() -> AsyncIterator[str]:
nonlocal cleanup_called
try:
yield "session"
finally:
cleanup_called = True
with pytest.raises(ValueError, match="test error"):
async with tracked_context():
raise ValueError("test error")
assert cleanup_called, "Cleanup should run despite exception"
@pytest.mark.asyncio
async def test_nested_context_cleanup_order(self) -> None:
"""Verify nested contexts clean up in correct (LIFO) order."""
cleanup_order: list[str] = []
@asynccontextmanager
async def tracked_context(name: str) -> AsyncIterator[str]:
try:
yield name
finally:
cleanup_order.append(name)
async with (
tracked_context("outer"),
tracked_context("middle"),
tracked_context("inner"),
):
pass
assert cleanup_order == ["inner", "middle", "outer"], (
f"Wrong cleanup order: {cleanup_order}"
)
@pytest.mark.asyncio
async def test_context_cancellation_during_body(self) -> None:
"""Verify cleanup runs when task cancelled inside context."""
cleanup_called = False
@asynccontextmanager
async def tracked_context() -> AsyncIterator[str]:
nonlocal cleanup_called
try:
yield "session"
finally:
cleanup_called = True
async def task_with_context() -> None:
async with tracked_context():
await asyncio.sleep(SLOW_ENTER_SLEEP_SECONDS)
task = asyncio.create_task(task_with_context())
await asyncio.sleep(CONTEXT_CANCEL_DELAY_SECONDS)
task.cancel()
with pytest.raises(asyncio.CancelledError, match=""):
await task
assert cleanup_called, "Cleanup should run on cancellation"
@pytest.mark.asyncio
async def test_context_cancellation_during_enter(self) -> None:
"""Verify proper handling when cancelled during __aenter__."""
enter_started = False
@asynccontextmanager
async def slow_enter_context() -> AsyncIterator[str]:
nonlocal enter_started
enter_started = True
await asyncio.sleep(SLOW_ENTER_SLEEP_SECONDS)
yield "session"
async def task_with_slow_enter() -> None:
async with slow_enter_context():
pass
task = asyncio.create_task(task_with_slow_enter())
await asyncio.sleep(CONTEXT_CANCEL_DELAY_SECONDS)
task.cancel()
with pytest.raises(asyncio.CancelledError, match=""):
await task
assert enter_started, "Enter should have started"
@pytest.mark.asyncio
async def test_context_exception_in_cleanup(self) -> None:
"""Verify cleanup exception propagates when cleanup fails.
Note: In Python, when an exception is raised in a finally block while another
exception is propagating, the cleanup exception is the one that propagates and
the original is preserved on it as __context__ (implicit exception chaining).
This test verifies the cleanup exception properly propagates.
"""
cleanup_executed = False
@asynccontextmanager
async def failing_cleanup_context() -> AsyncIterator[str]:
nonlocal cleanup_executed
try:
yield "session"
finally:
cleanup_executed = True
raise RuntimeError("cleanup failed")
with pytest.raises(RuntimeError, match="cleanup failed"):
async with failing_cleanup_context():
raise ValueError("original error")
assert cleanup_executed, "Cleanup should have executed despite body exception"
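An illustrative stdlib-only sketch of the chaining described in the docstring: the original error remains reachable via __context__ on the exception raised by cleanup:

# Illustrative sketch, not part of this commit: implicit exception chaining.
try:
    try:
        raise ValueError("original error")
    finally:
        raise RuntimeError("cleanup failed")
except RuntimeError as exc:
    assert isinstance(exc.__context__, ValueError)  # original error is chained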
@pytest.mark.asyncio
async def test_concurrent_context_isolation(self) -> None:
"""Verify concurrent context managers don't interfere."""
active_contexts: set[str] = set()
results: dict[str, int] = {"max_concurrent": 0}
context_ids = [f"ctx-{i}" for i in range(CONCURRENT_CONTEXT_COUNT)]
tasks = run_concurrent_context_tasks(context_ids, active_contexts, results)
await asyncio.gather(*tasks)
assert len(active_contexts) == 0, "All contexts should be cleaned up"
assert results["max_concurrent"] == CONCURRENT_CONTEXT_COUNT, (
f"Expected {CONCURRENT_CONTEXT_COUNT} concurrent, got {results['max_concurrent']}"
)
class TestMemoryPressure:
"""Test behavior under memory pressure conditions."""
@pytest.mark.slow
@pytest.mark.asyncio
async def test_operations_succeed_under_memory_pressure(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify streaming operations complete under memory pressure."""
successful, _gc_count = run_operations_under_memory_pressure(
memory_servicer,
cycle_count=STREAMING_CYCLES,
pressure_mb=MEMORY_PRESSURE_ALLOCATION_MB,
)
assert successful == STREAMING_CYCLES, (
f"Only {successful}/{STREAMING_CYCLES} cycles succeeded under pressure"
)
@pytest.mark.slow
@pytest.mark.asyncio
async def test_state_cleanup_under_pressure(
self, memory_servicer: NoteFlowServicer
) -> None:
"""Verify cleanup properly releases state even under pressure."""
run_operations_under_memory_pressure(
memory_servicer,
cycle_count=STREAMING_CYCLES,
pressure_mb=MEMORY_PRESSURE_ALLOCATION_MB,
)
gc.collect()
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
assert len(memory_servicer.vad_instances) == 0, "VAD instances leaked under pressure"
assert len(memory_servicer.active_streams) == 0, "Streams leaked under pressure"
@pytest.mark.asyncio
async def test_gc_collects_temporary_objects(self) -> None:
"""Verify GC properly collects temporary objects."""
gen0_diff, gen1_diff, gen2_diff = run_gc_collection_behavior()
total_uncollected = gen0_diff + gen1_diff + gen2_diff
assert total_uncollected >= 0, "GC count should not be negative"
@pytest.mark.asyncio
async def test_rss_measurement_returns_valid_value(self) -> None:
"""Verify RSS measurement returns valid value or graceful fallback."""
rss = measure_rss_bytes()
# Either measurement works (positive) or gracefully returns -1
valid_positive = rss > 0
valid_fallback = rss == -1
assert valid_positive or valid_fallback, f"Invalid RSS value: {rss}"