fix: update types-grpcio dependency version and improve meeting filtering

- Pinned the types-grpcio dev dependency, changing `>=1.0.0.20251009` to `==1.0.0.20251001` in both `pyproject.toml` and `uv.lock` so the two files agree on an exact version.
- Extended the meeting filtering logic in `MeetingStore` and related classes to support filtering by multiple project IDs.
- Added a repeated `project_ids` field to the `ListMeetingsRequest` proto message, allowing more granular project filtering; it takes precedence over `project_id` when provided.
- Adjusted the related repository and mixin classes to accommodate the new filtering capabilities (see the usage sketch below).
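
As a rough usage sketch (not part of the diff), a caller could exercise the new multi-project filter through the repository protocol changed below; the function and variable names here are illustrative:

```python
from uuid import UUID

async def meetings_for_projects(repo, ids: list[UUID]):
    """List meetings across several projects; project_ids wins over project_id."""
    meetings, total = await repo.meetings.list_all(
        limit=50,
        offset=0,
        sort_desc=True,
        project_ids=ids,
    )
    return meetings, total
```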
2026-01-05 00:38:33 +00:00
parent 52ecb89e89
commit 4fceb95438
55 changed files with 1920 additions and 2155 deletions

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@@ -2,6 +2,15 @@
Checking for magic numbers...
WARNING: Found potential magic numbers (consider using named constants):
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:60: let listener = TcpListener::bind("127.0.0.1:0")
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:280: background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:286: background: rgba(255,255,255,0.1);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:300: <div class="checkmark">&#10004;</div>
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:308: "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:332: background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:338: background: rgba(255,255,255,0.1);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:349: <div class="error">&#10006;</div>
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:359: "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/state/preferences.rs:63: server_host: "127.0.0.1".to_string(),
Checking for repeated string literals...
@@ -21,19 +30,19 @@ Checking for long functions...
Checking for deep nesting...
WARNING: Found potentially deep nesting (>7 levels):
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:232: INITIAL_RETRY_DELAY_MS * RETRY_BACKOFF_MULTIPLIER.pow(consecutive_errors - 1),
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:59: let duration = buffer
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:60: .last()
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:61: .map(|chunk| chunk.timestamp + chunk.duration)
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:62: .unwrap_or(0.0);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:63: tracing::debug!("Loaded encrypted audio from {:?}", audio_path);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:64: return LoadedAudio {
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:65: buffer,
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:66: sample_rate,
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:67: duration,
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:122: if let Some(result) = handle_connection(stream).await {
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:123: if let Some(tx) = result_tx.lock().await.take() {
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:124: let _ = tx.send(result);
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:125: }
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:126: }
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:145: user_id = %identity.user_id,
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:146: is_local = identity.is_local,
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:147: "Loaded identity from keychain"
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:233: INITIAL_RETRY_DELAY_MS
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:234: * RETRY_BACKOFF_MULTIPLIER.pow(consecutive_errors - 1),
Checking for unwrap() usage...
OK: No unwrap() calls found
OK: Found 2 unwrap() calls (within acceptable range)
Checking for excessive clone() usage...
OK: No excessive clone() usage detected
@@ -46,18 +55,18 @@ Checking for duplicated error messages...
Checking module file sizes...
WARNING: Large files (>500 lines):
597 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs
546 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/error.rs
537 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
593 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs
550 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/error.rs
533 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
Checking for scattered helper functions...
WARNING: Helper functions scattered across 11 files (consider consolidating):
WARNING: Helper functions scattered across 12 files (consider consolidating):
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/converters.rs
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/observability.rs
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/sync.rs
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/helpers.rs
=== Summary ===
Errors: 0

View File

@@ -1,53 +0,0 @@
# Phase Gaps: Backend-Client Synchronization Issues
This directory contains sprint specifications for addressing gaps, race conditions, and synchronization issues identified between the NoteFlow backend and Tauri/React client.
## Summary of Findings
Analysis of the gRPC API contracts, state management, and streaming operations revealed several categories of issues requiring remediation.
| Sprint | Category | Severity | Effort |
|--------|----------|----------|--------|
| [SPRINT-GAP-001](./sprint-gap-001-streaming-race-conditions.md) | Streaming Race Conditions | High | L |
| [SPRINT-GAP-002](./sprint-gap-002-state-sync-gaps.md) | State Synchronization | Medium | M |
| [SPRINT-GAP-003](./sprint-gap-003-error-handling.md) | Error Handling Mismatches | Medium | M |
| [SPRINT-GAP-004](./sprint-gap-004-diarization-lifecycle.md) | Diarization Job Lifecycle | Medium | S |
| [SPRINT-GAP-005](./sprint-gap-005-entity-resource-leak.md) | Entity Mixin Resource Leak | High | S |
| [SPRINT-GAP-006](./sprint-gap-006-connection-bootstrapping.md) | Connection Bootstrapping | High | M |
| [SPRINT-GAP-007](./sprint-gap-007-simulation-mode-clarity.md) | Simulation Mode Clarity | Medium | S |
| [SPRINT-GAP-008](./sprint-gap-008-server-address-consistency.md) | Server Address Consistency | Medium | M |
| [SPRINT-GAP-009](./sprint-gap-009-event-bridge-contracts.md) | Event Bridge Contracts | Medium | S |
| [SPRINT-GAP-010](./sprint-gap-010-identity-and-rpc-logging.md) | Identity + RPC Logging | Medium | M |
## Priority Matrix
### Critical (Address Immediately)
- **SPRINT-GAP-001**: Audio streaming fire-and-forget can cause data loss
- **SPRINT-GAP-005**: Entity mixin context manager misuse causes resource leaks
- **SPRINT-GAP-006**: Recording can start without a connected gRPC client
### High Priority
- **SPRINT-GAP-002**: Missing meeting cache invalidation can surface stale data
- **SPRINT-GAP-003**: Silenced errors hide critical failures
### Medium Priority
- **SPRINT-GAP-004**: Diarization polling resilience improvements
- **SPRINT-GAP-007**: Simulated paths need explicit UX and safety rails
- **SPRINT-GAP-008**: Default server address can point at the wrong host in Docker setups
- **SPRINT-GAP-009**: Event bridge should initialize before connection
- **SPRINT-GAP-010**: Identity metadata and per-RPC logging not wired
## Cross-Cutting Concerns
1. **Observability**: All fixes should emit appropriate log events and metrics
2. **Testing**: Each sprint must include integration tests for the identified scenarios
3. **Backwards Compatibility**: Client-side changes must gracefully handle older server versions
## Analysis Methodology
Issues were identified through:
1. Code review of gRPC mixins (`src/noteflow/grpc/_mixins/`)
2. Tauri command handlers (`client/src-tauri/src/commands/`)
3. TypeScript API adapters (`client/src/api/`)
4. Pattern matching for anti-patterns (`.catch(() => {})`, missing awaits)
5. State machine analysis for race conditions

View File

@@ -49,7 +49,7 @@ dev = [
"basedpyright>=1.18",
"pyrefly>=0.46.1",
"sourcery; sys_platform == 'darwin'",
"types-grpcio>=1.0.0.20251009",
"types-grpcio==1.0.0.20251001",
"testcontainers[postgres]>=4.0",
]
triggers = [
@@ -288,6 +288,6 @@ dev = [
"pytest-httpx>=0.36.0",
"ruff>=0.14.9",
"sourcery; sys_platform == 'darwin'",
"types-grpcio>=1.0.0.20251009",
"types-grpcio==1.0.0.20251001",
"watchfiles>=1.1.1",
]

View File

@@ -24,6 +24,7 @@ class MeetingListKwargs(TypedDict, total=False):
offset: int
sort_desc: bool
project_id: UUID | None
project_ids: list[UUID] | None
class MeetingRepository(Protocol):
@@ -83,7 +84,7 @@ class MeetingRepository(Protocol):
"""List meetings with optional filtering.
Args:
**kwargs: Optional filters (states, limit, offset, sort_desc, project_id).
**kwargs: Optional filters (states, limit, offset, sort_desc, project_id, project_ids).
Returns:
Tuple of (meetings list, total count matching filter).

View File

@@ -283,8 +283,30 @@ class MeetingMixin:
state_values = cast(Sequence[int], request.states)
states = [MeetingState(s) for s in state_values] if state_values else None
project_id: UUID | None = None
project_ids: list[UUID] | None = None
if cast(_HasField, request).HasField("project_id") and request.project_id:
if request.project_ids:
project_ids = []
for raw_project_id in request.project_ids:
try:
project_ids.append(UUID(raw_project_id))
except ValueError:
truncated = raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
logger.warning(
"ListMeetings: invalid project_ids format",
project_id_truncated=truncated,
project_id_length=len(raw_project_id),
)
await abort_invalid_argument(
context,
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
)
if (
not project_ids
and cast(_HasField, request).HasField("project_id")
and request.project_id
):
try:
project_id = UUID(request.project_id)
except ValueError:
@@ -297,7 +319,7 @@ class MeetingMixin:
await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
async with self.create_repository_provider() as repo:
if project_id is None:
if project_id is None and not project_ids:
project_id = await _resolve_active_project_id(self, repo)
meetings, total = await repo.meetings.list_all(
@@ -306,6 +328,7 @@ class MeetingMixin:
offset=offset,
sort_desc=sort_desc,
project_id=project_id,
project_ids=project_ids,
)
logger.debug(
"ListMeetings returned",
@@ -314,6 +337,7 @@ class MeetingMixin:
limit=limit,
offset=offset,
project_id=str(project_id) if project_id else None,
project_ids=[str(pid) for pid in project_ids] if project_ids else None,
)
return noteflow_pb2.ListMeetingsResponse(
meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
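
Condensed, the resolution order the mixin now implements is: use `project_ids` when non-empty (aborting on any malformed UUID), fall back to `project_id` otherwise, and let the handler resolve the active project when neither is set. A minimal standalone sketch; the helper name and return shape are illustrative, not the mixin's API:

```python
from uuid import UUID

def resolve_project_filter(
    raw_ids: list[str], raw_id: str | None
) -> tuple[UUID | None, list[UUID] | None]:
    if raw_ids:
        # in the mixin, a ValueError here becomes an INVALID_ARGUMENT abort
        return None, [UUID(raw) for raw in raw_ids]
    if raw_id:
        return UUID(raw_id), None
    return None, None  # caller substitutes the active project id
```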

View File

@@ -107,6 +107,7 @@ class MeetingStore:
offset = kwargs.get("offset", 0)
sort_desc = kwargs.get("sort_desc", True)
project_id = kwargs.get("project_id")
project_ids = kwargs.get("project_ids")
meetings = list(self._meetings.values())
# Filter by state
@@ -114,8 +115,13 @@ class MeetingStore:
state_set = set(states)
meetings = [m for m in meetings if m.state in state_set]
# Filter by project if requested
if project_id:
# Filter by project(s) if requested
if project_ids:
project_set = set(project_ids)
meetings = [
m for m in meetings if m.project_id is not None and str(m.project_id) in project_set
]
elif project_id:
meetings = [
m for m in meetings if m.project_id is not None and str(m.project_id) == project_id
]
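
Note the in-memory store compares stringified UUIDs, and `project_ids` takes precedence over `project_id`. A toy, self-contained check of the same semantics (the `_Meeting` stand-in and values are placeholders):

```python
from dataclasses import dataclass
from uuid import UUID, uuid4

@dataclass
class _Meeting:  # stand-in for the real Meeting entity
    project_id: UUID | None

pid = uuid4()
meetings = [_Meeting(pid), _Meeting(None), _Meeting(uuid4())]
project_set = {str(pid)}  # the memory repository stringifies UUIDs first
kept = [m for m in meetings if m.project_id is not None and str(m.project_id) in project_set]
assert len(kept) == 1
```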

View File

@@ -306,6 +306,9 @@ message ListMeetingsRequest {
// Optional project filter (defaults to active project if omitted)
optional string project_id = 5;
// Optional project filter for multiple projects (overrides project_id when provided)
repeated string project_ids = 6;
}
enum SortOrder {
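
On the wire, a client opts into the new filter by populating the repeated field; a minimal sketch using the regenerated message (the UUID strings are placeholders):

```python
from noteflow.grpc.proto import noteflow_pb2

# project_ids overrides project_id on the server when non-empty
request = noteflow_pb2.ListMeetingsRequest(
    limit=20,
    project_ids=["<project-uuid-1>", "<project-uuid-2>"],  # placeholders
)
```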

File diff suppressed because one or more lines are too long

View File

@@ -244,18 +244,20 @@ class StopMeetingRequest(_message.Message):
def __init__(self, meeting_id: _Optional[str] = ...) -> None: ...
class ListMeetingsRequest(_message.Message):
__slots__ = ("states", "limit", "offset", "sort_order", "project_id")
__slots__ = ("states", "limit", "offset", "sort_order", "project_id", "project_ids")
STATES_FIELD_NUMBER: _ClassVar[int]
LIMIT_FIELD_NUMBER: _ClassVar[int]
OFFSET_FIELD_NUMBER: _ClassVar[int]
SORT_ORDER_FIELD_NUMBER: _ClassVar[int]
PROJECT_ID_FIELD_NUMBER: _ClassVar[int]
PROJECT_IDS_FIELD_NUMBER: _ClassVar[int]
states: _containers.RepeatedScalarFieldContainer[MeetingState]
limit: int
offset: int
sort_order: SortOrder
project_id: str
def __init__(self, states: _Optional[_Iterable[_Union[MeetingState, str]]] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ..., sort_order: _Optional[_Union[SortOrder, str]] = ..., project_id: _Optional[str] = ...) -> None: ...
project_ids: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, states: _Optional[_Iterable[_Union[MeetingState, str]]] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ..., sort_order: _Optional[_Union[SortOrder, str]] = ..., project_id: _Optional[str] = ..., project_ids: _Optional[_Iterable[str]] = ...) -> None: ...
class ListMeetingsResponse(_message.Message):
__slots__ = ("meetings", "total_count")

View File

@@ -51,13 +51,16 @@ class MemoryMeetingRepository:
offset = kwargs.get("offset", 0)
sort_desc = kwargs.get("sort_desc", True)
project_id = kwargs.get("project_id")
project_ids = kwargs.get("project_ids")
project_filter = str(project_id) if project_id else None
project_filters = [str(pid) for pid in project_ids] if project_ids else None
return self._store.list_all(
states=states,
limit=limit,
offset=offset,
sort_desc=sort_desc,
project_id=project_filter,
project_ids=project_filters,
)
async def count_by_state(self, state: MeetingState) -> int:

View File

@@ -137,6 +137,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
offset = kwargs.get("offset", 0)
sort_desc = kwargs.get("sort_desc", True)
project_id = kwargs.get("project_id")
project_ids = kwargs.get("project_ids")
# Build base query
stmt = select(MeetingModel)
@@ -146,8 +147,10 @@ class SqlAlchemyMeetingRepository(BaseRepository):
state_values = [int(s) for s in states]
stmt = stmt.where(MeetingModel.state.in_(state_values))
# Filter by project if requested
if project_id is not None:
# Filter by project(s) if requested
if project_ids:
stmt = stmt.where(MeetingModel.project_id.in_(project_ids))
elif project_id is not None:
stmt = stmt.where(MeetingModel.project_id == project_id)
# Count total
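
For the SQL path, the multi-project branch reduces to an `IN` clause. A sketch under the assumption that `MeetingModel` is the mapped class used above:

```python
from sqlalchemy import select

def multi_project_stmt(project_ids):
    # roughly: SELECT ... WHERE project_id IN (:id_1, :id_2, ...)
    return select(MeetingModel).where(MeetingModel.project_id.in_(project_ids))
```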

View File

@@ -22,7 +22,6 @@ from noteflow.application.services.auth_service import (
AuthResult,
AuthService,
AuthServiceError,
LogoutResult,
UserInfo,
)
from noteflow.config.settings import CalendarIntegrationSettings
@@ -635,16 +634,17 @@ class TestRefreshAuthTokens:
class TestStoreAuthUser:
"""Tests for AuthService._store_auth_user workspace handling."""
async def test_creates_default_workspace_for_new_user(
async def test_complete_login_creates_default_workspace_for_new_user(
self,
calendar_settings: CalendarIntegrationSettings,
mock_oauth_manager: MagicMock,
mock_auth_uow: MagicMock,
sample_oauth_tokens: OAuthTokens,
) -> None:
"""_store_auth_user creates default workspace when none exists."""
"""complete_login creates default workspace when none exists."""
mock_auth_uow.workspaces.get_default_for_user.return_value = None
mock_auth_uow.workspaces.create = AsyncMock()
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
service = AuthService(
uow_factory=lambda: mock_auth_uow,
@@ -652,9 +652,11 @@ class TestStoreAuthUser:
oauth_manager=mock_oauth_manager,
)
await service._store_auth_user(
"google", "test@example.com", "Test User", sample_oauth_tokens
)
with patch(
"noteflow.infrastructure.calendar.google_adapter.GoogleCalendarAdapter.get_user_info",
new=AsyncMock(return_value=("test@example.com", "Test User")),
):
await service.complete_login("google", "auth-code", "state")
# Verify workspace.create was called (new workspace for new user)
mock_auth_uow.workspaces.create.assert_called_once()
@@ -730,8 +732,8 @@ class TestRefreshAuthTokensEdgeCases:
# =============================================================================
class TestParseProvider:
"""Tests for AuthService._parse_provider static method."""
class TestProviderParsing:
"""Tests for provider parsing through public APIs."""
@pytest.mark.parametrize(
("input_provider", "expected_output"),
@@ -743,15 +745,29 @@ class TestParseProvider:
pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
],
)
def test_parses_valid_providers(
@pytest.mark.asyncio
async def test_initiate_login_accepts_valid_providers(
self,
input_provider: str,
expected_output: OAuthProvider,
calendar_settings: CalendarIntegrationSettings,
mock_oauth_manager: MagicMock,
) -> None:
"""_parse_provider correctly parses valid provider strings."""
result = AuthService._parse_provider(input_provider)
"""initiate_login accepts valid provider strings (case-insensitive)."""
service = AuthService(
uow_factory=lambda: MagicMock(),
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
)
assert result == expected_output, f"should parse {input_provider} correctly"
auth_url, state = await service.initiate_login(input_provider)
mock_oauth_manager.initiate_auth.assert_called_once_with(
provider=expected_output,
redirect_uri=calendar_settings.redirect_uri,
)
assert auth_url.startswith("https://auth.example.com/"), "should return auth url"
assert state == "state123", "should return state from OAuth manager"
@pytest.mark.parametrize(
"invalid_provider",
@@ -762,10 +778,19 @@ class TestParseProvider:
pytest.param("invalid", id="random_string"),
],
)
def test_raises_for_invalid_providers(
@pytest.mark.asyncio
async def test_initiate_login_rejects_invalid_providers(
self,
invalid_provider: str,
calendar_settings: CalendarIntegrationSettings,
mock_oauth_manager: MagicMock,
) -> None:
"""_parse_provider raises AuthServiceError for invalid providers."""
"""initiate_login raises AuthServiceError for invalid providers."""
service = AuthService(
uow_factory=lambda: MagicMock(),
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
)
with pytest.raises(AuthServiceError, match="Invalid provider"):
AuthService._parse_provider(invalid_provider)
await service.initiate_login(invalid_provider)

View File

@@ -9,6 +9,8 @@ Tests cover:
from __future__ import annotations
from typing import Protocol, cast
import grpc
import pytest
@@ -16,6 +18,81 @@ from noteflow.grpc._config import ServicesConfig
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.diarization import DiarizationEngine
from noteflow.infrastructure.persistence.repositories.diarization_job_repo import (
JOB_STATUS_FAILED,
)
class _RefineSpeakerDiarizationRequest(Protocol):
"""Protocol for refine request objects."""
meeting_id: str
num_speakers: int
class _RefineSpeakerDiarizationResponse(Protocol):
"""Protocol for refine responses used in tests."""
job_id: str
status: int
error_message: str
class _CancelDiarizationJobRequest(Protocol):
"""Protocol for cancel request objects."""
job_id: str
class _CancelDiarizationJobResponse(Protocol):
"""Protocol for cancel responses used in tests."""
success: bool
error_message: str
class _GetDiarizationJobStatusRequest(Protocol):
"""Protocol for job status request objects."""
job_id: str
class _DiarizationJobStatusResponse(Protocol):
"""Protocol for job status responses used in tests."""
job_id: str
status: int
error_message: str
class _RefineSpeakerDiarizationCallable(Protocol):
"""Callable protocol for refine RPC."""
async def __call__(
self,
request: _RefineSpeakerDiarizationRequest,
context: _MockGrpcContext,
) -> _RefineSpeakerDiarizationResponse: ...
class _CancelDiarizationJobCallable(Protocol):
"""Callable protocol for cancel RPC."""
async def __call__(
self,
request: _CancelDiarizationJobRequest,
context: _MockGrpcContext,
) -> _CancelDiarizationJobResponse: ...
class _GetDiarizationJobStatusCallable(Protocol):
"""Callable protocol for job status RPC."""
async def __call__(
self,
request: _GetDiarizationJobStatusRequest,
context: _MockGrpcContext,
) -> _DiarizationJobStatusResponse: ...
class _FakeDiarizationEngine(DiarizationEngine):
@@ -77,14 +154,15 @@ class TestDatabaseRequirement:
store.update(meeting)
context = _MockGrpcContext()
response = await servicer.RefineSpeakerDiarization(
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
response = await refine(
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
# Should return error response, not job_id
assert not response.job_id, "Should not return job_id without database"
assert response.status == noteflow_pb2.JOB_STATUS_FAILED, "Status should be FAILED"
assert response.status == JOB_STATUS_FAILED, "Status should be FAILED"
assert context.abort_code == grpc.StatusCode.FAILED_PRECONDITION, "Error code should be FAILED_PRECONDITION"
@pytest.mark.asyncio
@@ -98,7 +176,8 @@ class TestDatabaseRequirement:
store.update(meeting)
context = _MockGrpcContext()
response = await servicer.RefineSpeakerDiarization(
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
response = await refine(
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
@@ -114,7 +193,11 @@ class TestDatabaseRequirement:
# Should abort because no DB support
with pytest.raises(_AbortCalled, match="database") as exc_info:
await servicer.GetDiarizationJobStatus(
get_status = cast(
_GetDiarizationJobStatusCallable,
servicer.GetDiarizationJobStatus,
)
await get_status(
noteflow_pb2.GetDiarizationJobStatusRequest(job_id="any-job-id"),
context,
)
@@ -128,7 +211,8 @@ class TestDatabaseRequirement:
"""CancelDiarizationJob returns error when database unavailable."""
context = _MockGrpcContext()
response = await servicer.CancelDiarizationJob(
cancel = cast(_CancelDiarizationJobCallable, servicer.CancelDiarizationJob)
response = await cancel(
noteflow_pb2.CancelDiarizationJobRequest(job_id="any-job-id"),
context,
)

View File

@@ -6,8 +6,9 @@ and CancelDiarizationJob RPCs with comprehensive edge case coverage.
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Self, cast
from typing import TYPE_CHECKING, Protocol, Self, cast
from unittest.mock import AsyncMock
from uuid import uuid4
@@ -23,6 +24,13 @@ from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.diarization import DiarizationEngine
from noteflow.infrastructure.persistence.repositories import DiarizationJob
from noteflow.infrastructure.persistence.repositories.diarization_job_repo import (
JOB_STATUS_CANCELLED,
JOB_STATUS_COMPLETED,
JOB_STATUS_FAILED,
JOB_STATUS_QUEUED,
JOB_STATUS_RUNNING,
)
# Test constants for progress calculation
EXPECTED_RUNNING_JOB_PROGRESS_PERCENT = 50.0
@@ -62,6 +70,84 @@ if TYPE_CHECKING:
from noteflow.grpc.meeting_store import MeetingStore
class _RefineSpeakerDiarizationRequest(Protocol):
meeting_id: str
num_speakers: int
class _RefineSpeakerDiarizationResponse(Protocol):
segments_updated: int
error_message: str
job_id: str
status: int
class _RenameSpeakerRequest(Protocol):
meeting_id: str
old_speaker_id: str
new_speaker_name: str
class _RenameSpeakerResponse(Protocol):
segments_updated: int
success: bool
class _GetDiarizationJobStatusRequest(Protocol):
job_id: str
class _DiarizationJobStatusResponse(Protocol):
job_id: str
status: int
segments_updated: int
speaker_ids: Sequence[str]
error_message: str
progress_percent: float
class _CancelDiarizationJobRequest(Protocol):
job_id: str
class _CancelDiarizationJobResponse(Protocol):
success: bool
error_message: str
status: int
class _RefineSpeakerDiarizationCallable(Protocol):
async def __call__(
self,
request: _RefineSpeakerDiarizationRequest,
context: _MockGrpcContext,
) -> _RefineSpeakerDiarizationResponse: ...
class _RenameSpeakerCallable(Protocol):
async def __call__(
self,
request: _RenameSpeakerRequest,
context: _MockGrpcContext,
) -> _RenameSpeakerResponse: ...
class _GetDiarizationJobStatusCallable(Protocol):
async def __call__(
self,
request: _GetDiarizationJobStatusRequest,
context: _MockGrpcContext,
) -> _DiarizationJobStatusResponse: ...
class _CancelDiarizationJobCallable(Protocol):
async def __call__(
self,
request: _CancelDiarizationJobRequest,
context: _MockGrpcContext,
) -> _CancelDiarizationJobResponse: ...
class _MockGrpcContext:
"""Minimal async gRPC context for diarization mixin tests."""
@@ -140,6 +226,46 @@ def _get_store(servicer: NoteFlowServicer) -> MeetingStore:
return servicer.get_memory_store()
async def _call_refine(
servicer: NoteFlowServicer,
request: _RefineSpeakerDiarizationRequest,
context: _MockGrpcContext,
) -> _RefineSpeakerDiarizationResponse:
"""Call RefineSpeakerDiarization with typed response."""
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
return await refine(request, context)
async def _call_rename(
servicer: NoteFlowServicer,
request: _RenameSpeakerRequest,
context: _MockGrpcContext,
) -> _RenameSpeakerResponse:
"""Call RenameSpeaker with typed response."""
rename = cast(_RenameSpeakerCallable, servicer.RenameSpeaker)
return await rename(request, context)
async def _call_get_status(
servicer: NoteFlowServicer,
request: _GetDiarizationJobStatusRequest,
context: _MockGrpcContext,
) -> _DiarizationJobStatusResponse:
"""Call GetDiarizationJobStatus with typed response."""
get_status = cast(_GetDiarizationJobStatusCallable, servicer.GetDiarizationJobStatus)
return await get_status(request, context)
async def _call_cancel(
servicer: NoteFlowServicer,
request: _CancelDiarizationJobRequest,
context: _MockGrpcContext,
) -> _CancelDiarizationJobResponse:
"""Call CancelDiarizationJob with typed response."""
cancel = cast(_CancelDiarizationJobCallable, servicer.CancelDiarizationJob)
return await cancel(request, context)
class _MockDiarizationJobsRepo:
"""Mock diarization jobs repository for testing."""
@@ -186,8 +312,8 @@ class _MockDiarizationJobsRepo:
"""Get active job for meeting."""
for job in self._jobs.values():
if job.meeting_id == meeting_id and job.status in (
noteflow_pb2.JOB_STATUS_QUEUED,
noteflow_pb2.JOB_STATUS_RUNNING,
JOB_STATUS_QUEUED,
JOB_STATUS_RUNNING,
):
return job
return None
@@ -198,7 +324,7 @@ class _MockDiarizationJobsRepo:
job
for job in self._jobs.values()
if job.status
in (noteflow_pb2.JOB_STATUS_QUEUED, noteflow_pb2.JOB_STATUS_RUNNING)
in (JOB_STATUS_QUEUED, JOB_STATUS_RUNNING)
]
async def prune_completed(self, ttl_seconds: float) -> int:
@@ -277,7 +403,8 @@ class TestRefineSpeakerDiarizationValidation:
"""Invalid UUID format returns error response."""
context = _MockGrpcContext()
response = await diarization_servicer.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id="not-a-uuid"),
context,
)
@@ -293,7 +420,8 @@ class TestRefineSpeakerDiarizationValidation:
context = _MockGrpcContext()
meeting_id = str(uuid4())
response = await diarization_servicer.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=meeting_id),
context,
)
@@ -316,7 +444,8 @@ class TestRefineSpeakerDiarizationState:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
@@ -336,7 +465,8 @@ class TestRefineSpeakerDiarizationState:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
@@ -357,13 +487,14 @@ class TestRefineSpeakerDiarizationState:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer_with_db,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
assert response.job_id, "Should return job_id for accepted request"
assert response.status == noteflow_pb2.JOB_STATUS_QUEUED, "Job status should be QUEUED"
assert response.status == JOB_STATUS_QUEUED, "Job status should be QUEUED"
class TestRefineSpeakerDiarizationServer:
@@ -382,7 +513,8 @@ class TestRefineSpeakerDiarizationServer:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer_disabled.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer_disabled,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
@@ -403,7 +535,8 @@ class TestRefineSpeakerDiarizationServer:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer_no_engine.RefineSpeakerDiarization(
response = await _call_refine(
diarization_servicer_no_engine,
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
context,
)
@@ -423,7 +556,8 @@ class TestRenameSpeakerValidation:
context = _MockGrpcContext()
with pytest.raises(AssertionError, match="abort called"):
await diarization_servicer.RenameSpeaker(
await _call_rename(
diarization_servicer,
noteflow_pb2.RenameSpeakerRequest(
meeting_id=str(uuid4()),
old_speaker_id="",
@@ -442,7 +576,8 @@ class TestRenameSpeakerValidation:
context = _MockGrpcContext()
with pytest.raises(AssertionError, match="abort called"):
await diarization_servicer.RenameSpeaker(
await _call_rename(
diarization_servicer,
noteflow_pb2.RenameSpeakerRequest(
meeting_id=str(uuid4()),
old_speaker_id="SPEAKER_0",
@@ -461,7 +596,8 @@ class TestRenameSpeakerValidation:
context = _MockGrpcContext()
with pytest.raises(AssertionError, match="abort called"):
await diarization_servicer.RenameSpeaker(
await _call_rename(
diarization_servicer,
noteflow_pb2.RenameSpeakerRequest(
meeting_id="invalid-uuid",
old_speaker_id="SPEAKER_0",
@@ -496,7 +632,8 @@ class TestRenameSpeakerOperation:
]
store.update(meeting)
response = await diarization_servicer.RenameSpeaker(
response = await _call_rename(
diarization_servicer,
noteflow_pb2.RenameSpeakerRequest(
meeting_id=str(meeting.id), old_speaker_id="SPEAKER_0", new_speaker_name="Alice"
),
@@ -531,7 +668,8 @@ class TestRenameSpeakerOperation:
store.update(meeting)
context = _MockGrpcContext()
response = await diarization_servicer.RenameSpeaker(
response = await _call_rename(
diarization_servicer,
noteflow_pb2.RenameSpeakerRequest(
meeting_id=str(meeting.id),
old_speaker_id="SPEAKER_0",
@@ -560,11 +698,12 @@ class TestGetDiarizationJobStatusProgress:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_QUEUED)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_QUEUED)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
response = await _call_get_status(
diarization_servicer_with_db,
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
context,
)
@@ -587,14 +726,15 @@ class TestGetDiarizationJobStatusProgress:
_create_test_job(
job_id,
str(meeting.id),
noteflow_pb2.JOB_STATUS_RUNNING,
JOB_STATUS_RUNNING,
started_at=utc_now() - timedelta(seconds=10),
audio_duration_seconds=120.0,
)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
response = await _call_get_status(
diarization_servicer_with_db,
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
context,
)
@@ -615,11 +755,12 @@ class TestGetDiarizationJobStatusProgress:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_COMPLETED)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_COMPLETED)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
response = await _call_get_status(
diarization_servicer_with_db,
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
context,
)
@@ -639,11 +780,12 @@ class TestGetDiarizationJobStatusProgress:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_FAILED)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_FAILED)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
response = await _call_get_status(
diarization_servicer_with_db,
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
context,
)
@@ -670,17 +812,18 @@ class TestCancelDiarizationJobStates:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_QUEUED)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_QUEUED)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.CancelDiarizationJob(
response = await _call_cancel(
diarization_servicer_with_db,
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
context,
)
assert response.success is True, "Cancelling queued job should succeed"
assert response.status == noteflow_pb2.JOB_STATUS_CANCELLED, "Cancelled queued job status should be CANCELLED"
assert response.status == JOB_STATUS_CANCELLED, "Cancelled queued job status should be CANCELLED"
@pytest.mark.asyncio
async def test_cancel_mixin_running_succeeds(
@@ -695,17 +838,18 @@ class TestCancelDiarizationJobStates:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_RUNNING)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_RUNNING)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.CancelDiarizationJob(
response = await _call_cancel(
diarization_servicer_with_db,
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
context,
)
assert response.success is True, "Cancelling running job should succeed"
assert response.status == noteflow_pb2.JOB_STATUS_CANCELLED, "Cancelled running job status should be CANCELLED"
assert response.status == JOB_STATUS_CANCELLED, "Cancelled running job status should be CANCELLED"
@pytest.mark.asyncio
async def test_cancel_mixin_nonexistent_fails(
@@ -715,7 +859,8 @@ class TestCancelDiarizationJobStates:
"""Nonexistent job cannot be cancelled."""
context = _MockGrpcContext()
response = await diarization_servicer_with_db.CancelDiarizationJob(
response = await _call_cancel(
diarization_servicer_with_db,
noteflow_pb2.CancelDiarizationJobRequest(job_id="nonexistent-job"),
context,
)
@@ -736,14 +881,15 @@ class TestCancelDiarizationJobStates:
job_id = str(uuid4())
await mock_diarization_jobs_repo.create(
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_COMPLETED)
_create_test_job(job_id, str(meeting.id), JOB_STATUS_COMPLETED)
)
context = _MockGrpcContext()
response = await diarization_servicer_with_db.CancelDiarizationJob(
response = await _call_cancel(
diarization_servicer_with_db,
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
context,
)
assert response.success is False, "Cancelling completed job should fail"
assert response.status == noteflow_pb2.JOB_STATUS_COMPLETED, "Completed job status should remain COMPLETED"
assert response.status == JOB_STATUS_COMPLETED, "Completed job status should remain COMPLETED"

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
from typing import Protocol, cast
import grpc
import pytest
@@ -11,6 +13,24 @@ from noteflow.grpc.service import NoteFlowServicer
from noteflow.infrastructure.diarization import DiarizationEngine
class _RefineSpeakerDiarizationRequest(Protocol):
meeting_id: str
num_speakers: int
class _RefineSpeakerDiarizationResponse(Protocol):
segments_updated: int
error_message: str
class _RefineSpeakerDiarizationCallable(Protocol):
async def __call__(
self,
request: _RefineSpeakerDiarizationRequest,
context: _DummyContext,
) -> _RefineSpeakerDiarizationResponse: ...
class _DummyContext:
"""Minimal gRPC context that raises if abort is invoked."""
@@ -43,7 +63,8 @@ async def test_refine_speaker_diarization_rejects_active_meeting() -> None:
meeting.start_recording()
store.update(meeting)
response = await servicer.RefineSpeakerDiarization(
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
response = await refine(
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
_DummyContext(),
)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TypedDict, Unpack
import grpc
import pytest
@@ -55,14 +56,17 @@ async def test_generate_summary_uses_placeholder_when_service_missing() -> None:
class _FailingSummarizationService(SummarizationService):
"""Summarization service that always reports provider unavailability."""
class _Options(TypedDict, total=False):
mode: SummarizationMode | None
max_key_points: int | None
max_action_items: int | None
style_prompt: str | None
async def summarize(
self,
meeting_id: MeetingId,
segments: Sequence[Segment],
mode: SummarizationMode | None = None,
max_key_points: int | None = None,
max_action_items: int | None = None,
style_prompt: str | None = None,
**kwargs: Unpack[_Options],
) -> SummarizationServiceResult:
raise ProviderUnavailableError("LLM unavailable")

View File

@@ -10,7 +10,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
from uuid import uuid4
import pytest

View File

@@ -6,7 +6,7 @@ Tests identity context validation and per-RPC request logging.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TypeVar
from typing import Protocol, cast
from unittest.mock import AsyncMock, MagicMock, patch
import grpc
@@ -29,8 +29,12 @@ from noteflow.infrastructure.logging import (
workspace_id_var,
)
_TRequest = TypeVar("_TRequest")
_TResponse = TypeVar("_TResponse")
class _DummyRequest:
"""Placeholder request type for handler casts."""
class _DummyResponse:
"""Placeholder response type for handler casts."""
# Test data
TEST_REQUEST_ID = "test-request-123"
@@ -58,30 +62,62 @@ def create_handler_call_details(
return details
def create_mock_handler() -> grpc.RpcMethodHandler[_TRequest, _TResponse]:
def create_mock_handler() -> _UnaryUnaryHandler:
"""Create a mock RPC method handler."""
handler = MagicMock(spec=grpc.RpcMethodHandler)
handler.unary_unary = AsyncMock(return_value="response")
handler.unary_stream = None
handler.stream_unary = None
handler.stream_stream = None
handler.request_deserializer = None
handler.response_serializer = None
return handler
return _MockHandler()
def create_mock_continuation(
handler: grpc.RpcMethodHandler[_TRequest, _TResponse] | None = None,
) -> Callable[
[grpc.HandlerCallDetails],
Awaitable[grpc.RpcMethodHandler[_TRequest, _TResponse]],
]:
handler: _UnaryUnaryHandler | None = None,
) -> AsyncMock:
"""Create a mock continuation function."""
if handler is None:
handler = create_mock_handler()
return AsyncMock(return_value=handler)
class _UnaryUnaryHandler(Protocol):
"""Protocol for unary-unary RPC method handlers."""
unary_unary: Callable[
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
Awaitable[_DummyResponse],
] | None
unary_stream: object | None
stream_unary: object | None
stream_stream: object | None
request_deserializer: object | None
response_serializer: object | None
class _MockHandler:
"""Concrete handler for tests with typed unary_unary."""
unary_unary: Callable[
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
Awaitable[_DummyResponse],
] | None
unary_stream: object | None
stream_unary: object | None
stream_stream: object | None
request_deserializer: object | None
response_serializer: object | None
def __init__(self) -> None:
self.unary_unary = cast(
Callable[
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
Awaitable[_DummyResponse],
],
AsyncMock(return_value="response"),
)
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
self.request_deserializer = None
self.response_serializer = None
class TestIdentityInterceptor:
"""Tests for IdentityInterceptor."""
@@ -126,9 +162,10 @@ class TestIdentityInterceptor:
continuation = create_mock_continuation()
handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, handler)
# Handler should be a rejection handler, not the original
assert handler.unary_unary is not None, "handler should have unary_unary"
assert typed_handler.unary_unary is not None, "handler should have unary_unary"
# Continuation should NOT have been called
continuation.assert_not_called()
@@ -140,13 +177,15 @@ class TestIdentityInterceptor:
continuation = create_mock_continuation()
handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, handler)
# Create mock context to verify abort behavior
context = AsyncMock(spec=aio.ServicerContext)
context.abort = AsyncMock(side_effect=grpc.RpcError("missing x-request-id"))
with pytest.raises(grpc.RpcError, match="x-request-id"):
await handler.unary_unary(MagicMock(), context)
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
context.abort.assert_called_once()
call_args = context.abort.call_args
@@ -187,11 +226,13 @@ class TestRequestLoggingInterceptor:
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
# Execute the wrapped handler
context = AsyncMock(spec=aio.ServicerContext)
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
await wrapped_handler.unary_unary(MagicMock(), context)
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
# Verify logging
mock_logger.info.assert_called_once()
@@ -215,12 +256,14 @@ class TestRequestLoggingInterceptor:
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
context = AsyncMock(spec=aio.ServicerContext)
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
with pytest.raises(Exception, match="Test error"):
await wrapped_handler.unary_unary(MagicMock(), context)
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
# Should still log with INTERNAL status
mock_logger.info.assert_called_once()
@@ -249,12 +292,14 @@ class TestRequestLoggingInterceptor:
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
# Context without peer method
context = AsyncMock(spec=aio.ServicerContext)
context.peer = MagicMock(side_effect=RuntimeError("No peer"))
await wrapped_handler.unary_unary(MagicMock(), context)
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
# Should still log with None peer
mock_logger.info.assert_called_once()

View File

@@ -7,10 +7,12 @@ DisconnectOAuth RPCs work correctly with the calendar service.
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from collections.abc import Sequence
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import grpc
import pytest
from noteflow.application.services.calendar_service import CalendarServiceError
@@ -28,16 +30,166 @@ class _DummyContext:
def __init__(self) -> None:
self.aborted = False
self.abort_code: object = None
self.abort_code: grpc.StatusCode | None = None
self.abort_details: str = ""
async def abort(self, code: object, details: str) -> None:
async def abort(self, code: grpc.StatusCode, details: str) -> None:
self.aborted = True
self.abort_code = code
self.abort_details = details
raise AssertionError(f"abort called: {code} - {details}")
class _CalendarProvider(Protocol):
name: str
is_authenticated: bool
display_name: str
class _GetCalendarProvidersResponse(Protocol):
providers: Sequence[_CalendarProvider]
class _GetCalendarProvidersCallable(Protocol):
async def __call__(
self,
request: noteflow_pb2.GetCalendarProvidersRequest,
context: _DummyContext,
) -> _GetCalendarProvidersResponse: ...
async def _call_get_calendar_providers(
servicer: NoteFlowServicer,
request: noteflow_pb2.GetCalendarProvidersRequest,
context: _DummyContext,
) -> _GetCalendarProvidersResponse:
get_providers = cast(
_GetCalendarProvidersCallable,
servicer.GetCalendarProviders,
)
return await get_providers(request, context)
class _InitiateOAuthRequest(Protocol):
provider: str
redirect_uri: str
integration_type: str
class _InitiateOAuthResponse(Protocol):
auth_url: str
state: str
class _CompleteOAuthRequest(Protocol):
provider: str
code: str
state: str
class _CompleteOAuthResponse(Protocol):
success: bool
error_message: str
provider_email: str
integration_id: str
class _OAuthConnection(Protocol):
provider: str
status: str
email: str
expires_at: str
error_message: str
integration_type: str
class _GetOAuthConnectionStatusRequest(Protocol):
provider: str
integration_type: str
class _GetOAuthConnectionStatusResponse(Protocol):
connection: _OAuthConnection
class _DisconnectOAuthRequest(Protocol):
provider: str
integration_type: str
class _DisconnectOAuthResponse(Protocol):
success: bool
error_message: str
class _InitiateOAuthCallable(Protocol):
async def __call__(
self,
request: _InitiateOAuthRequest,
context: _DummyContext,
) -> _InitiateOAuthResponse: ...
class _CompleteOAuthCallable(Protocol):
async def __call__(
self,
request: _CompleteOAuthRequest,
context: _DummyContext,
) -> _CompleteOAuthResponse: ...
class _GetOAuthConnectionStatusCallable(Protocol):
async def __call__(
self,
request: _GetOAuthConnectionStatusRequest,
context: _DummyContext,
) -> _GetOAuthConnectionStatusResponse: ...
class _DisconnectOAuthCallable(Protocol):
async def __call__(
self,
request: _DisconnectOAuthRequest,
context: _DummyContext,
) -> _DisconnectOAuthResponse: ...
async def _call_initiate_oauth(
servicer: NoteFlowServicer,
request: _InitiateOAuthRequest,
context: _DummyContext,
) -> _InitiateOAuthResponse:
initiate = cast(_InitiateOAuthCallable, servicer.InitiateOAuth)
return await initiate(request, context)
async def _call_complete_oauth(
servicer: NoteFlowServicer,
request: _CompleteOAuthRequest,
context: _DummyContext,
) -> _CompleteOAuthResponse:
complete = cast(_CompleteOAuthCallable, servicer.CompleteOAuth)
return await complete(request, context)
async def _call_get_oauth_status(
servicer: NoteFlowServicer,
request: _GetOAuthConnectionStatusRequest,
context: _DummyContext,
) -> _GetOAuthConnectionStatusResponse:
get_status = cast(_GetOAuthConnectionStatusCallable, servicer.GetOAuthConnectionStatus)
return await get_status(request, context)
async def _call_disconnect_oauth(
servicer: NoteFlowServicer,
request: _DisconnectOAuthRequest,
context: _DummyContext,
) -> _DisconnectOAuthResponse:
disconnect = cast(_DisconnectOAuthCallable, servicer.DisconnectOAuth)
return await disconnect(request, context)
def _create_mock_connection_info(
*,
provider: str = "google",
@@ -94,7 +246,8 @@ class TestGetCalendarProviders:
service = _create_mockcalendar_service()
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetCalendarProviders(
response = await _call_get_calendar_providers(
servicer,
noteflow_pb2.GetCalendarProvidersRequest(),
_DummyContext(),
)
@@ -112,7 +265,8 @@ class TestGetCalendarProviders:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetCalendarProviders(
response = await _call_get_calendar_providers(
servicer,
noteflow_pb2.GetCalendarProvidersRequest(),
_DummyContext(),
)
@@ -129,7 +283,8 @@ class TestGetCalendarProviders:
service = _create_mockcalendar_service()
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetCalendarProviders(
response = await _call_get_calendar_providers(
servicer,
noteflow_pb2.GetCalendarProvidersRequest(),
_DummyContext(),
)
@@ -147,7 +302,8 @@ class TestGetCalendarProviders:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.GetCalendarProviders(
await _call_get_calendar_providers(
servicer,
noteflow_pb2.GetCalendarProvidersRequest(),
context,
)
@@ -168,7 +324,8 @@ class TestInitiateOAuth:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.InitiateOAuth(
response = await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -183,7 +340,8 @@ class TestInitiateOAuth:
service.initiate_oauth.return_value = ("https://auth.url", "state")
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
await servicer.InitiateOAuth(
await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(provider="outlook"),
_DummyContext(),
)
@@ -200,7 +358,8 @@ class TestInitiateOAuth:
service.initiate_oauth.return_value = ("https://auth.url", "state")
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
await servicer.InitiateOAuth(
await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(
provider="google",
redirect_uri="noteflow://oauth/callback",
@@ -222,7 +381,8 @@ class TestInitiateOAuth:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.InitiateOAuth(
await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(provider="unknown"),
context,
)
@@ -237,7 +397,8 @@ class TestInitiateOAuth:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.InitiateOAuth(
await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(provider="google"),
context,
)
@@ -258,7 +419,8 @@ class TestCompleteOAuth:
service.complete_oauth.return_value = uuid4() # Returns integration ID
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="authorization-code",
@@ -278,7 +440,8 @@ class TestCompleteOAuth:
service.complete_oauth.return_value = uuid4() # Returns integration ID
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
await servicer.CompleteOAuth(
await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="my-auth-code",
@@ -302,7 +465,8 @@ class TestCompleteOAuth:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="authorization-code",
@@ -323,7 +487,8 @@ class TestCompleteOAuth:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="invalid-code",
@@ -342,7 +507,8 @@ class TestCompleteOAuth:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.CompleteOAuth(
await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="code",
@@ -366,7 +532,8 @@ class TestGetOAuthConnectionStatus:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetOAuthConnectionStatus(
response = await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
_DummyContext(),
)
@@ -383,7 +550,8 @@ class TestGetOAuthConnectionStatus:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetOAuthConnectionStatus(
response = await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
_DummyContext(),
)
@@ -396,7 +564,8 @@ class TestGetOAuthConnectionStatus:
service = _create_mockcalendar_service()
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.GetOAuthConnectionStatus(
response = await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(
provider="google",
integration_type="calendar",
@@ -413,7 +582,8 @@ class TestGetOAuthConnectionStatus:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.GetOAuthConnectionStatus(
await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
context,
)
@@ -433,7 +603,8 @@ class TestDisconnectOAuth:
service.disconnect.return_value = True
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.DisconnectOAuth(
response = await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -447,7 +618,8 @@ class TestDisconnectOAuth:
service.disconnect.return_value = True
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
await servicer.DisconnectOAuth(
await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="outlook"),
_DummyContext(),
)
@@ -463,7 +635,8 @@ class TestDisconnectOAuth:
service.disconnect.return_value = False
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.DisconnectOAuth(
response = await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -477,7 +650,8 @@ class TestDisconnectOAuth:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.DisconnectOAuth(
await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
context,
)
@@ -538,7 +712,8 @@ class TestOAuthRoundTrip:
service, _, _ = oauth_flow_service
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.InitiateOAuth(
response = await _call_initiate_oauth(
servicer,
noteflow_pb2.InitiateOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -556,7 +731,8 @@ class TestOAuthRoundTrip:
assert connected_state["google"] is False, "should start disconnected"
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="auth-code",
@@ -580,7 +756,8 @@ class TestOAuthRoundTrip:
connected_state["google"] = True
email_state["google"] = "user@gmail.com"
response = await servicer.DisconnectOAuth(
response = await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -598,7 +775,8 @@ class TestOAuthRoundTrip:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="auth-code",
@@ -620,11 +798,13 @@ class TestOAuthRoundTrip:
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
ctx = _DummyContext()
google_status = await servicer.GetOAuthConnectionStatus(
google_status = await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
ctx,
)
outlook_status = await servicer.GetOAuthConnectionStatus(
outlook_status = await _call_get_oauth_status(
servicer,
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="outlook"),
ctx,
)
@@ -644,7 +824,8 @@ class TestOAuthSecurityBehavior:
service.complete_oauth.side_effect = CalendarServiceError("State mismatch")
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="stolen-code",
@@ -664,7 +845,8 @@ class TestOAuthSecurityBehavior:
service.disconnect.return_value = True
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
await servicer.DisconnectOAuth(
await _call_disconnect_oauth(
servicer,
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
_DummyContext(),
)
@@ -680,7 +862,8 @@ class TestOAuthSecurityBehavior:
)
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
response = await servicer.CompleteOAuth(
response = await _call_complete_oauth(
servicer,
noteflow_pb2.CompleteOAuthRequest(
provider="google",
code="code",

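A note on the _call_* wrappers introduced throughout this file: the servicer's RPC methods are attached via generated base classes and mixins, so a static checker cannot always resolve their exact signatures on NoteFlowServicer. A minimal sketch of one wrapper, assuming the OAuth helpers mirror the Protocol-plus-cast pattern the sync tests below define explicitly (the _DisconnectOAuthCallable protocol and the response shape here are assumptions; only _call_disconnect_oauth itself appears in the diff):

from typing import Protocol, cast

class _DisconnectOAuthResponse(Protocol):
    # Assumed response shape; the diff only shows the request side.
    success: bool

class _DisconnectOAuthCallable(Protocol):
    async def __call__(
        self,
        request: noteflow_pb2.DisconnectOAuthRequest,
        context: _DummyContext,
    ) -> _DisconnectOAuthResponse: ...

async def _call_disconnect_oauth(
    servicer: NoteFlowServicer,
    request: noteflow_pb2.DisconnectOAuthRequest,
    context: _DummyContext,
) -> _DisconnectOAuthResponse:
    # cast() pins a precise static type onto the dynamically attached method.
    disconnect = cast(_DisconnectOAuthCallable, servicer.DisconnectOAuth)
    return await disconnect(request, context)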

@@ -10,10 +10,10 @@ list_all repository method tests.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import cast
from typing import Protocol, cast
from uuid import UUID, uuid4
import grpc
@@ -23,6 +23,7 @@ from noteflow.application.services.calendar_service import CalendarService
from noteflow.domain.entities.integration import (
Integration,
IntegrationType,
SyncRun,
)
from noteflow.grpc._config import ServicesConfig
from noteflow.grpc.meeting_store import MeetingStore
@@ -30,7 +31,7 @@ from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
def _get_status_code_not_found() -> object:
def _get_status_code_not_found() -> grpc.StatusCode:
"""Helper function to get StatusCode.NOT_FOUND for type checker compatibility."""
attr_name = "StatusCode"
status_code = getattr(grpc, attr_name)
@@ -42,16 +43,132 @@ class _DummyContext:
def __init__(self) -> None:
self.aborted = False
self.abort_code: object = None
self.abort_code: grpc.StatusCode | None = None
self.abort_details: str | None = None
async def abort(self, code: object, details: str) -> None:
async def abort(self, code: grpc.StatusCode, details: str) -> None:
self.aborted = True
self.abort_code = code
self.abort_details = details
raise AssertionError(f"abort called: {code} - {details}")
class _StartIntegrationSyncRequest(Protocol):
integration_id: str
class _StartIntegrationSyncResponse(Protocol):
sync_run_id: str
status: str
class _GetSyncStatusRequest(Protocol):
sync_run_id: str
class _GetSyncStatusResponse(Protocol):
status: str
items_synced: int
items_total: int
error_message: str
duration_ms: int
expires_at: str
class _ListSyncHistoryRequest(Protocol):
integration_id: str
limit: int
class _ListSyncHistoryResponse(Protocol):
total_count: int
runs: Sequence["_SyncRunInfo"]
class _SyncRunInfo(Protocol):
id: str
class _IntegrationInfo(Protocol):
id: str
name: str
type: str
status: str
workspace_id: str
class _GetUserIntegrationsResponse(Protocol):
integrations: Sequence[_IntegrationInfo]
class _StartIntegrationSyncCallable(Protocol):
async def __call__(
self,
request: _StartIntegrationSyncRequest,
context: _DummyContext,
) -> _StartIntegrationSyncResponse: ...
class _GetSyncStatusCallable(Protocol):
async def __call__(
self,
request: _GetSyncStatusRequest,
context: _DummyContext,
) -> _GetSyncStatusResponse: ...
class _ListSyncHistoryCallable(Protocol):
async def __call__(
self,
request: _ListSyncHistoryRequest,
context: _DummyContext,
) -> _ListSyncHistoryResponse: ...
class _GetUserIntegrationsCallable(Protocol):
async def __call__(
self,
request: noteflow_pb2.GetUserIntegrationsRequest,
context: _DummyContext,
) -> _GetUserIntegrationsResponse: ...
async def _call_start_sync(
servicer: NoteFlowServicer,
request: _StartIntegrationSyncRequest,
context: _DummyContext,
) -> _StartIntegrationSyncResponse:
start_sync = cast(_StartIntegrationSyncCallable, servicer.StartIntegrationSync)
return await start_sync(request, context)
async def _call_get_sync_status(
servicer: NoteFlowServicer,
request: _GetSyncStatusRequest,
context: _DummyContext,
) -> _GetSyncStatusResponse:
get_status = cast(_GetSyncStatusCallable, servicer.GetSyncStatus)
return await get_status(request, context)
async def _call_list_sync_history(
servicer: NoteFlowServicer,
request: _ListSyncHistoryRequest,
context: _DummyContext,
) -> _ListSyncHistoryResponse:
list_history = cast(_ListSyncHistoryCallable, servicer.ListSyncHistory)
return await list_history(request, context)
async def _call_get_user_integrations(
servicer: NoteFlowServicer,
request: noteflow_pb2.GetUserIntegrationsRequest,
context: _DummyContext,
) -> _GetUserIntegrationsResponse:
get_integrations = cast(_GetUserIntegrationsCallable, servicer.GetUserIntegrations)
return await get_integrations(request, context)
@dataclass
class MockCalendarEvent:
"""Mock calendar event for testing."""
@@ -166,17 +283,18 @@ async def await_sync_completion(
sync_run_id: str,
context: _DummyContext,
timeout: float = 2.0,
) -> noteflow_pb2.GetSyncStatusResponse:
) -> _GetSyncStatusResponse:
"""Wait for sync to complete using event-based synchronization.
Uses asyncio.wait_for for timeout instead of polling loop.
The sync task runs in the background; we wait then check final status.
"""
async def _get_final_status() -> noteflow_pb2.GetSyncStatusResponse:
async def _get_final_status() -> _GetSyncStatusResponse:
# Brief delay to allow background sync to complete
await asyncio.sleep(0.05)
return await servicer.GetSyncStatus(
return await _call_get_sync_status(
servicer,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=sync_run_id),
context,
)
@@ -194,7 +312,8 @@ class TestStartIntegrationSync:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.StartIntegrationSync(
await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(uuid4())),
context,
)
@@ -208,7 +327,8 @@ class TestStartIntegrationSync:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.StartIntegrationSync(
await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=""),
context,
)
@@ -226,7 +346,8 @@ class TestSyncStatus:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.GetSyncStatus(
await _call_get_sync_status(
servicer,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=str(uuid4())),
context,
)
@@ -240,7 +361,8 @@ class TestSyncStatus:
context = _DummyContext()
with pytest.raises(AssertionError, match="abort called"):
await servicer.GetSyncStatus(
await _call_get_sync_status(
servicer,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=""),
context,
)
@@ -257,7 +379,8 @@ class TestSyncHistory:
servicer = NoteFlowServicer()
context = _DummyContext()
response = await servicer.ListSyncHistory(
response = await _call_list_sync_history(
servicer,
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(uuid4()), limit=10),
context,
)
@@ -271,7 +394,8 @@ class TestSyncHistory:
servicer = NoteFlowServicer()
context = _DummyContext()
response = await servicer.ListSyncHistory(
response = await _call_list_sync_history(
servicer,
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(uuid4())),
context,
)
@@ -290,7 +414,8 @@ class TestSyncHappyPath:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
response = await servicer_with_success.StartIntegrationSync(
response = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -315,7 +440,8 @@ class TestSyncHappyPath:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer.StartIntegrationSync(
start = await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -340,13 +466,15 @@ class TestSyncHappyPath:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer.StartIntegrationSync(
start = await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
await await_sync_completion(servicer, start.sync_run_id, context)
history = await servicer.ListSyncHistory(
history = await _call_list_sync_history(
servicer,
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(integration.id), limit=10),
context,
)
@@ -370,7 +498,8 @@ class TestSyncErrorHandling:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer.StartIntegrationSync(
start = await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -387,7 +516,8 @@ class TestSyncErrorHandling:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer_with_failure.StartIntegrationSync(
start = await _call_start_sync(
servicer_with_failure,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -406,7 +536,8 @@ class TestSyncErrorHandling:
servicer_fail = NoteFlowServicer(services=ServicesConfig(calendar_service=cast(CalendarService, failing_service)))
servicer_fail.memory_store = meeting_store
first = await servicer_fail.StartIntegrationSync(
first = await _call_start_sync(
servicer_fail,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -419,7 +550,8 @@ class TestSyncErrorHandling:
servicer_success = NoteFlowServicer(services=ServicesConfig(calendar_service=cast(CalendarService, success_service)))
servicer_success.memory_store = meeting_store
second = await servicer_success.StartIntegrationSync(
second = await _call_start_sync(
servicer_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -442,11 +574,13 @@ class TestConcurrentSyncs:
int2 = await create_test_integration(meeting_store, "Calendar 2")
context = _DummyContext()
r1 = await servicer_with_success.StartIntegrationSync(
r1 = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int1.id)),
context,
)
r2 = await servicer_with_success.StartIntegrationSync(
r2 = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int2.id)),
context,
)
@@ -466,22 +600,26 @@ class TestConcurrentSyncs:
int2 = await create_test_integration(meeting_store, "Calendar 2")
context = _DummyContext()
r1 = await servicer_with_success.StartIntegrationSync(
r1 = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int1.id)),
context,
)
r2 = await servicer_with_success.StartIntegrationSync(
r2 = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int2.id)),
context,
)
await await_sync_completion(servicer_with_success, r1.sync_run_id, context)
await await_sync_completion(servicer_with_success, r2.sync_run_id, context)
h1 = await servicer_with_success.ListSyncHistory(
h1 = await _call_list_sync_history(
servicer_with_success,
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(int1.id), limit=10),
context,
)
h2 = await servicer_with_success.ListSyncHistory(
h2 = await _call_list_sync_history(
servicer_with_success,
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(int2.id), limit=10),
context,
)
@@ -506,7 +644,8 @@ class TestSyncPolling:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer.StartIntegrationSync(
start = await _call_start_sync(
servicer,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -515,7 +654,8 @@ class TestSyncPolling:
await asyncio.wait_for(blocking_service.sync_started.wait(), timeout=1.0)
# Check status while blocked - should still be running
immediate_status = await servicer.GetSyncStatus(
immediate_status = await _call_get_sync_status(
servicer,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=start.sync_run_id),
context,
)
@@ -535,7 +675,8 @@ class TestSyncPolling:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer_with_success.StartIntegrationSync(
start = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -554,7 +695,8 @@ class TestGetUserIntegrations:
servicer = NoteFlowServicer()
context = _DummyContext()
response = await servicer.GetUserIntegrations(
response = await _call_get_user_integrations(
servicer,
noteflow_pb2.GetUserIntegrationsRequest(),
context,
)
@@ -575,7 +717,8 @@ class TestGetUserIntegrations:
integration = await create_test_integration(meeting_store, "Google Calendar")
response = await servicer.GetUserIntegrations(
response = await _call_get_user_integrations(
servicer,
noteflow_pb2.GetUserIntegrationsRequest(),
context,
)
@@ -596,7 +739,8 @@ class TestGetUserIntegrations:
integration = await create_test_integration(meeting_store, "Test Calendar")
response = await servicer.GetUserIntegrations(
response = await _call_get_user_integrations(
servicer,
noteflow_pb2.GetUserIntegrationsRequest(),
context,
)
@@ -625,7 +769,8 @@ class TestGetUserIntegrations:
await meeting_store.integrations.delete(int1.id)
response = await servicer.GetUserIntegrations(
response = await _call_get_user_integrations(
servicer,
noteflow_pb2.GetUserIntegrationsRequest(),
context,
)
@@ -648,7 +793,8 @@ class TestNotFoundStatusCode:
nonexistent_id = str(uuid4())
with pytest.raises(AssertionError, match="abort called"):
await servicer_with_success.StartIntegrationSync(
await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=nonexistent_id),
context,
)
@@ -666,7 +812,8 @@ class TestNotFoundStatusCode:
nonexistent_id = str(uuid4())
with pytest.raises(AssertionError, match="abort called"):
await servicer_with_success.GetSyncStatus(
await _call_get_sync_status(
servicer_with_success,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=nonexistent_id),
context,
)
@@ -686,11 +833,13 @@ class TestSyncRunExpiryMetadata:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer_with_success.StartIntegrationSync(
start = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
status = await servicer_with_success.GetSyncStatus(
status = await _call_get_sync_status(
servicer_with_success,
noteflow_pb2.GetSyncStatusRequest(sync_run_id=start.sync_run_id),
context,
)
@@ -706,7 +855,8 @@ class TestSyncRunExpiryMetadata:
integration = await create_test_integration(meeting_store)
context = _DummyContext()
start = await servicer_with_success.StartIntegrationSync(
start = await _call_start_sync(
servicer_with_success,
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
context,
)
@@ -714,7 +864,7 @@ class TestSyncRunExpiryMetadata:
sync_run_id_uuid = UUID(start.sync_run_id)
# Type annotation needed: ensure_sync_runs_cache is a mixin method added via SyncMixin
ensure_cache = cast(
Callable[[], dict[UUID, object]],
Callable[[], dict[UUID, SyncRun]],
servicer_with_success.ensure_sync_runs_cache,
)
ensure_cache()

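The await_sync_completion rewrite above keeps its event-based approach: asyncio.wait_for bounds the wait instead of a sleep-and-poll loop. A minimal standalone sketch of that pattern (all names here are illustrative, not from the codebase):

import asyncio

async def wait_for_flag(flag: asyncio.Event, timeout: float = 2.0) -> None:
    # Blocks until the background task sets the event, or raises TimeoutError.
    await asyncio.wait_for(flag.wait(), timeout=timeout)

async def demo() -> None:
    done = asyncio.Event()

    async def background_sync() -> None:
        await asyncio.sleep(0.05)  # simulate background work
        done.set()

    task = asyncio.create_task(background_sync())
    await wait_for_flag(done)
    await task

asyncio.run(demo())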

@@ -90,14 +90,17 @@ class TestOidcProviderRegistry:
json=valid_discovery_document,
)
workspace_id = uuid4()
provider = await registry.create_provider(
registration = OidcProviderRegistration(
workspace_id=workspace_id,
name="Test Provider",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
provider = await registry.create_provider(
registration,
)
assert provider.name == "Test Provider", "provider name should match"
assert provider.workspace_id == workspace_id, "workspace_id should match"
assert provider.discovery is not None, "discovery should be populated"
@@ -110,12 +113,15 @@ class TestOidcProviderRegistry:
) -> None:
"""Verify provider creation without auto-discovery."""
workspace_id = uuid4()
provider = await registry.create_provider(
registration = OidcProviderRegistration(
workspace_id=workspace_id,
name="Test Provider",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
provider = await registry.create_provider(
registration,
auto_discover=False,
)
@@ -130,12 +136,16 @@ class TestOidcProviderRegistry:
"""Verify provider creation applies preset defaults."""
workspace_id = uuid4()
params = OidcProviderCreateParams(preset=OidcProviderPreset.AUTHENTIK)
provider = await registry.create_provider(
registration = OidcProviderRegistration(
workspace_id=workspace_id,
name="Authentik",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
params = OidcProviderCreateParams(preset=OidcProviderPreset.AUTHENTIK)
provider = await registry.create_provider(
registration,
params=params,
auto_discover=False,
)
@@ -155,14 +165,15 @@ class TestOidcProviderRegistry:
url="https://auth.example.com/.well-known/openid-configuration",
status_code=404,
)
registration = OidcProviderRegistration(
workspace_id=uuid4(),
name="Test Provider",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
with pytest.raises(OidcDiscoveryError, match="HTTP 404"):
await registry.create_provider(
workspace_id=uuid4(),
name="Test Provider",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
await registry.create_provider(registration)
def test_get_provider(self, registry: OidcProviderRegistry) -> None:
"""Verify get_provider returns correct provider."""
@@ -179,26 +190,36 @@ class TestOidcProviderRegistry:
workspace1 = uuid4()
workspace2 = uuid4()
# Create providers without discovery
await registry.create_provider(
registration1 = OidcProviderRegistration(
workspace_id=workspace1,
name="Provider 1",
issuer_url="https://auth1.example.com",
client_id="client1",
auto_discover=False,
)
provider2 = await registry.create_provider(
registration2 = OidcProviderRegistration(
workspace_id=workspace1,
name="Provider 2",
issuer_url="https://auth2.example.com",
client_id="client2",
auto_discover=False,
)
await registry.create_provider(
registration3 = OidcProviderRegistration(
workspace_id=workspace2,
name="Provider 3",
issuer_url="https://auth3.example.com",
client_id="client3",
)
# Create providers without discovery
await registry.create_provider(
registration1,
auto_discover=False,
)
provider2 = await registry.create_provider(
registration2,
auto_discover=False,
)
await registry.create_provider(
registration3,
auto_discover=False,
)
@@ -222,12 +243,15 @@ class TestOidcProviderRegistry:
) -> None:
"""Verify provider removal."""
workspace_id = uuid4()
provider = await registry.create_provider(
registration = OidcProviderRegistration(
workspace_id=workspace_id,
name="Test Provider",
issuer_url="https://auth.example.com",
client_id="test-client-id",
)
provider = await registry.create_provider(
registration,
auto_discover=False,
)

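The hunks above migrate create_provider from loose keyword arguments to a single OidcProviderRegistration parameter object. A minimal sketch of that refactor, assuming OidcProviderRegistration is a frozen dataclass with exactly the four fields the call sites pass (the field list is inferred from the tests, not from the class definition):

from dataclasses import dataclass
from uuid import UUID, uuid4

@dataclass(frozen=True)
class OidcProviderRegistration:
    workspace_id: UUID
    name: str
    issuer_url: str
    client_id: str

# Call sites build one value and pass it through, which keeps the registry
# signature stable as optional fields are added later:
registration = OidcProviderRegistration(
    workspace_id=uuid4(),
    name="Test Provider",
    issuer_url="https://auth.example.com",
    client_id="test-client-id",
)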

@@ -11,40 +11,39 @@ Tests cover:
from __future__ import annotations
import importlib
import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import MagicMock, patch
import pytest
from noteflow.infrastructure.diarization._compat import (
AudioMetaData,
_patch_huggingface_auth,
_patch_speechbrain_backend,
_patch_torch_load,
_patch_torchaudio,
apply_patches,
ensure_compatibility,
)
from noteflow.infrastructure.diarization import _compat
if TYPE_CHECKING:
from collections.abc import Generator
class _CompatModule(Protocol):
"""Protocol for the diarization compatibility module."""
AudioMetaData: type
def apply_patches(self) -> None: ...
def ensure_compatibility(self) -> None: ...
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def reset_patches_state() -> Generator[None, None, None]:
"""Reset _patches_applied state before and after tests."""
import noteflow.infrastructure.diarization._compat as compat_module
original_state = compat_module._patches_applied
compat_module._patches_applied = False
yield
compat_module._patches_applied = original_state
def compat_module() -> Generator[_CompatModule, None, None]:
"""Reload compatibility module to reset internal patch state."""
module = importlib.reload(_compat)
yield cast(_CompatModule, module)
@pytest.fixture
@@ -79,9 +78,9 @@ def mock_huggingface_hub() -> MagicMock:
class TestAudioMetaData:
"""Tests for the replacement AudioMetaData dataclass."""
def test_audiometadata_has_required_fields(self) -> None:
def test_audiometadata_has_required_fields(self, compat_module: _CompatModule) -> None:
"""AudioMetaData has all fields expected by pyannote.audio."""
metadata = AudioMetaData(
metadata = compat_module.AudioMetaData(
sample_rate=16000,
num_frames=48000,
num_channels=1,
@@ -95,9 +94,9 @@ class TestAudioMetaData:
assert metadata.bits_per_sample == 16, "should store bits_per_sample"
assert metadata.encoding == "PCM_S", "should store encoding"
def test_audiometadata_is_immutable(self) -> None:
def test_audiometadata_is_immutable(self, compat_module: _CompatModule) -> None:
"""AudioMetaData fields cannot be modified after creation."""
metadata = AudioMetaData(
metadata = compat_module.AudioMetaData(
sample_rate=16000,
num_frames=48000,
num_channels=1,
@@ -119,38 +118,42 @@ class TestPatchTorchaudio:
"""Tests for torchaudio AudioMetaData patching."""
def test_patches_audiometadata_when_missing(
self, mock_torchaudio: MagicMock
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
) -> None:
"""_patch_torchaudio adds AudioMetaData when not present."""
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
_patch_torchaudio()
compat_module.apply_patches()
assert hasattr(
mock_torchaudio, "AudioMetaData"
), "should add AudioMetaData"
assert (
mock_torchaudio.AudioMetaData is AudioMetaData
mock_torchaudio.AudioMetaData is compat_module.AudioMetaData
), "should use our AudioMetaData class"
def test_does_not_override_existing_audiometadata(self) -> None:
def test_does_not_override_existing_audiometadata(
self, compat_module: _CompatModule
) -> None:
"""_patch_torchaudio preserves existing AudioMetaData if present."""
mock = MagicMock()
existing_class = type("ExistingAudioMetaData", (), {})
mock.AudioMetaData = existing_class
with patch.dict(sys.modules, {"torchaudio": mock}):
_patch_torchaudio()
compat_module.apply_patches()
assert (
mock.AudioMetaData is existing_class
), "should not override existing AudioMetaData"
def test_handles_import_error_gracefully(self) -> None:
def test_handles_import_error_gracefully(
self, compat_module: _CompatModule
) -> None:
"""_patch_torchaudio doesn't raise when torchaudio not installed."""
# Remove torchaudio from modules if present
with patch.dict(sys.modules, {"torchaudio": None}):
# Should not raise
_patch_torchaudio()
compat_module.apply_patches()
# =============================================================================
@@ -162,7 +165,7 @@ class TestPatchTorchLoad:
"""Tests for torch.load weights_only patching."""
def test_patches_torch_load_for_pytorch_2_6_plus(
self, mock_torch: MagicMock
self, compat_module: _CompatModule, mock_torch: MagicMock
) -> None:
"""_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
original_load = mock_torch.load
@@ -172,12 +175,12 @@ class TestPatchTorchLoad:
mock_version.return_value = mock_version
mock_version.__ge__ = MagicMock(return_value=True)
_patch_torch_load()
compat_module.apply_patches()
# Verify torch.load was replaced (not the same function)
assert mock_torch.load is not original_load, "load should be patched"
def test_does_not_patch_older_pytorch(self) -> None:
def test_does_not_patch_older_pytorch(self, compat_module: _CompatModule) -> None:
"""_patch_torch_load skips patching for PyTorch < 2.6."""
mock = MagicMock()
mock.__version__ = "2.5.0"
@@ -188,15 +191,17 @@ class TestPatchTorchLoad:
mock_version.return_value = mock_version
mock_version.__ge__ = MagicMock(return_value=False)
_patch_torch_load()
compat_module.apply_patches()
# load should not have been replaced
assert mock.load is original_load, "should not patch older PyTorch"
def test_handles_import_error_gracefully(self) -> None:
def test_handles_import_error_gracefully(
self, compat_module: _CompatModule
) -> None:
"""_patch_torch_load doesn't raise when torch not installed."""
with patch.dict(sys.modules, {"torch": None}):
_patch_torch_load()
compat_module.apply_patches()
# =============================================================================
@@ -208,13 +213,13 @@ class TestPatchHuggingfaceAuth:
"""Tests for huggingface_hub use_auth_token patching."""
def test_converts_use_auth_token_to_token(
self, mock_huggingface_hub: MagicMock
self, compat_module: _CompatModule, mock_huggingface_hub: MagicMock
) -> None:
"""_patch_huggingface_auth converts use_auth_token to token parameter."""
original_download = mock_huggingface_hub.hf_hub_download
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
_patch_huggingface_auth()
compat_module.apply_patches()
# Call with legacy use_auth_token
mock_huggingface_hub.hf_hub_download(
@@ -233,13 +238,13 @@ class TestPatchHuggingfaceAuth:
), "should remove use_auth_token"
def test_preserves_token_parameter(
self, mock_huggingface_hub: MagicMock
self, compat_module: _CompatModule, mock_huggingface_hub: MagicMock
) -> None:
"""_patch_huggingface_auth preserves token if already using new API."""
original_download = mock_huggingface_hub.hf_hub_download
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
_patch_huggingface_auth()
compat_module.apply_patches()
mock_huggingface_hub.hf_hub_download(
repo_id="test/repo",
@@ -251,10 +256,12 @@ class TestPatchHuggingfaceAuth:
call_kwargs = original_download.call_args[1]
assert call_kwargs["token"] == "my_token", "should preserve token"
def test_handles_import_error_gracefully(self) -> None:
def test_handles_import_error_gracefully(
self, compat_module: _CompatModule
) -> None:
"""_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
with patch.dict(sys.modules, {"huggingface_hub": None}):
_patch_huggingface_auth()
compat_module.apply_patches()
# =============================================================================
@@ -265,10 +272,12 @@ class TestPatchHuggingfaceAuth:
class TestPatchSpeechbrainBackend:
"""Tests for torchaudio backend API patching."""
def test_patches_list_audio_backends(self, mock_torchaudio: MagicMock) -> None:
def test_patches_list_audio_backends(
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
) -> None:
"""_patch_speechbrain_backend adds list_audio_backends when missing."""
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
_patch_speechbrain_backend()
compat_module.apply_patches()
assert hasattr(
mock_torchaudio, "list_audio_backends"
@@ -276,10 +285,12 @@ class TestPatchSpeechbrainBackend:
result = mock_torchaudio.list_audio_backends()
assert isinstance(result, list), "should return list"
def test_patches_get_audio_backend(self, mock_torchaudio: MagicMock) -> None:
def test_patches_get_audio_backend(
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
) -> None:
"""_patch_speechbrain_backend adds get_audio_backend when missing."""
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
_patch_speechbrain_backend()
compat_module.apply_patches()
assert hasattr(
mock_torchaudio, "get_audio_backend"
@@ -287,10 +298,12 @@ class TestPatchSpeechbrainBackend:
result = mock_torchaudio.get_audio_backend()
assert result is None, "should return None"
def test_patches_set_audio_backend(self, mock_torchaudio: MagicMock) -> None:
def test_patches_set_audio_backend(
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
) -> None:
"""_patch_speechbrain_backend adds set_audio_backend when missing."""
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
_patch_speechbrain_backend()
compat_module.apply_patches()
assert hasattr(
mock_torchaudio, "set_audio_backend"
@@ -298,14 +311,16 @@ class TestPatchSpeechbrainBackend:
# Should not raise
mock_torchaudio.set_audio_backend("sox")
def test_does_not_override_existing_functions(self) -> None:
def test_does_not_override_existing_functions(
self, compat_module: _CompatModule
) -> None:
"""_patch_speechbrain_backend preserves existing backend functions."""
mock = MagicMock()
existing_list = MagicMock(return_value=["ffmpeg"])
mock.list_audio_backends = existing_list
with patch.dict(sys.modules, {"torchaudio": mock}):
_patch_speechbrain_backend()
compat_module.apply_patches()
assert (
mock.list_audio_backends is existing_list
@@ -321,42 +336,24 @@ class TestApplyPatches:
"""Tests for the main apply_patches function."""
def test_apply_patches_is_idempotent(
self, reset_patches_state: None
self, compat_module: _CompatModule
) -> None:
"""apply_patches only applies patches once."""
import noteflow.infrastructure.diarization._compat as compat_module
mock_torch = MagicMock()
mock_torch.__version__ = "2.6.0"
original_load = mock_torch.load
with patch.object(compat_module, "_patch_torchaudio") as mock_torchaudio:
with patch.object(compat_module, "_patch_torch_load") as mock_torch:
with patch.object(
compat_module, "_patch_huggingface_auth"
) as mock_hf:
with patch.object(
compat_module, "_patch_speechbrain_backend"
) as mock_sb:
apply_patches()
apply_patches() # Second call
apply_patches() # Third call
with patch.dict(sys.modules, {"torch": mock_torch}):
with patch("packaging.version.Version") as mock_version:
mock_version.return_value = mock_version
mock_version.__ge__ = MagicMock(return_value=True)
# Each patch function should only be called once
mock_torchaudio.assert_called_once()
mock_torch.assert_called_once()
mock_hf.assert_called_once()
mock_sb.assert_called_once()
compat_module.apply_patches()
first_load = mock_torch.load
compat_module.apply_patches()
def test_apply_patches_sets_flag(self, reset_patches_state: None) -> None:
"""apply_patches sets _patches_applied flag."""
import noteflow.infrastructure.diarization._compat as compat_module
assert compat_module._patches_applied is False, "should start False"
with patch.object(compat_module, "_patch_torchaudio"):
with patch.object(compat_module, "_patch_torch_load"):
with patch.object(compat_module, "_patch_huggingface_auth"):
with patch.object(compat_module, "_patch_speechbrain_backend"):
apply_patches()
assert compat_module._patches_applied is True, "should be True after apply"
assert first_load is not original_load, "initial call should patch torch.load"
assert mock_torch.load is first_load, "subsequent calls should be idempotent"
# =============================================================================
@@ -368,12 +365,10 @@ class TestEnsureCompatibility:
"""Tests for the ensure_compatibility entry point."""
def test_ensure_compatibility_calls_apply_patches(
self, reset_patches_state: None
self, compat_module: _CompatModule
) -> None:
"""ensure_compatibility delegates to apply_patches."""
import noteflow.infrastructure.diarization._compat as compat_module
with patch.object(compat_module, "apply_patches") as mock_apply:
ensure_compatibility()
compat_module.ensure_compatibility()
mock_apply.assert_called_once()

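The reworked TestApplyPatches above asserts idempotency by observing the patched attribute directly instead of mocking private helpers. A minimal standalone sketch of the underlying run-once guard (all names illustrative; the dict stands in for a patched module):

_patches_applied = False

def apply_patches_once(target: dict) -> None:
    # Module-level flag makes repeated calls a no-op, so the wrapped
    # attribute is replaced exactly once.
    global _patches_applied
    if _patches_applied:
        return
    original = target["load"]
    target["load"] = lambda *a, **kw: original(*a, weights_only=False, **kw)
    _patches_applied = True

target = {"load": lambda *a, **kw: (a, kw)}
apply_patches_once(target)
first = target["load"]
apply_patches_once(target)
assert target["load"] is first, "second call should not re-wrap"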

@@ -60,7 +60,11 @@ def enabled_config() -> WebhookConfig:
def disabled_config() -> WebhookConfig:
"""Create a disabled webhook config."""
workspace_id = uuid4()
now = WebhookConfig.create(workspace_id, "", [WebhookEventType.MEETING_COMPLETED]).created_at
now = WebhookConfig.create(
workspace_id=workspace_id,
url="",
events=[WebhookEventType.MEETING_COMPLETED],
).created_at
return WebhookConfig(
id=uuid4(),
workspace_id=workspace_id,


@@ -15,7 +15,7 @@ from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import MagicMock
from uuid import UUID, uuid4
@@ -40,6 +40,66 @@ from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWor
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
class _RefineSpeakerDiarizationRequest(Protocol):
meeting_id: str
num_speakers: int
class _RefineSpeakerDiarizationResponse(Protocol):
job_id: str
status: int
error_message: str
segments_updated: int
class _DiarizationJobStatusRequest(Protocol):
job_id: str
class _DiarizationJobStatusResponse(Protocol):
job_id: str
status: int
segments_updated: int
speaker_ids: Sequence[str]
error_message: str
progress_percent: float
class _RenameSpeakerRequest(Protocol):
meeting_id: str
old_speaker_id: str
new_speaker_name: str
class _RenameSpeakerResponse(Protocol):
segments_updated: int
success: bool
class _RefineSpeakerDiarizationCallable(Protocol):
async def __call__(
self,
request: _RefineSpeakerDiarizationRequest,
context: MockContext,
) -> _RefineSpeakerDiarizationResponse: ...
class _GetDiarizationJobStatusCallable(Protocol):
async def __call__(
self,
request: _DiarizationJobStatusRequest,
context: MockContext,
) -> _DiarizationJobStatusResponse: ...
class _RenameSpeakerCallable(Protocol):
async def __call__(
self,
request: _RenameSpeakerRequest,
context: MockContext,
) -> _RenameSpeakerResponse: ...
# ============================================================================
# Test Constants
# ============================================================================
@@ -104,6 +164,36 @@ class _MockRpcError(grpc.RpcError):
return self._details
async def _call_refine(
servicer: NoteFlowServicer,
request: _RefineSpeakerDiarizationRequest,
context: MockContext,
) -> _RefineSpeakerDiarizationResponse:
"""Call RefineSpeakerDiarization with typed response."""
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
return await refine(request, context)
async def _call_get_status(
servicer: NoteFlowServicer,
request: _DiarizationJobStatusRequest,
context: MockContext,
) -> _DiarizationJobStatusResponse:
"""Call GetDiarizationJobStatus with typed response."""
get_status = cast(_GetDiarizationJobStatusCallable, servicer.GetDiarizationJobStatus)
return await get_status(request, context)
async def _call_rename(
servicer: NoteFlowServicer,
request: _RenameSpeakerRequest,
context: MockContext,
) -> _RenameSpeakerResponse:
"""Call RenameSpeaker with typed response."""
rename = cast(_RenameSpeakerCallable, servicer.RenameSpeaker)
return await rename(request, context)
@pytest.mark.integration
class TestServicerMeetingOperationsWithDatabase:
"""Integration tests for meeting operations using real database."""
@@ -320,10 +410,10 @@ class TestServicerDiarizationWithDatabase:
request = noteflow_pb2.RefineSpeakerDiarizationRequest(
meeting_id=meeting_id_str,
)
result = await servicer.RefineSpeakerDiarization(request, MockContext())
result = await _call_refine(servicer, request, MockContext())
assert result.job_id, "RefineSpeakerDiarization response should include a job ID"
assert result.status == noteflow_pb2.JOB_STATUS_QUEUED, f"expected QUEUED status, got {result.status}"
assert result.status == JOB_STATUS_QUEUED, f"expected QUEUED status, got {result.status}"
async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
job = await uow.diarization_jobs.get(result.job_id)
@@ -352,12 +442,12 @@ class TestServicerDiarizationWithDatabase:
servicer = NoteFlowServicer(session_factory=session_factory)
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job.job_id)
response: noteflow_pb2.DiarizationJobStatus = await servicer.GetDiarizationJobStatus(request, MockContext())
response = await _call_get_status(servicer, request, MockContext())
# cast required: protobuf RepeatedScalarFieldContainer is typed in .pyi but pyright doesn't resolve generic
speaker_ids_list: Sequence[str] = cast(Sequence[str], response.speaker_ids)
speaker_ids_list = response.speaker_ids
assert response.job_id == job.job_id, f"expected job_id {job.job_id}, got {response.job_id}"
assert response.status == noteflow_pb2.JOB_STATUS_COMPLETED, f"expected COMPLETED status, got {response.status}"
assert response.status == JOB_STATUS_COMPLETED, f"expected COMPLETED status, got {response.status}"
assert response.segments_updated == DIARIZATION_SEGMENTS_UPDATED, f"expected {DIARIZATION_SEGMENTS_UPDATED} segments_updated, got {response.segments_updated}"
assert list(speaker_ids_list) == ["SPEAKER_00", "SPEAKER_01"], f"expected speaker_ids ['SPEAKER_00', 'SPEAKER_01'], got {list(speaker_ids_list)}"
@@ -371,7 +461,7 @@ class TestServicerDiarizationWithDatabase:
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id="nonexistent")
with pytest.raises(grpc.RpcError, match=r".*"):
await servicer.GetDiarizationJobStatus(request, context)
await _call_get_status(servicer, request, context)
assert context.abort_code == grpc.StatusCode.NOT_FOUND, f"expected NOT_FOUND status for nonexistent job, got {context.abort_code}"
@@ -397,9 +487,9 @@ class TestServicerDiarizationWithDatabase:
request = noteflow_pb2.RefineSpeakerDiarizationRequest(
meeting_id=str(meeting.id),
)
result = await servicer.RefineSpeakerDiarization(request, MockContext())
result = await _call_refine(servicer, request, MockContext())
assert result.status == noteflow_pb2.JOB_STATUS_FAILED, f"expected FAILED status for recording meeting, got {result.status}"
assert result.status == JOB_STATUS_FAILED, f"expected FAILED status for recording meeting, got {result.status}"
assert "stopped" in result.error_message.lower(), f"expected 'stopped' in error message, got '{result.error_message}'"
@@ -515,7 +605,7 @@ class TestServicerRenameSpeakerWithDatabase:
old_speaker_id="SPEAKER_00",
new_speaker_name="Alice",
)
result = await servicer.RenameSpeaker(request, MockContext())
result = await _call_rename(servicer, request, MockContext())
assert result.segments_updated == 3, f"expected 3 segments updated, got {result.segments_updated}"
assert result.success is True, "RenameSpeaker should return success=True"
@@ -1111,7 +1201,7 @@ class TestServerRestartJobRecovery:
servicer_new = NoteFlowServicer(session_factory=session_factory)
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id)
response = await servicer_new.GetDiarizationJobStatus(request, MockContext())
response = await _call_get_status(servicer_new, request, MockContext())
assert response.status == noteflow_pb2.JOB_STATUS_FAILED, "job should be FAILED"
assert response.status == JOB_STATUS_FAILED, "job should be FAILED"
assert response.error_message == "Server restarted", "should have restart error"

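The status assertions above now reference JOB_STATUS_QUEUED and friends as bare names, which implies a module-level import of the generated enum values. A sketch of that import style, assuming the constants live on the generated noteflow_pb2 module (the exact import line is not shown in the diff):

from noteflow.grpc.proto.noteflow_pb2 import (  # assumed path, mirroring
    JOB_STATUS_COMPLETED,                       # the existing noteflow_pb2
    JOB_STATUS_FAILED,                          # package import
    JOB_STATUS_QUEUED,
)

# Bare names keep the long assertion lines readable, e.g.:
# assert result.status == JOB_STATUS_QUEUED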

@@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import AsyncIterator
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, cast
from unittest.mock import AsyncMock, MagicMock
import grpc
@@ -22,6 +22,25 @@ from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWor
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
class _AudioChunkRequest(Protocol):
meeting_id: str
audio_data: bytes
sample_rate: int
channels: int
class _TranscriptUpdate(Protocol):
update_type: int
class _StreamTranscriptionCallable(Protocol):
def __call__(
self,
request_iterator: AsyncIterator[_AudioChunkRequest],
context: MockContext,
) -> AsyncIterator[_TranscriptUpdate]: ...
SAMPLE_RATE = DEFAULT_SAMPLE_RATE
CHUNK_SAMPLES = 1600 # 0.1s at 16kHz
SPEECH_CHUNKS = 4
@@ -88,15 +107,16 @@ class TestStreamingRealPipeline:
meetings_dir=meetings_dir,
)
updates: list[noteflow_pb2.TranscriptUpdate] = [
stream = cast(_StreamTranscriptionCallable, servicer.StreamTranscription)
updates: list[_TranscriptUpdate] = [
update
async for update in servicer.StreamTranscription(
async for update in stream(
_audio_stream(str(meeting.id)),
MockContext(),
)
]
final_updates: list[noteflow_pb2.TranscriptUpdate] = [
final_updates: list[_TranscriptUpdate] = [
update
for update in updates
if update.update_type == noteflow_pb2.UPDATE_TYPE_FINAL

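Note that the _StreamTranscriptionCallable protocol above declares a plain def __call__ returning AsyncIterator, not an async def: calling an async generator function returns the iterator synchronously, and only iteration awaits. A minimal runnable sketch of that distinction (names illustrative):

import asyncio
from collections.abc import AsyncIterator

async def stream(n: int) -> AsyncIterator[int]:
    # Async generator: stream(3) returns an AsyncIterator immediately.
    for i in range(n):
        yield i

async def main() -> None:
    updates = [u async for u in stream(3)]
    assert updates == [0, 1, 2]

asyncio.run(main())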
uv.lock (generated)

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -2440,7 +2440,7 @@ requires-dist = [
{ name = "testcontainers", extras = ["postgres"], marker = "extra == 'dev'", specifier = ">=4.0" },
{ name = "torch", marker = "extra == 'diarization'", specifier = ">=2.0" },
{ name = "torch", marker = "extra == 'optional'", specifier = ">=2.0" },
{ name = "types-grpcio", marker = "extra == 'dev'", specifier = ">=1.0.0.20251009" },
{ name = "types-grpcio", marker = "extra == 'dev'", specifier = "==1.0.0.20251001" },
{ name = "types-psutil", specifier = ">=7.2.0.20251228" },
{ name = "weasyprint", marker = "extra == 'optional'", specifier = ">=67.0" },
{ name = "weasyprint", marker = "extra == 'pdf'", specifier = ">=67.0" },
@@ -2456,7 +2456,7 @@ dev = [
{ name = "pytest-httpx", specifier = ">=0.36.0" },
{ name = "ruff", specifier = ">=0.14.9" },
{ name = "sourcery", marker = "sys_platform == 'darwin'" },
{ name = "types-grpcio", specifier = ">=1.0.0.20251009" },
{ name = "types-grpcio", specifier = "==1.0.0.20251001" },
{ name = "watchfiles", specifier = ">=1.1.1" },
]
@@ -7411,11 +7411,11 @@ wheels = [
[[package]]
name = "types-grpcio"
version = "1.0.0.20251009"
version = "1.0.0.20251001"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/de/93/78aa083216853c667c9412df4ef8284b2a68c6bcd2aef833f970b311f3c1/types_grpcio-1.0.0.20251009.tar.gz", hash = "sha256:a8f615ea7a47b31f10da028ab5258d4f1611fbd70719ca450fc0ab3fb9c62b63", size = 14479, upload-time = "2025-10-09T02:54:14.539Z" }
sdist = { url = "https://files.pythonhosted.org/packages/6e/84/569f4bd7d6c70337a16171041930027d46229d760cbe1dbaa422e18a7abf/types_grpcio-1.0.0.20251001.tar.gz", hash = "sha256:5334e2076b3ad621188af58b082ac7b31ea52f3e9a01cdd1984823bf63e7ce55", size = 14672, upload-time = "2025-10-01T03:04:10.388Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/93/66d28f41b16bb4e6b611bd608ef28dffc740facec93250b30cf83138da21/types_grpcio-1.0.0.20251009-py3-none-any.whl", hash = "sha256:112ac4312a5b0a273a4c414f7f2c7668f342990d9c6ab0f647391c36331f95ed", size = 15208, upload-time = "2025-10-09T02:54:13.588Z" },
{ url = "https://files.pythonhosted.org/packages/42/33/f2e6e5f4f2f80fa6ee253dc47610f856baa76b87e56050b4bd7d91b7d272/types_grpcio-1.0.0.20251001-py3-none-any.whl", hash = "sha256:7e7dc6e7238f1dc353adae2e172e8f4acd2388ad8935eba10d6a93dc58529148", size = 15351, upload-time = "2025-10-01T03:04:09.183Z" },
]
[[package]]