"""Tests for chunk sequence tracking and acknowledgment. Tests the sequence tracking and ack emission logic for streaming transcription. """ from __future__ import annotations from unittest.mock import MagicMock import pytest from noteflow.grpc.mixins.converters import create_ack_update from noteflow.grpc.mixins.streaming._processing._chunk_tracking import track_chunk_sequence from noteflow.grpc.mixins.streaming._processing._constants import ACK_CHUNK_INTERVAL from noteflow.grpc.proto import noteflow_pb2 # Test constants for sequence tracking MEETING_1_SEQUENCE = 10 """Sequence number for meeting 1 in multi-meeting tracking tests.""" MEETING_2_SEQUENCE = 20 """Sequence number for meeting 2 in multi-meeting tracking tests.""" # Test constants for congestion info TEST_PROCESSING_DELAY_MS = 500 """Test processing delay in milliseconds.""" TEST_QUEUE_DEPTH = 10 """Test queue depth value.""" def _send_chunks_up_to(host: MagicMock, meeting_id: str, count: int) -> None: """Send sequence of chunks 1 through count to build up tracking state.""" for i in range(1, count + 1): track_chunk_sequence(host, meeting_id, i) def _send_zero_sequences(host: MagicMock, meeting_id: str, count: int) -> None: """Send count zero-sequence chunks to build up state for legacy client tests.""" for _ in range(count): track_chunk_sequence(host, meeting_id, 0) @pytest.fixture def mock_host() -> MagicMock: """Create mock ServicerHost with sequence tracking dicts.""" host = MagicMock() host.chunk_sequences = {} host.chunk_counts = {} host.pending_chunks = {} host.chunk_receipt_times = {} return host class TestCreateAckUpdate: """Tests for create_ack_update converter function.""" def test_creates_valid_update(self) -> None: """Verify create_ack_update produces correct TranscriptUpdate.""" meeting_id = "test-meeting-123" ack_seq = 42 update = create_ack_update(meeting_id, ack_seq) assert update.meeting_id == meeting_id, "meeting_id should match" assert update.ack_sequence == ack_seq, "ack_sequence should be set" assert ( update.update_type == noteflow_pb2.UPDATE_TYPE_UNSPECIFIED ), "update_type should be UNSPECIFIED for ack-only" assert update.server_timestamp > 0, "server_timestamp should be set" def test_ack_update_has_no_content(self) -> None: """Verify ack-only update has no transcript content.""" update = create_ack_update("meeting-id", 100) assert update.partial_text == "", "should have no partial_text" assert not update.HasField("segment"), "should have no segment" def test_ack_update_includes_congestion_info(self) -> None: """Verify ack update can include congestion info.""" congestion = noteflow_pb2.CongestionInfo( processing_delay_ms=TEST_PROCESSING_DELAY_MS, queue_depth=TEST_QUEUE_DEPTH, throttle_recommended=False, ) update = create_ack_update("meeting-id", 50, congestion=congestion) assert update.HasField("congestion"), "should have congestion field" assert update.congestion.processing_delay_ms == TEST_PROCESSING_DELAY_MS, "delay should match" assert update.congestion.queue_depth == TEST_QUEUE_DEPTH, "queue depth should match" assert update.congestion.throttle_recommended is False, "throttle should match" def test_ack_update_congestion_with_throttle(self) -> None: """Verify ack update with throttle_recommended=True.""" congestion = noteflow_pb2.CongestionInfo( processing_delay_ms=1500, queue_depth=25, throttle_recommended=True, ) update = create_ack_update("meeting-id", 100, congestion=congestion) assert update.congestion.throttle_recommended is True, "throttle should be True" class TestTrackChunkSequence: """Tests for track_chunk_sequence function.""" def test_tracks_sequence_number(self, mock_host: MagicMock) -> None: """Verify sequence number is tracked correctly.""" meeting_id = "test-meeting" track_chunk_sequence(mock_host, meeting_id, 1) assert mock_host.chunk_sequences[meeting_id] == 1, "should track first sequence" track_chunk_sequence(mock_host, meeting_id, 2) assert mock_host.chunk_sequences[meeting_id] == 2, "should track second sequence" track_chunk_sequence(mock_host, meeting_id, 5) assert mock_host.chunk_sequences[meeting_id] == 5, "should track non-contiguous sequence" def test_ignores_zero_sequence(self, mock_host: MagicMock) -> None: """Verify zero sequence (legacy clients) is ignored.""" meeting_id = "test-meeting" track_chunk_sequence(mock_host, meeting_id, 0) assert meeting_id not in mock_host.chunk_sequences, "should not track zero sequence" def test_tracks_highest_sequence(self, mock_host: MagicMock) -> None: """Verify only highest sequence is stored (handles out-of-order).""" meeting_id = "test-meeting" track_chunk_sequence(mock_host, meeting_id, 5) track_chunk_sequence(mock_host, meeting_id, 3) # Out of order assert mock_host.chunk_sequences[meeting_id] == 5, "should keep highest sequence" @pytest.mark.parametrize( ("chunk_count", "expects_ack"), [ pytest.param(1, True, id="chunk_1_emits_immediate_ack"), pytest.param(ACK_CHUNK_INTERVAL // 2, False, id="midpoint_no_ack"), pytest.param(ACK_CHUNK_INTERVAL - 1, False, id="one_before_interval_no_ack"), pytest.param(ACK_CHUNK_INTERVAL, True, id="at_interval_emits_ack"), ], ) def test_ack_emission_at_chunk_count( self, mock_host: MagicMock, chunk_count: int, expects_ack: bool ) -> None: """Verify ack emission: first chunk gets immediate ack, then at interval.""" meeting_id = "test-meeting" _send_chunks_up_to(mock_host, meeting_id, chunk_count - 1) result = track_chunk_sequence(mock_host, meeting_id, chunk_count) assert (result is not None) == expects_ack, ( f"chunk {chunk_count} should {'emit' if expects_ack else 'not emit'} ack" ) def test_count_resets_to_zero_after_ack(self, mock_host: MagicMock) -> None: """Verify chunk count resets to zero after ack emission.""" meeting_id = "test-meeting" _send_chunks_up_to(mock_host, meeting_id, ACK_CHUNK_INTERVAL) assert mock_host.chunk_counts[meeting_id] == 0, "count should reset after ack" def test_count_increments_after_reset(self, mock_host: MagicMock) -> None: """Verify chunk count increments correctly after reset.""" meeting_id = "test-meeting" additional_chunks = 2 # First trigger ack to reset count _send_chunks_up_to(mock_host, meeting_id, ACK_CHUNK_INTERVAL) # Send additional chunks after reset _send_chunks_up_to(mock_host, meeting_id, ACK_CHUNK_INTERVAL + additional_chunks) assert mock_host.chunk_counts[meeting_id] == additional_chunks, "should count new chunks after reset" @pytest.mark.parametrize( "call_number", [ pytest.param(1, id="first_zero_seq"), pytest.param(ACK_CHUNK_INTERVAL, id="at_interval_zero_seq"), pytest.param(ACK_CHUNK_INTERVAL + 1, id="past_interval_zero_seq"), ], ) def test_no_ack_for_zero_sequence_at_interval( self, mock_host: MagicMock, call_number: int ) -> None: """Verify no ack is emitted for zero-sequence chunks even at interval count.""" meeting_id = "test-meeting" # Build up state with previous zero-sequence calls _send_zero_sequences(mock_host, meeting_id, call_number - 1) result = track_chunk_sequence(mock_host, meeting_id, 0) assert result is None, f"zero-seq call #{call_number} should not emit ack" def test_separate_tracking_per_meeting(self, mock_host: MagicMock) -> None: """Verify each meeting has independent sequence tracking.""" meeting_1 = "meeting-1" meeting_2 = "meeting-2" track_chunk_sequence(mock_host, meeting_1, MEETING_1_SEQUENCE) track_chunk_sequence(mock_host, meeting_2, MEETING_2_SEQUENCE) assert mock_host.chunk_sequences[meeting_1] == MEETING_1_SEQUENCE, ( f"meeting 1 should have seq {MEETING_1_SEQUENCE}" ) assert mock_host.chunk_sequences[meeting_2] == MEETING_2_SEQUENCE, ( f"meeting 2 should have seq {MEETING_2_SEQUENCE}" ) @pytest.mark.parametrize( ("current_seq", "next_seq", "expected_logged"), [ pytest.param(1, 2, False, id="contiguous_no_gap"), pytest.param(1, 3, True, id="gap_detected"), pytest.param(5, 10, True, id="large_gap_detected"), ], ) def test_gap_detection_logging( self, mock_host: MagicMock, current_seq: int, next_seq: int, expected_logged: bool, ) -> None: """Verify gap detection logs warning for non-contiguous sequences.""" from unittest.mock import patch meeting_id = "test-meeting" mock_logger = MagicMock() with patch( "noteflow.grpc.mixins.streaming._processing._chunk_tracking.logger", mock_logger, ): track_chunk_sequence(mock_host, meeting_id, current_seq) track_chunk_sequence(mock_host, meeting_id, next_seq) gap_logged = mock_logger.warning.called assert gap_logged == expected_logged, ( f"Gap from {current_seq} to {next_seq} should " f"{'trigger' if expected_logged else 'not trigger'} warning" )