Files
noteflow/tests/integration/test_e2e_streaming.py
Travis Vasceannie 301482c410
Some checks failed
CI / test-python (push) Successful in 8m41s
CI / test-typescript (push) Failing after 6m2s
CI / test-rust (push) Failing after 4m28s
Refactor: Improve CI workflow robustness and test environment variable management, and enable parallel quality test execution.
2026-01-26 02:04:38 +00:00

546 lines
21 KiB
Python

"""End-to-end integration tests for streaming transcription.
Tests the complete audio streaming pipeline with database persistence:
- Stream initialization with database
- Audio chunk processing and segment creation
- Streaming diarization turn recovery
- Stream cleanup and meeting state transitions
- Error handling during streaming
"""
from __future__ import annotations
import asyncio
import time
from collections.abc import AsyncIterator, Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, cast
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import grpc
import numpy as np
import pytest
from numpy.typing import NDArray
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting, Segment
from noteflow.domain.value_objects import AudioSource, MeetingId, MeetingState, SpeakerRole
from noteflow.grpc.mixins.streaming import StreamingMixin
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.grpc.stream_state import MeetingStreamState
from noteflow.infrastructure.audio.partial_buffer import PartialAudioBuffer
from noteflow.infrastructure.persistence.repositories import StreamingTurn
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from noteflow.grpc.mixins._types import GrpcContext
from support.async_helpers import drain_async_gen, yield_control
# Type alias for the StreamTranscription method signature.
# Used by TypedServicer below to cast the mixin-provided implementation,
# since pyright cannot infer the return type through mixin inheritance.
_StreamMethod = Callable[
    [NoteFlowServicer, AsyncIterator[noteflow_pb2.AudioChunk], GrpcContext],
    AsyncIterator[noteflow_pb2.TranscriptUpdate],
]
class TypedServicer(NoteFlowServicer):
    """NoteFlowServicer with an explicitly typed ``StreamTranscription``.

    The concrete implementation lives on ``StreamingMixin``, and pyright
    cannot infer its return type across the mixin MRO.  This subclass
    re-declares the method with full annotations so test code gets
    precise type checking without `# type: ignore` noise.
    """

    def StreamTranscription(
        self,
        request_iterator: AsyncIterator[noteflow_pb2.AudioChunk],
        context: GrpcContext,
    ) -> AsyncIterator[noteflow_pb2.TranscriptUpdate]:
        """Delegate to the mixin implementation with a precise return type."""
        # The cast is unavoidable: the mixin narrows ``self`` to the
        # ServicerHost protocol, which hides the return type from pyright.
        typed_impl: _StreamMethod = cast(
            _StreamMethod,
            StreamingMixin.StreamTranscription,
        )
        return typed_impl(self, request_iterator, context)
class MockContext:
    """Minimal stand-in for a gRPC servicer context in tests.

    Records abort parameters so assertions can inspect them, and raises
    ``grpc.RpcError`` from ``abort`` exactly like the real runtime does.
    """

    def __init__(self) -> None:
        """Start with no abort recorded."""
        self.aborted: bool = False
        self.abort_code: grpc.StatusCode | None = None
        self.abort_details: str | None = None

    async def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Remember the abort parameters, then raise as real gRPC would."""
        self.aborted = True
        self.abort_code = code
        self.abort_details = details
        raise grpc.RpcError()

    def invocation_metadata(self) -> list[tuple[str, str]]:
        """Mock calls carry no metadata."""
        return []
def create_audio_chunk(
    meeting_id: str,
    audio: NDArray[np.float32] | None = None,
    sample_rate: int = DEFAULT_SAMPLE_RATE,
    channels: int = 1,
) -> noteflow_pb2.AudioChunk:
    """Build an AudioChunk protobuf, synthesizing audio when none is given.

    When ``audio`` is None, 1600 samples of low-amplitude random noise
    are generated so the chunk carries plausible non-silent data.
    """
    if audio is None:
        audio = np.random.randn(1600).astype(np.float32) * 0.1
    return noteflow_pb2.AudioChunk(
        meeting_id=meeting_id,
        audio_data=audio.tobytes(),
        sample_rate=sample_rate,
        channels=channels,
    )
async def audio_chunk_iterator(
    meeting_id: str,
    num_chunks: int = 5,
    chunk_duration_ms: int = 100,
    sample_rate: int = DEFAULT_SAMPLE_RATE,
) -> AsyncIterator[noteflow_pb2.AudioChunk]:
    """Yield ``num_chunks`` synthetic audio chunks, pausing briefly between them."""
    chunk_len = int(sample_rate * chunk_duration_ms / 1000)
    for _ in range(num_chunks):
        noise = np.random.randn(chunk_len).astype(np.float32) * 0.1
        yield create_audio_chunk(meeting_id, noise, sample_rate)
        # Small sleep lets other coroutines (the consumer) run between chunks.
        await asyncio.sleep(0.01)
@pytest.mark.integration
class TestStreamInitialization:
    """Integration tests for stream initialization with database."""

    async def test_stream_init_loads_meeting_from_database(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """A persisted meeting is loaded and moved to RECORDING on stream start."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="Stream Test")
            await uow.meetings.create(meeting)
            await uow.commit()
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        asr_stub.transcribe_async = AsyncMock(return_value=[])
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        # Drain the stream; the assertions below check the persisted state.
        async for _ in servicer.StreamTranscription(chunk_iter(), MockContext()):
            pass
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            reloaded = await uow.meetings.get(meeting.id)
            assert reloaded is not None, (
                f"meeting {meeting.id} should exist in database after stream init"
            )
            assert reloaded.state == MeetingState.RECORDING, (
                f"expected meeting state RECORDING after stream start, got {reloaded.state}"
            )

    async def test_stream_init_recovers_streaming_turns(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """Persisted streaming turns survive a stream restart (crash recovery)."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="Recovery Test")
            meeting.start_recording()
            await uow.meetings.create(meeting)
            await uow.diarization_jobs.add_streaming_turns(
                str(meeting.id),
                [
                    StreamingTurn(
                        speaker="SPEAKER_00", start_time=0.0, end_time=1.5, confidence=0.9
                    ),
                    StreamingTurn(
                        speaker="SPEAKER_01", start_time=1.5, end_time=3.0, confidence=0.85
                    ),
                ],
            )
            await uow.commit()
        asr_stub = MagicMock(is_loaded=True, transcribe_async=AsyncMock(return_value=[]))
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory, asr_engine=asr_stub
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        async for _ in servicer.StreamTranscription(chunk_iter(), MockContext()):
            pass
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            reloaded = await uow.meetings.get(meeting.id)
            assert reloaded is not None, f"meeting {meeting.id} should exist after recovery"
            assert reloaded.state == MeetingState.RECORDING, (
                f"expected RECORDING, got {reloaded.state}"
            )
            persisted = await uow.diarization_jobs.get_streaming_turns(str(meeting.id))
            assert len(persisted) == 2, f"expected 2 turns, got {len(persisted)}"

    async def test_stream_init_fails_for_nonexistent_meeting(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """Streaming for an unknown meeting ID aborts with NOT_FOUND."""
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )
        context = MockContext()

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(uuid4()))

        with pytest.raises(grpc.RpcError):
            async for _ in servicer.StreamTranscription(chunk_iter(), context):
                pass
        assert context.abort_code == grpc.StatusCode.NOT_FOUND, (
            f"expected NOT_FOUND status for nonexistent meeting, got {context.abort_code}"
        )

    async def test_stream_rejects_invalid_meeting_id(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> None:
        """A malformed (non-UUID) meeting ID aborts with INVALID_ARGUMENT."""
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )
        context = MockContext()

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk("not-a-valid-uuid")

        with pytest.raises(grpc.RpcError):
            async for _ in servicer.StreamTranscription(chunk_iter(), context):
                pass
        assert context.abort_code == grpc.StatusCode.INVALID_ARGUMENT, (
            f"expected INVALID_ARGUMENT status for malformed meeting ID, got {context.abort_code}"
        )
@pytest.mark.integration
class TestStreamSegmentPersistence:
    """Integration tests for segment persistence during streaming."""

    async def _create_test_meeting(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        meetings_dir: Path,
        title: str,
    ) -> Meeting:
        """Persist and return a fresh meeting with the given title."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            created = Meeting.create(title=title)
            await uow.meetings.create(created)
            await uow.commit()
        return created

    def _setup_streaming_test(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        meeting: Meeting,
        asr_result: object,
    ) -> tuple[MagicMock, TypedServicer, NDArray[np.float32], MeetingStreamState, str]:
        """Wire up a mock ASR engine, servicer, test audio, and stream state."""
        asr_stub = MagicMock(is_loaded=True)
        asr_stub.transcribe_async = AsyncMock(return_value=[asr_result])
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory, asr_engine=asr_stub
        )
        # One second of low-amplitude noise at the default sample rate.
        audio: NDArray[np.float32] = (
            np.random.randn(DEFAULT_SAMPLE_RATE).astype(np.float32) * 0.1
        )
        return (
            asr_stub,
            servicer,
            audio,
            self._create_stream_mocks(audio),
            str(meeting.id),
        )

    def _create_stream_mocks(self, audio: NDArray[np.float32]) -> MeetingStreamState:
        """Build a MeetingStreamState whose VAD/segmenter always emit one segment."""
        fake_segment = MagicMock(
            audio=audio,
            start_time=0.0,
            audio_source=AudioSource.UNKNOWN,
            speaker_role=SpeakerRole.UNKNOWN,
        )
        return MeetingStreamState(
            vad=MagicMock(process_chunk=MagicMock(return_value=True)),
            segmenter=MagicMock(
                process_audio=MagicMock(return_value=[fake_segment]),
                flush=MagicMock(return_value=None),
            ),
            partial_buffer=PartialAudioBuffer(sample_rate=DEFAULT_SAMPLE_RATE),
            sample_rate=DEFAULT_SAMPLE_RATE,
            channels=1,
            next_segment_id=0,
            was_speaking=False,
            last_partial_time=time.time(),
            last_partial_text="",
        )

    async def test_segments_persisted_to_database(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """Segments produced while streaming end up in the database."""
        from noteflow.infrastructure.asr.dto import AsrResult

        meeting = await self._create_test_meeting(
            session_factory, meetings_dir, "Segment Test"
        )
        result = AsrResult(
            text="Hello world", start=0.0, end=1.0, language="en", language_probability=0.95
        )
        _, servicer, audio, state, meeting_id_str = self._setup_streaming_test(
            session_factory, meeting, result
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(meeting_id_str, audio)

        # Route get_stream_state lookups for this meeting to the mocked state.
        with patch.object(servicer, "get_stream_state", side_effect={meeting_id_str: state}.get):
            await drain_async_gen(servicer.StreamTranscription(chunk_iter(), MockContext()))
        await self._verify_segments_persisted(
            session_factory, meetings_dir, meeting.id, "Hello world"
        )

    async def _verify_segments_persisted(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        meetings_dir: Path,
        meeting_id: MeetingId,
        expected_text: str,
    ) -> None:
        """Assert at least one persisted segment carries the expected text."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            segments = await uow.segments.get_by_meeting(meeting_id)
            segment_texts = [s.text for s in segments]
            assert segments, f"expected at least 1 segment, got {len(segments)}"
            assert expected_text in segment_texts, f"'{expected_text}' not in {segment_texts}"

    def _verify_segment_persisted(
        self,
        segments: Sequence[Segment],
        segment_texts: list[str],
        expected_text: str,
    ) -> None:
        """Assert the expected text appears among already-fetched segments."""
        assert segments, f"expected at least 1 segment, got {len(segments)}"
        assert expected_text in segment_texts, f"'{expected_text}' not in {segment_texts}"
@pytest.mark.integration
class TestStreamStateManagement:
    """Integration tests for stream state management."""

    async def test_meeting_transitions_to_recording_on_stream_start(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """Starting a stream moves the meeting from CREATED to RECORDING."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="State Test")
            await uow.meetings.create(meeting)
            await uow.commit()
        assert meeting.state == MeetingState.CREATED, (
            f"expected initial meeting state CREATED, got {meeting.state}"
        )
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        asr_stub.transcribe_async = AsyncMock(return_value=[])
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        async for _ in servicer.StreamTranscription(chunk_iter(), MockContext()):
            pass
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            reloaded = await uow.meetings.get(meeting.id)
            assert reloaded is not None, (
                f"meeting {meeting.id} should exist in database after stream"
            )
            assert reloaded.state == MeetingState.RECORDING, (
                f"expected meeting state RECORDING after stream start, got {reloaded.state}"
            )

    async def test_concurrent_streams_rejected(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """A second stream for an active meeting aborts with FAILED_PRECONDITION."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="Concurrent Test")
            meeting.start_recording()
            await uow.meetings.create(meeting)
            await uow.commit()
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        asr_stub.transcribe_async = AsyncMock(return_value=[])
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )
        # Pretend another stream for this meeting is already running.
        servicer.active_streams.add(str(meeting.id))
        context = MockContext()

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        with pytest.raises(grpc.RpcError):
            async for _ in servicer.StreamTranscription(chunk_iter(), context):
                pass
        assert context.abort_code == grpc.StatusCode.FAILED_PRECONDITION, (
            f"expected FAILED_PRECONDITION for concurrent stream, got {context.abort_code}"
        )
@pytest.mark.integration
class TestStreamCleanup:
    """Integration tests for stream cleanup."""

    async def test_active_stream_removed_on_completion(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """A finished stream deregisters the meeting from active_streams."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="Cleanup Test")
            await uow.meetings.create(meeting)
            await uow.commit()
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        asr_stub.transcribe_async = AsyncMock(return_value=[])
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        async for _ in servicer.StreamTranscription(chunk_iter(), MockContext()):
            pass
        assert str(meeting.id) not in servicer.active_streams, (
            f"meeting {meeting.id} should be removed from active streams after completion"
        )

    async def test_streaming_state_cleaned_up_on_error(
        self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
    ) -> None:
        """An ASR failure mid-stream still deregisters the meeting."""
        async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
            meeting = Meeting.create(title="Error Cleanup Test")
            await uow.meetings.create(meeting)
            await uow.commit()
        asr_stub = MagicMock()
        asr_stub.is_loaded = True
        asr_stub.transcribe_async = AsyncMock(side_effect=RuntimeError("ASR failed"))
        servicer: TypedServicer = TypedServicer(
            session_factory=session_factory,
            asr_engine=asr_stub,
        )

        async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
            yield create_audio_chunk(str(meeting.id))

        try:
            async for _ in servicer.StreamTranscription(chunk_iter(), MockContext()):
                pass
        except RuntimeError:
            pass  # Expected: mock_asr.transcribe_async raises RuntimeError("ASR failed")
        meeting_id_str = str(meeting.id)
        assert meeting_id_str not in servicer.active_streams, (
            f"meeting {meeting_id_str} should be removed from active streams after error cleanup"
        )
@pytest.mark.integration
class TestStreamStopRequest:
"""Integration tests for graceful stream stop."""
def _create_stop_request_chunk_iterator(
self,
servicer: TypedServicer,
meeting_id: str,
) -> tuple[AsyncIterator[noteflow_pb2.AudioChunk], list[int]]:
"""Create chunk iterator that triggers stop request."""
chunks_processed: list[int] = [0]
async def chunk_iter() -> AsyncIterator[noteflow_pb2.AudioChunk]:
for i in range(10):
chunks_processed[0] += 1
yield create_audio_chunk(meeting_id)
if i == 2:
servicer.stop_requested.add(meeting_id)
await yield_control()
return chunk_iter(), chunks_processed
async def test_stop_request_exits_stream_gracefully(
self, session_factory: async_sessionmaker[AsyncSession], meetings_dir: Path
) -> None:
"""Test stop request causes stream to exit gracefully."""
async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
meeting = Meeting.create(title="Stop Test")
meeting.start_recording()
await uow.meetings.create(meeting)
await uow.commit()
mock_asr = MagicMock(is_loaded=True, transcribe_async=AsyncMock(return_value=[]))
servicer: TypedServicer = TypedServicer(
session_factory=session_factory, asr_engine=mock_asr
)
chunk_iter, chunks_processed = self._create_stop_request_chunk_iterator(
servicer, str(meeting.id)
)
async for _ in servicer.StreamTranscription(chunk_iter, MockContext()):
pass
assert chunks_processed[0] <= 5, (
f"expected stream to stop after ~3 chunks due to stop request, but processed {chunks_processed[0]}"
)