"""Tests for chunk sequence tracking and acknowledgment.

Exercises the sequence tracking and ack emission logic used by streaming
transcription.
"""
|
|
|
|
from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from noteflow.grpc.mixins.converters import create_ack_update
from noteflow.grpc.mixins.streaming._processing._chunk_tracking import track_chunk_sequence
from noteflow.grpc.mixins.streaming._processing._constants import ACK_CHUNK_INTERVAL
from noteflow.grpc.proto import noteflow_pb2
|
|
|
|
# Distinct sequence numbers for the two meetings in the multi-meeting
# tracking test, so each meeting's stored value can be told apart.
MEETING_1_SEQUENCE = 10
MEETING_2_SEQUENCE = 20

# Congestion values passed into create_ack_update and expected back out.
TEST_PROCESSING_DELAY_MS = 500  # processing delay, in milliseconds
TEST_QUEUE_DEPTH = 10  # queue depth reported in CongestionInfo
|
|
|
|
|
|
def _send_chunks_up_to(host: MagicMock, meeting_id: str, count: int) -> None:
    """Feed sequence numbers 1..count (inclusive, ascending) into the tracker.

    Used to build up per-meeting tracking state before the call under test.
    Nothing is sent when ``count`` is less than 1.
    """
    for offset in range(count):
        # offset is 0-based; chunk sequence numbers start at 1.
        track_chunk_sequence(host, meeting_id, offset + 1)
|
|
|
|
|
|
def _send_zero_sequences(host: MagicMock, meeting_id: str, count: int) -> None:
    """Feed ``count`` chunks carrying sequence number 0 into the tracker.

    Sequence 0 is what the legacy-client tests send; this builds up state for
    those tests before the call under test.
    """
    legacy_sequence = 0
    for _ in range(count):
        track_chunk_sequence(host, meeting_id, legacy_sequence)
|
|
|
|
|
|
@pytest.fixture
def mock_host() -> MagicMock:
    """Build a stand-in ServicerHost whose tracking attributes are real dicts.

    Each dict is freshly created per fixture invocation, so tests never share
    tracking state.
    """
    # MagicMock(**kwargs) assigns each keyword as an attribute on the mock.
    return MagicMock(
        chunk_sequences={},
        chunk_counts={},
        pending_chunks={},
        chunk_receipt_times={},
    )
|
|
|
|
|
|
class TestCreateAckUpdate:
    """Tests for the create_ack_update converter function."""

    def test_creates_valid_update(self) -> None:
        """Verify create_ack_update produces correct TranscriptUpdate."""
        meeting_id = "test-meeting-123"
        ack_seq = 42

        update = create_ack_update(meeting_id, ack_seq)

        assert update.meeting_id == meeting_id, "meeting_id should match"
        assert update.ack_sequence == ack_seq, "ack_sequence should be set"
        assert (
            update.update_type == noteflow_pb2.UPDATE_TYPE_UNSPECIFIED
        ), "update_type should be UNSPECIFIED for ack-only"
        assert update.server_timestamp > 0, "server_timestamp should be set"

    def test_ack_update_has_no_content(self) -> None:
        """Verify ack-only update has no transcript content."""
        update = create_ack_update("meeting-id", 100)

        assert update.partial_text == "", "should have no partial_text"
        assert not update.HasField("segment"), "should have no segment"

    def test_ack_update_includes_congestion_info(self) -> None:
        """Verify a non-throttling CongestionInfo is attached to the ack update."""
        ack_seq = 50
        congestion = noteflow_pb2.CongestionInfo(
            processing_delay_ms=TEST_PROCESSING_DELAY_MS,
            queue_depth=TEST_QUEUE_DEPTH,
            throttle_recommended=False,
        )

        update = create_ack_update("meeting-id", ack_seq, congestion=congestion)

        assert update.ack_sequence == ack_seq, "ack_sequence should be preserved"
        assert update.HasField("congestion"), "should have congestion field"
        assert update.congestion.processing_delay_ms == TEST_PROCESSING_DELAY_MS, "delay should match"
        assert update.congestion.queue_depth == TEST_QUEUE_DEPTH, "queue depth should match"
        assert update.congestion.throttle_recommended is False, "throttle should match"

    def test_ack_update_congestion_with_throttle(self) -> None:
        """Verify a throttling CongestionInfo is carried through with every field intact."""
        ack_seq = 100
        # Values deliberately differ from TEST_PROCESSING_DELAY_MS / TEST_QUEUE_DEPTH
        # so a converter that ignores its input cannot pass both congestion tests.
        delay_ms = 1500
        queue_depth = 25
        congestion = noteflow_pb2.CongestionInfo(
            processing_delay_ms=delay_ms,
            queue_depth=queue_depth,
            throttle_recommended=True,
        )

        update = create_ack_update("meeting-id", ack_seq, congestion=congestion)

        assert update.ack_sequence == ack_seq, "ack_sequence should be preserved"
        assert update.HasField("congestion"), "should have congestion field"
        assert update.congestion.processing_delay_ms == delay_ms, "delay should match"
        assert update.congestion.queue_depth == queue_depth, "queue depth should match"
        assert update.congestion.throttle_recommended is True, "throttle should be True"
|
|
|
|
|
|
class TestTrackChunkSequence:
    """Tests for the track_chunk_sequence function."""

    def test_tracks_sequence_number(self, mock_host: MagicMock) -> None:
        """Verify sequence number is tracked correctly."""
        meeting_id = "test-meeting"

        track_chunk_sequence(mock_host, meeting_id, 1)
        assert mock_host.chunk_sequences[meeting_id] == 1, "should track first sequence"

        track_chunk_sequence(mock_host, meeting_id, 2)
        assert mock_host.chunk_sequences[meeting_id] == 2, "should track second sequence"

        track_chunk_sequence(mock_host, meeting_id, 5)
        assert mock_host.chunk_sequences[meeting_id] == 5, "should track non-contiguous sequence"

    def test_ignores_zero_sequence(self, mock_host: MagicMock) -> None:
        """Verify zero sequence (legacy clients) is ignored."""
        meeting_id = "test-meeting"

        track_chunk_sequence(mock_host, meeting_id, 0)
        assert meeting_id not in mock_host.chunk_sequences, "should not track zero sequence"

    def test_tracks_highest_sequence(self, mock_host: MagicMock) -> None:
        """Verify only highest sequence is stored (handles out-of-order)."""
        meeting_id = "test-meeting"

        track_chunk_sequence(mock_host, meeting_id, 5)
        track_chunk_sequence(mock_host, meeting_id, 3)  # Out of order
        assert mock_host.chunk_sequences[meeting_id] == 5, "should keep highest sequence"

    @pytest.mark.parametrize(
        ("chunk_count", "expects_ack"),
        [
            pytest.param(1, True, id="chunk_1_emits_immediate_ack"),
            pytest.param(ACK_CHUNK_INTERVAL // 2, False, id="midpoint_no_ack"),
            pytest.param(ACK_CHUNK_INTERVAL - 1, False, id="one_before_interval_no_ack"),
            pytest.param(ACK_CHUNK_INTERVAL, True, id="at_interval_emits_ack"),
        ],
    )
    def test_ack_emission_at_chunk_count(
        self, mock_host: MagicMock, chunk_count: int, expects_ack: bool
    ) -> None:
        """Verify ack emission: first chunk gets immediate ack, then at interval."""
        meeting_id = "test-meeting"
        _send_chunks_up_to(mock_host, meeting_id, chunk_count - 1)

        result = track_chunk_sequence(mock_host, meeting_id, chunk_count)

        assert (result is not None) == expects_ack, (
            f"chunk {chunk_count} should {'emit' if expects_ack else 'not emit'} ack"
        )

    def test_count_resets_to_zero_after_ack(self, mock_host: MagicMock) -> None:
        """Verify chunk count resets to zero after ack emission."""
        meeting_id = "test-meeting"
        _send_chunks_up_to(mock_host, meeting_id, ACK_CHUNK_INTERVAL)

        assert mock_host.chunk_counts[meeting_id] == 0, "count should reset after ack"

    def test_count_increments_after_reset(self, mock_host: MagicMock) -> None:
        """Verify chunk count increments correctly for new chunks after a reset."""
        meeting_id = "test-meeting"
        additional_chunks = 2

        # Reach the interval so the ack resets the count.
        _send_chunks_up_to(mock_host, meeting_id, ACK_CHUNK_INTERVAL)
        assert mock_host.chunk_counts[meeting_id] == 0, "precondition: count should reset after ack"

        # Continue the stream with the *next* sequence numbers instead of
        # replaying 1..N, so exactly `additional_chunks` new chunks are counted.
        first_new_seq = ACK_CHUNK_INTERVAL + 1
        for seq in range(first_new_seq, first_new_seq + additional_chunks):
            track_chunk_sequence(mock_host, meeting_id, seq)

        assert mock_host.chunk_counts[meeting_id] == additional_chunks, "should count new chunks after reset"

    @pytest.mark.parametrize(
        "call_number",
        [
            pytest.param(1, id="first_zero_seq"),
            pytest.param(ACK_CHUNK_INTERVAL, id="at_interval_zero_seq"),
            pytest.param(ACK_CHUNK_INTERVAL + 1, id="past_interval_zero_seq"),
        ],
    )
    def test_no_ack_for_zero_sequence_at_interval(
        self, mock_host: MagicMock, call_number: int
    ) -> None:
        """Verify no ack is emitted for zero-sequence chunks even at interval count."""
        meeting_id = "test-meeting"
        # Build up state with previous zero-sequence calls
        _send_zero_sequences(mock_host, meeting_id, call_number - 1)

        result = track_chunk_sequence(mock_host, meeting_id, 0)

        assert result is None, f"zero-seq call #{call_number} should not emit ack"

    def test_separate_tracking_per_meeting(self, mock_host: MagicMock) -> None:
        """Verify each meeting has independent sequence tracking."""
        meeting_1 = "meeting-1"
        meeting_2 = "meeting-2"

        track_chunk_sequence(mock_host, meeting_1, MEETING_1_SEQUENCE)
        track_chunk_sequence(mock_host, meeting_2, MEETING_2_SEQUENCE)

        assert mock_host.chunk_sequences[meeting_1] == MEETING_1_SEQUENCE, (
            f"meeting 1 should have seq {MEETING_1_SEQUENCE}"
        )
        assert mock_host.chunk_sequences[meeting_2] == MEETING_2_SEQUENCE, (
            f"meeting 2 should have seq {MEETING_2_SEQUENCE}"
        )

    @pytest.mark.parametrize(
        ("current_seq", "next_seq", "expected_logged"),
        [
            pytest.param(1, 2, False, id="contiguous_no_gap"),
            pytest.param(1, 3, True, id="gap_detected"),
            pytest.param(5, 10, True, id="large_gap_detected"),
        ],
    )
    def test_gap_detection_logging(
        self,
        mock_host: MagicMock,
        current_seq: int,
        next_seq: int,
        expected_logged: bool,
    ) -> None:
        """Verify gap detection logs warning for non-contiguous sequences."""
        meeting_id = "test-meeting"

        # patch() without `new` substitutes a MagicMock for the module logger.
        with patch(
            "noteflow.grpc.mixins.streaming._processing._chunk_tracking.logger",
        ) as mock_logger:
            track_chunk_sequence(mock_host, meeting_id, current_seq)
            track_chunk_sequence(mock_host, meeting_id, next_seq)

        gap_logged = mock_logger.warning.called
        assert gap_logged == expected_logged, (
            f"Gap from {current_seq} to {next_seq} should "
            f"{'trigger' if expected_logged else 'not trigger'} warning"
        )
|