fix: update types-grpcio dependency version and improve meeting filtering
- Changed the types-grpcio dependency from `>=1.0.0.20251009` to `==1.0.0.20251001` in `pyproject.toml` and `uv.lock` for consistency. - Enhanced the meeting filtering logic in the `MeetingStore` and related classes to support filtering by multiple project IDs, improving query flexibility. - Updated the gRPC proto definitions to include a new `project_ids` field for the `ListMeetingsRequest` message, allowing for more granular project filtering. - Adjusted related repository and mixin classes to accommodate the new filtering capabilities.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -2,6 +2,15 @@
|
||||
|
||||
Checking for magic numbers...
|
||||
[1;33mWARNING:[0m Found potential magic numbers (consider using named constants):
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:60: let listener = TcpListener::bind("127.0.0.1:0")
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:280: background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:286: background: rgba(255,255,255,0.1);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:300: <div class="checkmark">✔</div>
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:308: "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:332: background: linear-gradient(135deg, #e74c3c 0%, #c0392b 100%);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:338: background: rgba(255,255,255,0.1);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:349: <div class="error">✖</div>
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:359: "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/state/preferences.rs:63: server_host: "127.0.0.1".to_string(),
|
||||
|
||||
Checking for repeated string literals...
|
||||
@@ -21,19 +30,19 @@ Checking for long functions...
|
||||
|
||||
Checking for deep nesting...
|
||||
[1;33mWARNING:[0m Found potentially deep nesting (>7 levels):
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:232: INITIAL_RETRY_DELAY_MS * RETRY_BACKOFF_MULTIPLIER.pow(consecutive_errors - 1),
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:59: let duration = buffer
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:60: .last()
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:61: .map(|chunk| chunk.timestamp + chunk.duration)
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:62: .unwrap_or(0.0);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:63: tracing::debug!("Loaded encrypted audio from {:?}", audio_path);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:64: return LoadedAudio {
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:65: buffer,
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:66: sample_rate,
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs:67: duration,
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:122: if let Some(result) = handle_connection(stream).await {
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:123: if let Some(tx) = result_tx.lock().await.take() {
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:124: let _ = tx.send(result);
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:125: }
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs:126: }
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:145: user_id = %identity.user_id,
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:146: is_local = identity.is_local,
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/identity/mod.rs:147: "Loaded identity from keychain"
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:233: INITIAL_RETRY_DELAY_MS
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/diarization.rs:234: * RETRY_BACKOFF_MULTIPLIER.pow(consecutive_errors - 1),
|
||||
|
||||
Checking for unwrap() usage...
|
||||
[0;32mOK:[0m No unwrap() calls found
|
||||
[0;32mOK:[0m Found 2 unwrap() calls (within acceptable range)
|
||||
|
||||
Checking for excessive clone() usage...
|
||||
[0;32mOK:[0m No excessive clone() usage detected
|
||||
@@ -46,18 +55,18 @@ Checking for duplicated error messages...
|
||||
|
||||
Checking module file sizes...
|
||||
[1;33mWARNING:[0m Large files (>500 lines):
|
||||
597 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs
|
||||
546 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/error.rs
|
||||
537 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
|
||||
593 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/commands/playback.rs
|
||||
550 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/error.rs
|
||||
533 /home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
|
||||
|
||||
|
||||
Checking for scattered helper functions...
|
||||
[1;33mWARNING:[0m Helper functions scattered across 11 files (consider consolidating):
|
||||
[1;33mWARNING:[0m Helper functions scattered across 12 files (consider consolidating):
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/oauth_loopback.rs
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/streaming.rs
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/converters.rs
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/observability.rs
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/grpc/client/sync.rs
|
||||
/home/trav/repos/noteflow/client/src-tauri/scripts/../src/helpers.rs
|
||||
|
||||
=== Summary ===
|
||||
Errors: [0;31m0[0m
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
# Phase Gaps: Backend-Client Synchronization Issues
|
||||
|
||||
This directory contains sprint specifications for addressing gaps, race conditions, and synchronization issues identified between the NoteFlow backend and Tauri/React client.
|
||||
|
||||
## Summary of Findings
|
||||
|
||||
Analysis of the gRPC API contracts, state management, and streaming operations revealed several categories of issues requiring remediation.
|
||||
|
||||
| Sprint | Category | Severity | Effort |
|
||||
|--------|----------|----------|--------|
|
||||
| [SPRINT-GAP-001](./sprint-gap-001-streaming-race-conditions.md) | Streaming Race Conditions | High | L |
|
||||
| [SPRINT-GAP-002](./sprint-gap-002-state-sync-gaps.md) | State Synchronization | Medium | M |
|
||||
| [SPRINT-GAP-003](./sprint-gap-003-error-handling.md) | Error Handling Mismatches | Medium | M |
|
||||
| [SPRINT-GAP-004](./sprint-gap-004-diarization-lifecycle.md) | Diarization Job Lifecycle | Medium | S |
|
||||
| [SPRINT-GAP-005](./sprint-gap-005-entity-resource-leak.md) | Entity Mixin Resource Leak | High | S |
|
||||
| [SPRINT-GAP-006](./sprint-gap-006-connection-bootstrapping.md) | Connection Bootstrapping | High | M |
|
||||
| [SPRINT-GAP-007](./sprint-gap-007-simulation-mode-clarity.md) | Simulation Mode Clarity | Medium | S |
|
||||
| [SPRINT-GAP-008](./sprint-gap-008-server-address-consistency.md) | Server Address Consistency | Medium | M |
|
||||
| [SPRINT-GAP-009](./sprint-gap-009-event-bridge-contracts.md) | Event Bridge Contracts | Medium | S |
|
||||
| [SPRINT-GAP-010](./sprint-gap-010-identity-and-rpc-logging.md) | Identity + RPC Logging | Medium | M |
|
||||
|
||||
## Priority Matrix
|
||||
|
||||
### Critical (Address Immediately)
|
||||
- **SPRINT-GAP-001**: Audio streaming fire-and-forget can cause data loss
|
||||
- **SPRINT-GAP-005**: Entity mixin context manager misuse causes resource leaks
|
||||
- **SPRINT-GAP-006**: Recording can start without a connected gRPC client
|
||||
|
||||
### High Priority
|
||||
- **SPRINT-GAP-002**: Meeting cache invalidation prevents stale data
|
||||
- **SPRINT-GAP-003**: Silenced errors hide critical failures
|
||||
|
||||
### Medium Priority
|
||||
- **SPRINT-GAP-004**: Diarization polling resilience improvements
|
||||
- **SPRINT-GAP-007**: Simulated paths need explicit UX and safety rails
|
||||
- **SPRINT-GAP-008**: Default server addressing can be mis-pointed in Docker setups
|
||||
- **SPRINT-GAP-009**: Event bridge should initialize before connection
|
||||
- **SPRINT-GAP-010**: Identity metadata and per-RPC logging not wired
|
||||
|
||||
## Cross-Cutting Concerns
|
||||
|
||||
1. **Observability**: All fixes should emit appropriate log events and metrics
|
||||
2. **Testing**: Each sprint must include integration tests for the identified scenarios
|
||||
3. **Backwards Compatibility**: Client-side changes must gracefully handle older server versions
|
||||
|
||||
## Analysis Methodology
|
||||
|
||||
Issues were identified through:
|
||||
1. Code review of gRPC mixins (`src/noteflow/grpc/_mixins/`)
|
||||
2. Tauri command handlers (`client/src-tauri/src/commands/`)
|
||||
3. TypeScript API adapters (`client/src/api/`)
|
||||
4. Pattern matching for anti-patterns (`.catch(() => {})`, missing awaits)
|
||||
5. State machine analysis for race conditions
|
||||
@@ -49,7 +49,7 @@ dev = [
|
||||
"basedpyright>=1.18",
|
||||
"pyrefly>=0.46.1",
|
||||
"sourcery; sys_platform == 'darwin'",
|
||||
"types-grpcio>=1.0.0.20251009",
|
||||
"types-grpcio==1.0.0.20251001",
|
||||
"testcontainers[postgres]>=4.0",
|
||||
]
|
||||
triggers = [
|
||||
@@ -288,6 +288,6 @@ dev = [
|
||||
"pytest-httpx>=0.36.0",
|
||||
"ruff>=0.14.9",
|
||||
"sourcery; sys_platform == 'darwin'",
|
||||
"types-grpcio>=1.0.0.20251009",
|
||||
"types-grpcio==1.0.0.20251001",
|
||||
"watchfiles>=1.1.1",
|
||||
]
|
||||
|
||||
@@ -24,6 +24,7 @@ class MeetingListKwargs(TypedDict, total=False):
|
||||
offset: int
|
||||
sort_desc: bool
|
||||
project_id: UUID | None
|
||||
project_ids: list[UUID] | None
|
||||
|
||||
|
||||
class MeetingRepository(Protocol):
|
||||
@@ -83,7 +84,7 @@ class MeetingRepository(Protocol):
|
||||
"""List meetings with optional filtering.
|
||||
|
||||
Args:
|
||||
**kwargs: Optional filters (states, limit, offset, sort_desc, project_id).
|
||||
**kwargs: Optional filters (states, limit, offset, sort_desc, project_id, project_ids).
|
||||
|
||||
Returns:
|
||||
Tuple of (meetings list, total count matching filter).
|
||||
|
||||
@@ -283,8 +283,30 @@ class MeetingMixin:
|
||||
state_values = cast(Sequence[int], request.states)
|
||||
states = [MeetingState(s) for s in state_values] if state_values else None
|
||||
project_id: UUID | None = None
|
||||
project_ids: list[UUID] | None = None
|
||||
|
||||
if cast(_HasField, request).HasField("project_id") and request.project_id:
|
||||
if request.project_ids:
|
||||
project_ids = []
|
||||
for raw_project_id in request.project_ids:
|
||||
try:
|
||||
project_ids.append(UUID(raw_project_id))
|
||||
except ValueError:
|
||||
truncated = raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
|
||||
logger.warning(
|
||||
"ListMeetings: invalid project_ids format",
|
||||
project_id_truncated=truncated,
|
||||
project_id_length=len(raw_project_id),
|
||||
)
|
||||
await abort_invalid_argument(
|
||||
context,
|
||||
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
|
||||
)
|
||||
|
||||
if (
|
||||
not project_ids
|
||||
and cast(_HasField, request).HasField("project_id")
|
||||
and request.project_id
|
||||
):
|
||||
try:
|
||||
project_id = UUID(request.project_id)
|
||||
except ValueError:
|
||||
@@ -297,7 +319,7 @@ class MeetingMixin:
|
||||
await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
|
||||
|
||||
async with self.create_repository_provider() as repo:
|
||||
if project_id is None:
|
||||
if project_id is None and not project_ids:
|
||||
project_id = await _resolve_active_project_id(self, repo)
|
||||
|
||||
meetings, total = await repo.meetings.list_all(
|
||||
@@ -306,6 +328,7 @@ class MeetingMixin:
|
||||
offset=offset,
|
||||
sort_desc=sort_desc,
|
||||
project_id=project_id,
|
||||
project_ids=project_ids,
|
||||
)
|
||||
logger.debug(
|
||||
"ListMeetings returned",
|
||||
@@ -314,6 +337,7 @@ class MeetingMixin:
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
project_id=str(project_id) if project_id else None,
|
||||
project_ids=[str(pid) for pid in project_ids] if project_ids else None,
|
||||
)
|
||||
return noteflow_pb2.ListMeetingsResponse(
|
||||
meetings=[meeting_to_proto(m, include_segments=False) for m in meetings],
|
||||
|
||||
@@ -107,6 +107,7 @@ class MeetingStore:
|
||||
offset = kwargs.get("offset", 0)
|
||||
sort_desc = kwargs.get("sort_desc", True)
|
||||
project_id = kwargs.get("project_id")
|
||||
project_ids = kwargs.get("project_ids")
|
||||
meetings = list(self._meetings.values())
|
||||
|
||||
# Filter by state
|
||||
@@ -114,8 +115,13 @@ class MeetingStore:
|
||||
state_set = set(states)
|
||||
meetings = [m for m in meetings if m.state in state_set]
|
||||
|
||||
# Filter by project if requested
|
||||
if project_id:
|
||||
# Filter by project(s) if requested
|
||||
if project_ids:
|
||||
project_set = set(project_ids)
|
||||
meetings = [
|
||||
m for m in meetings if m.project_id is not None and str(m.project_id) in project_set
|
||||
]
|
||||
elif project_id:
|
||||
meetings = [
|
||||
m for m in meetings if m.project_id is not None and str(m.project_id) == project_id
|
||||
]
|
||||
|
||||
@@ -306,6 +306,9 @@ message ListMeetingsRequest {
|
||||
|
||||
// Optional project filter (defaults to active project if omitted)
|
||||
optional string project_id = 5;
|
||||
|
||||
// Optional project filter for multiple projects (overrides project_id when provided)
|
||||
repeated string project_ids = 6;
|
||||
}
|
||||
|
||||
enum SortOrder {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -244,18 +244,20 @@ class StopMeetingRequest(_message.Message):
|
||||
def __init__(self, meeting_id: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class ListMeetingsRequest(_message.Message):
|
||||
__slots__ = ("states", "limit", "offset", "sort_order", "project_id")
|
||||
__slots__ = ("states", "limit", "offset", "sort_order", "project_id", "project_ids")
|
||||
STATES_FIELD_NUMBER: _ClassVar[int]
|
||||
LIMIT_FIELD_NUMBER: _ClassVar[int]
|
||||
OFFSET_FIELD_NUMBER: _ClassVar[int]
|
||||
SORT_ORDER_FIELD_NUMBER: _ClassVar[int]
|
||||
PROJECT_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
PROJECT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||
states: _containers.RepeatedScalarFieldContainer[MeetingState]
|
||||
limit: int
|
||||
offset: int
|
||||
sort_order: SortOrder
|
||||
project_id: str
|
||||
def __init__(self, states: _Optional[_Iterable[_Union[MeetingState, str]]] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ..., sort_order: _Optional[_Union[SortOrder, str]] = ..., project_id: _Optional[str] = ...) -> None: ...
|
||||
project_ids: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, states: _Optional[_Iterable[_Union[MeetingState, str]]] = ..., limit: _Optional[int] = ..., offset: _Optional[int] = ..., sort_order: _Optional[_Union[SortOrder, str]] = ..., project_id: _Optional[str] = ..., project_ids: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
class ListMeetingsResponse(_message.Message):
|
||||
__slots__ = ("meetings", "total_count")
|
||||
|
||||
@@ -51,13 +51,16 @@ class MemoryMeetingRepository:
|
||||
offset = kwargs.get("offset", 0)
|
||||
sort_desc = kwargs.get("sort_desc", True)
|
||||
project_id = kwargs.get("project_id")
|
||||
project_ids = kwargs.get("project_ids")
|
||||
project_filter = str(project_id) if project_id else None
|
||||
project_filters = [str(pid) for pid in project_ids] if project_ids else None
|
||||
return self._store.list_all(
|
||||
states=states,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort_desc=sort_desc,
|
||||
project_id=project_filter,
|
||||
project_ids=project_filters,
|
||||
)
|
||||
|
||||
async def count_by_state(self, state: MeetingState) -> int:
|
||||
|
||||
@@ -137,6 +137,7 @@ class SqlAlchemyMeetingRepository(BaseRepository):
|
||||
offset = kwargs.get("offset", 0)
|
||||
sort_desc = kwargs.get("sort_desc", True)
|
||||
project_id = kwargs.get("project_id")
|
||||
project_ids = kwargs.get("project_ids")
|
||||
|
||||
# Build base query
|
||||
stmt = select(MeetingModel)
|
||||
@@ -146,8 +147,10 @@ class SqlAlchemyMeetingRepository(BaseRepository):
|
||||
state_values = [int(s) for s in states]
|
||||
stmt = stmt.where(MeetingModel.state.in_(state_values))
|
||||
|
||||
# Filter by project if requested
|
||||
if project_id is not None:
|
||||
# Filter by project(s) if requested
|
||||
if project_ids:
|
||||
stmt = stmt.where(MeetingModel.project_id.in_(project_ids))
|
||||
elif project_id is not None:
|
||||
stmt = stmt.where(MeetingModel.project_id == project_id)
|
||||
|
||||
# Count total
|
||||
|
||||
@@ -22,7 +22,6 @@ from noteflow.application.services.auth_service import (
|
||||
AuthResult,
|
||||
AuthService,
|
||||
AuthServiceError,
|
||||
LogoutResult,
|
||||
UserInfo,
|
||||
)
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
@@ -635,16 +634,17 @@ class TestRefreshAuthTokens:
|
||||
class TestStoreAuthUser:
|
||||
"""Tests for AuthService._store_auth_user workspace handling."""
|
||||
|
||||
async def test_creates_default_workspace_for_new_user(
|
||||
async def test_complete_login_creates_default_workspace_for_new_user(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_oauth_tokens: OAuthTokens,
|
||||
) -> None:
|
||||
"""_store_auth_user creates default workspace when none exists."""
|
||||
"""complete_login creates default workspace when none exists."""
|
||||
mock_auth_uow.workspaces.get_default_for_user.return_value = None
|
||||
mock_auth_uow.workspaces.create = AsyncMock()
|
||||
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
@@ -652,9 +652,11 @@ class TestStoreAuthUser:
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
await service._store_auth_user(
|
||||
"google", "test@example.com", "Test User", sample_oauth_tokens
|
||||
)
|
||||
with patch(
|
||||
"noteflow.infrastructure.calendar.google_adapter.GoogleCalendarAdapter.get_user_info",
|
||||
new=AsyncMock(return_value=("test@example.com", "Test User")),
|
||||
):
|
||||
await service.complete_login("google", "auth-code", "state")
|
||||
|
||||
# Verify workspace.create was called (new workspace for new user)
|
||||
mock_auth_uow.workspaces.create.assert_called_once()
|
||||
@@ -730,8 +732,8 @@ class TestRefreshAuthTokensEdgeCases:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestParseProvider:
|
||||
"""Tests for AuthService._parse_provider static method."""
|
||||
class TestProviderParsing:
|
||||
"""Tests for provider parsing through public APIs."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_provider", "expected_output"),
|
||||
@@ -743,15 +745,29 @@ class TestParseProvider:
|
||||
pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
|
||||
],
|
||||
)
|
||||
def test_parses_valid_providers(
|
||||
@pytest.mark.asyncio
|
||||
async def test_initiate_login_accepts_valid_providers(
|
||||
self,
|
||||
input_provider: str,
|
||||
expected_output: OAuthProvider,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
) -> None:
|
||||
"""_parse_provider correctly parses valid provider strings."""
|
||||
result = AuthService._parse_provider(input_provider)
|
||||
"""initiate_login accepts valid provider strings (case-insensitive)."""
|
||||
service = AuthService(
|
||||
uow_factory=lambda: MagicMock(),
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
assert result == expected_output, f"should parse {input_provider} correctly"
|
||||
auth_url, state = await service.initiate_login(input_provider)
|
||||
|
||||
mock_oauth_manager.initiate_auth.assert_called_once_with(
|
||||
provider=expected_output,
|
||||
redirect_uri=calendar_settings.redirect_uri,
|
||||
)
|
||||
assert auth_url.startswith("https://auth.example.com/"), "should return auth url"
|
||||
assert state == "state123", "should return state from OAuth manager"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_provider",
|
||||
@@ -762,10 +778,19 @@ class TestParseProvider:
|
||||
pytest.param("invalid", id="random_string"),
|
||||
],
|
||||
)
|
||||
def test_raises_for_invalid_providers(
|
||||
@pytest.mark.asyncio
|
||||
async def test_initiate_login_rejects_invalid_providers(
|
||||
self,
|
||||
invalid_provider: str,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
) -> None:
|
||||
"""_parse_provider raises AuthServiceError for invalid providers."""
|
||||
"""initiate_login raises AuthServiceError for invalid providers."""
|
||||
service = AuthService(
|
||||
uow_factory=lambda: MagicMock(),
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthServiceError, match="Invalid provider"):
|
||||
AuthService._parse_provider(invalid_provider)
|
||||
await service.initiate_login(invalid_provider)
|
||||
|
||||
@@ -9,6 +9,8 @@ Tests cover:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, cast
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
|
||||
@@ -16,6 +18,81 @@ from noteflow.grpc._config import ServicesConfig
|
||||
from noteflow.grpc.proto import noteflow_pb2
|
||||
from noteflow.grpc.service import NoteFlowServicer
|
||||
from noteflow.infrastructure.diarization import DiarizationEngine
|
||||
from noteflow.infrastructure.persistence.repositories.diarization_job_repo import (
|
||||
JOB_STATUS_FAILED,
|
||||
)
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationRequest(Protocol):
|
||||
"""Protocol for refine request objects."""
|
||||
|
||||
meeting_id: str
|
||||
num_speakers: int
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationResponse(Protocol):
|
||||
"""Protocol for refine responses used in tests."""
|
||||
|
||||
job_id: str
|
||||
status: int
|
||||
error_message: str
|
||||
|
||||
|
||||
class _CancelDiarizationJobRequest(Protocol):
|
||||
"""Protocol for cancel request objects."""
|
||||
|
||||
job_id: str
|
||||
|
||||
|
||||
class _CancelDiarizationJobResponse(Protocol):
|
||||
"""Protocol for cancel responses used in tests."""
|
||||
|
||||
success: bool
|
||||
error_message: str
|
||||
|
||||
|
||||
class _GetDiarizationJobStatusRequest(Protocol):
|
||||
"""Protocol for job status request objects."""
|
||||
|
||||
job_id: str
|
||||
|
||||
|
||||
class _DiarizationJobStatusResponse(Protocol):
|
||||
"""Protocol for job status responses used in tests."""
|
||||
|
||||
job_id: str
|
||||
status: int
|
||||
error_message: str
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationCallable(Protocol):
|
||||
"""Callable protocol for refine RPC."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _RefineSpeakerDiarizationResponse: ...
|
||||
|
||||
|
||||
class _CancelDiarizationJobCallable(Protocol):
|
||||
"""Callable protocol for cancel RPC."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: _CancelDiarizationJobRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _CancelDiarizationJobResponse: ...
|
||||
|
||||
|
||||
class _GetDiarizationJobStatusCallable(Protocol):
|
||||
"""Callable protocol for job status RPC."""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: _GetDiarizationJobStatusRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _DiarizationJobStatusResponse: ...
|
||||
|
||||
|
||||
class _FakeDiarizationEngine(DiarizationEngine):
|
||||
@@ -77,14 +154,15 @@ class TestDatabaseRequirement:
|
||||
store.update(meeting)
|
||||
|
||||
context = _MockGrpcContext()
|
||||
response = await servicer.RefineSpeakerDiarization(
|
||||
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
|
||||
response = await refine(
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
|
||||
# Should return error response, not job_id
|
||||
assert not response.job_id, "Should not return job_id without database"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_FAILED, "Status should be FAILED"
|
||||
assert response.status == JOB_STATUS_FAILED, "Status should be FAILED"
|
||||
assert context.abort_code == grpc.StatusCode.FAILED_PRECONDITION, "Error code should be FAILED_PRECONDITION"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -98,7 +176,8 @@ class TestDatabaseRequirement:
|
||||
store.update(meeting)
|
||||
|
||||
context = _MockGrpcContext()
|
||||
response = await servicer.RefineSpeakerDiarization(
|
||||
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
|
||||
response = await refine(
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
@@ -114,7 +193,11 @@ class TestDatabaseRequirement:
|
||||
|
||||
# Should abort because no DB support
|
||||
with pytest.raises(_AbortCalled, match="database") as exc_info:
|
||||
await servicer.GetDiarizationJobStatus(
|
||||
get_status = cast(
|
||||
_GetDiarizationJobStatusCallable,
|
||||
servicer.GetDiarizationJobStatus,
|
||||
)
|
||||
await get_status(
|
||||
noteflow_pb2.GetDiarizationJobStatusRequest(job_id="any-job-id"),
|
||||
context,
|
||||
)
|
||||
@@ -128,7 +211,8 @@ class TestDatabaseRequirement:
|
||||
"""CancelDiarizationJob returns error when database unavailable."""
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await servicer.CancelDiarizationJob(
|
||||
cancel = cast(_CancelDiarizationJobCallable, servicer.CancelDiarizationJob)
|
||||
response = await cancel(
|
||||
noteflow_pb2.CancelDiarizationJobRequest(job_id="any-job-id"),
|
||||
context,
|
||||
)
|
||||
|
||||
@@ -6,8 +6,9 @@ and CancelDiarizationJob RPCs with comprehensive edge case coverage.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Self, cast
|
||||
from typing import TYPE_CHECKING, Protocol, Self, cast
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -23,6 +24,13 @@ from noteflow.grpc.proto import noteflow_pb2
|
||||
from noteflow.grpc.service import NoteFlowServicer
|
||||
from noteflow.infrastructure.diarization import DiarizationEngine
|
||||
from noteflow.infrastructure.persistence.repositories import DiarizationJob
|
||||
from noteflow.infrastructure.persistence.repositories.diarization_job_repo import (
|
||||
JOB_STATUS_CANCELLED,
|
||||
JOB_STATUS_COMPLETED,
|
||||
JOB_STATUS_FAILED,
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
)
|
||||
|
||||
# Test constants for progress calculation
|
||||
EXPECTED_RUNNING_JOB_PROGRESS_PERCENT = 50.0
|
||||
@@ -62,6 +70,84 @@ if TYPE_CHECKING:
|
||||
from noteflow.grpc.meeting_store import MeetingStore
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationRequest(Protocol):
|
||||
meeting_id: str
|
||||
num_speakers: int
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationResponse(Protocol):
|
||||
segments_updated: int
|
||||
error_message: str
|
||||
job_id: str
|
||||
status: int
|
||||
|
||||
|
||||
class _RenameSpeakerRequest(Protocol):
|
||||
meeting_id: str
|
||||
old_speaker_id: str
|
||||
new_speaker_name: str
|
||||
|
||||
|
||||
class _RenameSpeakerResponse(Protocol):
|
||||
segments_updated: int
|
||||
success: bool
|
||||
|
||||
|
||||
class _GetDiarizationJobStatusRequest(Protocol):
|
||||
job_id: str
|
||||
|
||||
|
||||
class _DiarizationJobStatusResponse(Protocol):
|
||||
job_id: str
|
||||
status: int
|
||||
segments_updated: int
|
||||
speaker_ids: Sequence[str]
|
||||
error_message: str
|
||||
progress_percent: float
|
||||
|
||||
|
||||
class _CancelDiarizationJobRequest(Protocol):
|
||||
job_id: str
|
||||
|
||||
|
||||
class _CancelDiarizationJobResponse(Protocol):
|
||||
success: bool
|
||||
error_message: str
|
||||
status: int
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _RefineSpeakerDiarizationResponse: ...
|
||||
|
||||
|
||||
class _RenameSpeakerCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RenameSpeakerRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _RenameSpeakerResponse: ...
|
||||
|
||||
|
||||
class _GetDiarizationJobStatusCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _GetDiarizationJobStatusRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _DiarizationJobStatusResponse: ...
|
||||
|
||||
|
||||
class _CancelDiarizationJobCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _CancelDiarizationJobRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _CancelDiarizationJobResponse: ...
|
||||
|
||||
|
||||
class _MockGrpcContext:
|
||||
"""Minimal async gRPC context for diarization mixin tests."""
|
||||
|
||||
@@ -140,6 +226,46 @@ def _get_store(servicer: NoteFlowServicer) -> MeetingStore:
|
||||
return servicer.get_memory_store()
|
||||
|
||||
|
||||
async def _call_refine(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _RefineSpeakerDiarizationResponse:
|
||||
"""Call RefineSpeakerDiarization with typed response."""
|
||||
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
|
||||
return await refine(request, context)
|
||||
|
||||
|
||||
async def _call_rename(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _RenameSpeakerRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _RenameSpeakerResponse:
|
||||
"""Call RenameSpeaker with typed response."""
|
||||
rename = cast(_RenameSpeakerCallable, servicer.RenameSpeaker)
|
||||
return await rename(request, context)
|
||||
|
||||
|
||||
async def _call_get_status(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _GetDiarizationJobStatusRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _DiarizationJobStatusResponse:
|
||||
"""Call GetDiarizationJobStatus with typed response."""
|
||||
get_status = cast(_GetDiarizationJobStatusCallable, servicer.GetDiarizationJobStatus)
|
||||
return await get_status(request, context)
|
||||
|
||||
|
||||
async def _call_cancel(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _CancelDiarizationJobRequest,
|
||||
context: _MockGrpcContext,
|
||||
) -> _CancelDiarizationJobResponse:
|
||||
"""Call CancelDiarizationJob with typed response."""
|
||||
cancel = cast(_CancelDiarizationJobCallable, servicer.CancelDiarizationJob)
|
||||
return await cancel(request, context)
|
||||
|
||||
|
||||
class _MockDiarizationJobsRepo:
|
||||
"""Mock diarization jobs repository for testing."""
|
||||
|
||||
@@ -186,8 +312,8 @@ class _MockDiarizationJobsRepo:
|
||||
"""Get active job for meeting."""
|
||||
for job in self._jobs.values():
|
||||
if job.meeting_id == meeting_id and job.status in (
|
||||
noteflow_pb2.JOB_STATUS_QUEUED,
|
||||
noteflow_pb2.JOB_STATUS_RUNNING,
|
||||
JOB_STATUS_QUEUED,
|
||||
JOB_STATUS_RUNNING,
|
||||
):
|
||||
return job
|
||||
return None
|
||||
@@ -198,7 +324,7 @@ class _MockDiarizationJobsRepo:
|
||||
job
|
||||
for job in self._jobs.values()
|
||||
if job.status
|
||||
in (noteflow_pb2.JOB_STATUS_QUEUED, noteflow_pb2.JOB_STATUS_RUNNING)
|
||||
in (JOB_STATUS_QUEUED, JOB_STATUS_RUNNING)
|
||||
]
|
||||
|
||||
async def prune_completed(self, ttl_seconds: float) -> int:
|
||||
@@ -277,7 +403,8 @@ class TestRefineSpeakerDiarizationValidation:
|
||||
"""Invalid UUID format returns error response."""
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id="not-a-uuid"),
|
||||
context,
|
||||
)
|
||||
@@ -293,7 +420,8 @@ class TestRefineSpeakerDiarizationValidation:
|
||||
context = _MockGrpcContext()
|
||||
meeting_id = str(uuid4())
|
||||
|
||||
response = await diarization_servicer.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=meeting_id),
|
||||
context,
|
||||
)
|
||||
@@ -316,7 +444,8 @@ class TestRefineSpeakerDiarizationState:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
@@ -336,7 +465,8 @@ class TestRefineSpeakerDiarizationState:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
@@ -357,13 +487,14 @@ class TestRefineSpeakerDiarizationState:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
|
||||
assert response.job_id, "Should return job_id for accepted request"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_QUEUED, "Job status should be QUEUED"
|
||||
assert response.status == JOB_STATUS_QUEUED, "Job status should be QUEUED"
|
||||
|
||||
|
||||
class TestRefineSpeakerDiarizationServer:
|
||||
@@ -382,7 +513,8 @@ class TestRefineSpeakerDiarizationServer:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_disabled.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer_disabled,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
@@ -403,7 +535,8 @@ class TestRefineSpeakerDiarizationServer:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_no_engine.RefineSpeakerDiarization(
|
||||
response = await _call_refine(
|
||||
diarization_servicer_no_engine,
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
context,
|
||||
)
|
||||
@@ -423,7 +556,8 @@ class TestRenameSpeakerValidation:
|
||||
context = _MockGrpcContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await diarization_servicer.RenameSpeaker(
|
||||
await _call_rename(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id=str(uuid4()),
|
||||
old_speaker_id="",
|
||||
@@ -442,7 +576,8 @@ class TestRenameSpeakerValidation:
|
||||
context = _MockGrpcContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await diarization_servicer.RenameSpeaker(
|
||||
await _call_rename(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id=str(uuid4()),
|
||||
old_speaker_id="SPEAKER_0",
|
||||
@@ -461,7 +596,8 @@ class TestRenameSpeakerValidation:
|
||||
context = _MockGrpcContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await diarization_servicer.RenameSpeaker(
|
||||
await _call_rename(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id="invalid-uuid",
|
||||
old_speaker_id="SPEAKER_0",
|
||||
@@ -496,7 +632,8 @@ class TestRenameSpeakerOperation:
|
||||
]
|
||||
store.update(meeting)
|
||||
|
||||
response = await diarization_servicer.RenameSpeaker(
|
||||
response = await _call_rename(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id=str(meeting.id), old_speaker_id="SPEAKER_0", new_speaker_name="Alice"
|
||||
),
|
||||
@@ -531,7 +668,8 @@ class TestRenameSpeakerOperation:
|
||||
store.update(meeting)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer.RenameSpeaker(
|
||||
response = await _call_rename(
|
||||
diarization_servicer,
|
||||
noteflow_pb2.RenameSpeakerRequest(
|
||||
meeting_id=str(meeting.id),
|
||||
old_speaker_id="SPEAKER_0",
|
||||
@@ -560,11 +698,12 @@ class TestGetDiarizationJobStatusProgress:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_QUEUED)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_QUEUED)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
|
||||
response = await _call_get_status(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
@@ -587,14 +726,15 @@ class TestGetDiarizationJobStatusProgress:
|
||||
_create_test_job(
|
||||
job_id,
|
||||
str(meeting.id),
|
||||
noteflow_pb2.JOB_STATUS_RUNNING,
|
||||
JOB_STATUS_RUNNING,
|
||||
started_at=utc_now() - timedelta(seconds=10),
|
||||
audio_duration_seconds=120.0,
|
||||
)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
|
||||
response = await _call_get_status(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
@@ -615,11 +755,12 @@ class TestGetDiarizationJobStatusProgress:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_COMPLETED)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_COMPLETED)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
|
||||
response = await _call_get_status(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
@@ -639,11 +780,12 @@ class TestGetDiarizationJobStatusProgress:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_FAILED)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_FAILED)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.GetDiarizationJobStatus(
|
||||
response = await _call_get_status(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
@@ -670,17 +812,18 @@ class TestCancelDiarizationJobStates:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_QUEUED)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_QUEUED)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.CancelDiarizationJob(
|
||||
response = await _call_cancel(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
|
||||
assert response.success is True, "Cancelling queued job should succeed"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_CANCELLED, "Cancelled queued job status should be CANCELLED"
|
||||
assert response.status == JOB_STATUS_CANCELLED, "Cancelled queued job status should be CANCELLED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_mixin_running_succeeds(
|
||||
@@ -695,17 +838,18 @@ class TestCancelDiarizationJobStates:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_RUNNING)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_RUNNING)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.CancelDiarizationJob(
|
||||
response = await _call_cancel(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
|
||||
assert response.success is True, "Cancelling running job should succeed"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_CANCELLED, "Cancelled running job status should be CANCELLED"
|
||||
assert response.status == JOB_STATUS_CANCELLED, "Cancelled running job status should be CANCELLED"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_mixin_nonexistent_fails(
|
||||
@@ -715,7 +859,8 @@ class TestCancelDiarizationJobStates:
|
||||
"""Nonexistent job cannot be cancelled."""
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.CancelDiarizationJob(
|
||||
response = await _call_cancel(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.CancelDiarizationJobRequest(job_id="nonexistent-job"),
|
||||
context,
|
||||
)
|
||||
@@ -736,14 +881,15 @@ class TestCancelDiarizationJobStates:
|
||||
|
||||
job_id = str(uuid4())
|
||||
await mock_diarization_jobs_repo.create(
|
||||
_create_test_job(job_id, str(meeting.id), noteflow_pb2.JOB_STATUS_COMPLETED)
|
||||
_create_test_job(job_id, str(meeting.id), JOB_STATUS_COMPLETED)
|
||||
)
|
||||
context = _MockGrpcContext()
|
||||
|
||||
response = await diarization_servicer_with_db.CancelDiarizationJob(
|
||||
response = await _call_cancel(
|
||||
diarization_servicer_with_db,
|
||||
noteflow_pb2.CancelDiarizationJobRequest(job_id=job_id),
|
||||
context,
|
||||
)
|
||||
|
||||
assert response.success is False, "Cancelling completed job should fail"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_COMPLETED, "Completed job status should remain COMPLETED"
|
||||
assert response.status == JOB_STATUS_COMPLETED, "Completed job status should remain COMPLETED"
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, cast
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
|
||||
@@ -11,6 +13,24 @@ from noteflow.grpc.service import NoteFlowServicer
|
||||
from noteflow.infrastructure.diarization import DiarizationEngine
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationRequest(Protocol):
|
||||
meeting_id: str
|
||||
num_speakers: int
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationResponse(Protocol):
|
||||
segments_updated: int
|
||||
error_message: str
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: _DummyContext,
|
||||
) -> _RefineSpeakerDiarizationResponse: ...
|
||||
|
||||
|
||||
class _DummyContext:
|
||||
"""Minimal gRPC context that raises if abort is invoked."""
|
||||
|
||||
@@ -43,7 +63,8 @@ async def test_refine_speaker_diarization_rejects_active_meeting() -> None:
|
||||
meeting.start_recording()
|
||||
store.update(meeting)
|
||||
|
||||
response = await servicer.RefineSpeakerDiarization(
|
||||
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
|
||||
response = await refine(
|
||||
noteflow_pb2.RefineSpeakerDiarizationRequest(meeting_id=str(meeting.id)),
|
||||
_DummyContext(),
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, Unpack
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
@@ -55,14 +56,17 @@ async def test_generate_summary_uses_placeholder_when_service_missing() -> None:
|
||||
class _FailingSummarizationService(SummarizationService):
|
||||
"""Summarization service that always reports provider unavailability."""
|
||||
|
||||
class _Options(TypedDict, total=False):
|
||||
mode: SummarizationMode | None
|
||||
max_key_points: int | None
|
||||
max_action_items: int | None
|
||||
style_prompt: str | None
|
||||
|
||||
async def summarize(
|
||||
self,
|
||||
meeting_id: MeetingId,
|
||||
segments: Sequence[Segment],
|
||||
mode: SummarizationMode | None = None,
|
||||
max_key_points: int | None = None,
|
||||
max_action_items: int | None = None,
|
||||
style_prompt: str | None = None,
|
||||
**kwargs: Unpack[_Options],
|
||||
) -> SummarizationServiceResult:
|
||||
raise ProviderUnavailableError("LLM unavailable")
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ Tests identity context validation and per-RPC request logging.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar
|
||||
from typing import Protocol, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import grpc
|
||||
@@ -29,8 +29,12 @@ from noteflow.infrastructure.logging import (
|
||||
workspace_id_var,
|
||||
)
|
||||
|
||||
_TRequest = TypeVar("_TRequest")
|
||||
_TResponse = TypeVar("_TResponse")
|
||||
class _DummyRequest:
|
||||
"""Placeholder request type for handler casts."""
|
||||
|
||||
|
||||
class _DummyResponse:
|
||||
"""Placeholder response type for handler casts."""
|
||||
|
||||
# Test data
|
||||
TEST_REQUEST_ID = "test-request-123"
|
||||
@@ -58,30 +62,62 @@ def create_handler_call_details(
|
||||
return details
|
||||
|
||||
|
||||
def create_mock_handler() -> grpc.RpcMethodHandler[_TRequest, _TResponse]:
|
||||
def create_mock_handler() -> _UnaryUnaryHandler:
|
||||
"""Create a mock RPC method handler."""
|
||||
handler = MagicMock(spec=grpc.RpcMethodHandler)
|
||||
handler.unary_unary = AsyncMock(return_value="response")
|
||||
handler.unary_stream = None
|
||||
handler.stream_unary = None
|
||||
handler.stream_stream = None
|
||||
handler.request_deserializer = None
|
||||
handler.response_serializer = None
|
||||
return handler
|
||||
return _MockHandler()
|
||||
|
||||
|
||||
def create_mock_continuation(
|
||||
handler: grpc.RpcMethodHandler[_TRequest, _TResponse] | None = None,
|
||||
) -> Callable[
|
||||
[grpc.HandlerCallDetails],
|
||||
Awaitable[grpc.RpcMethodHandler[_TRequest, _TResponse]],
|
||||
]:
|
||||
handler: _UnaryUnaryHandler | None = None,
|
||||
) -> AsyncMock:
|
||||
"""Create a mock continuation function."""
|
||||
if handler is None:
|
||||
handler = create_mock_handler()
|
||||
return AsyncMock(return_value=handler)
|
||||
|
||||
|
||||
class _UnaryUnaryHandler(Protocol):
|
||||
"""Protocol for unary-unary RPC method handlers."""
|
||||
|
||||
unary_unary: Callable[
|
||||
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
|
||||
Awaitable[_DummyResponse],
|
||||
] | None
|
||||
unary_stream: object | None
|
||||
stream_unary: object | None
|
||||
stream_stream: object | None
|
||||
request_deserializer: object | None
|
||||
response_serializer: object | None
|
||||
|
||||
|
||||
class _MockHandler:
|
||||
"""Concrete handler for tests with typed unary_unary."""
|
||||
|
||||
unary_unary: Callable[
|
||||
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
|
||||
Awaitable[_DummyResponse],
|
||||
] | None
|
||||
unary_stream: object | None
|
||||
stream_unary: object | None
|
||||
stream_stream: object | None
|
||||
request_deserializer: object | None
|
||||
response_serializer: object | None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.unary_unary = cast(
|
||||
Callable[
|
||||
[_DummyRequest, aio.ServicerContext[_DummyRequest, _DummyResponse]],
|
||||
Awaitable[_DummyResponse],
|
||||
],
|
||||
AsyncMock(return_value="response"),
|
||||
)
|
||||
self.unary_stream = None
|
||||
self.stream_unary = None
|
||||
self.stream_stream = None
|
||||
self.request_deserializer = None
|
||||
self.response_serializer = None
|
||||
|
||||
|
||||
class TestIdentityInterceptor:
|
||||
"""Tests for IdentityInterceptor."""
|
||||
|
||||
@@ -126,9 +162,10 @@ class TestIdentityInterceptor:
|
||||
continuation = create_mock_continuation()
|
||||
|
||||
handler = await interceptor.intercept_service(continuation, details)
|
||||
typed_handler = cast(_UnaryUnaryHandler, handler)
|
||||
|
||||
# Handler should be a rejection handler, not the original
|
||||
assert handler.unary_unary is not None, "handler should have unary_unary"
|
||||
assert typed_handler.unary_unary is not None, "handler should have unary_unary"
|
||||
# Continuation should NOT have been called
|
||||
continuation.assert_not_called()
|
||||
|
||||
@@ -140,13 +177,15 @@ class TestIdentityInterceptor:
|
||||
continuation = create_mock_continuation()
|
||||
|
||||
handler = await interceptor.intercept_service(continuation, details)
|
||||
typed_handler = cast(_UnaryUnaryHandler, handler)
|
||||
|
||||
# Create mock context to verify abort behavior
|
||||
context = AsyncMock(spec=aio.ServicerContext)
|
||||
context.abort = AsyncMock(side_effect=grpc.RpcError("missing x-request-id"))
|
||||
|
||||
with pytest.raises(grpc.RpcError, match="x-request-id"):
|
||||
await handler.unary_unary(MagicMock(), context)
|
||||
assert typed_handler.unary_unary is not None
|
||||
await typed_handler.unary_unary(MagicMock(), context)
|
||||
|
||||
context.abort.assert_called_once()
|
||||
call_args = context.abort.call_args
|
||||
@@ -187,11 +226,13 @@ class TestRequestLoggingInterceptor:
|
||||
|
||||
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
|
||||
wrapped_handler = await interceptor.intercept_service(continuation, details)
|
||||
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
|
||||
|
||||
# Execute the wrapped handler
|
||||
context = AsyncMock(spec=aio.ServicerContext)
|
||||
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
|
||||
await wrapped_handler.unary_unary(MagicMock(), context)
|
||||
assert typed_handler.unary_unary is not None
|
||||
await typed_handler.unary_unary(MagicMock(), context)
|
||||
|
||||
# Verify logging
|
||||
mock_logger.info.assert_called_once()
|
||||
@@ -215,12 +256,14 @@ class TestRequestLoggingInterceptor:
|
||||
|
||||
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
|
||||
wrapped_handler = await interceptor.intercept_service(continuation, details)
|
||||
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
|
||||
|
||||
context = AsyncMock(spec=aio.ServicerContext)
|
||||
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
|
||||
|
||||
with pytest.raises(Exception, match="Test error"):
|
||||
await wrapped_handler.unary_unary(MagicMock(), context)
|
||||
assert typed_handler.unary_unary is not None
|
||||
await typed_handler.unary_unary(MagicMock(), context)
|
||||
|
||||
# Should still log with INTERNAL status
|
||||
mock_logger.info.assert_called_once()
|
||||
@@ -249,12 +292,14 @@ class TestRequestLoggingInterceptor:
|
||||
|
||||
with patch("noteflow.grpc.interceptors.logging.logger") as mock_logger:
|
||||
wrapped_handler = await interceptor.intercept_service(continuation, details)
|
||||
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
|
||||
|
||||
# Context without peer method
|
||||
context = AsyncMock(spec=aio.ServicerContext)
|
||||
context.peer = MagicMock(side_effect=RuntimeError("No peer"))
|
||||
|
||||
await wrapped_handler.unary_unary(MagicMock(), context)
|
||||
assert typed_handler.unary_unary is not None
|
||||
await typed_handler.unary_unary(MagicMock(), context)
|
||||
|
||||
# Should still log with None peer
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
@@ -7,10 +7,12 @@ DisconnectOAuth RPCs work correctly with the calendar service.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
|
||||
from noteflow.application.services.calendar_service import CalendarServiceError
|
||||
@@ -28,16 +30,166 @@ class _DummyContext:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.aborted = False
|
||||
self.abort_code: object = None
|
||||
self.abort_code: grpc.StatusCode | None = None
|
||||
self.abort_details: str = ""
|
||||
|
||||
async def abort(self, code: object, details: str) -> None:
|
||||
async def abort(self, code: grpc.StatusCode, details: str) -> None:
|
||||
self.aborted = True
|
||||
self.abort_code = code
|
||||
self.abort_details = details
|
||||
raise AssertionError(f"abort called: {code} - {details}")
|
||||
|
||||
|
||||
class _CalendarProvider(Protocol):
|
||||
name: str
|
||||
is_authenticated: bool
|
||||
display_name: str
|
||||
|
||||
|
||||
class _GetCalendarProvidersResponse(Protocol):
|
||||
providers: Sequence[_CalendarProvider]
|
||||
|
||||
|
||||
class _GetCalendarProvidersCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: noteflow_pb2.GetCalendarProvidersRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetCalendarProvidersResponse: ...
|
||||
|
||||
|
||||
async def _call_get_calendar_providers(
|
||||
servicer: NoteFlowServicer,
|
||||
request: noteflow_pb2.GetCalendarProvidersRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetCalendarProvidersResponse:
|
||||
get_providers = cast(
|
||||
_GetCalendarProvidersCallable,
|
||||
servicer.GetCalendarProviders,
|
||||
)
|
||||
return await get_providers(request, context)
|
||||
|
||||
|
||||
class _InitiateOAuthRequest(Protocol):
|
||||
provider: str
|
||||
redirect_uri: str
|
||||
integration_type: str
|
||||
|
||||
|
||||
class _InitiateOAuthResponse(Protocol):
|
||||
auth_url: str
|
||||
state: str
|
||||
|
||||
|
||||
class _CompleteOAuthRequest(Protocol):
|
||||
provider: str
|
||||
code: str
|
||||
state: str
|
||||
|
||||
|
||||
class _CompleteOAuthResponse(Protocol):
|
||||
success: bool
|
||||
error_message: str
|
||||
provider_email: str
|
||||
integration_id: str
|
||||
|
||||
|
||||
class _OAuthConnection(Protocol):
|
||||
provider: str
|
||||
status: str
|
||||
email: str
|
||||
expires_at: str
|
||||
error_message: str
|
||||
integration_type: str
|
||||
|
||||
|
||||
class _GetOAuthConnectionStatusRequest(Protocol):
|
||||
provider: str
|
||||
integration_type: str
|
||||
|
||||
|
||||
class _GetOAuthConnectionStatusResponse(Protocol):
|
||||
connection: _OAuthConnection
|
||||
|
||||
|
||||
class _DisconnectOAuthRequest(Protocol):
|
||||
provider: str
|
||||
integration_type: str
|
||||
|
||||
|
||||
class _DisconnectOAuthResponse(Protocol):
|
||||
success: bool
|
||||
error_message: str
|
||||
|
||||
|
||||
class _InitiateOAuthCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _InitiateOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _InitiateOAuthResponse: ...
|
||||
|
||||
|
||||
class _CompleteOAuthCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _CompleteOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _CompleteOAuthResponse: ...
|
||||
|
||||
|
||||
class _GetOAuthConnectionStatusCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _GetOAuthConnectionStatusRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetOAuthConnectionStatusResponse: ...
|
||||
|
||||
|
||||
class _DisconnectOAuthCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _DisconnectOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _DisconnectOAuthResponse: ...
|
||||
|
||||
|
||||
async def _call_initiate_oauth(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _InitiateOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _InitiateOAuthResponse:
|
||||
initiate = cast(_InitiateOAuthCallable, servicer.InitiateOAuth)
|
||||
return await initiate(request, context)
|
||||
|
||||
|
||||
async def _call_complete_oauth(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _CompleteOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _CompleteOAuthResponse:
|
||||
complete = cast(_CompleteOAuthCallable, servicer.CompleteOAuth)
|
||||
return await complete(request, context)
|
||||
|
||||
|
||||
async def _call_get_oauth_status(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _GetOAuthConnectionStatusRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetOAuthConnectionStatusResponse:
|
||||
get_status = cast(_GetOAuthConnectionStatusCallable, servicer.GetOAuthConnectionStatus)
|
||||
return await get_status(request, context)
|
||||
|
||||
|
||||
async def _call_disconnect_oauth(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _DisconnectOAuthRequest,
|
||||
context: _DummyContext,
|
||||
) -> _DisconnectOAuthResponse:
|
||||
disconnect = cast(_DisconnectOAuthCallable, servicer.DisconnectOAuth)
|
||||
return await disconnect(request, context)
|
||||
|
||||
|
||||
def _create_mock_connection_info(
|
||||
*,
|
||||
provider: str = "google",
|
||||
@@ -94,7 +246,8 @@ class TestGetCalendarProviders:
|
||||
service = _create_mockcalendar_service()
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetCalendarProviders(
|
||||
response = await _call_get_calendar_providers(
|
||||
servicer,
|
||||
noteflow_pb2.GetCalendarProvidersRequest(),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -112,7 +265,8 @@ class TestGetCalendarProviders:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetCalendarProviders(
|
||||
response = await _call_get_calendar_providers(
|
||||
servicer,
|
||||
noteflow_pb2.GetCalendarProvidersRequest(),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -129,7 +283,8 @@ class TestGetCalendarProviders:
|
||||
service = _create_mockcalendar_service()
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetCalendarProviders(
|
||||
response = await _call_get_calendar_providers(
|
||||
servicer,
|
||||
noteflow_pb2.GetCalendarProvidersRequest(),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -147,7 +302,8 @@ class TestGetCalendarProviders:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.GetCalendarProviders(
|
||||
await _call_get_calendar_providers(
|
||||
servicer,
|
||||
noteflow_pb2.GetCalendarProvidersRequest(),
|
||||
context,
|
||||
)
|
||||
@@ -168,7 +324,8 @@ class TestInitiateOAuth:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.InitiateOAuth(
|
||||
response = await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -183,7 +340,8 @@ class TestInitiateOAuth:
|
||||
service.initiate_oauth.return_value = ("https://auth.url", "state")
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
await servicer.InitiateOAuth(
|
||||
await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(provider="outlook"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -200,7 +358,8 @@ class TestInitiateOAuth:
|
||||
service.initiate_oauth.return_value = ("https://auth.url", "state")
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
await servicer.InitiateOAuth(
|
||||
await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(
|
||||
provider="google",
|
||||
redirect_uri="noteflow://oauth/callback",
|
||||
@@ -222,7 +381,8 @@ class TestInitiateOAuth:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.InitiateOAuth(
|
||||
await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(provider="unknown"),
|
||||
context,
|
||||
)
|
||||
@@ -237,7 +397,8 @@ class TestInitiateOAuth:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.InitiateOAuth(
|
||||
await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(provider="google"),
|
||||
context,
|
||||
)
|
||||
@@ -258,7 +419,8 @@ class TestCompleteOAuth:
|
||||
service.complete_oauth.return_value = uuid4() # Returns integration ID
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="authorization-code",
|
||||
@@ -278,7 +440,8 @@ class TestCompleteOAuth:
|
||||
service.complete_oauth.return_value = uuid4() # Returns integration ID
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
await servicer.CompleteOAuth(
|
||||
await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="my-auth-code",
|
||||
@@ -302,7 +465,8 @@ class TestCompleteOAuth:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="authorization-code",
|
||||
@@ -323,7 +487,8 @@ class TestCompleteOAuth:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="invalid-code",
|
||||
@@ -342,7 +507,8 @@ class TestCompleteOAuth:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.CompleteOAuth(
|
||||
await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="code",
|
||||
@@ -366,7 +532,8 @@ class TestGetOAuthConnectionStatus:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetOAuthConnectionStatus(
|
||||
response = await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -383,7 +550,8 @@ class TestGetOAuthConnectionStatus:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetOAuthConnectionStatus(
|
||||
response = await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -396,7 +564,8 @@ class TestGetOAuthConnectionStatus:
|
||||
service = _create_mockcalendar_service()
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.GetOAuthConnectionStatus(
|
||||
response = await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(
|
||||
provider="google",
|
||||
integration_type="calendar",
|
||||
@@ -413,7 +582,8 @@ class TestGetOAuthConnectionStatus:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.GetOAuthConnectionStatus(
|
||||
await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
|
||||
context,
|
||||
)
|
||||
@@ -433,7 +603,8 @@ class TestDisconnectOAuth:
|
||||
service.disconnect.return_value = True
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.DisconnectOAuth(
|
||||
response = await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -447,7 +618,8 @@ class TestDisconnectOAuth:
|
||||
service.disconnect.return_value = True
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
await servicer.DisconnectOAuth(
|
||||
await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="outlook"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -463,7 +635,8 @@ class TestDisconnectOAuth:
|
||||
service.disconnect.return_value = False
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.DisconnectOAuth(
|
||||
response = await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -477,7 +650,8 @@ class TestDisconnectOAuth:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.DisconnectOAuth(
|
||||
await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
|
||||
context,
|
||||
)
|
||||
@@ -538,7 +712,8 @@ class TestOAuthRoundTrip:
|
||||
service, _, _ = oauth_flow_service
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.InitiateOAuth(
|
||||
response = await _call_initiate_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.InitiateOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -556,7 +731,8 @@ class TestOAuthRoundTrip:
|
||||
|
||||
assert connected_state["google"] is False, "should start disconnected"
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="auth-code",
|
||||
@@ -580,7 +756,8 @@ class TestOAuthRoundTrip:
|
||||
connected_state["google"] = True
|
||||
email_state["google"] = "user@gmail.com"
|
||||
|
||||
response = await servicer.DisconnectOAuth(
|
||||
response = await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -598,7 +775,8 @@ class TestOAuthRoundTrip:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="auth-code",
|
||||
@@ -620,11 +798,13 @@ class TestOAuthRoundTrip:
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
ctx = _DummyContext()
|
||||
|
||||
google_status = await servicer.GetOAuthConnectionStatus(
|
||||
google_status = await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
|
||||
ctx,
|
||||
)
|
||||
outlook_status = await servicer.GetOAuthConnectionStatus(
|
||||
outlook_status = await _call_get_oauth_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetOAuthConnectionStatusRequest(provider="outlook"),
|
||||
ctx,
|
||||
)
|
||||
@@ -644,7 +824,8 @@ class TestOAuthSecurityBehavior:
|
||||
service.complete_oauth.side_effect = CalendarServiceError("State mismatch")
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="stolen-code",
|
||||
@@ -664,7 +845,8 @@ class TestOAuthSecurityBehavior:
|
||||
service.disconnect.return_value = True
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
await servicer.DisconnectOAuth(
|
||||
await _call_disconnect_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.DisconnectOAuthRequest(provider="google"),
|
||||
_DummyContext(),
|
||||
)
|
||||
@@ -680,7 +862,8 @@ class TestOAuthSecurityBehavior:
|
||||
)
|
||||
servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
|
||||
|
||||
response = await servicer.CompleteOAuth(
|
||||
response = await _call_complete_oauth(
|
||||
servicer,
|
||||
noteflow_pb2.CompleteOAuthRequest(
|
||||
provider="google",
|
||||
code="code",
|
||||
|
||||
@@ -10,10 +10,10 @@ list_all repository method tests.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import grpc
|
||||
@@ -23,6 +23,7 @@ from noteflow.application.services.calendar_service import CalendarService
|
||||
from noteflow.domain.entities.integration import (
|
||||
Integration,
|
||||
IntegrationType,
|
||||
SyncRun,
|
||||
)
|
||||
from noteflow.grpc._config import ServicesConfig
|
||||
from noteflow.grpc.meeting_store import MeetingStore
|
||||
@@ -30,7 +31,7 @@ from noteflow.grpc.proto import noteflow_pb2
|
||||
from noteflow.grpc.service import NoteFlowServicer
|
||||
|
||||
|
||||
def _get_status_code_not_found() -> object:
|
||||
def _get_status_code_not_found() -> grpc.StatusCode:
|
||||
"""Helper function to get StatusCode.NOT_FOUND for type checker compatibility."""
|
||||
attr_name = "StatusCode"
|
||||
status_code = getattr(grpc, attr_name)
|
||||
@@ -42,16 +43,132 @@ class _DummyContext:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.aborted = False
|
||||
self.abort_code: object = None
|
||||
self.abort_code: grpc.StatusCode | None = None
|
||||
self.abort_details: str | None = None
|
||||
|
||||
async def abort(self, code: object, details: str) -> None:
|
||||
async def abort(self, code: grpc.StatusCode, details: str) -> None:
|
||||
self.aborted = True
|
||||
self.abort_code = code
|
||||
self.abort_details = details
|
||||
raise AssertionError(f"abort called: {code} - {details}")
|
||||
|
||||
|
||||
class _StartIntegrationSyncRequest(Protocol):
|
||||
integration_id: str
|
||||
|
||||
|
||||
class _StartIntegrationSyncResponse(Protocol):
|
||||
sync_run_id: str
|
||||
status: str
|
||||
|
||||
|
||||
class _GetSyncStatusRequest(Protocol):
|
||||
sync_run_id: str
|
||||
|
||||
|
||||
class _GetSyncStatusResponse(Protocol):
|
||||
status: str
|
||||
items_synced: int
|
||||
items_total: int
|
||||
error_message: str
|
||||
duration_ms: int
|
||||
expires_at: str
|
||||
|
||||
|
||||
class _ListSyncHistoryRequest(Protocol):
|
||||
integration_id: str
|
||||
limit: int
|
||||
|
||||
|
||||
class _ListSyncHistoryResponse(Protocol):
|
||||
total_count: int
|
||||
runs: Sequence["_SyncRunInfo"]
|
||||
|
||||
|
||||
class _SyncRunInfo(Protocol):
|
||||
id: str
|
||||
|
||||
|
||||
class _IntegrationInfo(Protocol):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
status: str
|
||||
workspace_id: str
|
||||
|
||||
|
||||
class _GetUserIntegrationsResponse(Protocol):
|
||||
integrations: Sequence[_IntegrationInfo]
|
||||
|
||||
|
||||
class _StartIntegrationSyncCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _StartIntegrationSyncRequest,
|
||||
context: _DummyContext,
|
||||
) -> _StartIntegrationSyncResponse: ...
|
||||
|
||||
|
||||
class _GetSyncStatusCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _GetSyncStatusRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetSyncStatusResponse: ...
|
||||
|
||||
|
||||
class _ListSyncHistoryCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _ListSyncHistoryRequest,
|
||||
context: _DummyContext,
|
||||
) -> _ListSyncHistoryResponse: ...
|
||||
|
||||
|
||||
class _GetUserIntegrationsCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: noteflow_pb2.GetUserIntegrationsRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetUserIntegrationsResponse: ...
|
||||
|
||||
|
||||
async def _call_start_sync(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _StartIntegrationSyncRequest,
|
||||
context: _DummyContext,
|
||||
) -> _StartIntegrationSyncResponse:
|
||||
start_sync = cast(_StartIntegrationSyncCallable, servicer.StartIntegrationSync)
|
||||
return await start_sync(request, context)
|
||||
|
||||
|
||||
async def _call_get_sync_status(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _GetSyncStatusRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetSyncStatusResponse:
|
||||
get_status = cast(_GetSyncStatusCallable, servicer.GetSyncStatus)
|
||||
return await get_status(request, context)
|
||||
|
||||
|
||||
async def _call_list_sync_history(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _ListSyncHistoryRequest,
|
||||
context: _DummyContext,
|
||||
) -> _ListSyncHistoryResponse:
|
||||
list_history = cast(_ListSyncHistoryCallable, servicer.ListSyncHistory)
|
||||
return await list_history(request, context)
|
||||
|
||||
|
||||
async def _call_get_user_integrations(
|
||||
servicer: NoteFlowServicer,
|
||||
request: noteflow_pb2.GetUserIntegrationsRequest,
|
||||
context: _DummyContext,
|
||||
) -> _GetUserIntegrationsResponse:
|
||||
get_integrations = cast(_GetUserIntegrationsCallable, servicer.GetUserIntegrations)
|
||||
return await get_integrations(request, context)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockCalendarEvent:
|
||||
"""Mock calendar event for testing."""
|
||||
@@ -166,17 +283,18 @@ async def await_sync_completion(
|
||||
sync_run_id: str,
|
||||
context: _DummyContext,
|
||||
timeout: float = 2.0,
|
||||
) -> noteflow_pb2.GetSyncStatusResponse:
|
||||
) -> _GetSyncStatusResponse:
|
||||
"""Wait for sync to complete using event-based synchronization.
|
||||
|
||||
Uses asyncio.wait_for for timeout instead of polling loop.
|
||||
The sync task runs in the background; we wait then check final status.
|
||||
"""
|
||||
|
||||
async def _get_final_status() -> noteflow_pb2.GetSyncStatusResponse:
|
||||
async def _get_final_status() -> _GetSyncStatusResponse:
|
||||
# Brief delay to allow background sync to complete
|
||||
await asyncio.sleep(0.05)
|
||||
return await servicer.GetSyncStatus(
|
||||
return await _call_get_sync_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=sync_run_id),
|
||||
context,
|
||||
)
|
||||
@@ -194,7 +312,8 @@ class TestStartIntegrationSync:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.StartIntegrationSync(
|
||||
await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(uuid4())),
|
||||
context,
|
||||
)
|
||||
@@ -208,7 +327,8 @@ class TestStartIntegrationSync:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.StartIntegrationSync(
|
||||
await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=""),
|
||||
context,
|
||||
)
|
||||
@@ -226,7 +346,8 @@ class TestSyncStatus:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.GetSyncStatus(
|
||||
await _call_get_sync_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=str(uuid4())),
|
||||
context,
|
||||
)
|
||||
@@ -240,7 +361,8 @@ class TestSyncStatus:
|
||||
context = _DummyContext()
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer.GetSyncStatus(
|
||||
await _call_get_sync_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=""),
|
||||
context,
|
||||
)
|
||||
@@ -257,7 +379,8 @@ class TestSyncHistory:
|
||||
servicer = NoteFlowServicer()
|
||||
context = _DummyContext()
|
||||
|
||||
response = await servicer.ListSyncHistory(
|
||||
response = await _call_list_sync_history(
|
||||
servicer,
|
||||
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(uuid4()), limit=10),
|
||||
context,
|
||||
)
|
||||
@@ -271,7 +394,8 @@ class TestSyncHistory:
|
||||
servicer = NoteFlowServicer()
|
||||
context = _DummyContext()
|
||||
|
||||
response = await servicer.ListSyncHistory(
|
||||
response = await _call_list_sync_history(
|
||||
servicer,
|
||||
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(uuid4())),
|
||||
context,
|
||||
)
|
||||
@@ -290,7 +414,8 @@ class TestSyncHappyPath:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
response = await servicer_with_success.StartIntegrationSync(
|
||||
response = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -315,7 +440,8 @@ class TestSyncHappyPath:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -340,13 +466,15 @@ class TestSyncHappyPath:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
await await_sync_completion(servicer, start.sync_run_id, context)
|
||||
|
||||
history = await servicer.ListSyncHistory(
|
||||
history = await _call_list_sync_history(
|
||||
servicer,
|
||||
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(integration.id), limit=10),
|
||||
context,
|
||||
)
|
||||
@@ -370,7 +498,8 @@ class TestSyncErrorHandling:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -387,7 +516,8 @@ class TestSyncErrorHandling:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer_with_failure.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer_with_failure,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -406,7 +536,8 @@ class TestSyncErrorHandling:
|
||||
servicer_fail = NoteFlowServicer(services=ServicesConfig(calendar_service=cast(CalendarService, failing_service)))
|
||||
servicer_fail.memory_store = meeting_store
|
||||
|
||||
first = await servicer_fail.StartIntegrationSync(
|
||||
first = await _call_start_sync(
|
||||
servicer_fail,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -419,7 +550,8 @@ class TestSyncErrorHandling:
|
||||
servicer_success = NoteFlowServicer(services=ServicesConfig(calendar_service=cast(CalendarService, success_service)))
|
||||
servicer_success.memory_store = meeting_store
|
||||
|
||||
second = await servicer_success.StartIntegrationSync(
|
||||
second = await _call_start_sync(
|
||||
servicer_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -442,11 +574,13 @@ class TestConcurrentSyncs:
|
||||
int2 = await create_test_integration(meeting_store, "Calendar 2")
|
||||
context = _DummyContext()
|
||||
|
||||
r1 = await servicer_with_success.StartIntegrationSync(
|
||||
r1 = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int1.id)),
|
||||
context,
|
||||
)
|
||||
r2 = await servicer_with_success.StartIntegrationSync(
|
||||
r2 = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int2.id)),
|
||||
context,
|
||||
)
|
||||
@@ -466,22 +600,26 @@ class TestConcurrentSyncs:
|
||||
int2 = await create_test_integration(meeting_store, "Calendar 2")
|
||||
context = _DummyContext()
|
||||
|
||||
r1 = await servicer_with_success.StartIntegrationSync(
|
||||
r1 = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int1.id)),
|
||||
context,
|
||||
)
|
||||
r2 = await servicer_with_success.StartIntegrationSync(
|
||||
r2 = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(int2.id)),
|
||||
context,
|
||||
)
|
||||
await await_sync_completion(servicer_with_success, r1.sync_run_id, context)
|
||||
await await_sync_completion(servicer_with_success, r2.sync_run_id, context)
|
||||
|
||||
h1 = await servicer_with_success.ListSyncHistory(
|
||||
h1 = await _call_list_sync_history(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(int1.id), limit=10),
|
||||
context,
|
||||
)
|
||||
h2 = await servicer_with_success.ListSyncHistory(
|
||||
h2 = await _call_list_sync_history(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.ListSyncHistoryRequest(integration_id=str(int2.id), limit=10),
|
||||
context,
|
||||
)
|
||||
@@ -506,7 +644,8 @@ class TestSyncPolling:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -515,7 +654,8 @@ class TestSyncPolling:
|
||||
await asyncio.wait_for(blocking_service.sync_started.wait(), timeout=1.0)
|
||||
|
||||
# Check status while blocked - should still be running
|
||||
immediate_status = await servicer.GetSyncStatus(
|
||||
immediate_status = await _call_get_sync_status(
|
||||
servicer,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=start.sync_run_id),
|
||||
context,
|
||||
)
|
||||
@@ -535,7 +675,8 @@ class TestSyncPolling:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer_with_success.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -554,7 +695,8 @@ class TestGetUserIntegrations:
|
||||
servicer = NoteFlowServicer()
|
||||
context = _DummyContext()
|
||||
|
||||
response = await servicer.GetUserIntegrations(
|
||||
response = await _call_get_user_integrations(
|
||||
servicer,
|
||||
noteflow_pb2.GetUserIntegrationsRequest(),
|
||||
context,
|
||||
)
|
||||
@@ -575,7 +717,8 @@ class TestGetUserIntegrations:
|
||||
|
||||
integration = await create_test_integration(meeting_store, "Google Calendar")
|
||||
|
||||
response = await servicer.GetUserIntegrations(
|
||||
response = await _call_get_user_integrations(
|
||||
servicer,
|
||||
noteflow_pb2.GetUserIntegrationsRequest(),
|
||||
context,
|
||||
)
|
||||
@@ -596,7 +739,8 @@ class TestGetUserIntegrations:
|
||||
|
||||
integration = await create_test_integration(meeting_store, "Test Calendar")
|
||||
|
||||
response = await servicer.GetUserIntegrations(
|
||||
response = await _call_get_user_integrations(
|
||||
servicer,
|
||||
noteflow_pb2.GetUserIntegrationsRequest(),
|
||||
context,
|
||||
)
|
||||
@@ -625,7 +769,8 @@ class TestGetUserIntegrations:
|
||||
|
||||
await meeting_store.integrations.delete(int1.id)
|
||||
|
||||
response = await servicer.GetUserIntegrations(
|
||||
response = await _call_get_user_integrations(
|
||||
servicer,
|
||||
noteflow_pb2.GetUserIntegrationsRequest(),
|
||||
context,
|
||||
)
|
||||
@@ -648,7 +793,8 @@ class TestNotFoundStatusCode:
|
||||
nonexistent_id = str(uuid4())
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer_with_success.StartIntegrationSync(
|
||||
await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=nonexistent_id),
|
||||
context,
|
||||
)
|
||||
@@ -666,7 +812,8 @@ class TestNotFoundStatusCode:
|
||||
nonexistent_id = str(uuid4())
|
||||
|
||||
with pytest.raises(AssertionError, match="abort called"):
|
||||
await servicer_with_success.GetSyncStatus(
|
||||
await _call_get_sync_status(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=nonexistent_id),
|
||||
context,
|
||||
)
|
||||
@@ -686,11 +833,13 @@ class TestSyncRunExpiryMetadata:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer_with_success.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
status = await servicer_with_success.GetSyncStatus(
|
||||
status = await _call_get_sync_status(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.GetSyncStatusRequest(sync_run_id=start.sync_run_id),
|
||||
context,
|
||||
)
|
||||
@@ -706,7 +855,8 @@ class TestSyncRunExpiryMetadata:
|
||||
integration = await create_test_integration(meeting_store)
|
||||
context = _DummyContext()
|
||||
|
||||
start = await servicer_with_success.StartIntegrationSync(
|
||||
start = await _call_start_sync(
|
||||
servicer_with_success,
|
||||
noteflow_pb2.StartIntegrationSyncRequest(integration_id=str(integration.id)),
|
||||
context,
|
||||
)
|
||||
@@ -714,7 +864,7 @@ class TestSyncRunExpiryMetadata:
|
||||
sync_run_id_uuid = UUID(start.sync_run_id)
|
||||
# Type annotation needed: ensure_sync_runs_cache is a mixin method added via SyncMixin
|
||||
ensure_cache = cast(
|
||||
Callable[[], dict[UUID, object]],
|
||||
Callable[[], dict[UUID, SyncRun]],
|
||||
servicer_with_success.ensure_sync_runs_cache,
|
||||
)
|
||||
ensure_cache()
|
||||
|
||||
@@ -90,14 +90,17 @@ class TestOidcProviderRegistry:
|
||||
json=valid_discovery_document,
|
||||
)
|
||||
workspace_id = uuid4()
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration = OidcProviderRegistration(
|
||||
workspace_id=workspace_id,
|
||||
name="Test Provider",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration,
|
||||
)
|
||||
|
||||
assert provider.name == "Test Provider", "provider name should match"
|
||||
assert provider.workspace_id == workspace_id, "workspace_id should match"
|
||||
assert provider.discovery is not None, "discovery should be populated"
|
||||
@@ -110,12 +113,15 @@ class TestOidcProviderRegistry:
|
||||
) -> None:
|
||||
"""Verify provider creation without auto-discovery."""
|
||||
workspace_id = uuid4()
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration = OidcProviderRegistration(
|
||||
workspace_id=workspace_id,
|
||||
name="Test Provider",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration,
|
||||
auto_discover=False,
|
||||
)
|
||||
|
||||
@@ -130,12 +136,16 @@ class TestOidcProviderRegistry:
|
||||
"""Verify provider creation applies preset defaults."""
|
||||
workspace_id = uuid4()
|
||||
|
||||
params = OidcProviderCreateParams(preset=OidcProviderPreset.AUTHENTIK)
|
||||
provider = await registry.create_provider(
|
||||
registration = OidcProviderRegistration(
|
||||
workspace_id=workspace_id,
|
||||
name="Authentik",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
params = OidcProviderCreateParams(preset=OidcProviderPreset.AUTHENTIK)
|
||||
provider = await registry.create_provider(
|
||||
registration,
|
||||
params=params,
|
||||
auto_discover=False,
|
||||
)
|
||||
@@ -155,14 +165,15 @@ class TestOidcProviderRegistry:
|
||||
url="https://auth.example.com/.well-known/openid-configuration",
|
||||
status_code=404,
|
||||
)
|
||||
registration = OidcProviderRegistration(
|
||||
workspace_id=uuid4(),
|
||||
name="Test Provider",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
with pytest.raises(OidcDiscoveryError, match="HTTP 404"):
|
||||
await registry.create_provider(
|
||||
workspace_id=uuid4(),
|
||||
name="Test Provider",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
await registry.create_provider(registration)
|
||||
|
||||
def test_get_provider(self, registry: OidcProviderRegistry) -> None:
|
||||
"""Verify get_provider returns correct provider."""
|
||||
@@ -179,26 +190,36 @@ class TestOidcProviderRegistry:
|
||||
workspace1 = uuid4()
|
||||
workspace2 = uuid4()
|
||||
|
||||
# Create providers without discovery
|
||||
await registry.create_provider(
|
||||
registration1 = OidcProviderRegistration(
|
||||
workspace_id=workspace1,
|
||||
name="Provider 1",
|
||||
issuer_url="https://auth1.example.com",
|
||||
client_id="client1",
|
||||
auto_discover=False,
|
||||
)
|
||||
provider2 = await registry.create_provider(
|
||||
registration2 = OidcProviderRegistration(
|
||||
workspace_id=workspace1,
|
||||
name="Provider 2",
|
||||
issuer_url="https://auth2.example.com",
|
||||
client_id="client2",
|
||||
auto_discover=False,
|
||||
)
|
||||
await registry.create_provider(
|
||||
registration3 = OidcProviderRegistration(
|
||||
workspace_id=workspace2,
|
||||
name="Provider 3",
|
||||
issuer_url="https://auth3.example.com",
|
||||
client_id="client3",
|
||||
)
|
||||
|
||||
# Create providers without discovery
|
||||
await registry.create_provider(
|
||||
registration1,
|
||||
auto_discover=False,
|
||||
)
|
||||
provider2 = await registry.create_provider(
|
||||
registration2,
|
||||
auto_discover=False,
|
||||
)
|
||||
await registry.create_provider(
|
||||
registration3,
|
||||
auto_discover=False,
|
||||
)
|
||||
|
||||
@@ -222,12 +243,15 @@ class TestOidcProviderRegistry:
|
||||
) -> None:
|
||||
"""Verify provider removal."""
|
||||
workspace_id = uuid4()
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration = OidcProviderRegistration(
|
||||
workspace_id=workspace_id,
|
||||
name="Test Provider",
|
||||
issuer_url="https://auth.example.com",
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
provider = await registry.create_provider(
|
||||
registration,
|
||||
auto_discover=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,40 +11,39 @@ Tests cover:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.diarization._compat import (
|
||||
AudioMetaData,
|
||||
_patch_huggingface_auth,
|
||||
_patch_speechbrain_backend,
|
||||
_patch_torch_load,
|
||||
_patch_torchaudio,
|
||||
apply_patches,
|
||||
ensure_compatibility,
|
||||
)
|
||||
from noteflow.infrastructure.diarization import _compat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
class _CompatModule(Protocol):
|
||||
"""Protocol for the diarization compatibility module."""
|
||||
|
||||
AudioMetaData: type
|
||||
|
||||
def apply_patches(self) -> None: ...
|
||||
|
||||
def ensure_compatibility(self) -> None: ...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_patches_state() -> Generator[None, None, None]:
|
||||
"""Reset _patches_applied state before and after tests."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
original_state = compat_module._patches_applied
|
||||
compat_module._patches_applied = False
|
||||
yield
|
||||
compat_module._patches_applied = original_state
|
||||
def compat_module() -> Generator[_CompatModule, None, None]:
|
||||
"""Reload compatibility module to reset internal patch state."""
|
||||
module = importlib.reload(_compat)
|
||||
yield cast(_CompatModule, module)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -79,9 +78,9 @@ def mock_huggingface_hub() -> MagicMock:
|
||||
class TestAudioMetaData:
|
||||
"""Tests for the replacement AudioMetaData dataclass."""
|
||||
|
||||
def test_audiometadata_has_required_fields(self) -> None:
|
||||
def test_audiometadata_has_required_fields(self, compat_module: _CompatModule) -> None:
|
||||
"""AudioMetaData has all fields expected by pyannote.audio."""
|
||||
metadata = AudioMetaData(
|
||||
metadata = compat_module.AudioMetaData(
|
||||
sample_rate=16000,
|
||||
num_frames=48000,
|
||||
num_channels=1,
|
||||
@@ -95,9 +94,9 @@ class TestAudioMetaData:
|
||||
assert metadata.bits_per_sample == 16, "should store bits_per_sample"
|
||||
assert metadata.encoding == "PCM_S", "should store encoding"
|
||||
|
||||
def test_audiometadata_is_immutable(self) -> None:
|
||||
def test_audiometadata_is_immutable(self, compat_module: _CompatModule) -> None:
|
||||
"""AudioMetaData fields cannot be modified after creation."""
|
||||
metadata = AudioMetaData(
|
||||
metadata = compat_module.AudioMetaData(
|
||||
sample_rate=16000,
|
||||
num_frames=48000,
|
||||
num_channels=1,
|
||||
@@ -119,38 +118,42 @@ class TestPatchTorchaudio:
|
||||
"""Tests for torchaudio AudioMetaData patching."""
|
||||
|
||||
def test_patches_audiometadata_when_missing(
|
||||
self, mock_torchaudio: MagicMock
|
||||
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
|
||||
) -> None:
|
||||
"""_patch_torchaudio adds AudioMetaData when not present."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_torchaudio()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "AudioMetaData"
|
||||
), "should add AudioMetaData"
|
||||
assert (
|
||||
mock_torchaudio.AudioMetaData is AudioMetaData
|
||||
mock_torchaudio.AudioMetaData is compat_module.AudioMetaData
|
||||
), "should use our AudioMetaData class"
|
||||
|
||||
def test_does_not_override_existing_audiometadata(self) -> None:
|
||||
def test_does_not_override_existing_audiometadata(
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""_patch_torchaudio preserves existing AudioMetaData if present."""
|
||||
mock = MagicMock()
|
||||
existing_class = type("ExistingAudioMetaData", (), {})
|
||||
mock.AudioMetaData = existing_class
|
||||
|
||||
with patch.dict(sys.modules, {"torchaudio": mock}):
|
||||
_patch_torchaudio()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert (
|
||||
mock.AudioMetaData is existing_class
|
||||
), "should not override existing AudioMetaData"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
def test_handles_import_error_gracefully(
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""_patch_torchaudio doesn't raise when torchaudio not installed."""
|
||||
# Remove torchaudio from modules if present
|
||||
with patch.dict(sys.modules, {"torchaudio": None}):
|
||||
# Should not raise
|
||||
_patch_torchaudio()
|
||||
compat_module.apply_patches()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -162,7 +165,7 @@ class TestPatchTorchLoad:
|
||||
"""Tests for torch.load weights_only patching."""
|
||||
|
||||
def test_patches_torch_load_for_pytorch_2_6_plus(
|
||||
self, mock_torch: MagicMock
|
||||
self, compat_module: _CompatModule, mock_torch: MagicMock
|
||||
) -> None:
|
||||
"""_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
|
||||
original_load = mock_torch.load
|
||||
@@ -172,12 +175,12 @@ class TestPatchTorchLoad:
|
||||
mock_version.return_value = mock_version
|
||||
mock_version.__ge__ = MagicMock(return_value=True)
|
||||
|
||||
_patch_torch_load()
|
||||
compat_module.apply_patches()
|
||||
|
||||
# Verify torch.load was replaced (not the same function)
|
||||
assert mock_torch.load is not original_load, "load should be patched"
|
||||
|
||||
def test_does_not_patch_older_pytorch(self) -> None:
|
||||
def test_does_not_patch_older_pytorch(self, compat_module: _CompatModule) -> None:
|
||||
"""_patch_torch_load skips patching for PyTorch < 2.6."""
|
||||
mock = MagicMock()
|
||||
mock.__version__ = "2.5.0"
|
||||
@@ -188,15 +191,17 @@ class TestPatchTorchLoad:
|
||||
mock_version.return_value = mock_version
|
||||
mock_version.__ge__ = MagicMock(return_value=False)
|
||||
|
||||
_patch_torch_load()
|
||||
compat_module.apply_patches()
|
||||
|
||||
# load should not have been replaced
|
||||
assert mock.load is original_load, "should not patch older PyTorch"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
def test_handles_import_error_gracefully(
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""_patch_torch_load doesn't raise when torch not installed."""
|
||||
with patch.dict(sys.modules, {"torch": None}):
|
||||
_patch_torch_load()
|
||||
compat_module.apply_patches()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -208,13 +213,13 @@ class TestPatchHuggingfaceAuth:
|
||||
"""Tests for huggingface_hub use_auth_token patching."""
|
||||
|
||||
def test_converts_use_auth_token_to_token(
|
||||
self, mock_huggingface_hub: MagicMock
|
||||
self, compat_module: _CompatModule, mock_huggingface_hub: MagicMock
|
||||
) -> None:
|
||||
"""_patch_huggingface_auth converts use_auth_token to token parameter."""
|
||||
original_download = mock_huggingface_hub.hf_hub_download
|
||||
|
||||
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
|
||||
_patch_huggingface_auth()
|
||||
compat_module.apply_patches()
|
||||
|
||||
# Call with legacy use_auth_token
|
||||
mock_huggingface_hub.hf_hub_download(
|
||||
@@ -233,13 +238,13 @@ class TestPatchHuggingfaceAuth:
|
||||
), "should remove use_auth_token"
|
||||
|
||||
def test_preserves_token_parameter(
|
||||
self, mock_huggingface_hub: MagicMock
|
||||
self, compat_module: _CompatModule, mock_huggingface_hub: MagicMock
|
||||
) -> None:
|
||||
"""_patch_huggingface_auth preserves token if already using new API."""
|
||||
original_download = mock_huggingface_hub.hf_hub_download
|
||||
|
||||
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
|
||||
_patch_huggingface_auth()
|
||||
compat_module.apply_patches()
|
||||
|
||||
mock_huggingface_hub.hf_hub_download(
|
||||
repo_id="test/repo",
|
||||
@@ -251,10 +256,12 @@ class TestPatchHuggingfaceAuth:
|
||||
call_kwargs = original_download.call_args[1]
|
||||
assert call_kwargs["token"] == "my_token", "should preserve token"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
def test_handles_import_error_gracefully(
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
|
||||
with patch.dict(sys.modules, {"huggingface_hub": None}):
|
||||
_patch_huggingface_auth()
|
||||
compat_module.apply_patches()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -265,10 +272,12 @@ class TestPatchHuggingfaceAuth:
|
||||
class TestPatchSpeechbrainBackend:
|
||||
"""Tests for torchaudio backend API patching."""
|
||||
|
||||
def test_patches_list_audio_backends(self, mock_torchaudio: MagicMock) -> None:
|
||||
def test_patches_list_audio_backends(
|
||||
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
|
||||
) -> None:
|
||||
"""_patch_speechbrain_backend adds list_audio_backends when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "list_audio_backends"
|
||||
@@ -276,10 +285,12 @@ class TestPatchSpeechbrainBackend:
|
||||
result = mock_torchaudio.list_audio_backends()
|
||||
assert isinstance(result, list), "should return list"
|
||||
|
||||
def test_patches_get_audio_backend(self, mock_torchaudio: MagicMock) -> None:
|
||||
def test_patches_get_audio_backend(
|
||||
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
|
||||
) -> None:
|
||||
"""_patch_speechbrain_backend adds get_audio_backend when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "get_audio_backend"
|
||||
@@ -287,10 +298,12 @@ class TestPatchSpeechbrainBackend:
|
||||
result = mock_torchaudio.get_audio_backend()
|
||||
assert result is None, "should return None"
|
||||
|
||||
def test_patches_set_audio_backend(self, mock_torchaudio: MagicMock) -> None:
|
||||
def test_patches_set_audio_backend(
|
||||
self, compat_module: _CompatModule, mock_torchaudio: MagicMock
|
||||
) -> None:
|
||||
"""_patch_speechbrain_backend adds set_audio_backend when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "set_audio_backend"
|
||||
@@ -298,14 +311,16 @@ class TestPatchSpeechbrainBackend:
|
||||
# Should not raise
|
||||
mock_torchaudio.set_audio_backend("sox")
|
||||
|
||||
def test_does_not_override_existing_functions(self) -> None:
|
||||
def test_does_not_override_existing_functions(
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""_patch_speechbrain_backend preserves existing backend functions."""
|
||||
mock = MagicMock()
|
||||
existing_list = MagicMock(return_value=["ffmpeg"])
|
||||
mock.list_audio_backends = existing_list
|
||||
|
||||
with patch.dict(sys.modules, {"torchaudio": mock}):
|
||||
_patch_speechbrain_backend()
|
||||
compat_module.apply_patches()
|
||||
|
||||
assert (
|
||||
mock.list_audio_backends is existing_list
|
||||
@@ -321,42 +336,24 @@ class TestApplyPatches:
|
||||
"""Tests for the main apply_patches function."""
|
||||
|
||||
def test_apply_patches_is_idempotent(
|
||||
self, reset_patches_state: None
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""apply_patches only applies patches once."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
mock_torch = MagicMock()
|
||||
mock_torch.__version__ = "2.6.0"
|
||||
original_load = mock_torch.load
|
||||
|
||||
with patch.object(compat_module, "_patch_torchaudio") as mock_torchaudio:
|
||||
with patch.object(compat_module, "_patch_torch_load") as mock_torch:
|
||||
with patch.object(
|
||||
compat_module, "_patch_huggingface_auth"
|
||||
) as mock_hf:
|
||||
with patch.object(
|
||||
compat_module, "_patch_speechbrain_backend"
|
||||
) as mock_sb:
|
||||
apply_patches()
|
||||
apply_patches() # Second call
|
||||
apply_patches() # Third call
|
||||
with patch.dict(sys.modules, {"torch": mock_torch}):
|
||||
with patch("packaging.version.Version") as mock_version:
|
||||
mock_version.return_value = mock_version
|
||||
mock_version.__ge__ = MagicMock(return_value=True)
|
||||
|
||||
# Each patch function should only be called once
|
||||
mock_torchaudio.assert_called_once()
|
||||
mock_torch.assert_called_once()
|
||||
mock_hf.assert_called_once()
|
||||
mock_sb.assert_called_once()
|
||||
compat_module.apply_patches()
|
||||
first_load = mock_torch.load
|
||||
compat_module.apply_patches()
|
||||
|
||||
def test_apply_patches_sets_flag(self, reset_patches_state: None) -> None:
|
||||
"""apply_patches sets _patches_applied flag."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
assert compat_module._patches_applied is False, "should start False"
|
||||
|
||||
with patch.object(compat_module, "_patch_torchaudio"):
|
||||
with patch.object(compat_module, "_patch_torch_load"):
|
||||
with patch.object(compat_module, "_patch_huggingface_auth"):
|
||||
with patch.object(compat_module, "_patch_speechbrain_backend"):
|
||||
apply_patches()
|
||||
|
||||
assert compat_module._patches_applied is True, "should be True after apply"
|
||||
assert first_load is not original_load, "initial call should patch torch.load"
|
||||
assert mock_torch.load is first_load, "subsequent calls should be idempotent"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -368,12 +365,10 @@ class TestEnsureCompatibility:
|
||||
"""Tests for the ensure_compatibility entry point."""
|
||||
|
||||
def test_ensure_compatibility_calls_apply_patches(
|
||||
self, reset_patches_state: None
|
||||
self, compat_module: _CompatModule
|
||||
) -> None:
|
||||
"""ensure_compatibility delegates to apply_patches."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
with patch.object(compat_module, "apply_patches") as mock_apply:
|
||||
ensure_compatibility()
|
||||
compat_module.ensure_compatibility()
|
||||
|
||||
mock_apply.assert_called_once()
|
||||
|
||||
@@ -60,7 +60,11 @@ def enabled_config() -> WebhookConfig:
|
||||
def disabled_config() -> WebhookConfig:
|
||||
"""Create a disabled webhook config."""
|
||||
workspace_id = uuid4()
|
||||
now = WebhookConfig.create(workspace_id, "", [WebhookEventType.MEETING_COMPLETED]).created_at
|
||||
now = WebhookConfig.create(
|
||||
workspace_id=workspace_id,
|
||||
url="",
|
||||
events=[WebhookEventType.MEETING_COMPLETED],
|
||||
).created_at
|
||||
return WebhookConfig(
|
||||
id=uuid4(),
|
||||
workspace_id=workspace_id,
|
||||
|
||||
@@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -40,6 +40,66 @@ from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWor
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationRequest(Protocol):
|
||||
meeting_id: str
|
||||
num_speakers: int
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationResponse(Protocol):
|
||||
job_id: str
|
||||
status: int
|
||||
error_message: str
|
||||
segments_updated: int
|
||||
|
||||
|
||||
class _DiarizationJobStatusRequest(Protocol):
|
||||
job_id: str
|
||||
|
||||
|
||||
class _DiarizationJobStatusResponse(Protocol):
|
||||
job_id: str
|
||||
status: int
|
||||
segments_updated: int
|
||||
speaker_ids: Sequence[str]
|
||||
error_message: str
|
||||
progress_percent: float
|
||||
|
||||
|
||||
class _RenameSpeakerRequest(Protocol):
|
||||
meeting_id: str
|
||||
old_speaker_id: str
|
||||
new_speaker_name: str
|
||||
|
||||
|
||||
class _RenameSpeakerResponse(Protocol):
|
||||
segments_updated: int
|
||||
success: bool
|
||||
|
||||
|
||||
class _RefineSpeakerDiarizationCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: MockContext,
|
||||
) -> _RefineSpeakerDiarizationResponse: ...
|
||||
|
||||
|
||||
class _GetDiarizationJobStatusCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _DiarizationJobStatusRequest,
|
||||
context: MockContext,
|
||||
) -> _DiarizationJobStatusResponse: ...
|
||||
|
||||
|
||||
class _RenameSpeakerCallable(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
request: _RenameSpeakerRequest,
|
||||
context: MockContext,
|
||||
) -> _RenameSpeakerResponse: ...
|
||||
|
||||
# ============================================================================
|
||||
# Test Constants
|
||||
# ============================================================================
|
||||
@@ -104,6 +164,36 @@ class _MockRpcError(grpc.RpcError):
|
||||
return self._details
|
||||
|
||||
|
||||
async def _call_refine(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _RefineSpeakerDiarizationRequest,
|
||||
context: MockContext,
|
||||
) -> _RefineSpeakerDiarizationResponse:
|
||||
"""Call RefineSpeakerDiarization with typed response."""
|
||||
refine = cast(_RefineSpeakerDiarizationCallable, servicer.RefineSpeakerDiarization)
|
||||
return await refine(request, context)
|
||||
|
||||
|
||||
async def _call_get_status(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _DiarizationJobStatusRequest,
|
||||
context: MockContext,
|
||||
) -> _DiarizationJobStatusResponse:
|
||||
"""Call GetDiarizationJobStatus with typed response."""
|
||||
get_status = cast(_GetDiarizationJobStatusCallable, servicer.GetDiarizationJobStatus)
|
||||
return await get_status(request, context)
|
||||
|
||||
|
||||
async def _call_rename(
|
||||
servicer: NoteFlowServicer,
|
||||
request: _RenameSpeakerRequest,
|
||||
context: MockContext,
|
||||
) -> _RenameSpeakerResponse:
|
||||
"""Call RenameSpeaker with typed response."""
|
||||
rename = cast(_RenameSpeakerCallable, servicer.RenameSpeaker)
|
||||
return await rename(request, context)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestServicerMeetingOperationsWithDatabase:
|
||||
"""Integration tests for meeting operations using real database."""
|
||||
@@ -320,10 +410,10 @@ class TestServicerDiarizationWithDatabase:
|
||||
request = noteflow_pb2.RefineSpeakerDiarizationRequest(
|
||||
meeting_id=meeting_id_str,
|
||||
)
|
||||
result = await servicer.RefineSpeakerDiarization(request, MockContext())
|
||||
result = await _call_refine(servicer, request, MockContext())
|
||||
|
||||
assert result.job_id, "RefineSpeakerDiarization response should include a job ID"
|
||||
assert result.status == noteflow_pb2.JOB_STATUS_QUEUED, f"expected QUEUED status, got {result.status}"
|
||||
assert result.status == JOB_STATUS_QUEUED, f"expected QUEUED status, got {result.status}"
|
||||
|
||||
async with SqlAlchemyUnitOfWork(session_factory, meetings_dir) as uow:
|
||||
job = await uow.diarization_jobs.get(result.job_id)
|
||||
@@ -352,12 +442,12 @@ class TestServicerDiarizationWithDatabase:
|
||||
servicer = NoteFlowServicer(session_factory=session_factory)
|
||||
|
||||
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job.job_id)
|
||||
response: noteflow_pb2.DiarizationJobStatus = await servicer.GetDiarizationJobStatus(request, MockContext())
|
||||
response = await _call_get_status(servicer, request, MockContext())
|
||||
|
||||
# cast required: protobuf RepeatedScalarFieldContainer is typed in .pyi but pyright doesn't resolve generic
|
||||
speaker_ids_list: Sequence[str] = cast(Sequence[str], response.speaker_ids)
|
||||
speaker_ids_list = response.speaker_ids
|
||||
assert response.job_id == job.job_id, f"expected job_id {job.job_id}, got {response.job_id}"
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_COMPLETED, f"expected COMPLETED status, got {response.status}"
|
||||
assert response.status == JOB_STATUS_COMPLETED, f"expected COMPLETED status, got {response.status}"
|
||||
assert response.segments_updated == DIARIZATION_SEGMENTS_UPDATED, f"expected {DIARIZATION_SEGMENTS_UPDATED} segments_updated, got {response.segments_updated}"
|
||||
assert list(speaker_ids_list) == ["SPEAKER_00", "SPEAKER_01"], f"expected speaker_ids ['SPEAKER_00', 'SPEAKER_01'], got {list(speaker_ids_list)}"
|
||||
|
||||
@@ -371,7 +461,7 @@ class TestServicerDiarizationWithDatabase:
|
||||
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id="nonexistent")
|
||||
|
||||
with pytest.raises(grpc.RpcError, match=r".*"):
|
||||
await servicer.GetDiarizationJobStatus(request, context)
|
||||
await _call_get_status(servicer, request, context)
|
||||
|
||||
assert context.abort_code == grpc.StatusCode.NOT_FOUND, f"expected NOT_FOUND status for nonexistent job, got {context.abort_code}"
|
||||
|
||||
@@ -397,9 +487,9 @@ class TestServicerDiarizationWithDatabase:
|
||||
request = noteflow_pb2.RefineSpeakerDiarizationRequest(
|
||||
meeting_id=str(meeting.id),
|
||||
)
|
||||
result = await servicer.RefineSpeakerDiarization(request, MockContext())
|
||||
result = await _call_refine(servicer, request, MockContext())
|
||||
|
||||
assert result.status == noteflow_pb2.JOB_STATUS_FAILED, f"expected FAILED status for recording meeting, got {result.status}"
|
||||
assert result.status == JOB_STATUS_FAILED, f"expected FAILED status for recording meeting, got {result.status}"
|
||||
assert "stopped" in result.error_message.lower(), f"expected 'stopped' in error message, got '{result.error_message}'"
|
||||
|
||||
|
||||
@@ -515,7 +605,7 @@ class TestServicerRenameSpeakerWithDatabase:
|
||||
old_speaker_id="SPEAKER_00",
|
||||
new_speaker_name="Alice",
|
||||
)
|
||||
result = await servicer.RenameSpeaker(request, MockContext())
|
||||
result = await _call_rename(servicer, request, MockContext())
|
||||
|
||||
assert result.segments_updated == 3, f"expected 3 segments updated, got {result.segments_updated}"
|
||||
assert result.success is True, "RenameSpeaker should return success=True"
|
||||
@@ -1111,7 +1201,7 @@ class TestServerRestartJobRecovery:
|
||||
|
||||
servicer_new = NoteFlowServicer(session_factory=session_factory)
|
||||
request = noteflow_pb2.GetDiarizationJobStatusRequest(job_id=job_id)
|
||||
response = await servicer_new.GetDiarizationJobStatus(request, MockContext())
|
||||
response = await _call_get_status(servicer_new, request, MockContext())
|
||||
|
||||
assert response.status == noteflow_pb2.JOB_STATUS_FAILED, "job should be FAILED"
|
||||
assert response.status == JOB_STATUS_FAILED, "job should be FAILED"
|
||||
assert response.error_message == "Server restarted", "should have restart error"
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import grpc
|
||||
@@ -22,6 +22,25 @@ from noteflow.infrastructure.persistence.unit_of_work import SqlAlchemyUnitOfWor
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
|
||||
class _AudioChunkRequest(Protocol):
|
||||
meeting_id: str
|
||||
audio_data: bytes
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
||||
|
||||
class _TranscriptUpdate(Protocol):
|
||||
update_type: int
|
||||
|
||||
|
||||
class _StreamTranscriptionCallable(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request_iterator: AsyncIterator[_AudioChunkRequest],
|
||||
context: MockContext,
|
||||
) -> AsyncIterator[_TranscriptUpdate]: ...
|
||||
|
||||
SAMPLE_RATE = DEFAULT_SAMPLE_RATE
|
||||
CHUNK_SAMPLES = 1600 # 0.1s at 16kHz
|
||||
SPEECH_CHUNKS = 4
|
||||
@@ -88,15 +107,16 @@ class TestStreamingRealPipeline:
|
||||
meetings_dir=meetings_dir,
|
||||
)
|
||||
|
||||
updates: list[noteflow_pb2.TranscriptUpdate] = [
|
||||
stream = cast(_StreamTranscriptionCallable, servicer.StreamTranscription)
|
||||
updates: list[_TranscriptUpdate] = [
|
||||
update
|
||||
async for update in servicer.StreamTranscription(
|
||||
async for update in stream(
|
||||
_audio_stream(str(meeting.id)),
|
||||
MockContext(),
|
||||
)
|
||||
]
|
||||
|
||||
final_updates: list[noteflow_pb2.TranscriptUpdate] = [
|
||||
final_updates: list[_TranscriptUpdate] = [
|
||||
update
|
||||
for update in updates
|
||||
if update.update_type == noteflow_pb2.UPDATE_TYPE_FINAL
|
||||
|
||||
12
uv.lock
generated
12
uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13'",
|
||||
@@ -2440,7 +2440,7 @@ requires-dist = [
|
||||
{ name = "testcontainers", extras = ["postgres"], marker = "extra == 'dev'", specifier = ">=4.0" },
|
||||
{ name = "torch", marker = "extra == 'diarization'", specifier = ">=2.0" },
|
||||
{ name = "torch", marker = "extra == 'optional'", specifier = ">=2.0" },
|
||||
{ name = "types-grpcio", marker = "extra == 'dev'", specifier = ">=1.0.0.20251009" },
|
||||
{ name = "types-grpcio", marker = "extra == 'dev'", specifier = "==1.0.0.20251001" },
|
||||
{ name = "types-psutil", specifier = ">=7.2.0.20251228" },
|
||||
{ name = "weasyprint", marker = "extra == 'optional'", specifier = ">=67.0" },
|
||||
{ name = "weasyprint", marker = "extra == 'pdf'", specifier = ">=67.0" },
|
||||
@@ -2456,7 +2456,7 @@ dev = [
|
||||
{ name = "pytest-httpx", specifier = ">=0.36.0" },
|
||||
{ name = "ruff", specifier = ">=0.14.9" },
|
||||
{ name = "sourcery", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "types-grpcio", specifier = ">=1.0.0.20251009" },
|
||||
{ name = "types-grpcio", specifier = "==1.0.0.20251001" },
|
||||
{ name = "watchfiles", specifier = ">=1.1.1" },
|
||||
]
|
||||
|
||||
@@ -7411,11 +7411,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-grpcio"
|
||||
version = "1.0.0.20251009"
|
||||
version = "1.0.0.20251001"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/de/93/78aa083216853c667c9412df4ef8284b2a68c6bcd2aef833f970b311f3c1/types_grpcio-1.0.0.20251009.tar.gz", hash = "sha256:a8f615ea7a47b31f10da028ab5258d4f1611fbd70719ca450fc0ab3fb9c62b63", size = 14479, upload-time = "2025-10-09T02:54:14.539Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6e/84/569f4bd7d6c70337a16171041930027d46229d760cbe1dbaa422e18a7abf/types_grpcio-1.0.0.20251001.tar.gz", hash = "sha256:5334e2076b3ad621188af58b082ac7b31ea52f3e9a01cdd1984823bf63e7ce55", size = 14672, upload-time = "2025-10-01T03:04:10.388Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/93/66d28f41b16bb4e6b611bd608ef28dffc740facec93250b30cf83138da21/types_grpcio-1.0.0.20251009-py3-none-any.whl", hash = "sha256:112ac4312a5b0a273a4c414f7f2c7668f342990d9c6ab0f647391c36331f95ed", size = 15208, upload-time = "2025-10-09T02:54:13.588Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/33/f2e6e5f4f2f80fa6ee253dc47610f856baa76b87e56050b4bd7d91b7d272/types_grpcio-1.0.0.20251001-py3-none-any.whl", hash = "sha256:7e7dc6e7238f1dc353adae2e172e8f4acd2388ad8935eba10d6a93dc58529148", size = 15351, upload-time = "2025-10-01T03:04:09.183Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user