"""Integration test for real VAD + Segmenter streaming pipeline.

Feeds synthetic audio (random-noise "speech" chunks followed by silent chunks)
into ``NoteFlowServicer.StreamTranscription`` with a mocked ASR engine, then
checks that a final transcript update is emitted and a segment is persisted.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import AsyncMock, MagicMock

import grpc
import numpy as np
import numpy.typing as npt
import pytest

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingId
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.asr.dto import AsrResult
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork

if TYPE_CHECKING:
    # Only referenced in annotations, which are lazy strings under
    # ``from __future__ import annotations``; no runtime import is needed.
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


class _AudioChunkRequest(Protocol):
    """Structural type for the request messages consumed by the stream.

    ``noteflow_pb2.AudioChunk`` messages are passed where this type is
    expected.
    """

    meeting_id: str
    audio_data: bytes  # this module packs float32 samples here (see _make_chunk)
    sample_rate: int
    channels: int


class _TranscriptUpdate(Protocol):
    """Structural type for the updates yielded by the stream.

    Only ``update_type`` is read in this module; it is compared against
    ``noteflow_pb2.UPDATE_TYPE_FINAL`` to select final updates.
    """

    update_type: int


class _StreamTranscriptionCallable(Protocol):
    """Call signature that ``NoteFlowServicer.StreamTranscription`` is cast to.

    Lets the test invoke the servicer method directly with an async request
    iterator and a mock context, without running a gRPC server.
    """

    def __call__(
        self,
        request_iterator: AsyncIterator[_AudioChunkRequest],
        # MockContext is defined later in this module; the annotation is lazy.
        context: MockContext,
    ) -> AsyncIterator[_TranscriptUpdate]: ...
SAMPLE_RATE = DEFAULT_SAMPLE_RATE
# Each chunk carries 0.1 s of audio. Derived from the sample rate (1600 samples
# at 16 kHz) so the chunk duration stays correct if DEFAULT_SAMPLE_RATE changes.
CHUNK_SAMPLES = SAMPLE_RATE // 10
SPEECH_CHUNKS = 4
SILENCE_CHUNKS = 10
# Duration of the synthetic "speech" burst in seconds (about 0.4 s with the
# values above); used as the end time of the mocked ASR result.
SPEECH_DURATION_S = SPEECH_CHUNKS * CHUNK_SAMPLES / SAMPLE_RATE
EXPECTED_TEXT = "Real pipeline segment"


class MockContext:
    """Minimal gRPC context for integration streaming tests.

    Provides only ``abort`` and ``invocation_metadata``; any other context
    attribute the servicer touches will raise ``AttributeError``.
    """

    async def abort(
        self,
        code: grpc.StatusCode,
        details: str = "",
        trailing_metadata: object = (),
    ) -> None:
        """Terminate the call by raising ``grpc.RpcError``.

        The signature mirrors ``grpc.aio.ServicerContext.abort``. The status
        code and details are put in the exception message so an unexpected
        abort inside the servicer explains itself in the test failure output.

        Raises:
            grpc.RpcError: Always.
        """
        _ = trailing_metadata  # accepted for API compatibility only
        raise grpc.RpcError(f"{code}: {details}")

    def invocation_metadata(self) -> list[tuple[str, str]]:
        """Return empty metadata for mock context."""
        return []


def _make_chunk(meeting_id: str, audio: npt.NDArray[np.float32]) -> noteflow_pb2.AudioChunk:
    """Create a mono protobuf audio chunk carrying ``audio`` as raw float32 bytes."""
    return noteflow_pb2.AudioChunk(
        meeting_id=meeting_id,
        audio_data=audio.astype(np.float32).tobytes(),
        sample_rate=SAMPLE_RATE,
        channels=1,
    )


async def _audio_stream(meeting_id: str) -> AsyncIterator[noteflow_pb2.AudioChunk]:
    """Yield speech then silence chunks to exercise VAD + Segmenter.

    "Speech" is seeded uniform noise in [-0.2, 0.2) so the run is
    deterministic; it is followed by all-zero chunks.
    """
    rng = np.random.default_rng(0)
    speech = rng.uniform(-0.2, 0.2, CHUNK_SAMPLES).astype(np.float32)
    silence = np.zeros(CHUNK_SAMPLES, dtype=np.float32)
    for _ in range(SPEECH_CHUNKS):
        yield _make_chunk(meeting_id, speech)
    for _ in range(SILENCE_CHUNKS):
        yield _make_chunk(meeting_id, silence)


async def _create_meeting_for_streaming(
    session_factory: async_sessionmaker[AsyncSession],
    meetings_dir: Path,
) -> Meeting:
    """Create and commit a fresh meeting that the audio stream can target."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        meeting = Meeting.create(title="Streaming Pipeline Test")
        await uow.meetings.create(meeting)
        await uow.commit()
        return meeting


def _create_mock_asr_engine() -> MagicMock:
    """Build a loaded ASR stand-in whose ``transcribe_async`` returns EXPECTED_TEXT."""
    mock_asr = MagicMock()
    mock_asr.is_loaded = True
    mock_asr.transcribe_async = AsyncMock(
        return_value=[
            AsrResult(text=EXPECTED_TEXT, start=0.0, end=SPEECH_DURATION_S),
        ]
    )
    return mock_asr


async def _collect_final_updates(
    stream: _StreamTranscriptionCallable,
    meeting_id: str,
) -> list[_TranscriptUpdate]:
    """Drive ``stream`` with the synthetic audio and return only final updates."""
    return [
        update
        async for update in stream(_audio_stream(meeting_id), MockContext())
        if update.update_type == noteflow_pb2.UPDATE_TYPE_FINAL
    ]


async def _verify_segment_persisted(
    session_factory: async_sessionmaker[AsyncSession],
    meetings_dir: Path,
    meeting_id: MeetingId,
) -> None:
    """Assert that a segment containing EXPECTED_TEXT was stored for the meeting."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        segments = await uow.segments.get_by_meeting(meeting_id)
        segment_texts = [s.text for s in segments]
        assert segments, "Expected at least one segment persisted"
        assert EXPECTED_TEXT in segment_texts, f"'{EXPECTED_TEXT}' not in {segment_texts}"


@pytest.mark.integration
class TestStreamingRealPipeline:
    """Validate streaming with real VAD + Segmenter path."""

    @pytest.mark.asyncio
    async def test_streaming_emits_final_segment(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        meetings_dir: Path,
    ) -> None:
        """Real VAD + Segmenter should emit at least one final segment."""
        meeting = await _create_meeting_for_streaming(session_factory, meetings_dir)
        mock_asr = _create_mock_asr_engine()
        servicer = NoteFlowServicer(
            session_factory=session_factory,
            asr_engine=mock_asr,
            meetings_dir=meetings_dir,
        )
        # Invoke the servicer method directly; the cast gives it a typed call
        # signature for the request iterator and the mock context.
        stream = cast(_StreamTranscriptionCallable, servicer.StreamTranscription)

        final_updates = await _collect_final_updates(stream, str(meeting.id))

        assert final_updates, "Expected at least one final transcript update"
        await _verify_segment_persisted(session_factory, meetings_dir, meeting.id)