"""Integration test for real VAD + Segmenter streaming pipeline."""

from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import AsyncMock, MagicMock

import grpc
import numpy as np
import numpy.typing as npt
import pytest

from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting
from noteflow.domain.value_objects import MeetingId
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.asr.dto import AsrResult
from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWork

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


class _AudioChunkRequest(Protocol):
    """Structural type for incoming audio chunk requests."""

    meeting_id: str
    audio_data: bytes
    sample_rate: int
    channels: int


class _TranscriptUpdate(Protocol):
    """Structural type for transcript updates emitted by the stream."""

    update_type: int


class _StreamTranscriptionCallable(Protocol):
    """Typed callable view of the servicer's StreamTranscription handler."""

    def __call__(
        self,
        request_iterator: AsyncIterator[_AudioChunkRequest],
        context: MockContext,
    ) -> AsyncIterator[_TranscriptUpdate]: ...
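
# Note: these Protocols give the test a statically typed view of the generated
# gRPC/protobuf interfaces, which are otherwise only loosely typed; the servicer's
# StreamTranscription handler is cast to _StreamTranscriptionCallable in the test
# below before being iterated.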


SAMPLE_RATE = DEFAULT_SAMPLE_RATE
CHUNK_SAMPLES = 1600  # 0.1s at 16kHz
SPEECH_CHUNKS = 4
SILENCE_CHUNKS = 10
EXPECTED_TEXT = "Real pipeline segment"
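
# Timing sketch (assuming DEFAULT_SAMPLE_RATE is 16 kHz, as the CHUNK_SAMPLES
# comment implies): the stream carries SPEECH_CHUNKS * 0.1 s = 0.4 s of "speech"
# followed by SILENCE_CHUNKS * 0.1 s = 1.0 s of silence.  The trailing second of
# silence is assumed to be long enough for the real VAD + Segmenter to decide the
# utterance has ended and finalize a segment.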


class MockContext:
    """Minimal gRPC context for integration streaming tests."""

    async def abort(self, code: grpc.StatusCode, details: str) -> None:
        """Raise instead of returning, mirroring a real context where abort() never returns."""
        _ = (code, details)
        raise grpc.RpcError()

    def invocation_metadata(self) -> list[tuple[str, str]]:
        """Return empty metadata for mock context."""
        return []


def _make_chunk(meeting_id: str, audio: npt.NDArray[np.float32]) -> noteflow_pb2.AudioChunk:
    """Create a protobuf audio chunk."""
    return noteflow_pb2.AudioChunk(
        meeting_id=meeting_id,
        audio_data=audio.astype(np.float32).tobytes(),
        sample_rate=SAMPLE_RATE,
        channels=1,
    )


async def _audio_stream(meeting_id: str) -> AsyncIterator[noteflow_pb2.AudioChunk]:
    """Yield speech then silence chunks to exercise VAD + Segmenter."""
    rng = np.random.default_rng(0)
    speech = rng.uniform(-0.2, 0.2, CHUNK_SAMPLES).astype(np.float32)
    silence = np.zeros(CHUNK_SAMPLES, dtype=np.float32)

    for _ in range(SPEECH_CHUNKS):
        yield _make_chunk(meeting_id, speech)

    for _ in range(SILENCE_CHUNKS):
        yield _make_chunk(meeting_id, silence)
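
# The "speech" chunks are seeded uniform noise at roughly +/-0.2 amplitude and the
# "silence" chunks are exact zeros.  The test relies on the real VAD treating the
# noise as speech and the zeros as silence (an assumption about the VAD in use,
# not something this test configures explicitly).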


async def _create_meeting_for_streaming(
    session_factory: async_sessionmaker[AsyncSession],
    meetings_dir: Path,
) -> Meeting:
    """Persist a meeting for the streaming test and return it."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        meeting = Meeting.create(title="Streaming Pipeline Test")
        await uow.meetings.create(meeting)
        await uow.commit()
        return meeting


def _create_mock_asr_engine() -> MagicMock:
    """Build a mock ASR engine that always transcribes to EXPECTED_TEXT."""
    mock_asr = MagicMock()
    mock_asr.is_loaded = True
    mock_asr.transcribe_async = AsyncMock(
        return_value=[
            AsrResult(text=EXPECTED_TEXT, start=0.0, end=0.4),
        ]
    )
    return mock_asr
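
# Only the ASR engine is mocked; VAD and Segmenter run for real, which is the point
# of this integration test.  Whenever the servicer asks the engine to transcribe a
# finalized segment (presumably via transcribe_async), it gets back a single
# AsrResult covering 0.0 to 0.4 s with EXPECTED_TEXT, so every final update and
# persisted segment should carry that text.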


async def _collect_final_updates(
    stream: _StreamTranscriptionCallable,
    meeting_id: str,
) -> list[_TranscriptUpdate]:
    """Drive the stream to completion and collect only the final updates."""
    updates: list[_TranscriptUpdate] = []
    async for update in stream(_audio_stream(meeting_id), MockContext()):
        if update.update_type == noteflow_pb2.UPDATE_TYPE_FINAL:
            updates.append(update)
    return updates
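
# Partial/interim updates are deliberately dropped here: the test only asserts that
# at least one UPDATE_TYPE_FINAL update arrives.  The response iterator is expected
# to finish once the request iterator is exhausted and the handler returns.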


async def _verify_segment_persisted(
    session_factory: async_sessionmaker[AsyncSession],
    meetings_dir: Path,
    meeting_id: MeetingId,
) -> None:
    """Assert that the transcribed segment was persisted for the meeting."""
    async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
        segments = await uow.segments.get_by_meeting(meeting_id)
        segment_texts = [s.text for s in segments]
        assert segments, "Expected at least one segment persisted"
        assert EXPECTED_TEXT in segment_texts, f"'{EXPECTED_TEXT}' not in {segment_texts}"


@pytest.mark.integration
class TestStreamingRealPipeline:
    """Validate streaming with real VAD + Segmenter path."""

    @pytest.mark.asyncio
    async def test_streaming_emits_final_segment(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        meetings_dir: Path,
    ) -> None:
        """Real VAD + Segmenter should emit at least one final segment."""
        meeting = await _create_meeting_for_streaming(session_factory, meetings_dir)
        mock_asr = _create_mock_asr_engine()

        servicer = NoteFlowServicer(
            session_factory=session_factory,
            asr_engine=mock_asr,
            meetings_dir=meetings_dir,
        )

        stream = cast(_StreamTranscriptionCallable, servicer.StreamTranscription)
        final_updates = await _collect_final_updates(stream, str(meeting.id))
        assert final_updates, "Expected at least one final transcript update"
        await _verify_segment_persisted(session_factory, meetings_dir, meeting.id)
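
# The `session_factory` and `meetings_dir` fixtures used above are provided outside
# this file (most likely in a shared conftest.py, which is not shown here).  Purely
# as an illustration of what they might look like; the engine URL, schema setup, and
# fixture bodies below are assumptions, not the project's actual fixtures:
#
#     import pytest_asyncio
#     from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
#
#     @pytest_asyncio.fixture
#     async def session_factory() -> AsyncIterator[async_sessionmaker[AsyncSession]]:
#         engine = create_async_engine("sqlite+aiosqlite:///:memory:")
#         # Schema creation would go here (e.g. the project's metadata or migrations).
#         yield async_sessionmaker(engine, expire_on_commit=False)
#         await engine.dispose()
#
#     @pytest.fixture
#     def meetings_dir(tmp_path: Path) -> Path:
#         directory = tmp_path / "meetings"
#         directory.mkdir()
#         return directory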