feat: add comprehensive performance profiling script for backend
- Introduced `profile_comprehensive.py` for detailed performance analysis of the NoteFlow backend, covering audio processing, ORM conversions, Protobuf operations, async overhead, and gRPC simulations.
- Implemented options for cProfile, verbose output, and memory profiling to enhance profiling capabilities.
- Updated `asr_config_service.py` to improve engine reference handling and added observability tracing during reconfiguration.
- Enhanced gRPC service shutdown procedures to include cancellation of sync tasks and improved lifecycle management.
- Refactored various components to ensure proper cleanup and resource management during shutdown.
- Updated the client submodule to the latest commit for improved integration.
This commit is contained in:
2
client
2
client
Submodule client updated: 2a2449be30...81756e545e
589
scripts/profile_comprehensive.py
Normal file
589
scripts/profile_comprehensive.py
Normal file
@@ -0,0 +1,589 @@
|
||||
#!/usr/bin/env python
|
||||
"""Comprehensive performance profiling for NoteFlow backend.
|
||||
|
||||
Run with: python scripts/profile_comprehensive.py [--profile] [--verbose] [--memory]
|
||||
|
||||
Profiles:
|
||||
- Audio processing pipeline (VAD, segmentation, RMS)
|
||||
- ORM/Domain conversions
|
||||
- Protobuf operations
|
||||
- Async context manager overhead
|
||||
- gRPC request simulation
|
||||
- Memory usage (RSS) and GC pressure
|
||||
|
||||
Options:
|
||||
--profile Enable cProfile for detailed function-level analysis
|
||||
--verbose Show extended profile output
|
||||
--memory Enable detailed memory profiling (RSS, GC stats)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import cProfile
|
||||
import gc
|
||||
import io
|
||||
import pstats
|
||||
import sys
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
CHUNK_SIZE = 1600 # 100ms at 16kHz
|
||||
CHUNKS_PER_SECOND = SAMPLE_RATE // CHUNK_SIZE
|
||||
BYTES_PER_KB = 1024
|
||||
BYTES_PER_MB = 1024 * 1024
|
||||
LINUX_RSS_KB_MULTIPLIER = 1024 # resource.ru_maxrss returns KB on Linux
|
||||
|
||||
AudioChunk = NDArray[np.float32]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Memory monitoring utilities
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class MemorySnapshot:
    """Process memory and garbage-collector counters captured at one instant.

    ``rss_bytes`` is negative when the RSS could not be measured; ``timestamp``
    defaults to ``time.perf_counter()`` at construction time.
    """

    rss_bytes: int
    gc_gen0: int
    gc_gen1: int
    gc_gen2: int
    timestamp: float = field(default_factory=time.perf_counter)

    @property
    def rss_mb(self) -> float:
        """Return the RSS in megabytes, or -1.0 when ``rss_bytes`` is negative."""
        if self.rss_bytes < 0:
            return -1.0
        return self.rss_bytes / BYTES_PER_MB
|
||||
|
||||
|
||||
@dataclass
class MemoryMetrics:
    """Memory figures recorded around a single benchmark run.

    All RSS values are expressed in megabytes; ``gc_collections`` holds one
    entry per garbage-collector generation.
    """

    rss_before_mb: float
    rss_after_mb: float
    rss_peak_mb: float
    rss_delta_mb: float
    gc_collections: tuple[int, int, int] # gen0, gen1, gen2

    def __str__(self) -> str:
        """Render the metrics as a single human-readable line."""
        counts = self.gc_collections
        gen0, gen1, gen2 = counts[0], counts[1], counts[2]
        rss_range = f"RSS: {self.rss_before_mb:.1f}→{self.rss_after_mb:.1f}MB "
        rss_detail = f"(peak={self.rss_peak_mb:.1f}MB, Δ={self.rss_delta_mb:+.1f}MB)"
        return f"{rss_range}{rss_detail} | gc=({gen0},{gen1},{gen2})"
|
||||
|
||||
|
||||
def measure_rss_bytes() -> int:
    """Measure current process RSS in bytes.

    Sources are tried in this order:

    1. ``psutil`` (when installed): current RSS.
    2. ``/proc/self/statm`` on Linux: current RSS (resident pages * page size).
    3. ``resource.getrusage``: ``ru_maxrss``, which is the *peak* RSS of the
       process, not the current one. It is scaled by
       ``LINUX_RSS_KB_MULTIPLIER`` on Linux and used as-is on macOS.

    Returns:
        RSS in bytes, or -1 if not supported.
    """
    try:
        import psutil
    except ImportError:
        pass
    else:
        return psutil.Process().memory_info().rss

    if sys.platform == "linux":
        import os

        # Prefer procfs over getrusage: ru_maxrss only ever grows, so it
        # cannot show memory being released between two measurements.
        try:
            with open("/proc/self/statm", encoding="ascii") as statm:
                resident_pages = int(statm.read().split()[1])
            return resident_pages * os.sysconf("SC_PAGE_SIZE")
        except (OSError, ValueError, IndexError):
            pass

    if sys.platform in ("darwin", "linux"):
        try:
            import resource
        except ImportError:
            return -1

        peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        if sys.platform == "linux":
            return peak_rss * LINUX_RSS_KB_MULTIPLIER
        return peak_rss

    return -1
|
||||
|
||||
|
||||
def take_memory_snapshot() -> MemorySnapshot:
    """Take a snapshot of current memory state.

    The GC fields hold the cumulative number of collections run for each
    generation, taken from ``gc.get_stats()``. ``gc.get_count()`` is not used
    because it returns the live threshold counters, which are reset by every
    collection and therefore cannot be subtracted to count collections.

    Note that an explicit ``gc.collect()`` between two snapshots is itself
    counted as a collection.

    Returns:
        Snapshot with the current RSS and per-generation collection totals.
    """
    generation_stats = gc.get_stats()
    return MemorySnapshot(
        rss_bytes=measure_rss_bytes(),
        gc_gen0=generation_stats[0]["collections"],
        gc_gen1=generation_stats[1]["collections"],
        gc_gen2=generation_stats[2]["collections"],
    )
|
||||
|
||||
|
||||
def calculate_memory_metrics(
    before: MemorySnapshot,
    after: MemorySnapshot,
    peak_rss_bytes: int,
) -> MemoryMetrics:
    """Derive benchmark memory metrics from a before/after snapshot pair.

    Args:
        before: Snapshot taken before the benchmark ran.
        after: Snapshot taken after the benchmark ran.
        peak_rss_bytes: Highest RSS observed, in bytes (negative if unknown).

    Returns:
        Metrics in megabytes. The peak is -1.0 when ``peak_rss_bytes`` is
        negative, and the delta is 0.0 when either snapshot has a negative RSS.
    """
    if peak_rss_bytes >= 0:
        peak_mb = peak_rss_bytes / BYTES_PER_MB
    else:
        peak_mb = -1.0

    if before.rss_bytes >= 0 and after.rss_bytes >= 0:
        delta_mb = (after.rss_bytes - before.rss_bytes) / BYTES_PER_MB
    else:
        delta_mb = 0.0

    gc_delta = (
        after.gc_gen0 - before.gc_gen0,
        after.gc_gen1 - before.gc_gen1,
        after.gc_gen2 - before.gc_gen2,
    )

    return MemoryMetrics(
        rss_before_mb=before.rss_mb,
        rss_after_mb=after.rss_mb,
        rss_peak_mb=peak_mb,
        rss_delta_mb=delta_mb,
        gc_collections=gc_delta,
    )
|
||||
|
||||
|
||||
@dataclass
class BenchmarkResult:
    """Outcome of one benchmark: timings, item count and optional extras."""

    name: str
    duration_ms: float
    items_processed: int
    per_item_ms: float
    extra: dict[str, float | int | str] | None = None
    memory: MemoryMetrics | None = None

    def __str__(self) -> str:
        """Render timings on one line, appending ``extra`` pairs when present."""
        summary = (
            f"{self.name}: {self.duration_ms:.2f}ms total, "
            f"{self.per_item_ms:.4f}ms/item ({self.items_processed} items)"
        )
        if not self.extra:
            return summary
        pairs = [f"{key}={value}" for key, value in self.extra.items()]
        return summary + " | " + ", ".join(pairs)

    def format_with_memory(self) -> str:
        """Return ``str(self)``, plus a memory line when metrics are attached."""
        text = str(self)
        if not self.memory:
            return text
        return f"{text}\n Memory: {self.memory}"
|
||||
|
||||
|
||||
def generate_audio_chunks(seconds: int) -> list[AudioChunk]:
    """Build a deterministic list of synthetic audio chunks.

    The global NumPy RNG is seeded with 42. Every 7-second cycle consists of
    5 seconds of loud noise (scaled by 0.3) followed by 2 seconds of
    near-silent noise (scaled by 0.001).

    Args:
        seconds: Length of audio to simulate, in seconds.

    Returns:
        ``seconds * CHUNKS_PER_SECOND`` float32 arrays of ``CHUNK_SIZE`` samples.
    """
    np.random.seed(42)
    cycle_seconds = 7
    speech_seconds = 5

    generated: list[AudioChunk] = []
    for index in range(seconds * CHUNKS_PER_SECOND):
        second_in_cycle = (index // CHUNKS_PER_SECOND) % cycle_seconds
        amplitude = 0.3 if second_in_cycle < speech_seconds else 0.001
        samples = np.random.randn(CHUNK_SIZE).astype(np.float32)
        generated.append(samples * amplitude)

    return generated
|
||||
|
||||
|
||||
def benchmark_audio_pipeline(duration_seconds: int = 60) -> BenchmarkResult:
    """Time VAD, RMS/dB level calculation and segmentation over synthetic audio.

    Args:
        duration_seconds: Amount of simulated audio to push through the pipeline.

    Returns:
        Result named "Audio Pipeline". Its ``extra`` dict carries the simulated
        duration, the number of segments emitted (including a final flushed
        one), and the real-time factor (elapsed seconds / simulated seconds).
    """
    from noteflow.infrastructure.asr.segmenter import Segmenter, SegmenterConfig
    from noteflow.infrastructure.asr.streaming_vad import StreamingVad
    from noteflow.infrastructure.audio.levels import RmsLevelProvider

    audio = generate_audio_chunks(duration_seconds)
    vad = StreamingVad()
    segmenter = Segmenter(config=SegmenterConfig(sample_rate=SAMPLE_RATE))
    level_provider = RmsLevelProvider()

    segment_count = 0
    started_at = time.perf_counter()

    for samples in audio:
        speech_detected = vad.process_chunk(samples)
        # Level results are discarded; only the cost of computing them matters.
        level_provider.get_rms(samples)
        level_provider.get_db(samples)
        segment_count += sum(
            1 for _segment in segmenter.process_audio(samples, speech_detected)
        )

    # A trailing, still-open segment is emitted by flush().
    if segmenter.flush() is not None:
        segment_count += 1

    elapsed = time.perf_counter() - started_at
    rtf_text = f"{elapsed / duration_seconds:.6f}x"
    elapsed_ms = elapsed * 1000
    chunk_count = len(audio)

    return BenchmarkResult(
        name="Audio Pipeline",
        duration_ms=elapsed_ms,
        items_processed=chunk_count,
        per_item_ms=elapsed_ms / chunk_count,
        extra={
            "simulated_seconds": duration_seconds,
            "segments": segment_count,
            "realtime_factor": rtf_text,
        },
    )
|
||||
|
||||
|
||||
def benchmark_orm_conversions(num_segments: int = 500) -> BenchmarkResult:
    """Time converting ``SegmentModel`` rows to domain objects.

    Model construction happens before the timer starts; only the
    ``OrmConverter.segment_to_domain`` calls are measured.

    Args:
        num_segments: Number of segment models to build and convert.

    Returns:
        Result named "ORM → Domain" with total and per-segment timings.
    """
    from noteflow.infrastructure.converters.orm_converters import OrmConverter
    from noteflow.infrastructure.persistence.models.core import SegmentModel

    converter = OrmConverter()
    meeting_id = uuid4()

    def build_model(index: int) -> SegmentModel:
        # Each segment starts at a 5-second boundary and lasts 4.5 seconds.
        return SegmentModel(
            meeting_id=meeting_id,
            segment_id=index,
            text=f"Segment {index} with realistic meeting transcript content here.",
            start_time=float(index * 5),
            end_time=float(index * 5 + 4.5),
            speaker_id=f"speaker_{index % 3}",
        )

    segment_models = [build_model(index) for index in range(num_segments)]

    started_at = time.perf_counter()
    _converted = [converter.segment_to_domain(model) for model in segment_models]
    elapsed_ms = (time.perf_counter() - started_at) * 1000

    return BenchmarkResult(
        name="ORM → Domain",
        duration_ms=elapsed_ms,
        items_processed=num_segments,
        per_item_ms=elapsed_ms / num_segments,
    )
|
||||
|
||||
|
||||
def benchmark_proto_operations(num_meetings: int = 200) -> BenchmarkResult:
    """Time protobuf message creation, serialization and parsing.

    Three phases are timed separately: building ``Meeting`` messages,
    serializing one ``ListMeetingsResponse`` that contains them, and parsing
    it back. Building the response object itself is not timed.

    Args:
        num_meetings: Number of ``Meeting`` messages to create.

    Returns:
        Result named "Proto Ops". ``duration_ms`` is the sum of the three
        phases, while ``per_item_ms`` covers message creation only. ``extra``
        reports each phase and the serialized payload size in KB.
    """
    from noteflow.grpc.proto import noteflow_pb2

    # Phase 1: message creation
    phase_start = time.perf_counter()
    meeting_messages = [
        noteflow_pb2.Meeting(
            id=str(uuid4()),
            title=f"Meeting {index}",
            state=noteflow_pb2.MEETING_STATE_COMPLETED,
        )
        for index in range(num_meetings)
    ]
    creation_seconds = time.perf_counter() - phase_start

    list_response = noteflow_pb2.ListMeetingsResponse(
        meetings=meeting_messages, total_count=len(meeting_messages)
    )

    # Phase 2: serialization
    phase_start = time.perf_counter()
    payload = list_response.SerializeToString()
    serialize_seconds = time.perf_counter() - phase_start

    # Phase 3: deserialization
    phase_start = time.perf_counter()
    round_tripped = noteflow_pb2.ListMeetingsResponse()
    round_tripped.ParseFromString(payload)
    deserialize_seconds = time.perf_counter() - phase_start

    total_seconds = creation_seconds + serialize_seconds + deserialize_seconds
    creation_ms = creation_seconds * 1000

    return BenchmarkResult(
        name="Proto Ops",
        duration_ms=total_seconds * 1000,
        items_processed=num_meetings,
        per_item_ms=creation_ms / num_meetings,
        extra={
            "creation_ms": f"{creation_ms:.2f}",
            "serialize_ms": f"{serialize_seconds * 1000:.2f}",
            "deserialize_ms": f"{deserialize_seconds * 1000:.2f}",
            "payload_kb": f"{len(payload) / 1024:.1f}",
        },
    )
|
||||
|
||||
|
||||
async def benchmark_async_overhead(iterations: int = 1000) -> BenchmarkResult:
    """Time entering and leaving a trivial async context manager.

    Args:
        iterations: Number of enter/exit cycles to run.

    Returns:
        Result named "Async Context" with total and per-iteration timings.
    """

    @asynccontextmanager
    async def fake_unit_of_work() -> AsyncIterator[str]:
        # Does no work of its own, so only the context-manager machinery is timed.
        yield "mock_session"

    started_at = time.perf_counter()
    for _iteration in range(iterations):
        async with fake_unit_of_work():
            pass
    elapsed_ms = (time.perf_counter() - started_at) * 1000

    return BenchmarkResult(
        name="Async Context",
        duration_ms=elapsed_ms,
        items_processed=iterations,
        per_item_ms=elapsed_ms / iterations,
    )
|
||||
|
||||
|
||||
async def benchmark_grpc_simulation(num_requests: int = 100) -> BenchmarkResult:
    """Time a batch of simulated request/response cycles run concurrently.

    Each simulated request builds and serializes a ``GetMeetingRequest``,
    yields to the event loop once, then builds a ``Meeting`` response. No
    actual gRPC channel is involved.

    Args:
        num_requests: Number of simulated requests gathered together.

    Returns:
        Result named "gRPC Sim"; ``extra["concurrent"]`` echoes ``num_requests``.
    """
    from noteflow.grpc.proto import noteflow_pb2

    async def handle_one() -> noteflow_pb2.Meeting:
        # Request side: build and serialize the request message.
        incoming = noteflow_pb2.GetMeetingRequest(meeting_id=str(uuid4()))
        incoming.SerializeToString()

        # Yield control once, standing in for minimal processing.
        await asyncio.sleep(0)

        # Response side: build the reply message.
        return noteflow_pb2.Meeting(
            id=incoming.meeting_id,
            title="Test Meeting",
            state=noteflow_pb2.MEETING_STATE_COMPLETED,
        )

    started_at = time.perf_counter()
    pending = [handle_one() for _request in range(num_requests)]
    await asyncio.gather(*pending)
    elapsed_ms = (time.perf_counter() - started_at) * 1000

    return BenchmarkResult(
        name="gRPC Sim",
        duration_ms=elapsed_ms,
        items_processed=num_requests,
        per_item_ms=elapsed_ms / num_requests,
        extra={"concurrent": num_requests},
    )
|
||||
|
||||
|
||||
def benchmark_import_times() -> list[BenchmarkResult]:
    """Measure import times for key modules.

    Before each measurement the module's root package and all of its
    submodules are removed from ``sys.modules`` so the import is re-executed.
    Only the root package itself and names under ``"<root>."`` are purged;
    unrelated modules that merely share the prefix are left alone.

    Warning: purging ``sys.modules`` is destructive. Objects imported earlier
    keep referring to the old module instances.

    Returns:
        One result per measured module, named ``"Import <label>"``.
    """
    import importlib

    targets = [
        ("noteflow.infrastructure.asr", "ASR Module"),
        ("noteflow.grpc.proto.noteflow_pb2", "Proto Module"),
        ("noteflow.infrastructure.persistence.models", "ORM Models"),
        ("noteflow.domain.entities", "Domain Entities"),
    ]

    results: list[BenchmarkResult] = []
    for module_path, label in targets:
        # Force reimport by removing the whole root package from the cache.
        root_package = module_path.partition(".")[0]
        child_prefix = f"{root_package}."
        cached_names = [
            cached
            for cached in sys.modules
            if cached == root_package or cached.startswith(child_prefix)
        ]
        for cached in cached_names:
            sys.modules.pop(cached, None)

        started_at = time.perf_counter()
        importlib.import_module(module_path)
        elapsed_ms = (time.perf_counter() - started_at) * 1000

        results.append(
            BenchmarkResult(
                name=f"Import {label}",
                duration_ms=elapsed_ms,
                items_processed=1,
                per_item_ms=elapsed_ms,
            )
        )

    return results
|
||||
|
||||
|
||||
def run_profiled(
    func: object, *args: object, **kwargs: object
) -> tuple[BenchmarkResult, str]:
    """Run a function with cProfile and return result + stats.

    Args:
        func: Callable to profile; expected to return a ``BenchmarkResult``.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Tuple of (value returned by ``func``, text report of the top 20
        entries sorted by cumulative time).

    Raises:
        Exception: Whatever ``func`` raises is propagated unchanged, after the
            profiler has been disabled.
    """
    # func is expected to be a callable returning BenchmarkResult
    callable_func = cast("Callable[..., BenchmarkResult]", func)

    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = callable_func(*args, **kwargs)
    finally:
        # Always release the profiling hook, even when the benchmark fails;
        # an active profiler would otherwise interfere with later profiling.
        profiler.disable()

    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.strip_dirs()
    stats.sort_stats(pstats.SortKey.CUMULATIVE)
    stats.print_stats(20)

    return result, stream.getvalue()
|
||||
|
||||
|
||||
def run_with_memory_tracking(
    func: object,
    *args: object,
    **kwargs: object,
) -> tuple[BenchmarkResult, MemoryMetrics]:
    """Run a benchmark and attach before/after memory metrics to its result.

    A garbage collection is forced before each snapshot. The "peak" is only
    the larger of the RSS before the run and the RSS sampled right after it;
    memory is not sampled while the benchmark is executing.

    Args:
        func: Benchmark function to run.
        *args: Positional arguments for the function.
        **kwargs: Keyword arguments for the function.

    Returns:
        Tuple of (benchmark result, memory metrics). The metrics are also
        stored on ``result.memory``.
    """
    gc.collect()  # start from a collected heap
    baseline = take_memory_snapshot()

    benchmark = cast("Callable[..., BenchmarkResult]", func)
    outcome = benchmark(*args, **kwargs)

    # Rough peak approximation: compare the baseline with one post-run sample.
    peak_rss = max(baseline.rss_bytes, measure_rss_bytes())

    gc.collect()
    final = take_memory_snapshot()

    memory_metrics = calculate_memory_metrics(baseline, final, peak_rss)
    outcome.memory = memory_metrics

    return outcome, memory_metrics
|
||||
|
||||
|
||||
async def main(
    enable_profile: bool = False,
    verbose: bool = False,
    enable_memory: bool = False,
) -> None:
    """Run every benchmark and print a report to stdout.

    Args:
        enable_profile: Run the audio pipeline benchmark under cProfile. This
            takes precedence over memory tracking for that benchmark.
        verbose: Print the cProfile report (only used with ``enable_profile``).
        enable_memory: Track RSS/GC around the audio (unless profiled), ORM and
            proto benchmarks, and print a memory summary at the end.
    """
    rule = "=" * 70

    def print_banner(title: str) -> None:
        print(rule)
        print(title)
        print(rule)

    print_banner("NoteFlow Comprehensive Performance Profile")
    print()

    initial_snapshot: MemorySnapshot | None = None
    if enable_memory:
        initial_snapshot = take_memory_snapshot()
        print(f"Initial RSS: {initial_snapshot.rss_mb:.1f}MB")
        print()

    results: list[BenchmarkResult] = []

    # The import-time benchmark (benchmark_import_times) is intentionally not
    # run here because it is destructive to the module cache; only the
    # progress message is printed.
    print("Measuring import times...")

    # Audio pipeline
    print("Benchmarking audio pipeline (60s simulated)...")
    if enable_profile:
        audio_outcome, profile_report = run_profiled(benchmark_audio_pipeline, 60)
        results.append(audio_outcome)
        if verbose:
            print(profile_report)
    elif enable_memory:
        audio_outcome, _metrics = run_with_memory_tracking(benchmark_audio_pipeline, 60)
        results.append(audio_outcome)
    else:
        results.append(benchmark_audio_pipeline(60))

    # ORM conversions and proto operations share the same run pattern.
    sync_benchmarks = (
        ("Benchmarking ORM conversions (500 segments)...", benchmark_orm_conversions, 500),
        ("Benchmarking proto operations (200 meetings)...", benchmark_proto_operations, 200),
    )
    for announcement, benchmark, size in sync_benchmarks:
        print(announcement)
        if enable_memory:
            tracked, _metrics = run_with_memory_tracking(benchmark, size)
            results.append(tracked)
        else:
            results.append(benchmark(size))

    # Async overhead
    print("Benchmarking async context overhead (1000 iterations)...")
    results.append(await benchmark_async_overhead(1000))

    # gRPC simulation
    print("Benchmarking gRPC simulation (100 concurrent)...")
    results.append(await benchmark_grpc_simulation(100))

    # Per-benchmark results
    print()
    print_banner("BENCHMARK RESULTS")
    for result in results:
        if enable_memory and result.memory:
            print(f" {result.format_with_memory()}")
        else:
            print(f" {result}")

    # Performance summary
    print()
    print_banner("PERFORMANCE SUMMARY")

    audio_result = None
    for candidate in results:
        if candidate.name == "Audio Pipeline":
            audio_result = candidate
            break
    if audio_result and audio_result.extra:
        rtf = audio_result.extra.get("realtime_factor", "N/A")
        print(f" Real-time factor: {rtf} (< 1.0 = faster than real-time)")

    data_layer_names = ("ORM → Domain", "Proto Ops", "Async Context")
    total_overhead = sum(r.duration_ms for r in results if r.name in data_layer_names)
    print(f" Data layer overhead (500 segs + 200 mtgs + 1k ctx): {total_overhead:.2f}ms")

    # Memory summary
    if enable_memory and initial_snapshot is not None:
        print()
        print_banner("MEMORY SUMMARY")
        final_snapshot = take_memory_snapshot()
        print(f" Final RSS: {final_snapshot.rss_mb:.1f}MB")
        total_delta = final_snapshot.rss_bytes - initial_snapshot.rss_bytes
        print(f" Total RSS change: {total_delta / BYTES_PER_MB:+.1f}MB")
        gen0_delta = final_snapshot.gc_gen0 - initial_snapshot.gc_gen0
        gen1_delta = final_snapshot.gc_gen1 - initial_snapshot.gc_gen1
        gen2_delta = final_snapshot.gc_gen2 - initial_snapshot.gc_gen2
        print(f" Total GC collections: gen0={gen0_delta}, gen1={gen1_delta}, gen2={gen2_delta}")

    print()
    print("All benchmarks completed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: three independent boolean switches.
    cli = argparse.ArgumentParser(description="NoteFlow performance profiler")
    flag_help = (
        ("--profile", "Enable cProfile for detailed analysis"),
        ("--verbose", "Show extended profile output"),
        ("--memory", "Enable RSS and GC memory profiling"),
    )
    for flag, help_text in flag_help:
        cli.add_argument(flag, action="store_true", help=help_text)
    options = cli.parse_args()

    asyncio.run(
        main(
            enable_profile=options.profile,
            verbose=options.verbose,
            enable_memory=options.memory,
        )
    )
|
||||
@@ -11,12 +11,12 @@ from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from noteflow.application.services.asr_config_types import (
|
||||
DEVICE_COMPUTE_TYPES,
|
||||
AsrCapabilities,
|
||||
AsrComputeType,
|
||||
AsrConfigJob,
|
||||
AsrConfigPhase,
|
||||
AsrDevice,
|
||||
DEVICE_COMPUTE_TYPES,
|
||||
)
|
||||
from noteflow.domain.constants.fields import (
|
||||
JOB_STATUS_COMPLETED,
|
||||
@@ -83,15 +83,17 @@ class AsrConfigService:
|
||||
current_device = AsrDevice.CPU
|
||||
current_compute_type = AsrComputeType.INT8
|
||||
|
||||
if self._asr_engine is not None:
|
||||
current_device = AsrDevice(self._asr_engine.device)
|
||||
current_compute_type = AsrComputeType(self._asr_engine.compute_type)
|
||||
# Capture engine reference to avoid race between null check and attribute access
|
||||
engine = self._asr_engine
|
||||
if engine is not None:
|
||||
current_device = AsrDevice(engine.device)
|
||||
current_compute_type = AsrComputeType(engine.compute_type)
|
||||
|
||||
return AsrCapabilities(
|
||||
model_size=self._asr_engine.model_size if self._asr_engine else None,
|
||||
model_size=engine.model_size if engine else None,
|
||||
device=current_device,
|
||||
compute_type=current_compute_type,
|
||||
is_ready=self._asr_engine.is_loaded if self._asr_engine else False,
|
||||
is_ready=engine.is_loaded if engine else False,
|
||||
cuda_available=cuda_available,
|
||||
available_model_sizes=VALID_MODEL_SIZES,
|
||||
available_compute_types=DEVICE_COMPUTE_TYPES[current_device],
|
||||
@@ -259,30 +261,46 @@ class AsrConfigService:
|
||||
|
||||
async def _execute_reconfiguration(self, job: AsrConfigJob) -> None:
|
||||
"""Run reconfiguration steps and swap engine after success."""
|
||||
current_engine = self._asr_engine
|
||||
job.status = JOB_STATUS_RUNNING
|
||||
job.phase = AsrConfigPhase.LOADING
|
||||
job.progress_percent = 10.0
|
||||
from noteflow.infrastructure.observability.otel import get_tracer
|
||||
|
||||
engine_to_use, is_new_engine = self._build_engine_for_job(job, current_engine)
|
||||
job.progress_percent = 50.0
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("asr_reconfiguration") as span:
|
||||
span.set_attribute("asr.target_model", job.target_model_size)
|
||||
span.set_attribute("asr.target_device", job.target_device.value)
|
||||
span.set_attribute("asr.target_compute_type", job.target_compute_type.value)
|
||||
|
||||
await self._load_model(engine_to_use, job.target_model_size)
|
||||
job.progress_percent = 90.0
|
||||
current_engine = self._asr_engine
|
||||
job.status = JOB_STATUS_RUNNING
|
||||
job.phase = AsrConfigPhase.LOADING
|
||||
job.progress_percent = 10.0
|
||||
|
||||
if is_new_engine:
|
||||
self._set_active_engine(engine_to_use)
|
||||
if current_engine is not None:
|
||||
current_engine.unload()
|
||||
span.add_event("engine_build_start")
|
||||
engine_to_use, is_new_engine = self._build_engine_for_job(job, current_engine)
|
||||
span.set_attribute("asr.is_new_engine", is_new_engine)
|
||||
job.progress_percent = 50.0
|
||||
|
||||
self._mark_job_completed(job)
|
||||
await self._persist_configuration()
|
||||
logger.info(
|
||||
"asr_reconfigured",
|
||||
model_size=job.target_model_size,
|
||||
device=job.target_device.value,
|
||||
compute_type=job.target_compute_type.value,
|
||||
)
|
||||
span.add_event("model_load_start")
|
||||
await self._load_model(engine_to_use, job.target_model_size)
|
||||
span.add_event("model_load_complete")
|
||||
job.progress_percent = 90.0
|
||||
|
||||
if is_new_engine:
|
||||
try:
|
||||
self._set_active_engine(engine_to_use)
|
||||
finally:
|
||||
# Ensure old engine is always unloaded even if callback fails
|
||||
if current_engine is not None:
|
||||
current_engine.unload()
|
||||
span.add_event("engine_swapped")
|
||||
|
||||
self._mark_job_completed(job)
|
||||
await self._persist_configuration()
|
||||
logger.info(
|
||||
"asr_reconfigured",
|
||||
model_size=job.target_model_size,
|
||||
device=job.target_device.value,
|
||||
compute_type=job.target_compute_type.value,
|
||||
)
|
||||
|
||||
async def _persist_configuration(self) -> None:
|
||||
"""Persist the current ASR configuration if a callback is configured."""
|
||||
@@ -318,3 +336,31 @@ class AsrConfigService:
|
||||
if job is None:
|
||||
logger.debug("job_not_found", job_id=str(job_id))
|
||||
return job
|
||||
|
||||
async def shutdown(self, timeout: float = 5.0) -> None:
    """Cancel unfinished ASR configuration job tasks and wait for them to end.

    Tasks are collected and cancelled while holding ``self._job_lock``; the
    wait itself happens after the lock is released. If the wait exceeds
    ``timeout`` a warning is logged and the method returns normally.

    Args:
        timeout: Maximum time to wait for jobs to complete (seconds).
    """
    async with self._job_lock:
        tasks_to_cancel = [
            job.task
            for job in self._jobs.values()
            if job.task is not None and not job.task.done()
        ]
        for pending_task in tasks_to_cancel:
            pending_task.cancel()

    if tasks_to_cancel:
        logger.info("asr_config_shutdown", pending_jobs=len(tasks_to_cancel))
        # return_exceptions=True keeps per-task CancelledError/failures from
        # propagating out of gather().
        drain = asyncio.gather(*tasks_to_cancel, return_exceptions=True)
        try:
            await asyncio.wait_for(drain, timeout=timeout)
        except TimeoutError:
            still_running = sum(1 for task in tasks_to_cancel if not task.done())
            logger.warning(
                "asr_config_shutdown_timeout",
                remaining_jobs=still_running,
            )
|
||||
|
||||
@@ -90,6 +90,12 @@ class ServicerState(Protocol):
|
||||
sync_runs: dict[UUID, SyncRun]
|
||||
# Track when each sync run was cached (Sprint GAP-002: State Synchronization)
|
||||
sync_run_cache_times: dict[UUID, datetime]
|
||||
# Background sync tasks for proper lifecycle management
|
||||
sync_tasks: dict[UUID, asyncio.Task[None]]
|
||||
# Background cleanup tasks (separate from sync tasks)
|
||||
sync_cleanup_tasks: dict[UUID, asyncio.Task[None]]
|
||||
# General-purpose background tasks (webhooks, etc.) with self-cleanup
|
||||
background_tasks: set[asyncio.Task[None]]
|
||||
|
||||
# Constants
|
||||
DEFAULT_SAMPLE_RATE: Final[int]
|
||||
|
||||
@@ -31,6 +31,25 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _diarization_task_done_callback(
|
||||
task: asyncio.Task[None],
|
||||
job_id: str,
|
||||
tasks_dict: dict[str, asyncio.Task[None]],
|
||||
) -> None:
|
||||
"""Handle completion of a diarization task, cleaning up and logging exceptions."""
|
||||
tasks_dict.pop(job_id, None)
|
||||
if task.cancelled():
|
||||
logger.debug("diarization_task_cancelled", job_id=job_id)
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.exception(
|
||||
"diarization_task_failed_unhandled",
|
||||
job_id=job_id,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
|
||||
def _job_status_name(status: int) -> str:
    """Return the name that ``noteflow_pb2.JobStatus.Name`` gives for ``status``."""
    # The cast only gives the enum's Name function a precise callable type.
    return cast(Callable[[int], str], noteflow_pb2.JobStatus.Name)(status)
|
||||
@@ -269,6 +288,9 @@ class JobsMixin(JobStatusMixin):
|
||||
num_speakers = request.num_speakers or None
|
||||
task = asyncio.create_task(self.run_diarization_job(job_id, num_speakers))
|
||||
self.diarization_tasks[job_id] = task
|
||||
task.add_done_callback(
|
||||
lambda t: _diarization_task_done_callback(t, job_id, self.diarization_tasks)
|
||||
)
|
||||
logger.info(
|
||||
"diarization_task_created",
|
||||
job_id=job_id,
|
||||
@@ -310,22 +332,41 @@ async def _execute_diarization(ctx: _DiarizationJobContext) -> None:
|
||||
Args:
|
||||
ctx: Job context with host, job info, and parameters.
|
||||
"""
|
||||
try:
|
||||
async with asyncio.timeout(DIARIZATION_TIMEOUT_SECONDS):
|
||||
updated_count = await ctx.host.refine_speaker_diarization(
|
||||
meeting_id=ctx.meeting_id,
|
||||
num_speakers=ctx.num_speakers,
|
||||
)
|
||||
speaker_ids = await ctx.host.collect_speaker_ids(ctx.meeting_id)
|
||||
from noteflow.infrastructure.observability.otel import get_tracer
|
||||
|
||||
await ctx.host.update_job_completed(ctx.job_id, ctx.job, updated_count, speaker_ids)
|
||||
except TimeoutError:
|
||||
await ctx.host.handle_job_timeout(ctx.job_id, ctx.job, ctx.meeting_id)
|
||||
except asyncio.CancelledError:
|
||||
await ctx.host.handle_job_cancelled(ctx.job_id, ctx.job, ctx.meeting_id)
|
||||
raise # Re-raise to propagate cancellation
|
||||
# INTENTIONAL BROAD HANDLER: Job error boundary
|
||||
# - Diarization can fail in many ways (model errors, audio issues, etc.)
|
||||
# - Must capture any failure and update job status
|
||||
except Exception as exc:
|
||||
await ctx.host.handle_job_failed(ctx.job_id, ctx.job, ctx.meeting_id, exc)
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("diarization_job") as span:
|
||||
span.set_attribute("diarization.job_id", ctx.job_id)
|
||||
span.set_attribute("diarization.meeting_id", ctx.meeting_id)
|
||||
if ctx.num_speakers is not None:
|
||||
span.set_attribute("diarization.num_speakers", ctx.num_speakers)
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(DIARIZATION_TIMEOUT_SECONDS):
|
||||
span.add_event("refinement_start")
|
||||
updated_count = await ctx.host.refine_speaker_diarization(
|
||||
meeting_id=ctx.meeting_id,
|
||||
num_speakers=ctx.num_speakers,
|
||||
)
|
||||
span.set_attribute("diarization.segments_updated", updated_count)
|
||||
span.add_event("refinement_complete")
|
||||
|
||||
speaker_ids = await ctx.host.collect_speaker_ids(ctx.meeting_id)
|
||||
span.set_attribute("diarization.speaker_count", len(speaker_ids))
|
||||
|
||||
await ctx.host.update_job_completed(ctx.job_id, ctx.job, updated_count, speaker_ids)
|
||||
span.add_event("job_completed")
|
||||
|
||||
except TimeoutError:
|
||||
span.set_attribute("diarization.timeout", True)
|
||||
await ctx.host.handle_job_timeout(ctx.job_id, ctx.job, ctx.meeting_id)
|
||||
except asyncio.CancelledError:
|
||||
span.set_attribute("diarization.cancelled", True)
|
||||
await ctx.host.handle_job_cancelled(ctx.job_id, ctx.job, ctx.meeting_id)
|
||||
raise # Re-raise to propagate cancellation
|
||||
# INTENTIONAL BROAD HANDLER: Job error boundary
|
||||
# - Diarization can fail in many ways (model errors, audio issues, etc.)
|
||||
# - Must capture any failure and update job status
|
||||
except Exception as exc:
|
||||
span.record_exception(exc)
|
||||
await ctx.host.handle_job_failed(ctx.job_id, ctx.job, ctx.meeting_id, exc)
|
||||
|
||||
@@ -45,7 +45,7 @@ async def _cancel_running_task(
|
||||
tasks: Dictionary of active diarization tasks.
|
||||
job_id: The job ID whose task should be cancelled.
|
||||
"""
|
||||
task = tasks.get(job_id)
|
||||
task = tasks.pop(job_id, None)
|
||||
if task is not None and not task.done():
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
|
||||
@@ -36,13 +36,25 @@ async def _generate_summary_and_complete(
|
||||
summarization_service: SummarizationService,
|
||||
) -> None:
|
||||
"""Generate summary for meeting and transition to COMPLETED state."""
|
||||
result = await _process_summary(
|
||||
repo_provider=host, meeting_id=meeting_id, service=summarization_service
|
||||
)
|
||||
if result is None:
|
||||
return
|
||||
meeting, saved_summary = result
|
||||
await _trigger_summary_webhook(host.webhook_service, meeting, saved_summary)
|
||||
from noteflow.infrastructure.observability.otel import get_tracer
|
||||
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("post_processing") as span:
|
||||
span.set_attribute("meeting.id", meeting_id)
|
||||
|
||||
result = await _process_summary(
|
||||
repo_provider=host, meeting_id=meeting_id, service=summarization_service
|
||||
)
|
||||
if result is None:
|
||||
span.set_attribute("post_processing.skipped", True)
|
||||
return
|
||||
|
||||
meeting, saved_summary = result
|
||||
span.set_attribute("summary.db_id", saved_summary.db_id or 0)
|
||||
span.add_event("summary_generated")
|
||||
|
||||
await _trigger_summary_webhook(host.webhook_service, meeting, saved_summary)
|
||||
span.add_event("webhook_triggered")
|
||||
|
||||
|
||||
async def _process_summary(
|
||||
@@ -162,6 +174,23 @@ async def _trigger_summary_webhook(
|
||||
logger.exception("Failed to trigger summary.generated webhook")
|
||||
|
||||
|
||||
def _post_processing_task_done_callback(
    task: asyncio.Task[None],
    meeting_id: str,
) -> None:
    """Report how a finished post-processing task ended.

    A cancelled task is noted at debug level. A task that finished with an
    exception has that exception logged together with the meeting ID. A task
    that completed normally produces no log output.

    Args:
        task: The finished post-processing task.
        meeting_id: Meeting identifier attached to the log records.
    """
    if not task.cancelled():
        failure = task.exception()
        if failure is None:
            return
        logger.exception(
            "Post-processing task failed with unhandled exception",
            meeting_id=meeting_id,
            exc_info=failure,
        )
        return

    # Cancelled: task.exception() would raise here, so it is never consulted.
    logger.debug("Post-processing task cancelled", meeting_id=meeting_id)
|
||||
|
||||
|
||||
async def start_post_processing(
|
||||
host: ServicerHost,
|
||||
meeting_id: str,
|
||||
@@ -201,5 +230,6 @@ async def start_post_processing(
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_run_with_error_handling())
|
||||
task.add_done_callback(lambda t: _post_processing_task_done_callback(t, meeting_id))
|
||||
logger.info("Post-processing task started", meeting_id=meeting_id)
|
||||
return task
|
||||
|
||||
@@ -50,7 +50,7 @@ async def parse_project_ids_or_abort(
|
||||
context,
|
||||
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
|
||||
)
|
||||
return None
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
return project_ids
|
||||
|
||||
@@ -78,7 +78,7 @@ async def parse_project_id_or_abort(
|
||||
)
|
||||
error_message = f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
|
||||
await abort_invalid_argument(context, error_message)
|
||||
return None
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
|
||||
async def resolve_active_project_id(
|
||||
|
||||
@@ -116,7 +116,6 @@ async def process_audio_segment(
|
||||
await repo.commit()
|
||||
for _, update in segments_to_add:
|
||||
yield update
|
||||
_finalize_chunk_processing(host, meeting_id)
|
||||
|
||||
|
||||
def _validate_meeting_id(meeting_id: str) -> MeetingId | None:
|
||||
@@ -127,12 +126,6 @@ def _validate_meeting_id(meeting_id: str) -> MeetingId | None:
|
||||
return parsed
|
||||
|
||||
|
||||
def _finalize_chunk_processing(host: ServicerHost, meeting_id: str) -> None:
|
||||
"""Finalize chunk processing by decrementing pending count."""
|
||||
from ._processing._congestion import decrement_pending_chunks
|
||||
decrement_pending_chunks(host, meeting_id)
|
||||
|
||||
|
||||
async def _build_segments_from_results(
|
||||
ctx: _SegmentBuildContext,
|
||||
results: list[_AsrResultLike],
|
||||
|
||||
@@ -104,7 +104,8 @@ class StreamingMixin:
|
||||
):
|
||||
yield update
|
||||
finally:
|
||||
if cleanup_meeting := stream_state.current or stream_state.initialized:
|
||||
cleanup_meeting = stream_state.current or stream_state.initialized
|
||||
if cleanup_meeting:
|
||||
cleanup_stream_resources(self, cleanup_meeting)
|
||||
|
||||
async def process_stream_chunks(
|
||||
@@ -187,7 +188,7 @@ class StreamingMixin:
|
||||
meeting_id = chunk.meeting_id
|
||||
if not meeting_id:
|
||||
await abort_invalid_argument(context, "meeting_id required")
|
||||
return None
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
if current_meeting_id is None:
|
||||
# Track meeting_id BEFORE init to guarantee cleanup on any exception
|
||||
@@ -202,7 +203,7 @@ class StreamingMixin:
|
||||
return None if init_result is None else (meeting_id, initialized_meeting_id)
|
||||
if meeting_id != current_meeting_id:
|
||||
await abort_invalid_argument(context, "Stream may only contain a single meeting_id")
|
||||
return None
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
return current_meeting_id, initialized_meeting_id
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ def _maybe_create_partial_update(
|
||||
state: MeetingStreamState,
|
||||
meeting_id: str,
|
||||
partial_text: str,
|
||||
now: float,
|
||||
monotonic_now: float,
|
||||
) -> noteflow_pb2.TranscriptUpdate | None:
|
||||
"""Create partial update if text changed, updating state.
|
||||
|
||||
@@ -85,24 +85,24 @@ def _maybe_create_partial_update(
|
||||
state: Stream state to update.
|
||||
meeting_id: Meeting identifier.
|
||||
partial_text: Transcribed text.
|
||||
now: Current timestamp.
|
||||
monotonic_now: Current monotonic timestamp (for interval calculations).
|
||||
|
||||
Returns:
|
||||
TranscriptUpdate if text changed, None otherwise.
|
||||
"""
|
||||
# Only emit if text changed (debounce)
|
||||
if partial_text and partial_text != state.last_partial_text:
|
||||
state.last_partial_time = now
|
||||
state.last_partial_time = monotonic_now
|
||||
state.last_partial_text = partial_text
|
||||
return noteflow_pb2.TranscriptUpdate(
|
||||
meeting_id=meeting_id,
|
||||
update_type=noteflow_pb2.UPDATE_TYPE_PARTIAL,
|
||||
partial_text=partial_text,
|
||||
server_timestamp=now,
|
||||
server_timestamp=time.time(), # Wall-clock for client display
|
||||
)
|
||||
|
||||
# Update time even if no text change (cadence tracking)
|
||||
state.last_partial_time = now
|
||||
state.last_partial_time = monotonic_now
|
||||
return None
|
||||
|
||||
|
||||
@@ -128,13 +128,14 @@ async def maybe_emit_partial(
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
now = time.time()
|
||||
# Use monotonic time for interval calculations (immune to NTP adjustments)
|
||||
monotonic_now = time.monotonic()
|
||||
|
||||
if not _should_emit_partial(host, state, now):
|
||||
if not _should_emit_partial(host, state, monotonic_now):
|
||||
return None
|
||||
|
||||
partial_text = await _transcribe_partial_audio(host, state)
|
||||
return _maybe_create_partial_update(state, meeting_id, partial_text, now)
|
||||
return _maybe_create_partial_update(state, meeting_id, partial_text, monotonic_now)
|
||||
|
||||
|
||||
def clear_partial_buffer(host: ServicerHost, meeting_id: str) -> None:
|
||||
@@ -146,7 +147,8 @@ def clear_partial_buffer(host: ServicerHost, meeting_id: str) -> None:
|
||||
host: The servicer host.
|
||||
meeting_id: Meeting identifier.
|
||||
"""
|
||||
current_time = time.time()
|
||||
# Use monotonic time for interval calculations (immune to NTP adjustments)
|
||||
current_time = time.monotonic()
|
||||
|
||||
# Use consolidated state if available
|
||||
if state := host.get_stream_state(meeting_id):
|
||||
|
||||
@@ -15,6 +15,9 @@ from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
|
||||
from ._vad_processing import flush_segmenter, process_audio_with_vad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ...protocols import ServicerHost
|
||||
|
||||
__all__ = [
|
||||
@@ -27,6 +30,50 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _track_ack_update(
    host: ServicerHost,
    meeting_id: str,
    chunk: noteflow_pb2.AudioChunk,
) -> noteflow_pb2.TranscriptUpdate | None:
    """Record the chunk's sequence number and return the tracker's result.

    The sequence number is clamped so that it is never negative before it is
    handed to ``track_chunk_sequence``; that function's return value is
    passed back unchanged.
    """
    raw_sequence = chunk.chunk_sequence
    sequence = raw_sequence if raw_sequence > 0 else 0
    return track_chunk_sequence(host, meeting_id, sequence)
|
||||
|
||||
|
||||
async def _decode_chunk_audio(
    host: ServicerHost,
    meeting_id: str,
    chunk: noteflow_pb2.AudioChunk,
    context: GrpcContext,
) -> NDArray[np.float32] | None:
    """Normalize the chunk's stream format, then decode its audio.

    A ``ValueError`` raised while normalizing the format is reported through
    ``abort_invalid_argument``; the error is re-raised afterwards in case the
    abort call returns.

    Returns:
        The result of ``decode_and_convert_audio`` for this chunk.
    """
    try:
        rate, channel_count = normalize_stream_format(
            host,
            meeting_id,
            chunk.sample_rate,
            chunk.channels,
        )
    except ValueError as format_error:
        await abort_invalid_argument(context, str(format_error))
        raise  # Unreachable but helps type checker
    else:
        # Decoding stays outside the try body so its own errors are not
        # mistaken for a format-validation failure.
        return await decode_and_convert_audio(
            host=host,
            chunk=chunk,
            stream_format=(rate, channel_count),
            context=context,
        )
|
||||
|
||||
|
||||
async def _yield_processed_audio(
    host: ServicerHost,
    meeting_id: str,
    audio: NDArray[np.float32],
) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
    """Persist the audio chunk, then stream the updates produced from it.

    The audio is first passed to ``write_audio_chunk_safe``; afterwards every
    update yielded by ``process_audio_with_vad`` is forwarded to the caller
    in order.
    """
    write_audio_chunk_safe(host, meeting_id, audio)
    vad_updates = process_audio_with_vad(host, meeting_id, audio)
    async for transcript_update in vad_updates:
        yield transcript_update
|
||||
|
||||
|
||||
async def process_stream_chunk(
|
||||
host: ServicerHost,
|
||||
meeting_id: str,
|
||||
@@ -44,35 +91,16 @@ async def process_stream_chunk(
|
||||
Yields:
|
||||
Transcript updates from processing.
|
||||
"""
|
||||
# Track chunk sequence for acknowledgment (default 0 for backwards compat)
|
||||
chunk_sequence = max(chunk.chunk_sequence, 0)
|
||||
ack_update = track_chunk_sequence(host, meeting_id, chunk_sequence)
|
||||
if ack_update is not None:
|
||||
yield ack_update
|
||||
|
||||
try:
|
||||
sample_rate, channels = normalize_stream_format(
|
||||
host,
|
||||
meeting_id,
|
||||
chunk.sample_rate,
|
||||
chunk.channels,
|
||||
)
|
||||
except ValueError as e:
|
||||
await abort_invalid_argument(context, str(e))
|
||||
raise # Unreachable but helps type checker
|
||||
ack_update = _track_ack_update(host, meeting_id, chunk)
|
||||
if ack_update is not None:
|
||||
yield ack_update
|
||||
|
||||
audio = await decode_and_convert_audio(
|
||||
host=host,
|
||||
chunk=chunk,
|
||||
stream_format=(sample_rate, channels),
|
||||
context=context,
|
||||
)
|
||||
if audio is None:
|
||||
return
|
||||
audio = await _decode_chunk_audio(host, meeting_id, chunk, context)
|
||||
if audio is None:
|
||||
return
|
||||
|
||||
# Write to encrypted audio file
|
||||
write_audio_chunk_safe(host, meeting_id, audio)
|
||||
|
||||
# VAD-driven segmentation
|
||||
async for update in process_audio_with_vad(host, meeting_id, audio):
|
||||
yield update
|
||||
async for update in _yield_processed_audio(host, meeting_id, audio):
|
||||
yield update
|
||||
finally:
|
||||
decrement_pending_chunks(host, meeting_id)
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ....proto import noteflow_pb2
|
||||
from ...converters import create_congestion_info
|
||||
from ._constants import ACK_CHUNK_INTERVAL, PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
|
||||
from ._constants import PROCESSING_DELAY_THRESHOLD_MS, QUEUE_DEPTH_THRESHOLD
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...protocols import ServicerHost
|
||||
@@ -28,11 +27,16 @@ def calculate_congestion_info(
|
||||
Returns:
|
||||
CongestionInfo with processing delay, queue depth, and throttle recommendation.
|
||||
"""
|
||||
if receipt_times := host.chunk_receipt_times.get(meeting_id, deque()):
|
||||
oldest_receipt = receipt_times[0]
|
||||
processing_delay_ms = int((current_time - oldest_receipt) * 1000)
|
||||
else:
|
||||
processing_delay_ms = 0
|
||||
processing_delay_ms = 0
|
||||
receipt_times = host.chunk_receipt_times.get(meeting_id)
|
||||
if receipt_times:
|
||||
try:
|
||||
# Access [0] can race with popleft() in decrement_pending_chunks
|
||||
oldest_receipt = receipt_times[0]
|
||||
processing_delay_ms = int((current_time - oldest_receipt) * 1000)
|
||||
except IndexError:
|
||||
# Deque was emptied by concurrent popleft() - no pending chunks
|
||||
pass
|
||||
|
||||
# Get queue depth (pending chunks not yet processed through ASR)
|
||||
queue_depth = host.pending_chunks.get(meeting_id, 0)
|
||||
@@ -53,18 +57,17 @@ def calculate_congestion_info(
|
||||
def decrement_pending_chunks(host: ServicerHost, meeting_id: str) -> None:
|
||||
"""Decrement pending chunks counter after processing.
|
||||
|
||||
Call this after ASR processing completes for a segment.
|
||||
Call this after a chunk finishes processing.
|
||||
"""
|
||||
if meeting_id not in host.pending_chunks:
|
||||
return
|
||||
|
||||
# Decrement by ACK_CHUNK_INTERVAL since we process in batches
|
||||
host.pending_chunks[meeting_id] = max(
|
||||
0, host.pending_chunks[meeting_id] - ACK_CHUNK_INTERVAL
|
||||
)
|
||||
host.pending_chunks[meeting_id] = max(0, host.pending_chunks[meeting_id] - 1)
|
||||
receipt_times = host.chunk_receipt_times.get(meeting_id)
|
||||
if not receipt_times:
|
||||
return
|
||||
# Remove timestamps corresponding to processed chunks
|
||||
for _ in range(min(ACK_CHUNK_INTERVAL, len(receipt_times))):
|
||||
try:
|
||||
receipt_times.popleft()
|
||||
except IndexError:
|
||||
# Deque already empty - can occur if multiple coroutines race
|
||||
pass
|
||||
|
||||
@@ -59,13 +59,19 @@ async def _trigger_recording_webhook(
|
||||
"""
|
||||
if host.webhook_service is None:
|
||||
return
|
||||
await fire_webhook_safe(
|
||||
host.webhook_service.trigger_recording_started(
|
||||
meeting_id=meeting_id,
|
||||
title=title,
|
||||
task = asyncio.create_task(
|
||||
fire_webhook_safe(
|
||||
host.webhook_service.trigger_recording_started(
|
||||
meeting_id=meeting_id,
|
||||
title=title,
|
||||
),
|
||||
"recording.started",
|
||||
),
|
||||
"recording.started",
|
||||
name=f"webhook-recording-started-{meeting_id}",
|
||||
)
|
||||
# Store reference to prevent garbage collection before completion
|
||||
host.background_tasks.add(task)
|
||||
task.add_done_callback(host.background_tasks.discard)
|
||||
|
||||
|
||||
async def _prepare_meeting_for_streaming(
|
||||
@@ -160,7 +166,9 @@ class StreamSessionManager:
|
||||
init_result = await StreamSessionManager._init_stream_session(host, meeting_id)
|
||||
success = init_result.success
|
||||
if not success:
|
||||
host.active_streams.discard(meeting_id)
|
||||
# Release stream slot under lock to prevent race conditions
|
||||
async with host.stream_init_lock:
|
||||
host.active_streams.discard(meeting_id)
|
||||
error_code = init_result.error_code
|
||||
status = error_code if error_code is not None else grpc.StatusCode.INTERNAL
|
||||
error_message = init_result.error_message or ""
|
||||
@@ -206,7 +214,7 @@ class StreamSessionManager:
|
||||
await abort_failed_precondition(
|
||||
context, "Stream initialization timed out - server may be overloaded"
|
||||
)
|
||||
return False
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
return reserved
|
||||
|
||||
@staticmethod
|
||||
@@ -234,7 +242,7 @@ class StreamSessionManager:
|
||||
await abort_failed_precondition(
|
||||
context, f"{ERROR_MSG_MEETING_PREFIX}{meeting_id} already streaming"
|
||||
)
|
||||
return False
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
@staticmethod
|
||||
async def _init_stream_session(
|
||||
|
||||
@@ -39,6 +39,31 @@ SYNC_RUN_CACHE_EXPIRY_SECONDS = 60
|
||||
_ERR_CALENDAR_NOT_ENABLED = "Calendar integration not enabled"
|
||||
|
||||
|
||||
def _sync_task_done_callback(
    task: asyncio.Task[None],
    sync_run_id: UUID,
    tasks_dict: dict[UUID, asyncio.Task[None]],
) -> None:
    """Unregister a finished sync-related task and log its failure, if any.

    Args:
        task: The finished task.
        sync_run_id: Key under which the task was registered.
        tasks_dict: Registry the task is removed from.
    """
    # Drop the registry entry first so it is removed on every outcome.
    tasks_dict.pop(sync_run_id, None)

    # exception() raises on a cancelled task, so only ask when not cancelled.
    failure = None if task.cancelled() else task.exception()
    if failure is None:
        return
    logger.exception("Sync task %s failed with unhandled exception", sync_run_id, exc_info=failure)
|
||||
|
||||
|
||||
async def _cleanup_sync_cache(
    sync_run_id: UUID,
    cache: dict[UUID, SyncRun],
    cache_times: dict[UUID, datetime],
) -> None:
    """Evict a sync run from both caches once the expiry delay has elapsed.

    Sleeps for ``SYNC_RUN_CACHE_EXPIRY_SECONDS`` and then removes
    ``sync_run_id`` from ``cache`` and from ``cache_times``; a missing key
    in either mapping is ignored.
    """
    await asyncio.sleep(SYNC_RUN_CACHE_EXPIRY_SECONDS)
    for store in (cache, cache_times):
        store.pop(sync_run_id, None)
|
||||
|
||||
|
||||
def _format_enum_value(value: str | None) -> str:
|
||||
"""Format an enum value to string."""
|
||||
return "" if value is None else value
|
||||
@@ -68,6 +93,10 @@ class SyncMixin:
|
||||
sync_runs: dict[UUID, SyncRun]
|
||||
# Track when each sync run was cached (Sprint GAP-002: State Synchronization)
|
||||
sync_run_cache_times: dict[UUID, datetime]
|
||||
# Track background sync tasks for proper lifecycle management
|
||||
sync_tasks: dict[UUID, asyncio.Task[None]]
|
||||
# Track background cleanup tasks separately to avoid overwriting sync tasks
|
||||
sync_cleanup_tasks: dict[UUID, asyncio.Task[None]]
|
||||
|
||||
def ensure_sync_runs_cache(self: ServicerHost) -> dict[UUID, SyncRun]:
|
||||
"""Ensure the sync runs cache exists."""
|
||||
@@ -75,6 +104,10 @@ class SyncMixin:
|
||||
self.sync_runs = {}
|
||||
if not hasattr(self, "sync_run_cache_times"):
|
||||
self.sync_run_cache_times = {}
|
||||
if not hasattr(self, "sync_tasks"):
|
||||
self.sync_tasks = {}
|
||||
if not hasattr(self, "sync_cleanup_tasks"):
|
||||
self.sync_cleanup_tasks = {}
|
||||
return self.sync_runs
|
||||
|
||||
def cache_sync_run(self: ServicerHost, sync_run: SyncRun) -> None:
|
||||
@@ -101,6 +134,7 @@ class SyncMixin:
|
||||
"""Start a sync operation for an integration."""
|
||||
if self.calendar_service is None:
|
||||
await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED)
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
integration_id = await parse_integration_id(request.integration_id, context)
|
||||
|
||||
@@ -113,17 +147,22 @@ class SyncMixin:
|
||||
provider = provider_value if isinstance(provider_value, str) else None
|
||||
if not provider:
|
||||
await abort_failed_precondition(context, "Integration provider not configured")
|
||||
return noteflow_pb2.StartIntegrationSyncResponse()
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
sync_run = SyncRun.start(integration_id)
|
||||
sync_run = await uow.integrations.create_sync_run(sync_run)
|
||||
await uow.commit()
|
||||
|
||||
self.cache_sync_run(sync_run)
|
||||
asyncio.create_task(
|
||||
self.ensure_sync_runs_cache() # Ensure sync_tasks is initialized
|
||||
task = asyncio.create_task(
|
||||
self.perform_sync(integration_id, sync_run.id, str(provider)),
|
||||
name=f"sync-{sync_run.id}",
|
||||
).add_done_callback(lambda _: None)
|
||||
)
|
||||
self.sync_tasks[sync_run.id] = task
|
||||
task.add_done_callback(
|
||||
lambda t: _sync_task_done_callback(t, sync_run.id, self.sync_tasks)
|
||||
)
|
||||
logger.info("Started sync run %s for integration %s", sync_run.id, integration_id)
|
||||
return noteflow_pb2.StartIntegrationSyncResponse(sync_run_id=str(sync_run.id), status="running")
|
||||
|
||||
@@ -154,7 +193,7 @@ class SyncMixin:
|
||||
return candidate, candidate.id
|
||||
|
||||
await abort_not_found(context, ENTITY_INTEGRATION, request.integration_id)
|
||||
return None, integration_id
|
||||
raise AssertionError("unreachable") # abort is NoReturn
|
||||
|
||||
async def perform_sync(
|
||||
self: ServicerHost,
|
||||
@@ -166,37 +205,56 @@ class SyncMixin:
|
||||
|
||||
Fetches calendar events and updates the sync run status.
|
||||
"""
|
||||
from noteflow.infrastructure.observability.otel import get_tracer
|
||||
|
||||
tracer = get_tracer(__name__)
|
||||
cache = self.ensure_sync_runs_cache()
|
||||
|
||||
try:
|
||||
items_synced = await self.execute_sync_fetch(provider)
|
||||
sync_run = await self.complete_sync_run(
|
||||
integration_id, sync_run_id, items_synced
|
||||
)
|
||||
if sync_run:
|
||||
self.cache_sync_run(sync_run)
|
||||
logger.info(
|
||||
"Sync run %s completed: %d items synced",
|
||||
sync_run_id,
|
||||
items_synced,
|
||||
)
|
||||
with tracer.start_as_current_span("integration_sync") as span:
|
||||
span.set_attribute("sync.run_id", str(sync_run_id))
|
||||
span.set_attribute("sync.integration_id", str(integration_id))
|
||||
span.set_attribute("sync.provider", provider)
|
||||
|
||||
# INTENTIONAL BROAD HANDLER: Sync run error boundary
|
||||
# - Calendar sync can fail in many ways (network, auth, parsing)
|
||||
# - Must capture any failure and update sync run status
|
||||
except Exception as e:
|
||||
logger.exception("Sync run %s failed: %s", sync_run_id, e)
|
||||
sync_run = await self.fail_sync_run(sync_run_id, str(e))
|
||||
if sync_run:
|
||||
self.cache_sync_run(sync_run)
|
||||
try:
|
||||
span.add_event("fetch_start")
|
||||
items_synced = await self.execute_sync_fetch(provider)
|
||||
span.set_attribute("sync.items_synced", items_synced)
|
||||
span.add_event("fetch_complete")
|
||||
|
||||
finally:
|
||||
# Clean up cache after a delay (keep for status queries)
|
||||
# Sprint GAP-002: SYNC_RUN_CACHE_EXPIRY_SECONDS
|
||||
await asyncio.sleep(SYNC_RUN_CACHE_EXPIRY_SECONDS)
|
||||
cache.pop(sync_run_id, None)
|
||||
if hasattr(self, "sync_run_cache_times"):
|
||||
self.sync_run_cache_times.pop(sync_run_id, None)
|
||||
sync_run = await self.complete_sync_run(
|
||||
integration_id, sync_run_id, items_synced
|
||||
)
|
||||
if sync_run:
|
||||
self.cache_sync_run(sync_run)
|
||||
logger.info(
|
||||
"Sync run %s completed: %d items synced",
|
||||
sync_run_id,
|
||||
items_synced,
|
||||
)
|
||||
|
||||
# INTENTIONAL BROAD HANDLER: Sync run error boundary
|
||||
# - Calendar sync can fail in many ways (network, auth, parsing)
|
||||
# - Must capture any failure and update sync run status
|
||||
except Exception as e:
|
||||
span.record_exception(e)
|
||||
span.set_attribute("sync.error", str(e))
|
||||
logger.exception("Sync run %s failed: %s", sync_run_id, e)
|
||||
sync_run = await self.fail_sync_run(sync_run_id, str(e))
|
||||
if sync_run:
|
||||
self.cache_sync_run(sync_run)
|
||||
|
||||
finally:
|
||||
# Schedule non-blocking cache cleanup after expiry delay
|
||||
# Sprint GAP-002: SYNC_RUN_CACHE_EXPIRY_SECONDS
|
||||
cleanup_task = asyncio.create_task(
|
||||
_cleanup_sync_cache(sync_run_id, cache, self.sync_run_cache_times),
|
||||
name=f"sync-cleanup-{sync_run_id}",
|
||||
)
|
||||
# Register cleanup task separately to avoid overwriting the sync task reference
|
||||
cleanup_task.add_done_callback(
|
||||
lambda t: _sync_task_done_callback(t, sync_run_id, self.sync_cleanup_tasks)
|
||||
)
|
||||
self.sync_cleanup_tasks[sync_run_id] = cleanup_task
|
||||
|
||||
async def execute_sync_fetch(self: ServicerHost, provider: str) -> int:
|
||||
"""Execute the calendar fetch and return items count."""
|
||||
|
||||
@@ -30,6 +30,7 @@ from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
|
||||
|
||||
from ._service_shutdown import (
|
||||
cancel_diarization_tasks,
|
||||
cancel_sync_tasks,
|
||||
close_audio_writers,
|
||||
close_diarization_sessions,
|
||||
close_webhook_service,
|
||||
@@ -137,7 +138,8 @@ class ServicerStreamingStateMixin:
|
||||
)
|
||||
)
|
||||
partial_buffer = PartialAudioBuffer(sample_rate=self.DEFAULT_SAMPLE_RATE)
|
||||
current_time = time.time()
|
||||
# Use monotonic time for interval calculations (immune to NTP adjustments)
|
||||
current_time = time.monotonic()
|
||||
|
||||
state = MeetingStreamState(
|
||||
vad=vad,
|
||||
@@ -336,6 +338,7 @@ class ServicerLifecycleMixin:
|
||||
async def shutdown(self) -> None:
|
||||
"""Clean up servicer state before server stops."""
|
||||
logger.info("Shutting down servicer...")
|
||||
await cancel_sync_tasks(self)
|
||||
cancelled_job_ids = await cancel_diarization_tasks(self)
|
||||
mark_in_memory_jobs_failed(self, cancelled_job_ids)
|
||||
close_diarization_sessions(self)
|
||||
|
||||
@@ -14,6 +14,29 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def cancel_sync_tasks(servicer: NoteFlowServicer) -> None:
    """Cancel all active sync tasks and cleanup tasks, then clear both registries.

    Sync tasks are drained before cleanup tasks: a sync task may register a
    cleanup task while it is being cancelled, and draining in this order
    lets that late cleanup task be cancelled as well.

    Either registry attribute may be absent from ``servicer``; a missing
    registry is skipped without affecting the other. A task that ends with
    an error other than cancellation is logged and does not interrupt the
    shutdown of the remaining tasks.

    Args:
        servicer: Object that may carry ``sync_tasks`` and
            ``sync_cleanup_tasks`` mappings of run ID to task.
    """
    await _cancel_task_registry(getattr(servicer, "sync_tasks", None), log_cancellation=True)
    await _cancel_task_registry(
        getattr(servicer, "sync_cleanup_tasks", None), log_cancellation=False
    )


async def _cancel_task_registry(
    tasks: dict[object, asyncio.Task[None]] | None,
    *,
    log_cancellation: bool,
) -> None:
    """Cancel every unfinished task in ``tasks``, wait for all of them, and clear it."""
    if not tasks:
        return

    # Snapshot first: done-callbacks may remove entries from ``tasks`` while
    # we are waiting below.
    pending = [(run_id, task) for run_id, task in list(tasks.items()) if not task.done()]
    for run_id, task in pending:
        if log_cancellation:
            logger.debug("Cancelling sync task %s", run_id)
        task.cancel()

    if pending:
        # return_exceptions=True collects each task's outcome instead of
        # raising the first failure, so one failing task cannot leave the
        # others unawaited. Cancellation of this coroutine itself still
        # propagates to the caller.
        outcomes = await asyncio.gather(
            *(task for _, task in pending), return_exceptions=True
        )
        for (run_id, _), outcome in zip(pending, outcomes):
            if isinstance(outcome, BaseException) and not isinstance(
                outcome, asyncio.CancelledError
            ):
                logger.warning("Sync task %s raised during shutdown: %r", run_id, outcome)

    tasks.clear()
|
||||
|
||||
|
||||
async def cancel_diarization_tasks(servicer: NoteFlowServicer) -> list[str]:
|
||||
"""Cancel all active diarization tasks and return their IDs."""
|
||||
cancelled_job_ids = list(servicer.diarization_tasks.keys())
|
||||
|
||||
@@ -9,7 +9,9 @@ header are rejected with UNAUTHENTICATED status.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from functools import partial
|
||||
from typing import Protocol, cast
|
||||
|
||||
import grpc
|
||||
from grpc import aio
|
||||
@@ -33,11 +35,95 @@ METADATA_WORKSPACE_ID = "x-workspace-id"
|
||||
# Error messages
|
||||
_ERR_MISSING_REQUEST_ID = "Missing required x-request-id header"
|
||||
|
||||
class _TypedRpcMethodHandler(Protocol[TRequest, TResponse]):
    """Typed structural view of the attributes read from an RPC method handler.

    Each of the four behavior attributes is typed as either a callable for
    that request/response arity or ``None``. The two (de)serializer
    attributes are likewise optional.
    """

    # Single request in, single awaitable response out.
    unary_unary: Callable[
        [TRequest, aio.ServicerContext[TRequest, TResponse]],
        Awaitable[TResponse],
    ] | None
    # Single request in, async stream of responses out.
    unary_stream: Callable[
        [TRequest, aio.ServicerContext[TRequest, TResponse]],
        AsyncIterator[TResponse],
    ] | None
    # Async stream of requests in, single awaitable response out.
    stream_unary: Callable[
        [AsyncIterator[TRequest], aio.ServicerContext[TRequest, TResponse]],
        Awaitable[TResponse],
    ] | None
    # Async stream of requests in, async stream of responses out.
    stream_stream: Callable[
        [AsyncIterator[TRequest], aio.ServicerContext[TRequest, TResponse]],
        AsyncIterator[TResponse],
    ] | None
    # Converts raw request bytes into a request object, when set.
    request_deserializer: Callable[[bytes], TRequest] | None
    # Converts a response object into raw bytes, when set.
    response_serializer: Callable[[TResponse], bytes] | None
|
||||
|
||||
def _coerce_metadata_value(value: str | bytes) -> str:
|
||||
"""Normalize metadata values to string."""
|
||||
return value.decode() if isinstance(value, bytes) else value
|
||||
|
||||
|
||||
def _get_request_id(
    metadata: dict[str, str | bytes],
    method: str | None,
) -> str | None:
    """Look up the request ID in the call metadata.

    Args:
        metadata: Invocation metadata keyed by header name.
        method: RPC method name, included in the warning when the ID is absent.

    Returns:
        The request ID as a string, or ``None`` (after logging a warning)
        when the metadata has no entry under ``METADATA_REQUEST_ID``.
    """
    raw_request_id = metadata.get(METADATA_REQUEST_ID)
    if raw_request_id is not None:
        return _coerce_metadata_value(raw_request_id)

    logger.warning(
        "Rejecting RPC: missing x-request-id header",
        method=method,
    )
    return None
|
||||
|
||||
|
||||
def _apply_identity_context(metadata: dict[str, str | bytes], request_id: str) -> None:
    """Store the request ID and any user/workspace IDs in their context variables.

    The request ID is always set. The user ID and workspace ID are set only
    when the corresponding metadata entry is present and non-empty.
    """
    request_id_var.set(request_id)

    optional_fields = (
        (METADATA_USER_ID, user_id_var),
        (METADATA_WORKSPACE_ID, workspace_id_var),
    )
    for header_name, context_var in optional_fields:
        raw_value = metadata.get(header_name)
        if raw_value:
            context_var.set(_coerce_metadata_value(raw_value))
|
||||
|
||||
|
||||
async def _reject_unary_unary(
    message: str,
    request: TRequest,
    context: aio.ServicerContext[TRequest, TResponse],
) -> TResponse:
    """Unary-unary handler that aborts the call with UNAUTHENTICATED.

    ``request`` is part of the handler signature only; it is never read.
    """
    status = grpc.StatusCode.UNAUTHENTICATED
    await context.abort(status, message)
    # Only reached if abort() returns instead of raising.
    raise AssertionError("Unreachable after abort")
|
||||
|
||||
|
||||
async def _reject_unary_stream(
    message: str,
    request: TRequest,
    context: aio.ServicerContext[TRequest, TResponse],
) -> AsyncIterator[TResponse]:
    """Unary-stream handler that aborts the call with UNAUTHENTICATED.

    ``request`` is part of the handler signature only; it is never read.
    No response is ever yielded.
    """
    status = grpc.StatusCode.UNAUTHENTICATED
    await context.abort(status, message)
    return
    # Unreachable; keeps async generator type for gRPC handler signature.
    yield cast(TResponse, None)
|
||||
|
||||
|
||||
async def _reject_stream_unary(
    message: str,
    request_iterator: AsyncIterator[TRequest],
    context: aio.ServicerContext[TRequest, TResponse],
) -> TResponse:
    """Stream-unary handler that aborts the call with UNAUTHENTICATED.

    ``request_iterator`` is part of the handler signature only; it is never
    consumed.
    """
    status = grpc.StatusCode.UNAUTHENTICATED
    await context.abort(status, message)
    # Only reached if abort() returns instead of raising.
    raise AssertionError("Unreachable after abort")
|
||||
|
||||
|
||||
async def _reject_stream_stream(
    message: str,
    request_iterator: AsyncIterator[TRequest],
    context: aio.ServicerContext[TRequest, TResponse],
) -> AsyncIterator[TResponse]:
    """Stream-stream handler that aborts the call with UNAUTHENTICATED.

    ``request_iterator`` is part of the handler signature only; it is never
    consumed. No response is ever yielded.
    """
    status = grpc.StatusCode.UNAUTHENTICATED
    await context.abort(status, message)
    return
    # Unreachable; keeps async generator type for gRPC handler signature.
    yield cast(TResponse, None)
|
||||
|
||||
|
||||
class IdentityInterceptor(aio.ServerInterceptor):
|
||||
"""Interceptor that validates and populates identity context for RPC calls.
|
||||
|
||||
@@ -61,38 +147,15 @@ class IdentityInterceptor(aio.ServerInterceptor):
|
||||
],
|
||||
handler_call_details: grpc.HandlerCallDetails,
|
||||
) -> grpc.RpcMethodHandler[TRequest, TResponse]:
|
||||
"""Intercept incoming RPC calls to validate and set identity context.
|
||||
|
||||
Args:
|
||||
continuation: The next interceptor or handler.
|
||||
handler_call_details: Details about the RPC call.
|
||||
|
||||
Returns:
|
||||
The RPC handler for this call.
|
||||
|
||||
Raises:
|
||||
grpc.RpcError: UNAUTHENTICATED if x-request-id header is missing.
|
||||
"""
|
||||
"""Intercept incoming RPC calls to validate and set identity context."""
|
||||
metadata = dict(handler_call_details.invocation_metadata or [])
|
||||
|
||||
# Validate required x-request-id header
|
||||
request_id_value = metadata.get(METADATA_REQUEST_ID)
|
||||
if request_id_value is None:
|
||||
logger.warning(
|
||||
"Rejecting RPC: missing x-request-id header",
|
||||
method=handler_call_details.method,
|
||||
)
|
||||
return _create_unauthenticated_handler(_ERR_MISSING_REQUEST_ID)
|
||||
request_id = _get_request_id(metadata, handler_call_details.method)
|
||||
if request_id is None:
|
||||
handler = await continuation(handler_call_details)
|
||||
return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID)
|
||||
|
||||
request_id = _coerce_metadata_value(request_id_value)
|
||||
request_id_var.set(request_id)
|
||||
|
||||
# Extract optional user and workspace IDs from metadata
|
||||
if user_id_value := metadata.get(METADATA_USER_ID):
|
||||
user_id_var.set(_coerce_metadata_value(user_id_value))
|
||||
|
||||
if workspace_id_value := metadata.get(METADATA_WORKSPACE_ID):
|
||||
workspace_id_var.set(_coerce_metadata_value(workspace_id_value))
|
||||
_apply_identity_context(metadata, request_id)
|
||||
|
||||
logger.debug(
|
||||
"Identity context: request=%s user=%s workspace=%s method=%s",
|
||||
@@ -106,26 +169,36 @@ class IdentityInterceptor(aio.ServerInterceptor):
|
||||
|
||||
|
||||
def _create_unauthenticated_handler(
|
||||
handler: grpc.RpcMethodHandler[TRequest, TResponse],
|
||||
message: str,
|
||||
) -> grpc.RpcMethodHandler[TRequest, TResponse]:
|
||||
"""Create a handler that rejects with UNAUTHENTICATED status.
|
||||
"""Create a handler that rejects with UNAUTHENTICATED status."""
|
||||
typed_handler = cast(_TypedRpcMethodHandler[TRequest, TResponse], handler)
|
||||
request_deserializer = typed_handler.request_deserializer
|
||||
response_serializer = typed_handler.response_serializer
|
||||
|
||||
Args:
|
||||
message: Error message to include in the response.
|
||||
|
||||
Returns:
|
||||
A gRPC method handler that rejects all requests.
|
||||
"""
|
||||
|
||||
async def reject_unary_unary(
|
||||
request: TRequest,
|
||||
context: aio.ServicerContext[TRequest, TResponse],
|
||||
) -> TResponse:
|
||||
await context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
|
||||
raise AssertionError("Unreachable after abort")
|
||||
|
||||
return grpc.unary_unary_rpc_method_handler(
|
||||
reject_unary_unary,
|
||||
request_deserializer=None,
|
||||
response_serializer=None,
|
||||
)
|
||||
if typed_handler.unary_unary is not None:
|
||||
return grpc.unary_unary_rpc_method_handler(
|
||||
partial(_reject_unary_unary, message),
|
||||
request_deserializer=request_deserializer,
|
||||
response_serializer=response_serializer,
|
||||
)
|
||||
if typed_handler.unary_stream is not None:
|
||||
return grpc.unary_stream_rpc_method_handler(
|
||||
partial(_reject_unary_stream, message),
|
||||
request_deserializer=request_deserializer,
|
||||
response_serializer=response_serializer,
|
||||
)
|
||||
if typed_handler.stream_unary is not None:
|
||||
return grpc.stream_unary_rpc_method_handler(
|
||||
partial(_reject_stream_unary, message),
|
||||
request_deserializer=request_deserializer,
|
||||
response_serializer=response_serializer,
|
||||
)
|
||||
if typed_handler.stream_stream is not None:
|
||||
return grpc.stream_stream_rpc_method_handler(
|
||||
partial(_reject_stream_stream, message),
|
||||
request_deserializer=request_deserializer,
|
||||
response_serializer=response_serializer,
|
||||
)
|
||||
return handler
|
||||
|
||||
@@ -87,7 +87,15 @@ def register_shutdown_handlers(loop: asyncio.AbstractEventLoop) -> asyncio.Event
|
||||
shutdown_event.set()
|
||||
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
try:
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
except (NotImplementedError, RuntimeError, ValueError) as exc:
|
||||
logger.warning(
|
||||
"Signal handlers not supported; relying on default handlers",
|
||||
signal=sig,
|
||||
error=str(exc),
|
||||
)
|
||||
break
|
||||
return shutdown_event
|
||||
|
||||
|
||||
|
||||
@@ -41,5 +41,9 @@ def bind_server(
|
||||
server,
|
||||
)
|
||||
address = f"{bind_address}:{port}"
|
||||
server.add_insecure_port(address)
|
||||
bound_port = server.add_insecure_port(address)
|
||||
if bound_port == 0:
|
||||
raise RuntimeError(f"Failed to bind gRPC server on {address}")
|
||||
if port == 0 and bound_port != 0:
|
||||
address = f"{bind_address}:{bound_port}"
|
||||
return address
|
||||
|
||||
@@ -265,6 +265,7 @@ class NoteFlowServicer(
|
||||
self.pending_chunks: dict[str, int] = {}
|
||||
self.audio_write_failed: set[str] = set()
|
||||
self.stream_states: dict[str, MeetingStreamState] = {}
|
||||
self.background_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
def _init_diarization_state(self) -> None:
|
||||
"""Initialize diarization job tracking state."""
|
||||
|
||||
@@ -207,11 +207,13 @@ class GoogleCalendarAdapter(CalendarPort):
|
||||
|
||||
# Get display name from 'name' field, fall back to email prefix
|
||||
name = data.get("name")
|
||||
display_name = (
|
||||
str(name)
|
||||
if name
|
||||
else str(email).split("@")[0].replace(".", " ").title()
|
||||
)
|
||||
if name:
|
||||
display_name = str(name)
|
||||
else:
|
||||
# Extract username from email, handling edge cases where @ may be missing
|
||||
email_str = str(email)
|
||||
local_part = email_str.split("@")[0] if "@" in email_str else email_str
|
||||
display_name = local_part.replace(".", " ").title() if local_part else email_str
|
||||
|
||||
return str(email), display_name
|
||||
|
||||
|
||||
@@ -64,7 +64,9 @@ def _extract_email(profile: OutlookProfile) -> str:
|
||||
def _format_display_name(profile: OutlookProfile, email: str) -> str:
|
||||
if display_name_raw := profile.get("displayName"):
|
||||
return str(display_name_raw)
|
||||
return email.split("@")[0].replace(".", " ").title()
|
||||
# Extract username from email, handling edge cases where @ may be missing
|
||||
local_part = email.split("@")[0] if "@" in email else email
|
||||
return local_part.replace(".", " ").title() if local_part else email
|
||||
|
||||
|
||||
async def fetch_user_info(
|
||||
|
||||
@@ -28,6 +28,19 @@ from ._usage_event_builders import build_usage_event, extract_event_context
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _flush_task_done_callback(
|
||||
task: asyncio.Task[None],
|
||||
tasks_set: set[asyncio.Task[None]],
|
||||
) -> None:
|
||||
"""Handle completion of a flush task, logging any unhandled exceptions."""
|
||||
tasks_set.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.exception("Flush task failed with unhandled exception", exc_info=exc)
|
||||
|
||||
|
||||
class BufferedDatabaseUsageEventSink:
|
||||
"""Usage event sink that persists events to database with buffering.
|
||||
|
||||
@@ -41,6 +54,7 @@ class BufferedDatabaseUsageEventSink:
|
||||
|
||||
DEFAULT_BUFFER_SIZE = 100
|
||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5.0
|
||||
DEFAULT_FLUSH_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -61,7 +75,7 @@ class BufferedDatabaseUsageEventSink:
|
||||
self._flush_interval = flush_interval
|
||||
self._buffer: deque[UsageEvent] = deque(maxlen=buffer_size * 2)
|
||||
self._lock = Lock()
|
||||
self._flush_task: asyncio.Task[None] | None = None
|
||||
self._flush_tasks: set[asyncio.Task[None]] = set()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def record(self, event: UsageEvent) -> None:
|
||||
@@ -99,7 +113,11 @@ class BufferedDatabaseUsageEventSink:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._loop = loop
|
||||
# Store task reference to prevent garbage collection
|
||||
self._flush_task = loop.create_task(self._flush_async())
|
||||
task = loop.create_task(self._flush_async())
|
||||
self._flush_tasks.add(task)
|
||||
task.add_done_callback(
|
||||
lambda t: _flush_task_done_callback(t, self._flush_tasks)
|
||||
)
|
||||
except RuntimeError:
|
||||
# No running loop, will flush on next opportunity
|
||||
logger.debug("No event loop available for flush scheduling")
|
||||
@@ -125,17 +143,40 @@ class BufferedDatabaseUsageEventSink:
|
||||
|
||||
async def _flush_async(self) -> None:
|
||||
"""Flush buffered events to database."""
|
||||
from noteflow.infrastructure.observability.otel import get_tracer
|
||||
|
||||
events = self._drain_buffer()
|
||||
if not events:
|
||||
return
|
||||
|
||||
try:
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("usage_events_flush") as span:
|
||||
span.set_attribute("events.count", len(events))
|
||||
repo = self._repository_factory()
|
||||
count = await repo.add_batch(events)
|
||||
logger.debug("Flushed %d usage events to database", count)
|
||||
except Exception:
|
||||
logger.exception("Failed to flush usage events to database")
|
||||
self._restore_events(events)
|
||||
try:
|
||||
count = await asyncio.wait_for(
|
||||
repo.add_batch(events),
|
||||
timeout=self.DEFAULT_FLUSH_TIMEOUT_SECONDS,
|
||||
)
|
||||
span.set_attribute("events.flushed", count)
|
||||
logger.debug("Flushed %d usage events to database", count)
|
||||
except TimeoutError:
|
||||
span.set_attribute("events.timeout", True)
|
||||
logger.warning(
|
||||
"Database flush timed out after %.1fs, restoring %d events",
|
||||
self.DEFAULT_FLUSH_TIMEOUT_SECONDS,
|
||||
len(events),
|
||||
)
|
||||
self._restore_events(events)
|
||||
except Exception as exc:
|
||||
span.record_exception(exc)
|
||||
logger.exception("Failed to flush usage events to database")
|
||||
self._restore_events(events)
|
||||
finally:
|
||||
# Ensure session cleanup even on timeout/exception
|
||||
# Repository session is accessed via protected attribute
|
||||
if hasattr(repo, "_session") and repo._session is not None:
|
||||
await repo._session.close()
|
||||
|
||||
async def flush(self) -> int:
|
||||
"""Manually flush all buffered events.
|
||||
@@ -147,16 +188,47 @@ class BufferedDatabaseUsageEventSink:
|
||||
if not events:
|
||||
return 0
|
||||
|
||||
repo = self._repository_factory()
|
||||
try:
|
||||
repo = self._repository_factory()
|
||||
return await repo.add_batch(events)
|
||||
except Exception:
|
||||
logger.exception("Failed to flush usage events")
|
||||
self._restore_events(events)
|
||||
return 0
|
||||
finally:
|
||||
# Ensure session cleanup even on exception
|
||||
if hasattr(repo, "_session") and repo._session is not None:
|
||||
await repo._session.close()
|
||||
|
||||
@property
def pending_count(self) -> int:
    """Return how many events are currently buffered and not yet flushed."""
    # The buffer length is read while holding the sink's lock.
    with self._lock:
        queued = len(self._buffer)
    return queued
|
||||
|
||||
async def shutdown(self, timeout: float = 5.0) -> None:
    """Shut the sink down, flushing whatever is still buffered.

    The time budget is split in two halves: the first is spent waiting for
    flush tasks that are still running, the second on one final ``flush()``.
    Timeouts and flush failures are logged, not raised.

    Args:
        timeout: Maximum total time, in seconds, to spend on both phases.
    """
    half_budget = timeout / 2

    # Phase 1: let in-flight flush tasks finish.
    unfinished = [task for task in self._flush_tasks if not task.done()]
    if unfinished:
        try:
            await asyncio.wait_for(
                asyncio.gather(*unfinished, return_exceptions=True),
                timeout=half_budget,
            )
        except TimeoutError:
            logger.warning(
                "Usage sink: %d pending flush tasks timed out", len(unfinished)
            )

    # Phase 2: push out anything left in the buffer.
    try:
        await asyncio.wait_for(self.flush(), timeout=half_budget)
    except TimeoutError:
        logger.warning("Usage sink: final flush timed out")
    except Exception:
        logger.exception("Usage sink: final flush failed")
|
||||
|
||||
@@ -9,18 +9,28 @@ from __future__ import annotations
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
import tracemalloc
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.domain.entities import Segment
|
||||
from noteflow.infrastructure.persistence.models.core import SegmentModel
|
||||
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
|
||||
|
||||
|
||||
class _OrmConverter(Protocol):
    """Protocol for ORM to domain converter."""

    # Structural type only: any object exposing this method satisfies it.
    # Takes a ``SegmentModel`` and returns a domain ``Segment``.
    def segment_to_domain(self, model: SegmentModel) -> Segment: ...
|
||||
|
||||
|
||||
class _StreamingServicer(Protocol):
|
||||
active_streams: set[str]
|
||||
|
||||
@@ -135,16 +145,33 @@ def run_audio_writer_cycles(
|
||||
cycle_count: int,
|
||||
buffer_size: int = 1024,
|
||||
sample_rate: int = DEFAULT_SAMPLE_RATE,
|
||||
) -> None:
|
||||
"""Run audio writer open/close cycles (helper to avoid loops in tests)."""
|
||||
chunks_per_cycle: int = 10,
|
||||
chunk_duration: float = 0.1,
|
||||
) -> list[int]:
|
||||
"""Run audio writer open/close cycles with actual audio data.
|
||||
|
||||
Returns:
|
||||
List of bytes written per cycle (for metric logging).
|
||||
"""
|
||||
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
|
||||
|
||||
bytes_per_cycle: list[int] = []
|
||||
for i in range(cycle_count):
|
||||
writer = MeetingAudioWriter(crypto, base_path, buffer_size=buffer_size)
|
||||
dek = crypto.generate_dek()
|
||||
wrapped_dek = crypto.wrap_dek(dek)
|
||||
writer.open(f"meeting-{i}", dek, wrapped_dek, sample_rate=sample_rate)
|
||||
|
||||
# Write actual audio data
|
||||
for _ in range(chunks_per_cycle):
|
||||
chunk = generate_audio_chunk(chunk_duration, sample_rate)
|
||||
writer.write_chunk(chunk)
|
||||
|
||||
bytes_written = writer.bytes_written
|
||||
writer.close()
|
||||
bytes_per_cycle.append(bytes_written)
|
||||
|
||||
return bytes_per_cycle
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -253,3 +280,856 @@ def run_fuzz_iterations_batch(
|
||||
errors = verify_fuzz_iteration_results(results, seed)
|
||||
all_errors.extend(errors)
|
||||
return all_errors
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Heap growth test helpers (tracemalloc-based)
|
||||
# =============================================================================
|
||||
|
||||
# Heap growth test constants
|
||||
# Byte thresholds and cycle counts for the heap-growth helpers. None of them
# is referenced inside this chunk; presumably the test modules import them
# (TODO confirm against callers).
HEAP_MAX_GROWTH_BYTES = 1 * 1024 * 1024  # 1MB
HEAP_CONVERTER_MAX_GROWTH_BYTES = 512 * 1024  # 512KB
HEAP_PROTO_MAX_GROWTH_BYTES = 256 * 1024  # 256KB
ORM_CONVERSION_CYCLES = 500
PROTO_CYCLES = 200
|
||||
|
||||
|
||||
def calculate_noteflow_heap_growth(
    snapshot_before: tracemalloc.Snapshot,
    snapshot_after: tracemalloc.Snapshot,
) -> int:
    """Sum the heap growth attributed to noteflow code between two snapshots.

    The snapshots are compared per source line; only entries whose traceback
    text contains ``"noteflow"`` (case-sensitive) are counted.

    Args:
        snapshot_before: tracemalloc snapshot taken before the operation.
        snapshot_after: tracemalloc snapshot taken after the operation.

    Returns:
        Net size difference in bytes over the matching entries.
    """
    growth = 0
    for diff in snapshot_after.compare_to(snapshot_before, "lineno"):
        if "noteflow" in str(diff.traceback):
            growth += diff.size_diff
    return growth
|
||||
|
||||
|
||||
def calculate_filtered_heap_growth(
    snapshot_before: tracemalloc.Snapshot,
    snapshot_after: tracemalloc.Snapshot,
    filter_str: str,
) -> int:
    """Calculate heap growth for entries whose traceback matches a filter.

    Matching is case-insensitive: both the traceback text and ``filter_str``
    are lowercased before the substring test. Previously only the traceback
    was lowercased, so a filter containing an uppercase letter never matched.

    Args:
        snapshot_before: tracemalloc snapshot taken before the operation.
        snapshot_after: tracemalloc snapshot taken after the operation.
        filter_str: Substring to look for in each entry's traceback text.

    Returns:
        Net size difference in bytes over the matching entries.
    """
    needle = filter_str.lower()  # normalise once, outside the loop
    return sum(
        stat.size_diff
        for stat in snapshot_after.compare_to(snapshot_before, "lineno")
        if needle in str(stat.traceback).lower()
    )
|
||||
|
||||
|
||||
def run_orm_conversion_cycles(
    converter: _OrmConverter,
    meeting_id: UUID,
    cycle_count: int,
) -> None:
    """Build ``cycle_count`` segment models and convert each to a domain object.

    Each model gets a 5-unit-spaced start time, an end time 4.5 later, and
    one of three rotating speaker ids. The conversion results are discarded.

    Args:
        converter: Object providing ``segment_to_domain``.
        meeting_id: Meeting id assigned to every generated model.
        cycle_count: Number of models to create and convert.
    """
    from noteflow.infrastructure.persistence.models.core import SegmentModel

    for index in range(cycle_count):
        offset = index * 5
        row = SegmentModel(
            meeting_id=meeting_id,
            segment_id=index,
            text=f"Segment {index} content.",
            start_time=float(offset),
            end_time=float(offset + 4.5),
            speaker_id=f"speaker_{index % 3}",
        )
        converter.segment_to_domain(row)
|
||||
|
||||
|
||||
def run_protobuf_cycles(cycle_count: int) -> None:
    """Create, serialize and re-parse a ``Meeting`` message ``cycle_count`` times.

    Every iteration builds a message with a fresh UUID id, serializes it to
    bytes and parses those bytes into a new message. Nothing is returned.

    Args:
        cycle_count: Number of create/serialize/parse round trips.
    """
    from uuid import uuid4

    from noteflow.grpc.proto import noteflow_pb2

    for _ in range(cycle_count):
        wire_bytes = noteflow_pb2.Meeting(
            id=str(uuid4()),
            title="Test Meeting",
            state=noteflow_pb2.MEETING_STATE_COMPLETED,
        ).SerializeToString()
        round_tripped = noteflow_pb2.Meeting()
        round_tripped.ParseFromString(wire_bytes)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Async context edge case helpers
|
||||
# =============================================================================
|
||||
|
||||
# Context test constants
|
||||
# How long each tracked context in run_concurrent_context_tasks stays open.
CONTEXT_SLEEP_DURATION_SECONDS = 0.05
# The remaining values are not referenced in this chunk; presumably used by
# the async context edge-case tests (TODO confirm against callers).
CONTEXT_CANCEL_DELAY_SECONDS = 0.01
SLOW_ENTER_SLEEP_SECONDS = 10
CONCURRENT_CONTEXT_COUNT = 10
|
||||
|
||||
|
||||
def run_concurrent_context_tasks(
    context_ids: list[str],
    active_contexts: set[str],
    results: dict[str, int],
) -> list[Awaitable[None]]:
    """Build one coroutine per context id for concurrent-context testing.

    Each coroutine enters a tracked async context, sleeps for
    ``CONTEXT_SLEEP_DURATION_SECONDS`` and exits. While inside, its id is in
    ``active_contexts`` and ``results["max_concurrent"]`` is raised to the
    largest number of simultaneously active contexts seen so far.

    Args:
        context_ids: Ids of the contexts to create.
        active_contexts: Shared set tracking the currently active contexts.
        results: Dict that must already contain a ``"max_concurrent"`` entry.

    Returns:
        Un-awaited coroutines, one per id, for the caller to gather.
    """
    import asyncio

    class _TrackedContext:
        """Async context manager that registers its id while it is open."""

        def __init__(self, ctx_id: str) -> None:
            self._ctx_id = ctx_id

        async def __aenter__(self) -> str:
            active_contexts.add(self._ctx_id)
            results["max_concurrent"] = max(
                results["max_concurrent"], len(active_contexts)
            )
            return self._ctx_id

        async def __aexit__(self, *exc_info: object) -> None:
            # Always deregister, whether or not the body raised.
            active_contexts.discard(self._ctx_id)

    async def use_context(ctx_id: str) -> None:
        async with _TrackedContext(ctx_id):
            await asyncio.sleep(CONTEXT_SLEEP_DURATION_SECONDS)

    pending: list[Awaitable[None]] = []
    for ctx_id in context_ids:
        pending.append(use_context(ctx_id))
    return pending
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Memory pressure test helpers
|
||||
# =============================================================================
|
||||
|
||||
# Memory pressure constants
|
||||
# Not referenced in this chunk; presumably the default pressure size used by
# the tests (TODO confirm against callers).
MEMORY_PRESSURE_ALLOCATION_MB = 50
# Size of each block allocated by allocate_memory_pressure.
MEMORY_PRESSURE_CHUNK_SIZE_BYTES = 1024 * 1024  # 1MB chunks
# Number of throwaway objects created by run_gc_collection_behavior.
GC_TEMP_OBJECT_COUNT = 1000
# Factor applied to ru_maxrss on Linux in measure_rss_bytes.
LINUX_RSS_MULTIPLIER = 1024  # resource.ru_maxrss returns KB on Linux
|
||||
|
||||
|
||||
def allocate_memory_pressure(size_mb: int) -> list[bytes]:
    """Allocate random byte blocks to put pressure on memory.

    One block of ``MEMORY_PRESSURE_CHUNK_SIZE_BYTES`` random bytes is created
    per requested megabyte.

    Args:
        size_mb: Number of blocks (megabytes) to allocate.

    Returns:
        The allocated blocks; the caller must keep this list referenced for
        as long as the pressure should last.
    """
    return [os.urandom(MEMORY_PRESSURE_CHUNK_SIZE_BYTES) for _ in range(size_mb)]
|
||||
|
||||
|
||||
def run_operations_under_memory_pressure(
    servicer: _StreamingServicer,
    cycle_count: int,
    pressure_mb: int,
) -> tuple[int, int]:
    """Run streaming init/cleanup cycles while holding memory pressure.

    Args:
        servicer: The servicer to exercise. It must provide
            ``init_streaming_state``, ``cleanup_streaming_state``,
            ``close_audio_writer`` and an ``active_streams`` set.
        cycle_count: Number of init/cleanup cycles to attempt.
        pressure_mb: Megabytes of memory to hold during the cycles.

    Returns:
        Tuple of (successful_cycles, gc_collections), where gc_collections is
        the number of garbage collections, summed over all generations, that
        ran while the cycles executed.
    """
    import gc

    def collections_so_far() -> int:
        # gc.get_stats() reports cumulative collections per generation.
        # gc.get_count() only holds the counters compared against the
        # thresholds, which reset, so a difference of those is not a number
        # of collections and can even be negative.
        return sum(generation["collections"] for generation in gc.get_stats())

    pressure_data = allocate_memory_pressure(pressure_mb)
    gc_before = collections_so_far()

    successful = 0
    for i in range(cycle_count):
        meeting_id = f"pressure-test-{i:03d}"
        try:
            servicer.init_streaming_state(meeting_id, next_segment_id=0)
            servicer.active_streams.add(meeting_id)
            servicer.cleanup_streaming_state(meeting_id)
            servicer.close_audio_writer(meeting_id)
            servicer.active_streams.discard(meeting_id)
            successful += 1
        except MemoryError:
            # Out of memory: stop and report how many cycles completed.
            break

    gc_after = collections_so_far()
    # Drop the pressure only after the measurement has been taken.
    del pressure_data

    return successful, gc_after - gc_before
|
||||
|
||||
|
||||
def run_gc_collection_behavior() -> tuple[int, int, int]:
    """Create short-lived objects between two forced collections.

    ``gc.get_count()`` is sampled right after a ``gc.collect()`` both before
    and after ``GC_TEMP_OBJECT_COUNT`` temporary dicts are created.

    Returns:
        Tuple of (gen0_diff, gen1_diff, gen2_diff): per-generation difference
        between the two ``gc.get_count()`` samples.
    """
    import gc

    gc.collect()
    before = gc.get_count()

    for _ in range(GC_TEMP_OBJECT_COUNT):
        # Each dict is dropped as soon as the next one is bound.
        _ = {"key": "value", "data": [1, 1, 1, 1, 1]}

    gc.collect()
    after = gc.get_count()

    gen0_diff, gen1_diff, gen2_diff = (
        later - earlier for earlier, later in zip(before, after)
    )
    return gen0_diff, gen1_diff, gen2_diff
|
||||
|
||||
|
||||
def measure_rss_bytes() -> int:
    """Measure the process RSS (Resident Set Size) in bytes.

    Uses ``psutil`` when it can be imported. Otherwise, on macOS and Linux,
    falls back to ``resource.getrusage(...).ru_maxrss``; note that this
    fallback reports the maximum resident set size, not the current one.

    Returns:
        RSS in bytes, or -1 if no measurement method is available.
    """
    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        return psutil.Process().memory_info().rss

    if sys.platform not in ("darwin", "linux"):
        return -1

    try:
        import resource
    except ImportError:
        return -1

    max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # Linux reports kilobytes and is scaled to bytes; the macOS value is
    # returned unchanged.
    if sys.platform == "linux":
        return max_rss * LINUX_RSS_MULTIPLIER
    return max_rss
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rigorous stress test helpers
|
||||
# =============================================================================
|
||||
|
||||
# High-cycle constants for rigorous testing
|
||||
# Sizing values for the rigorous stress tests. Only
# RIGOROUS_AUDIO_DURATION_SECONDS and the two tolerances are referenced in
# this chunk; the others are presumably used by the tests (TODO confirm).
RIGOROUS_CYCLE_COUNT = 1000
RIGOROUS_AUDIO_CHUNKS_PER_CYCLE = 100
RIGOROUS_CONCURRENT_MEETINGS = 20
# Length of each generated chunk in the high-cycle writer/segmenter helpers.
RIGOROUS_AUDIO_DURATION_SECONDS = 0.1
# Default tolerances of check_resource_leaks.
THREAD_LEAK_TOLERANCE = 2  # Allow up to 2 extra threads (GC, etc.)
FD_LEAK_TOLERANCE = 5  # Allow up to 5 extra FDs (system variation)
|
||||
|
||||
|
||||
def get_thread_count() -> int:
    """Return the number of threads currently alive in this process."""
    from threading import active_count

    return active_count()
|
||||
|
||||
|
||||
def generate_audio_chunk(
    duration_seconds: float,
    sample_rate: int,
) -> NDArray[np.float32]:
    """Generate a chunk of smoothed random noise as stand-in speech audio.

    White noise is passed through a moving-average filter of up to 8 taps
    (a simple low-pass approximation of pink noise) and scaled so that the
    peak magnitude is at most 0.5.

    Args:
        duration_seconds: Length of the chunk in seconds.
        sample_rate: Samples per second.

    Returns:
        A float32 array holding exactly ``int(duration_seconds * sample_rate)``
        samples; empty when that product truncates to zero.
    """
    samples = int(duration_seconds * sample_rate)
    if samples == 0:
        # np.convolve and ndarray.max both raise on empty input.
        return np.zeros(0, dtype=np.float32)

    white = np.random.randn(samples).astype(np.float32)
    # mode="same" returns max(len(signal), len(kernel)) samples, so the kernel
    # must never be longer than the signal or very short chunks come back
    # with 8 samples.
    taps = min(8, samples)
    pink = np.convolve(white, np.ones(taps) / taps, mode="same")
    # Normalize to [-0.5, 0.5]; the epsilon avoids dividing by zero.
    pink = pink / (np.abs(pink).max() + 1e-6) * 0.5
    return pink.astype(np.float32)
|
||||
|
||||
|
||||
def run_audio_writer_high_cycle(
    crypto: AesGcmCryptoBox,
    base_path: Path,
    cycle_count: int,
    chunks_per_cycle: int,
    sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, int, int]:
    """Run a high-cycle audio writer stress test with generated audio.

    Each cycle creates a writer with a fresh data-encryption key, opens it,
    writes ``chunks_per_cycle`` generated chunks and closes it. The writer is
    closed even when a write fails, so a failing cycle does not leave it open.

    Args:
        crypto: Crypto box used to generate and wrap the per-cycle key.
        base_path: Directory passed to each writer.
        cycle_count: Number of open/write/close cycles.
        chunks_per_cycle: Chunks written in every cycle.
        sample_rate: Sample rate for the writer and the generated audio.

    Returns:
        Tuple of (successful_cycles, total_bytes_written, total_chunks).
    """
    from noteflow.infrastructure.audio.writer import MeetingAudioWriter

    successful = 0
    total_bytes = 0
    total_chunks = 0

    for i in range(cycle_count):
        writer = MeetingAudioWriter(crypto, base_path, buffer_size=4096)
        dek = crypto.generate_dek()
        wrapped_dek = crypto.wrap_dek(dek)

        writer.open(f"stress-{i:05d}", dek, wrapped_dek, sample_rate=sample_rate)
        try:
            # Write real audio data (writer expects float32 arrays)
            for _ in range(chunks_per_cycle):
                chunk = generate_audio_chunk(RIGOROUS_AUDIO_DURATION_SECONDS, sample_rate)
                writer.write_chunk(chunk)
                total_chunks += 1

            # The byte counter is read before close(), as in the original.
            total_bytes += writer.bytes_written
        finally:
            writer.close()
        successful += 1

    return successful, total_bytes, total_chunks
|
||||
|
||||
|
||||
def run_segmenter_high_cycle(
    cycle_count: int,
    chunks_per_cycle: int,
    sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, int]:
    """Run a high-cycle segmenter stress test with generated audio.

    Each cycle creates a new segmenter and feeds it ``chunks_per_cycle``
    chunks, marking 15 of every 20 chunks as speech, then flushes it.

    Args:
        cycle_count: Number of segmenter instances to exercise.
        chunks_per_cycle: Chunks fed to each segmenter.
        sample_rate: Sample rate for the segmenter and the generated audio.

    Returns:
        Tuple of (total_segments_produced, total_audio_seconds_processed),
        with the seconds truncated to a whole number.
    """
    from noteflow.infrastructure.asr.segmenter import Segmenter, SegmenterConfig

    total_segments = 0
    total_chunks = 0

    for _ in range(cycle_count):
        config = SegmenterConfig(sample_rate=sample_rate)
        segmenter = Segmenter(config=config)

        for j in range(chunks_per_cycle):
            chunk = generate_audio_chunk(RIGOROUS_AUDIO_DURATION_SECONDS, sample_rate)
            # Alternate speech/silence to produce segments
            is_speech = (j % 20) < 15  # 75% speech
            segments = segmenter.process_audio(chunk, is_speech)
            total_segments += len(segments)
            total_chunks += 1

        if segmenter.flush():
            total_segments += 1

    # Multiply once instead of adding the chunk duration repeatedly: ten
    # additions of 0.1 sum to 0.9999999999999999, which truncated to 0.
    # Rounding to microseconds first removes the remaining float noise while
    # keeping the truncate-to-whole-seconds result.
    total_seconds = round(total_chunks * RIGOROUS_AUDIO_DURATION_SECONDS, 6)
    return total_segments, int(total_seconds)
|
||||
|
||||
|
||||
def run_concurrent_streaming_sessions(
    servicer: _StreamingServicer,
    meeting_count: int,
    chunks_per_meeting: int,
    sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> tuple[int, list[str]]:
    """Open ``meeting_count`` streaming sessions at once, then close them all.

    All sessions are initialised first, the number of active streams is
    checked against ``meeting_count``, and then every session is cleaned up.
    ``chunks_per_meeting`` and ``sample_rate`` are accepted but not used by
    this function body.

    Returns:
        Tuple of (successful_meetings, any_error_messages).
    """
    problems: list[str] = []
    meeting_ids = [f"concurrent-{index:03d}" for index in range(meeting_count)]

    # Phase 1: bring every session up before tearing any down.
    for meeting_id in meeting_ids:
        servicer.init_streaming_state(meeting_id, next_segment_id=0)
        servicer.active_streams.add(meeting_id)

    # Phase 2: all sessions should now be registered as active.
    active_now = len(servicer.active_streams)
    if active_now != meeting_count:
        problems.append(f"Expected {meeting_count} active streams, got {active_now}")

    # Phase 3: tear everything down, counting completed cleanups.
    cleaned = 0
    for meeting_id in meeting_ids:
        servicer.cleanup_streaming_state(meeting_id)
        servicer.close_audio_writer(meeting_id)
        servicer.active_streams.discard(meeting_id)
        cleaned += 1

    return cleaned, problems
|
||||
|
||||
|
||||
def measure_resource_baseline() -> tuple[int, int, int]:
    """Take a resource-usage reading after forcing a garbage collection.

    Returns:
        Tuple of (fd_count, thread_count, rss_bytes).
    """
    import gc

    gc.collect()
    fd_count = get_fd_count()
    thread_count = get_thread_count()
    rss_bytes = measure_rss_bytes()
    return fd_count, thread_count, rss_bytes
|
||||
|
||||
|
||||
def check_resource_leaks(
    baseline: tuple[int, int, int],
    fd_tolerance: int = FD_LEAK_TOLERANCE,
    thread_tolerance: int = THREAD_LEAK_TOLERANCE,
) -> list[str]:
    """Compare current FD and thread counts with a baseline reading.

    A garbage collection is forced before measuring. The FD check is skipped
    when either FD count is negative; the baseline's RSS value is not
    compared.

    Args:
        baseline: (fd_count, thread_count, rss_bytes) taken earlier.
        fd_tolerance: Extra file descriptors allowed before reporting.
        thread_tolerance: Extra threads allowed before reporting.

    Returns:
        List of error messages (empty if no leaks).
    """
    import gc

    gc.collect()

    fd_before, thread_before, _ = baseline
    fd_now = get_fd_count()
    thread_now = get_thread_count()

    problems: list[str] = []

    fd_measurable = fd_before >= 0 and fd_now >= 0
    if fd_measurable:
        fd_delta = fd_now - fd_before
        if fd_delta > fd_tolerance:
            problems.append(f"FD leak: {fd_delta} new FDs (tolerance={fd_tolerance})")

    thread_delta = thread_now - thread_before
    if thread_delta > thread_tolerance:
        problems.append(
            f"Thread leak: {thread_delta} new threads (tolerance={thread_tolerance})"
        )

    return problems
|
||||
|
||||
|
||||
def log_stress_metrics(
    test_name: str,
    baseline: tuple[int, int, int],
    *,
    bytes_written: int = 0,
    cycles: int = 0,
    extra: dict[str, object] | None = None,
) -> None:
    """Print a stress-test metrics report to stdout.

    Current FD, thread and RSS readings are taken and shown as deltas against
    ``baseline``. The FD delta is reported as 0 when either FD count is
    negative. Run pytest with ``-s`` to see the output.

    Args:
        test_name: Name shown in the report header.
        baseline: (fd_count, thread_count, rss_bytes) taken before the test.
        bytes_written: Bytes written during the test.
        cycles: Number of cycles executed.
        extra: Optional additional key/value pairs listed under "Extra".
    """
    fd_before, thread_before, rss_before = baseline
    fd_after = get_fd_count()
    thread_after = get_thread_count()
    rss_after = measure_rss_bytes()

    if fd_before >= 0 and fd_after >= 0:
        fd_delta = fd_after - fd_before
    else:
        fd_delta = 0
    thread_delta = thread_after - thread_before
    rss_delta_kb = (rss_after - rss_before) // 1024

    rule = "=" * 60
    report = [
        f"\n{rule}",
        f"STRESS METRICS: {test_name}",
        f"{rule}",
        f" Cycles: {cycles:,}",
        f" Bytes written: {bytes_written:,}",
        f" FD delta: {fd_delta:+d} ({fd_before} -> {fd_after})",
        f" Thread delta: {thread_delta:+d} ({thread_before} -> {thread_after})",
        f" RSS delta: {rss_delta_kb:+,} KB ({rss_before // 1024:,} -> {rss_after // 1024:,} KB)",
    ]

    if extra:
        report.append(" Extra:")
        for key, value in extra.items():
            report.append(f" {key}: {value}")

    report.append(f"{rule}\n")
    print("\n".join(report))
|
||||
|
||||
|
||||
def verify_audio_writer_high_cycle_metrics(
    test_name: str,
    baseline: tuple[int, int, int],
    *,
    successful_cycles: int,
    total_bytes: int,
    total_chunks: int,
    cycle_count: int,
    chunks_per_cycle: int,
) -> None:
    """Log and validate the outcome of a high-cycle audio writer run.

    The metrics are logged and the leak check is run first; after that the
    cycle count, the chunk count and the leak report are checked in that
    order.

    Raises:
        AssertionError: If not every cycle completed, the chunk count is not
            ``cycle_count * chunks_per_cycle``, or resource leaks were found.
    """
    log_stress_metrics(
        test_name,
        baseline,
        bytes_written=total_bytes,
        cycles=cycle_count,
        extra={"chunks": total_chunks, "successful_cycles": successful_cycles},
    )
    leak_errors = check_resource_leaks(baseline)

    if successful_cycles != cycle_count:
        raise AssertionError(f"Only {successful_cycles}/{cycle_count} cycles completed")

    expected_chunks = cycle_count * chunks_per_cycle
    if total_chunks != expected_chunks:
        raise AssertionError(f"Expected {expected_chunks} chunks, wrote {total_chunks}")

    if leak_errors:
        raise AssertionError(
            f"Resource leaks after {total_bytes} bytes written: {leak_errors}"
        )
|
||||
|
||||
|
||||
def log_heap_metrics(
    test_name: str,
    heap_growth_bytes: int,
    cycles: int = 0,
    *,
    max_allowed_bytes: int = 0,
    filter_name: str = "",
) -> None:
    """Print a heap-growth report for a tracemalloc-based test.

    Args:
        test_name: Name shown in the report header.
        heap_growth_bytes: Measured heap growth in bytes.
        cycles: Number of operation cycles.
        max_allowed_bytes: Allowed growth; when non-zero, the limit and the
            percentage of it that was used are included.
        filter_name: Filter name used (e.g., 'noteflow', 'converter');
            included when non-empty.
    """
    rule = "=" * 60
    growth_kb = heap_growth_bytes / 1024

    report = [
        f"\n{rule}",
        f"HEAP METRICS: {test_name}",
        f"{rule}",
        f" Cycles: {cycles:,}",
        f" Heap growth: {growth_kb:,.1f} KB",
    ]

    if max_allowed_bytes:
        max_kb = max_allowed_bytes / 1024
        utilization = heap_growth_bytes / max_allowed_bytes * 100
        report.append(f" Max allowed: {max_kb:,.1f} KB ({utilization:.1f}% utilized)")

    if filter_name:
        report.append(f" Filter: {filter_name}")

    report.append(f"{rule}\n")
    print("\n".join(report))
|
||||
|
||||
|
||||
def log_thread_metrics(
    test_name: str,
    initial_threads: int,
    final_threads: int,
    cycles: int,
    tolerance: int,
) -> None:
    """Print a thread-count report for a thread leak test.

    Args:
        test_name: Name shown in the report header.
        initial_threads: Thread count before the test.
        final_threads: Thread count after the test.
        cycles: Number of operation cycles.
        tolerance: Allowed thread leak tolerance.
    """
    rule = "=" * 60
    thread_delta = final_threads - initial_threads

    report = "\n".join(
        (
            f"\n{rule}",
            f"THREAD METRICS: {test_name}",
            f"{rule}",
            f" Cycles: {cycles}",
            f" Thread delta: {thread_delta:+d} ({initial_threads} -> {final_threads})",
            f" Tolerance: {tolerance}",
            f"{rule}\n",
        )
    )
    print(report)
|
||||
|
||||
|
||||
def run_streaming_with_resource_tracking(
    servicer: _StreamingServicer,
    cycle_count: int,
) -> tuple[int, list[str]]:
    """Run streaming init/cleanup cycles and check for leaked resources.

    A resource baseline is recorded, the cycles are run with the
    ``"tracked"`` label, and the resources are then compared with the
    baseline.

    Returns:
        Tuple of (successful_cycles, error_messages); the first element is
        always ``cycle_count``.
    """
    before = measure_resource_baseline()
    run_streaming_init_cleanup_cycles(servicer, cycle_count, "tracked")
    leak_messages = check_resource_leaks(before)
    return cycle_count, leak_messages
|
||||
|
||||
|
||||
def run_interleaved_init_cleanup(
    servicer: _StreamingServicer,
    cycle_count: int,
    init_per_cycle: int = 3,
    cleanup_per_cycle: int = 1,
) -> tuple[int, int]:
    """Run interleaved init/cleanup pattern (init N, cleanup M, repeat).

    Every cycle initialises ``init_per_cycle`` meetings and then tears down up
    to ``cleanup_per_cycle`` of the oldest still-active ones. After the last
    cycle every meeting that is still active is torn down, oldest first.

    Args:
        servicer: Servicer whose streaming state is exercised.
        cycle_count: Number of init/cleanup rounds to run.
        init_per_cycle: Meetings initialised per round.
        cleanup_per_cycle: Maximum meetings torn down per round.

    Returns:
        Tuple of (total_inits, total_cleanups).
    """
    from collections import deque

    def _teardown(meeting_id: str) -> None:
        """Release everything that was set up for ``meeting_id``."""
        servicer.cleanup_streaming_state(meeting_id)
        servicer.close_audio_writer(meeting_id)
        servicer.active_streams.discard(meeting_id)

    # FIFO of live meetings; a deque gives O(1) removal of the oldest entry
    # where list.pop(0) would shift the whole backlog on every eviction.
    active_meetings: deque[str] = deque()
    total_inits = 0
    total_cleanups = 0

    for cycle in range(cycle_count):
        # Init N meetings
        for j in range(init_per_cycle):
            meeting_id = f"interleave-{cycle}-{j}"
            servicer.init_streaming_state(meeting_id, next_segment_id=0)
            servicer.active_streams.add(meeting_id)
            active_meetings.append(meeting_id)
            total_inits += 1

        # Cleanup M oldest (fewer if not that many are active)
        for _ in range(min(cleanup_per_cycle, len(active_meetings))):
            _teardown(active_meetings.popleft())
            total_cleanups += 1

    # Cleanup remaining, preserving oldest-first order
    while active_meetings:
        _teardown(active_meetings.popleft())
        total_cleanups += 1

    return total_inits, total_cleanups
|
||||
|
||||
|
||||
class _MockSessionFactory(Protocol):
    """Structural type for a zero-argument callable that builds a mock session."""

    def __call__(self) -> object:
        """Return a new mock session object."""
|
||||
|
||||
|
||||
class _ServicerWithStreamStates(Protocol):
    """Servicer surface needed by helpers that also reach into stream_states."""

    # Meeting ids of streams currently considered active.
    active_streams: set[str]
    # Per-meeting state objects keyed by meeting id.
    stream_states: dict[str, object]

    def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
        """Set up streaming state for ``meeting_id``."""

    def cleanup_streaming_state(self, meeting_id: str) -> None:
        """Tear down streaming state for ``meeting_id``."""

    def close_audio_writer(self, meeting_id: str) -> None:
        """Close the audio writer associated with ``meeting_id``."""
|
||||
|
||||
|
||||
def run_diarization_session_cycles(
    servicer: _ServicerWithStreamStates,
    cycle_count: int,
    mock_session_factory: _MockSessionFactory,
) -> int:
    """Run diarization session init/cleanup cycles with mock sessions.

    Each cycle initialises streaming state for a fresh meeting id, attaches a
    mock diarization session to the resulting stream state, and then tears
    everything down again.

    Args:
        servicer: The servicer to test (must have stream_states attribute).
        cycle_count: Number of cycles to run.
        mock_session_factory: Callable that returns mock session objects.

    Returns:
        Number of successful cycles, i.e. cycles in which a stream state
        existed after init so the mock session was actually attached. A cycle
        whose state is missing is still torn down but is not counted, so a
        caller comparing the result with ``cycle_count`` notices the problem.
    """
    successful = 0

    for i in range(cycle_count):
        meeting_id = f"diarize-{i:04d}"
        mock_session = mock_session_factory()

        servicer.init_streaming_state(meeting_id, next_segment_id=0)
        servicer.active_streams.add(meeting_id)

        # Access stream state and set diarization session
        state = servicer.stream_states.get(meeting_id)
        session_attached = state is not None
        if session_attached:
            # object.__setattr__ bypasses any __setattr__ override on the
            # state class.
            object.__setattr__(state, "diarization_session", mock_session)

        # Always tear down, even when nothing could be attached.
        servicer.cleanup_streaming_state(meeting_id)
        servicer.close_audio_writer(meeting_id)
        servicer.active_streams.discard(meeting_id)

        if session_attached:
            successful += 1

    return successful
|
||||
|
||||
|
||||
def create_diarization_tasks(
    task_count: int,
    sleep_seconds: float,
) -> list[tuple[str, Awaitable[None]]]:
    """Build (job_id, awaitable) pairs that stand in for diarization jobs.

    Each awaitable is an unscheduled ``asyncio.sleep(sleep_seconds)``
    coroutine; nothing runs until the caller schedules or awaits it.

    Returns:
        List of (job_id, task_coroutine) tuples. Caller must schedule tasks.
    """
    import asyncio

    pending: list[tuple[str, Awaitable[None]]] = []
    for index in range(task_count):
        job_id = f"job-{index:04d}"
        pending.append((job_id, asyncio.sleep(sleep_seconds)))
    return pending
|
||||
|
||||
|
||||
def schedule_tasks_to_servicer(
    servicer: object,
    task_count: int,
    sleep_seconds: float,
) -> list[object]:
    """Create sleeping asyncio tasks and register them on the servicer.

    Tasks are stored in the servicer's ``diarization_tasks`` mapping under
    the keys ``job-0000``, ``job-0001``, ... in creation order. Because
    ``asyncio.create_task`` is used, a running event loop is required.

    Args:
        servicer: Servicer with diarization_tasks dict attribute.
        task_count: Number of tasks to create.
        sleep_seconds: How long each task should sleep.

    Returns:
        List of created task objects.
    """
    import asyncio

    # Resolved before any task is created so a servicer without the
    # attribute fails without leaving orphaned tasks behind.
    registry = object.__getattribute__(servicer, "diarization_tasks")

    created: list[object] = [
        asyncio.create_task(asyncio.sleep(sleep_seconds))
        for _ in range(task_count)
    ]
    for index, scheduled in enumerate(created):
        registry[f"job-{index:04d}"] = scheduled

    return created
|
||||
|
||||
|
||||
def verify_all_tasks_cancelled(tasks: list[object]) -> list[str]:
    """Check that every task is both finished and cancelled.

    Each task's ``done`` and ``cancelled`` methods are looked up and called;
    a message is recorded for each check that fails, keyed by the task's
    position in ``tasks``.

    Returns:
        List of error messages (empty if all tasks properly cancelled).
    """
    problems: list[str] = []

    for index, candidate in enumerate(tasks):
        finished = object.__getattribute__(candidate, "done")
        was_cancelled = object.__getattribute__(candidate, "cancelled")

        checks = (
            (finished(), f"Task {index} not done after shutdown"),
            (was_cancelled(), f"Task {index} not cancelled after shutdown"),
        )
        problems.extend(message for passed, message in checks if not passed)

    return problems
|
||||
|
||||
|
||||
async def run_webhook_executor_cycles(
    cycle_count: int,
) -> list[tuple[bool, bool]]:
    """Open and close a fresh WebhookExecutor ``cycle_count`` times.

    For every cycle a new executor is built, its ``_ensure_client`` coroutine
    is awaited, the executor is closed, and the ``_client`` attribute is
    inspected before and after the close.

    Args:
        cycle_count: Number of cycles to run.

    Returns:
        List of (client_created, client_cleared) tuples for each cycle.
        ``client_created`` is True when ``_client`` was an
        ``httpx.AsyncClient`` after ensuring; ``client_cleared`` is True when
        it no longer was one after ``close()``.
    """
    import httpx

    from noteflow.infrastructure.webhooks.executor import WebhookExecutor

    def _holds_client(executor: object) -> bool:
        """Report whether the executor's private client is an AsyncClient."""
        current = object.__getattribute__(executor, "_client")
        return isinstance(current, httpx.AsyncClient)

    async def _single_cycle() -> tuple[bool, bool]:
        """Run one ensure/close round and report both observations."""
        executor = WebhookExecutor()
        ensure = object.__getattribute__(executor, "_ensure_client")
        await ensure()
        created = _holds_client(executor)
        await executor.close()
        return created, not _holds_client(executor)

    # Cycles run strictly one after another.
    return [await _single_cycle() for _ in range(cycle_count)]
|
||||
|
||||
|
||||
def verify_webhook_executor_results(
    results: list[tuple[bool, bool]],
) -> tuple[bool, bool]:
    """Collapse per-cycle webhook executor outcomes into two overall flags.

    Args:
        results: List of (client_created, client_cleared) tuples.

    Returns:
        Tuple of (all_created, all_cleared). Both are True for an empty list.
    """
    # Materialise both columns first so every tuple is unpacked exactly as
    # a plain loop over the results would do.
    created_flags = [bool(created) for created, _ in results]
    cleared_flags = [bool(cleared) for _, cleared in results]
    return all(created_flags), all(cleared_flags)
|
||||
|
||||
|
||||
def run_audio_writer_thread_cycles(
    crypto: object,
    tmp_path: object,
    cycle_count: int,
    sample_rate: int,
) -> None:
    """Open and immediately close audio writers for thread-leak testing.

    Every cycle builds a new ``MeetingAudioWriter``, generates and wraps a
    fresh DEK through ``crypto``, opens the writer under the meeting id
    ``thread-test-NNNN`` and closes it again without writing audio.

    Args:
        crypto: AesGcmCryptoBox instance.
        tmp_path: Path for temporary files.
        cycle_count: Number of cycles to run.
        sample_rate: Audio sample rate.
    """
    from noteflow.infrastructure.audio.writer import MeetingAudioWriter

    for index in range(cycle_count):
        writer = MeetingAudioWriter(crypto, tmp_path, buffer_size=1024)

        # crypto is typed as a plain object, so its methods are looked up
        # dynamically.
        make_key = object.__getattribute__(crypto, "generate_dek")
        wrap_key = object.__getattribute__(crypto, "wrap_dek")
        key = make_key()
        wrapped_key = wrap_key(key)

        meeting_id = f"thread-test-{index:04d}"
        writer.open(meeting_id, key, wrapped_key, sample_rate=sample_rate)
        writer.close()
|
||||
|
||||
@@ -1,5 +1,24 @@
|
||||
{
|
||||
"generated_at": "2026-01-07T01:08:35.197277+00:00",
|
||||
"rules": {},
|
||||
"generated_at": "2026-01-14T04:30:39.258014+00:00",
|
||||
"rules": {
|
||||
"deep_nesting": [
|
||||
"deep_nesting|src/noteflow/application/services/asr_config_service.py|_execute_reconfiguration|depth=3",
|
||||
"deep_nesting|src/noteflow/application/services/asr_config_service.py|shutdown|depth=3",
|
||||
"deep_nesting|src/noteflow/grpc/_mixins/diarization/_jobs.py|_execute_diarization|depth=3",
|
||||
"deep_nesting|src/noteflow/grpc/_mixins/sync.py|perform_sync|depth=3",
|
||||
"deep_nesting|src/noteflow/grpc/_service_shutdown.py|cancel_sync_tasks|depth=3"
|
||||
],
|
||||
"god_class": [
|
||||
"god_class|src/noteflow/application/services/asr_config_service.py|AsrConfigService|methods=16"
|
||||
],
|
||||
"long_method": [
|
||||
"long_method|src/noteflow/grpc/_mixins/sync.py|perform_sync|lines=60"
|
||||
],
|
||||
"module_size_soft": [
|
||||
"module_size_soft|src/noteflow/application/services/asr_config_service.py|module|lines=361",
|
||||
"module_size_soft|src/noteflow/grpc/_mixins/diarization/_jobs.py|module|lines=372",
|
||||
"module_size_soft|src/noteflow/grpc/_mixins/sync.py|module|lines=384"
|
||||
]
|
||||
},
|
||||
"schema_version": 1
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock
|
||||
@@ -18,30 +19,76 @@ import pytest
|
||||
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
|
||||
from noteflow.grpc.service import NoteFlowServicer
|
||||
from support.stress_helpers import (
|
||||
get_fd_count,
|
||||
CONCURRENT_CONTEXT_COUNT,
|
||||
CONTEXT_CANCEL_DELAY_SECONDS,
|
||||
HEAP_CONVERTER_MAX_GROWTH_BYTES,
|
||||
HEAP_MAX_GROWTH_BYTES,
|
||||
HEAP_PROTO_MAX_GROWTH_BYTES,
|
||||
MEMORY_PRESSURE_ALLOCATION_MB,
|
||||
ORM_CONVERSION_CYCLES,
|
||||
PROTO_CYCLES,
|
||||
RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
|
||||
RIGOROUS_CONCURRENT_MEETINGS,
|
||||
SLOW_ENTER_SLEEP_SECONDS,
|
||||
calculate_filtered_heap_growth,
|
||||
calculate_noteflow_heap_growth,
|
||||
check_resource_leaks,
|
||||
log_heap_metrics,
|
||||
log_stress_metrics,
|
||||
log_thread_metrics,
|
||||
measure_resource_baseline,
|
||||
measure_rss_bytes,
|
||||
run_audio_writer_cycles,
|
||||
run_audio_writer_high_cycle,
|
||||
run_audio_writer_thread_cycles,
|
||||
run_concurrent_context_tasks,
|
||||
run_concurrent_streaming_sessions,
|
||||
run_diarization_session_cycles,
|
||||
run_gc_collection_behavior,
|
||||
run_interleaved_init_cleanup,
|
||||
run_operations_under_memory_pressure,
|
||||
run_orm_conversion_cycles,
|
||||
run_protobuf_cycles,
|
||||
run_streaming_init_cleanup_cycles,
|
||||
run_webhook_executor_cycles,
|
||||
schedule_tasks_to_servicer,
|
||||
verify_all_tasks_cancelled,
|
||||
verify_audio_writer_high_cycle_metrics,
|
||||
verify_webhook_executor_results,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from threading import Thread
|
||||
|
||||
import httpx
|
||||
|
||||
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
|
||||
|
||||
# Test constants
|
||||
STREAMING_CYCLES = 50
|
||||
AUDIO_WRITER_CYCLES = 20
|
||||
MEMORY_TEST_CYCLES = 100
|
||||
INTERLEAVE_CYCLES = 20
|
||||
DIARIZATION_CYCLES = 50
|
||||
WRITER_THREAD_CYCLES = 10
|
||||
# Test constants - rigorous values for real stress testing
|
||||
STREAMING_CYCLES = 500 # 10x increase for slow leak detection
|
||||
AUDIO_WRITER_CYCLES = 100 # 5x increase with real data writes
|
||||
MEMORY_TEST_CYCLES = 1000 # 10x increase for state dict stability
|
||||
INTERLEAVE_CYCLES = 200 # 10x increase for interleaved patterns
|
||||
DIARIZATION_CYCLES = 200 # 4x increase for session lifecycle
|
||||
WRITER_THREAD_CYCLES = 50 # 5x increase for thread lifecycle
|
||||
FD_LEAK_TOLERANCE = 10
|
||||
THREAD_LEAK_TOLERANCE = 2
|
||||
TASK_TEST_COUNT = 5
|
||||
TASK_TEST_COUNT = 50 # 10x increase for coroutine leak detection
|
||||
TASK_SLEEP_SECONDS = 10
|
||||
WEBHOOK_EXECUTOR_CYCLES = 5
|
||||
WEBHOOK_EXECUTOR_CYCLES = 20 # 4x increase for HTTP client lifecycle
|
||||
BYTES_PER_KB = 1024
|
||||
|
||||
# High-stress constants (use sparingly - these take longer)
|
||||
HIGH_STRESS_STREAMING_CYCLES = 2000
|
||||
HIGH_STRESS_SEGMENTER_CYCLES = 500
|
||||
HIGH_STRESS_SEGMENTER_CHUNKS = 200
|
||||
|
||||
# Async cleanup wait times
|
||||
ASYNC_CLEANUP_DELAY_SECONDS = 0.1
|
||||
FLUSH_THREAD_CLEANUP_DELAY_SECONDS = 0.2
|
||||
FAILING_TASK_DELAY_SECONDS = 0.1
|
||||
THREAD_STOP_WAIT_SECONDS = 0.5 # Longer wait for thread pool cleanup
|
||||
|
||||
|
||||
def _get_executor_client(executor: object) -> httpx.AsyncClient | None:
|
||||
@@ -96,57 +143,120 @@ def _get_writer_flush_thread(writer: object) -> Thread | None:
|
||||
|
||||
|
||||
class TestFileDescriptorLeaks:
|
||||
"""Detect file descriptor leaks."""
|
||||
"""Detect file descriptor and thread leaks under high load."""
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_fd_cleanup(
|
||||
self, memory_servicer: NoteFlowServicer, initial_fd_count: int
|
||||
self, memory_servicer: NoteFlowServicer
|
||||
) -> None:
|
||||
"""Verify file descriptors released after streaming cycles."""
|
||||
"""Verify FDs and threads released after 500 streaming cycles."""
|
||||
baseline = measure_resource_baseline()
|
||||
|
||||
run_streaming_init_cleanup_cycles(memory_servicer, STREAMING_CYCLES)
|
||||
|
||||
gc.collect()
|
||||
await asyncio.sleep(0.1) # Allow async cleanup
|
||||
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
|
||||
|
||||
final_fds = get_fd_count()
|
||||
|
||||
assert final_fds <= initial_fd_count + FD_LEAK_TOLERANCE, (
|
||||
f"File descriptor leak: started with {initial_fd_count}, ended with {final_fds}"
|
||||
)
|
||||
errors = check_resource_leaks(baseline)
|
||||
assert not errors, f"Resource leaks after {STREAMING_CYCLES} cycles: {errors}"
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_writer_fd_cleanup(
|
||||
self, tmp_path: Path, initial_fd_count: int, crypto: AesGcmCryptoBox
|
||||
self, tmp_path: Path, crypto: AesGcmCryptoBox
|
||||
) -> None:
|
||||
"""Verify audio writer closes file handles."""
|
||||
run_audio_writer_cycles(crypto, tmp_path, AUDIO_WRITER_CYCLES)
|
||||
"""Verify audio writer closes FDs and threads after 100 cycles."""
|
||||
baseline = measure_resource_baseline()
|
||||
|
||||
bytes_per_cycle = run_audio_writer_cycles(crypto, tmp_path, AUDIO_WRITER_CYCLES)
|
||||
total_bytes = sum(bytes_per_cycle)
|
||||
|
||||
gc.collect()
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
|
||||
|
||||
final_fds = get_fd_count()
|
||||
assert final_fds <= initial_fd_count + FD_LEAK_TOLERANCE // 2, (
|
||||
f"Audio writer FD leak: {initial_fd_count} -> {final_fds}"
|
||||
log_stress_metrics(
|
||||
"test_audio_writer_fd_cleanup",
|
||||
baseline,
|
||||
bytes_written=total_bytes,
|
||||
cycles=AUDIO_WRITER_CYCLES,
|
||||
)
|
||||
errors = check_resource_leaks(baseline)
|
||||
assert not errors, f"Audio writer resource leaks: {errors}"
|
||||
|
||||
@pytest.mark.slow
@pytest.mark.asyncio
async def test_audio_writer_high_cycle_with_real_data(
    self, tmp_path: Path, crypto: AesGcmCryptoBox
) -> None:
    """Verify no leaks after writing real audio data across many cycles.

    This test writes actual audio chunks (not just open/close) to stress
    the buffer management, flush thread, and encryption pipeline.
    """
    # Capture the resource baseline before any writer is created so the
    # verification step below compares against a clean starting point.
    baseline = measure_resource_baseline()

    successful, total_bytes, total_chunks = run_audio_writer_high_cycle(
        crypto,
        tmp_path,
        cycle_count=AUDIO_WRITER_CYCLES,
        chunks_per_cycle=RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
    )

    # Collect garbage, then yield to the loop for a short while before
    # verifying (presumably to let writer flush threads exit -- confirm).
    gc.collect()
    await asyncio.sleep(FLUSH_THREAD_CLEANUP_DELAY_SECONDS)

    # All checks on the collected totals and the baseline are delegated to
    # the shared verification helper.
    verify_audio_writer_high_cycle_metrics(
        "test_audio_writer_high_cycle_with_real_data",
        baseline,
        successful_cycles=successful,
        total_bytes=total_bytes,
        total_chunks=total_chunks,
        cycle_count=AUDIO_WRITER_CYCLES,
        chunks_per_cycle=RIGOROUS_AUDIO_CHUNKS_PER_CYCLE,
    )
|
||||
|
||||
@pytest.mark.slow
@pytest.mark.asyncio
async def test_concurrent_meetings_resource_cleanup(
    self, memory_servicer: NoteFlowServicer
) -> None:
    """Verify resources cleaned up after concurrent meeting sessions."""
    # Baseline is taken before any session starts.
    baseline = measure_resource_baseline()

    # chunks_per_meeting=0: sessions are run without any audio chunks.
    successful, session_errors = run_concurrent_streaming_sessions(
        memory_servicer,
        meeting_count=RIGOROUS_CONCURRENT_MEETINGS,
        chunks_per_meeting=0,
    )

    # Collect garbage and yield to the event loop before measuring.
    gc.collect()
    await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)

    log_stress_metrics(
        "test_concurrent_meetings_resource_cleanup",
        baseline,
        cycles=RIGOROUS_CONCURRENT_MEETINGS,
        extra={"successful": successful, "session_errors": len(session_errors)},
    )
    errors = check_resource_leaks(baseline)
    # Session failures and leak findings are reported together so one
    # assertion message shows everything that went wrong.
    all_errors = session_errors + errors
    assert successful == RIGOROUS_CONCURRENT_MEETINGS, (
        f"Only {successful}/{RIGOROUS_CONCURRENT_MEETINGS} meetings succeeded"
    )
    assert not all_errors, f"Concurrent session errors: {all_errors}"
|
||||
|
||||
|
||||
class TestMemoryLeaks:
|
||||
"""Detect memory leaks under load."""
|
||||
"""Detect memory leaks under high load (1000+ cycles)."""
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_state_memory_stability(
|
||||
self, memory_servicer: NoteFlowServicer
|
||||
) -> None:
|
||||
"""Verify state dicts don't grow unbounded during streaming cycles."""
|
||||
for i in range(MEMORY_TEST_CYCLES):
|
||||
meeting_id = f"mem-test-{i:03d}"
|
||||
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
|
||||
memory_servicer.active_streams.add(meeting_id)
|
||||
memory_servicer.cleanup_streaming_state(meeting_id)
|
||||
memory_servicer.active_streams.discard(meeting_id)
|
||||
"""Verify state dicts don't grow unbounded during 1000 streaming cycles."""
|
||||
run_streaming_init_cleanup_cycles(memory_servicer, MEMORY_TEST_CYCLES)
|
||||
|
||||
gc.collect()
|
||||
|
||||
@@ -161,32 +271,17 @@ class TestMemoryLeaks:
|
||||
async def test_resource_interleaved_init_cleanup(
|
||||
self, memory_servicer: NoteFlowServicer
|
||||
) -> None:
|
||||
"""Verify memory stability with interleaved init/cleanup patterns."""
|
||||
active_meetings: list[str] = []
|
||||
|
||||
# Interleaved pattern: init 3, cleanup 1, repeat
|
||||
for cycle in range(INTERLEAVE_CYCLES):
|
||||
# Init 3 meetings
|
||||
for j in range(3):
|
||||
meeting_id = f"interleave-{cycle}-{j}"
|
||||
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
|
||||
memory_servicer.active_streams.add(meeting_id)
|
||||
active_meetings.append(meeting_id)
|
||||
|
||||
# Cleanup 1 oldest
|
||||
if active_meetings:
|
||||
old_meeting = active_meetings.pop(0)
|
||||
memory_servicer.cleanup_streaming_state(old_meeting)
|
||||
memory_servicer.active_streams.discard(old_meeting)
|
||||
|
||||
# Cleanup remaining
|
||||
for meeting_id in active_meetings:
|
||||
memory_servicer.cleanup_streaming_state(meeting_id)
|
||||
memory_servicer.active_streams.discard(meeting_id)
|
||||
"""Verify memory stability with 200 interleaved init/cleanup patterns."""
|
||||
total_inits, total_cleanups = run_interleaved_init_cleanup(
|
||||
memory_servicer, INTERLEAVE_CYCLES
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
|
||||
# Verify clean state
|
||||
# Verify clean state and all operations completed
|
||||
assert total_inits == total_cleanups, (
|
||||
f"Mismatched init/cleanup: {total_inits} inits, {total_cleanups} cleanups"
|
||||
)
|
||||
assert len(memory_servicer.vad_instances) == 0, "VAD leaked after interleave"
|
||||
assert len(memory_servicer.active_streams) == 0, "Streams leaked after interleave"
|
||||
|
||||
@@ -195,21 +290,22 @@ class TestMemoryLeaks:
|
||||
async def test_diarization_session_memory(
|
||||
self, memory_servicer: NoteFlowServicer
|
||||
) -> None:
|
||||
"""Verify diarization sessions don't leak memory."""
|
||||
# Create and close many sessions
|
||||
for i in range(DIARIZATION_CYCLES):
|
||||
meeting_id = f"diarize-mem-{i:03d}"
|
||||
mock_session = MagicMock()
|
||||
mock_session.close = MagicMock()
|
||||
"""Verify 200 diarization session cycles don't leak memory."""
|
||||
|
||||
memory_servicer.init_streaming_state(meeting_id, next_segment_id=0)
|
||||
state = memory_servicer.get_stream_state(meeting_id)
|
||||
assert state is not None
|
||||
state.diarization_session = mock_session
|
||||
memory_servicer.cleanup_streaming_state(meeting_id)
|
||||
def create_mock_session() -> MagicMock:
|
||||
mock = MagicMock()
|
||||
mock.close = MagicMock()
|
||||
return mock
|
||||
|
||||
successful = run_diarization_session_cycles(
|
||||
memory_servicer, DIARIZATION_CYCLES, create_mock_session
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
|
||||
assert successful == DIARIZATION_CYCLES, (
|
||||
f"Only {successful}/{DIARIZATION_CYCLES} cycles completed"
|
||||
)
|
||||
assert len(memory_servicer.stream_states) == 0, "Diarization sessions leaked"
|
||||
|
||||
|
||||
@@ -221,14 +317,10 @@ class TestCoroutineLeaks:
|
||||
self, memory_servicer: NoteFlowServicer
|
||||
) -> None:
|
||||
"""Verify no tasks remain after servicer shutdown."""
|
||||
# Create some diarization tasks with explicit type annotation
|
||||
tasks_created: list[asyncio.Task[None]] = []
|
||||
for i in range(TASK_TEST_COUNT):
|
||||
task: asyncio.Task[None] = asyncio.create_task(
|
||||
asyncio.sleep(TASK_SLEEP_SECONDS)
|
||||
)
|
||||
memory_servicer.diarization_tasks[f"job-{i}"] = task
|
||||
tasks_created.append(task)
|
||||
# Create tasks using helper (avoids inline loop)
|
||||
tasks_created = schedule_tasks_to_servicer(
|
||||
memory_servicer, TASK_TEST_COUNT, TASK_SLEEP_SECONDS
|
||||
)
|
||||
|
||||
# Shutdown
|
||||
await memory_servicer.shutdown()
|
||||
@@ -238,10 +330,9 @@ class TestCoroutineLeaks:
|
||||
f"Expected 0 diarization tasks after shutdown, found {len(memory_servicer.diarization_tasks)}"
|
||||
)
|
||||
|
||||
# Check tasks are cancelled
|
||||
for task in tasks_created:
|
||||
assert task.done(), "Task should be done after shutdown"
|
||||
assert task.cancelled(), "Task should be cancelled after shutdown"
|
||||
# Check tasks are cancelled using helper (avoids inline loop)
|
||||
errors = verify_all_tasks_cancelled(tasks_created)
|
||||
assert not errors, f"Task cleanup errors: {errors}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_exception(
|
||||
@@ -250,7 +341,7 @@ class TestCoroutineLeaks:
|
||||
"""Verify tasks cleaned up even if task raises exception."""
|
||||
|
||||
async def failing_task() -> None:
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(FAILING_TASK_DELAY_SECONDS)
|
||||
raise ValueError("Task failed")
|
||||
|
||||
task = asyncio.create_task(failing_task())
|
||||
@@ -303,23 +394,13 @@ class TestWebhookClientLeaks:
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_executor_multiple_cycles(self) -> None:
|
||||
"""Verify webhook executor can be opened and closed multiple times."""
|
||||
from noteflow.infrastructure.webhooks.executor import WebhookExecutor
|
||||
# Run cycles using helper function (avoids inline loop)
|
||||
results = await run_webhook_executor_cycles(WEBHOOK_EXECUTOR_CYCLES)
|
||||
|
||||
async def run_executor_cycle() -> tuple[bool, bool]:
|
||||
"""Run one executor open/close cycle, return (client_created, client_cleared)."""
|
||||
executor = WebhookExecutor()
|
||||
await _ensure_executor_client(executor)
|
||||
client_created = _get_executor_client(executor) is not None
|
||||
await executor.close()
|
||||
client_cleared = _get_executor_client(executor) is None
|
||||
return client_created, client_cleared
|
||||
|
||||
# Run cycles and collect results
|
||||
results = [await run_executor_cycle() for _ in range(WEBHOOK_EXECUTOR_CYCLES)]
|
||||
|
||||
# Verify all cycles succeeded
|
||||
assert all(created for created, _ in results), "Client creation failed in some cycles"
|
||||
assert all(cleared for _, cleared in results), "Client cleanup failed in some cycles"
|
||||
# Verify all cycles succeeded using helper (avoids inline loop)
|
||||
all_created, all_cleared = verify_webhook_executor_results(results)
|
||||
assert all_created, "Client creation failed in some cycles"
|
||||
assert all_cleared, "Client cleanup failed in some cycles"
|
||||
|
||||
|
||||
class TestDiarizationSessionLeaks:
|
||||
@@ -395,7 +476,7 @@ class TestAudioWriterThreadLeaks:
|
||||
writer.close()
|
||||
|
||||
# Give thread time to stop
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)
|
||||
|
||||
# Thread should be stopped and cleared
|
||||
assert _get_writer_flush_thread(writer) is None, "Flush thread should be cleared after close"
|
||||
@@ -406,27 +487,326 @@ class TestAudioWriterThreadLeaks:
|
||||
"""Verify no thread leaks across multiple writer cycles."""
|
||||
import threading
|
||||
|
||||
from noteflow.infrastructure.audio.writer import MeetingAudioWriter
|
||||
from noteflow.infrastructure.security.crypto import AesGcmCryptoBox
|
||||
from noteflow.infrastructure.security.keystore import InMemoryKeyStore
|
||||
|
||||
crypto = AesGcmCryptoBox(InMemoryKeyStore())
|
||||
initial_threads = threading.active_count()
|
||||
|
||||
# Create and close writers
|
||||
for i in range(WRITER_THREAD_CYCLES):
|
||||
writer = MeetingAudioWriter(crypto, tmp_path, buffer_size=1024)
|
||||
dek = crypto.generate_dek()
|
||||
wrapped_dek = crypto.wrap_dek(dek)
|
||||
writer.open(f"thread-test-{i}", dek, wrapped_dek, sample_rate=DEFAULT_SAMPLE_RATE)
|
||||
writer.close()
|
||||
# Create and close writers using helper (avoids inline loop)
|
||||
run_audio_writer_thread_cycles(
|
||||
crypto, tmp_path, WRITER_THREAD_CYCLES, DEFAULT_SAMPLE_RATE
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5) # Allow threads to fully stop
|
||||
await asyncio.sleep(THREAD_STOP_WAIT_SECONDS) # Allow threads to fully stop
|
||||
gc.collect()
|
||||
|
||||
final_threads = threading.active_count()
|
||||
|
||||
log_thread_metrics(
|
||||
"test_multiple_writer_cycles",
|
||||
initial_threads,
|
||||
final_threads,
|
||||
WRITER_THREAD_CYCLES,
|
||||
THREAD_LEAK_TOLERANCE,
|
||||
)
|
||||
|
||||
# Should not have leaked threads
|
||||
assert final_threads <= initial_threads + THREAD_LEAK_TOLERANCE, (
|
||||
f"Thread leak: {initial_threads} -> {final_threads}"
|
||||
)
|
||||
|
||||
|
||||
class TestHeapGrowth:
    """Detect heap growth using tracemalloc snapshots.

    Each test brackets its workload with two tracemalloc snapshots, reduces
    the difference through a filtering helper, logs the result and asserts it
    stays below a configured ceiling. Tracing is stopped in a ``finally``
    block so a workload that raises cannot leave tracemalloc enabled for the
    tests that run afterwards.
    """

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_streaming_heap_stability(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Verify heap doesn't grow unbounded during streaming cycles."""
        import tracemalloc

        tracemalloc.start()
        try:
            # Collect before each snapshot so garbage awaiting collection
            # is not counted as growth.
            gc.collect()
            snapshot_before = tracemalloc.take_snapshot()

            run_streaming_init_cleanup_cycles(memory_servicer, STREAMING_CYCLES)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        noteflow_growth = calculate_noteflow_heap_growth(snapshot_before, snapshot_after)

        log_heap_metrics(
            "test_streaming_heap_stability",
            noteflow_growth,
            cycles=STREAMING_CYCLES,
            max_allowed_bytes=HEAP_MAX_GROWTH_BYTES,
            filter_name="noteflow",
        )
        assert noteflow_growth < HEAP_MAX_GROWTH_BYTES, (
            f"Heap growth detected: {noteflow_growth / BYTES_PER_KB:.1f}KB in noteflow code"
        )

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_orm_conversion_heap_stability(self) -> None:
        """Verify ORM conversions don't leak memory."""
        import tracemalloc
        from uuid import uuid4

        from noteflow.infrastructure.converters.orm_converters import OrmConverter

        tracemalloc.start()
        try:
            gc.collect()
            snapshot_before = tracemalloc.take_snapshot()

            # Converter and meeting id are created after the first snapshot,
            # so their own allocations are part of the measured window.
            converter = OrmConverter()
            meeting_id = uuid4()
            run_orm_conversion_cycles(converter, meeting_id, ORM_CONVERSION_CYCLES)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        converter_growth = calculate_filtered_heap_growth(
            snapshot_before, snapshot_after, "converter"
        )

        log_heap_metrics(
            "test_orm_conversion_heap_stability",
            converter_growth,
            cycles=ORM_CONVERSION_CYCLES,
            max_allowed_bytes=HEAP_CONVERTER_MAX_GROWTH_BYTES,
            filter_name="converter",
        )
        assert converter_growth < HEAP_CONVERTER_MAX_GROWTH_BYTES, (
            f"Converter heap growth: {converter_growth / BYTES_PER_KB:.1f}KB"
        )

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_protobuf_heap_stability(self) -> None:
        """Verify protobuf operations don't leak memory."""
        import tracemalloc

        tracemalloc.start()
        try:
            gc.collect()
            snapshot_before = tracemalloc.take_snapshot()

            run_protobuf_cycles(PROTO_CYCLES)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        proto_growth = calculate_filtered_heap_growth(
            snapshot_before, snapshot_after, "noteflow_pb2"
        )

        log_heap_metrics(
            "test_protobuf_heap_stability",
            proto_growth,
            cycles=PROTO_CYCLES,
            max_allowed_bytes=HEAP_PROTO_MAX_GROWTH_BYTES,
            filter_name="noteflow_pb2",
        )
        assert proto_growth < HEAP_PROTO_MAX_GROWTH_BYTES, (
            f"Protobuf heap growth: {proto_growth / BYTES_PER_KB:.1f}KB"
        )
|
||||
|
||||
|
||||
class TestAsyncContextEdgeCases:
|
||||
"""Test async context manager edge cases."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_exception_cleanup(self) -> None:
    """Verify cleanup runs even when exception raised inside context."""
    # Set to True by the context manager's finally block.
    cleanup_called = False

    @asynccontextmanager
    async def tracked_context() -> AsyncIterator[str]:
        """Yield a placeholder value and record that cleanup ran."""
        nonlocal cleanup_called
        try:
            yield "session"
        finally:
            cleanup_called = True

    # The error raised in the body must propagate out of the context...
    with pytest.raises(ValueError, match="test error"):
        async with tracked_context():
            raise ValueError("test error")

    # ...and the finally block must still have executed.
    assert cleanup_called, "Cleanup should run despite exception"
|
||||
|
||||
@pytest.mark.asyncio
async def test_nested_context_cleanup_order(self) -> None:
    """Verify nested contexts clean up in correct (LIFO) order."""
    # Names are appended as each context's finally block runs.
    cleanup_order: list[str] = []

    @asynccontextmanager
    async def tracked_context(name: str) -> AsyncIterator[str]:
        """Yield ``name`` and record it in cleanup_order on exit."""
        try:
            yield name
        finally:
            cleanup_order.append(name)

    # Contexts are entered left to right and exited in reverse.
    async with (
        tracked_context("outer"),
        tracked_context("middle"),
        tracked_context("inner"),
    ):
        pass

    assert cleanup_order == ["inner", "middle", "outer"], (
        f"Wrong cleanup order: {cleanup_order}"
    )
|
||||
|
||||
@pytest.mark.asyncio
async def test_context_cancellation_during_body(self) -> None:
    """Verify cleanup runs when task cancelled inside context."""
    # Set to True by the context manager's finally block.
    cleanup_called = False

    @asynccontextmanager
    async def tracked_context() -> AsyncIterator[str]:
        """Yield a placeholder value and record that cleanup ran."""
        nonlocal cleanup_called
        try:
            yield "session"
        finally:
            cleanup_called = True

    async def task_with_context() -> None:
        """Enter the context and sleep inside its body."""
        async with tracked_context():
            await asyncio.sleep(SLOW_ENTER_SLEEP_SECONDS)

    task = asyncio.create_task(task_with_context())
    # Give the task time to enter the context before cancelling it
    # (assumes CONTEXT_CANCEL_DELAY_SECONDS is shorter than
    # SLOW_ENTER_SLEEP_SECONDS -- confirm against the constants).
    await asyncio.sleep(CONTEXT_CANCEL_DELAY_SECONDS)
    task.cancel()

    with pytest.raises(asyncio.CancelledError, match=""):
        await task

    assert cleanup_called, "Cleanup should run on cancellation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_cancellation_during_enter(self) -> None:
|
||||
"""Verify proper handling when cancelled during __aenter__."""
|
||||
enter_started = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def slow_enter_context() -> AsyncIterator[str]:
|
||||
nonlocal enter_started
|
||||
enter_started = True
|
||||
await asyncio.sleep(SLOW_ENTER_SLEEP_SECONDS)
|
||||
yield "session"
|
||||
|
||||
async def task_with_slow_enter() -> None:
|
||||
async with slow_enter_context():
|
||||
pass
|
||||
|
||||
task = asyncio.create_task(task_with_slow_enter())
|
||||
await asyncio.sleep(CONTEXT_CANCEL_DELAY_SECONDS)
|
||||
task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError, match=""):
|
||||
await task
|
||||
|
||||
assert enter_started, "Enter should have started"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_exception_in_cleanup(self) -> None:
|
||||
"""Verify cleanup exception propagates when cleanup fails.
|
||||
|
||||
Note: In Python, when an exception occurs in a finally block during
|
||||
exception handling, the cleanup exception replaces the original.
|
||||
This test verifies the cleanup exception properly propagates.
|
||||
"""
|
||||
cleanup_executed = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def failing_cleanup_context() -> AsyncIterator[str]:
|
||||
nonlocal cleanup_executed
|
||||
try:
|
||||
yield "session"
|
||||
finally:
|
||||
cleanup_executed = True
|
||||
raise RuntimeError("cleanup failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="cleanup failed"):
|
||||
async with failing_cleanup_context():
|
||||
raise ValueError("original error")
|
||||
|
||||
assert cleanup_executed, "Cleanup should have executed despite body exception"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_context_isolation(self) -> None:
|
||||
"""Verify concurrent context managers don't interfere."""
|
||||
active_contexts: set[str] = set()
|
||||
results: dict[str, int] = {"max_concurrent": 0}
|
||||
context_ids = [f"ctx-{i}" for i in range(CONCURRENT_CONTEXT_COUNT)]
|
||||
|
||||
tasks = run_concurrent_context_tasks(context_ids, active_contexts, results)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
assert len(active_contexts) == 0, "All contexts should be cleaned up"
|
||||
assert results["max_concurrent"] == CONCURRENT_CONTEXT_COUNT, (
|
||||
f"Expected {CONCURRENT_CONTEXT_COUNT} concurrent, got {results['max_concurrent']}"
|
||||
)
|
||||
|
||||
|
||||
class TestMemoryPressure:
    """Checks that run the servicer and measurement helpers under memory pressure."""

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_operations_succeed_under_memory_pressure(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """Every requested cycle is reported as successful under pressure."""
        pressure_args = {
            "cycle_count": STREAMING_CYCLES,
            "pressure_mb": MEMORY_PRESSURE_ALLOCATION_MB,
        }
        # The second element of the result is not used by this test
        successful, _ = run_operations_under_memory_pressure(memory_servicer, **pressure_args)

        assert successful == STREAMING_CYCLES, (
            f"Only {successful}/{STREAMING_CYCLES} cycles succeeded under pressure"
        )

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_state_cleanup_under_pressure(
        self, memory_servicer: NoteFlowServicer
    ) -> None:
        """No VAD instances or active streams remain on the servicer afterwards."""
        run_operations_under_memory_pressure(
            memory_servicer,
            pressure_mb=MEMORY_PRESSURE_ALLOCATION_MB,
            cycle_count=STREAMING_CYCLES,
        )

        # Force a collection, then yield to the event loop before inspecting state
        gc.collect()
        await asyncio.sleep(ASYNC_CLEANUP_DELAY_SECONDS)

        assert len(memory_servicer.vad_instances) == 0, "VAD instances leaked under pressure"
        assert len(memory_servicer.active_streams) == 0, "Streams leaked under pressure"

    @pytest.mark.asyncio
    async def test_gc_collects_temporary_objects(self) -> None:
        """The three values from ``run_gc_collection_behavior`` sum to a non-negative total."""
        gen0_diff, gen1_diff, gen2_diff = run_gc_collection_behavior()
        total_uncollected = sum((gen0_diff, gen1_diff, gen2_diff))

        assert total_uncollected >= 0, "GC count should not be negative"

    @pytest.mark.asyncio
    async def test_rss_measurement_returns_valid_value(self) -> None:
        """``measure_rss_bytes`` yields either a positive value or the ``-1`` fallback."""
        rss = measure_rss_bytes()

        # A positive reading or the -1 sentinel are the only accepted outcomes
        assert rss > 0 or rss == -1, f"Invalid RSS value: {rss}"
|
||||
|
||||
Reference in New Issue
Block a user