898 lines
31 KiB
Python
898 lines
31 KiB
Python
"""Tests for OAuth gRPC endpoints.
|
|
|
|
Validates the InitiateOAuth, CompleteOAuth, GetOAuthConnectionStatus, and
|
|
DisconnectOAuth RPCs work correctly with the calendar service.
|
|
"""
|
|
|
|
from __future__ import annotations

from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, NoReturn, Protocol, cast
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4

import grpc
import pytest

from noteflow.application.services.calendar import CalendarServiceError
from noteflow.domain.entities.integration import IntegrationStatus
from noteflow.domain.identity import DEFAULT_WORKSPACE_ID
from noteflow.grpc.config.config import ServicesConfig
from noteflow.grpc.proto import noteflow_pb2
from noteflow.grpc.service import NoteFlowServicer
|
|
|
|
if TYPE_CHECKING:
|
|
from noteflow.domain.ports.calendar import OAuthConnectionInfo
|
|
|
|
|
|
class _DummyContext:
    """Minimal stand-in for a gRPC servicer context, used by these tests.

    Only ``abort`` is implemented. It records the status code and details it
    was called with, then raises ``AssertionError`` so the aborting handler
    stops executing. Tests catch that with
    ``pytest.raises(AssertionError, match="abort called")`` and then inspect
    the recorded fields.
    """

    def __init__(self) -> None:
        # Populated by ``abort``; these defaults remain if the handler never aborts.
        self.aborted = False
        self.abort_code: grpc.StatusCode | None = None
        self.abort_details: str = ""

    async def abort(self, code: grpc.StatusCode, details: str) -> NoReturn:
        """Record the abort arguments and raise; this coroutine never returns.

        Args:
            code: Status code the handler aborted with.
            details: Human-readable abort details.

        Raises:
            AssertionError: Always, with message ``"abort called: <code> - <details>"``.
        """
        self.aborted = True
        self.abort_code = code
        self.abort_details = details
        raise AssertionError(f"abort called: {code} - {details}")
|
|
|
|
|
|
class _CalendarProvider(Protocol):
    """Structural type for one entry of a GetCalendarProviders reply."""

    name: str
    is_authenticated: bool
    display_name: str


class _GetCalendarProvidersResponse(Protocol):
    """Structural type for the GetCalendarProviders reply message."""

    providers: Sequence[_CalendarProvider]


class _GetCalendarProvidersCallable(Protocol):
    """Typed signature used to ``cast`` ``NoteFlowServicer.GetCalendarProviders``."""

    async def __call__(
        self,
        request: noteflow_pb2.GetCalendarProvidersRequest,
        context: _DummyContext,
    ) -> _GetCalendarProvidersResponse:
        ...
|
|
|
|
|
|
async def _call_get_calendar_providers(
    servicer: NoteFlowServicer,
    request: noteflow_pb2.GetCalendarProvidersRequest,
    context: _DummyContext,
) -> _GetCalendarProvidersResponse:
    """Invoke ``servicer.GetCalendarProviders`` through a typed Protocol.

    The ``cast`` only informs the type checker about the reply shape; the
    servicer method is awaited unchanged at runtime.
    """
    handler = cast(_GetCalendarProvidersCallable, servicer.GetCalendarProviders)
    reply = await handler(request, context)
    return reply
|
|
|
|
|
|
class _InitiateOAuthRequest(Protocol):
    """Structural type for the fields of an InitiateOAuth request."""

    provider: str
    redirect_uri: str
    integration_type: str


class _InitiateOAuthResponse(Protocol):
    """Structural type for the fields of an InitiateOAuth reply."""

    auth_url: str
    state: str


class _CompleteOAuthRequest(Protocol):
    """Structural type for the fields of a CompleteOAuth request."""

    provider: str
    code: str
    state: str


class _CompleteOAuthResponse(Protocol):
    """Structural type for the fields of a CompleteOAuth reply."""

    success: bool
    error_message: str
    provider_email: str
    integration_id: str


class _OAuthConnection(Protocol):
    """Structural type for the ``connection`` field of a status reply."""

    provider: str
    status: str
    email: str
    expires_at: str
    error_message: str
    integration_type: str


class _GetOAuthConnectionStatusRequest(Protocol):
    """Structural type for the fields of a GetOAuthConnectionStatus request."""

    provider: str
    integration_type: str


class _GetOAuthConnectionStatusResponse(Protocol):
    """Structural type for a GetOAuthConnectionStatus reply."""

    connection: _OAuthConnection


class _DisconnectOAuthRequest(Protocol):
    """Structural type for the fields of a DisconnectOAuth request."""

    provider: str
    integration_type: str


class _DisconnectOAuthResponse(Protocol):
    """Structural type for the fields of a DisconnectOAuth reply."""

    success: bool
    error_message: str
|
|
|
|
|
|
class _InitiateOAuthCallable(Protocol):
    """Typed signature used to ``cast`` ``NoteFlowServicer.InitiateOAuth``."""

    async def __call__(
        self,
        request: _InitiateOAuthRequest,
        context: _DummyContext,
    ) -> _InitiateOAuthResponse:
        ...


class _CompleteOAuthCallable(Protocol):
    """Typed signature used to ``cast`` ``NoteFlowServicer.CompleteOAuth``."""

    async def __call__(
        self,
        request: _CompleteOAuthRequest,
        context: _DummyContext,
    ) -> _CompleteOAuthResponse:
        ...


class _GetOAuthConnectionStatusCallable(Protocol):
    """Typed signature used to ``cast`` ``NoteFlowServicer.GetOAuthConnectionStatus``."""

    async def __call__(
        self,
        request: _GetOAuthConnectionStatusRequest,
        context: _DummyContext,
    ) -> _GetOAuthConnectionStatusResponse:
        ...


class _DisconnectOAuthCallable(Protocol):
    """Typed signature used to ``cast`` ``NoteFlowServicer.DisconnectOAuth``."""

    async def __call__(
        self,
        request: _DisconnectOAuthRequest,
        context: _DummyContext,
    ) -> _DisconnectOAuthResponse:
        ...
|
|
|
|
|
|
async def _call_initiate_oauth(
    servicer: NoteFlowServicer,
    request: _InitiateOAuthRequest,
    context: _DummyContext,
) -> _InitiateOAuthResponse:
    """Await ``servicer.InitiateOAuth`` through its typed Protocol signature."""
    return await cast(_InitiateOAuthCallable, servicer.InitiateOAuth)(request, context)
|
|
|
|
|
|
async def _call_complete_oauth(
    servicer: NoteFlowServicer,
    request: _CompleteOAuthRequest,
    context: _DummyContext,
) -> _CompleteOAuthResponse:
    """Await ``servicer.CompleteOAuth`` through its typed Protocol signature."""
    return await cast(_CompleteOAuthCallable, servicer.CompleteOAuth)(request, context)
|
|
|
|
|
|
async def _call_get_oauth_status(
    servicer: NoteFlowServicer,
    request: _GetOAuthConnectionStatusRequest,
    context: _DummyContext,
) -> _GetOAuthConnectionStatusResponse:
    """Await ``servicer.GetOAuthConnectionStatus`` through its typed Protocol signature."""
    handler = cast(_GetOAuthConnectionStatusCallable, servicer.GetOAuthConnectionStatus)
    reply = await handler(request, context)
    return reply
|
|
|
|
|
|
async def _call_disconnect_oauth(
    servicer: NoteFlowServicer,
    request: _DisconnectOAuthRequest,
    context: _DummyContext,
) -> _DisconnectOAuthResponse:
    """Await ``servicer.DisconnectOAuth`` through its typed Protocol signature."""
    return await cast(_DisconnectOAuthCallable, servicer.DisconnectOAuth)(request, context)
|
|
|
|
|
|
def _create_mock_connection_info(
    *,
    provider: str = "google",
    status: str = "disconnected",
    email: str | None = None,
    expires_at: datetime | None = None,
    error_message: str | None = None,
) -> OAuthConnectionInfo:
    """Build a ``MagicMock`` posing as an ``OAuthConnectionInfo``.

    The five keyword arguments are pinned as attributes. No ``spec`` is set,
    so reading any other attribute yields an auto-created ``MagicMock``.
    """
    mock_info = MagicMock()
    mock_info.configure_mock(
        provider=provider,
        status=status,
        email=email,
        expires_at=expires_at,
        error_message=error_message,
    )
    return mock_info
|
|
|
|
|
|
def _create_mockcalendar_service(
    *,
    providers_connected: dict[str, bool] | None = None,
    provider_emails: dict[str, str] | None = None,
) -> MagicMock:
    """Return a ``MagicMock`` calendar service with scripted status lookups.

    Args:
        providers_connected: Maps provider name to whether it reports as
            connected. Providers missing from the map report as disconnected.
        provider_emails: Maps provider name to the email it reports.

    ``get_connection_status`` answers from these maps; connected providers
    also get an ``expires_at`` one hour from now. ``initiate_oauth`` and
    ``complete_oauth`` are unconfigured ``AsyncMock``s for tests to set up.
    ``disconnect`` resolves to ``True`` by default.
    """
    connected_by_provider = providers_connected or {}
    email_by_provider = provider_emails or {}

    async def _status_for(
        provider: str,
        workspace_id: UUID | None = None,
    ) -> OAuthConnectionInfo:
        expires_at: datetime | None
        if connected_by_provider.get(provider, False):
            status = IntegrationStatus.CONNECTED.value
            expires_at = datetime.now(UTC) + timedelta(hours=1)
        else:
            status = IntegrationStatus.DISCONNECTED.value
            expires_at = None
        return _create_mock_connection_info(
            provider=provider,
            status=status,
            email=email_by_provider.get(provider),
            expires_at=expires_at,
        )

    mock_service = MagicMock()
    mock_service.get_connection_status = AsyncMock(side_effect=_status_for)
    mock_service.initiate_oauth = AsyncMock()
    mock_service.complete_oauth = AsyncMock()
    mock_service.disconnect = AsyncMock(return_value=True)
    return mock_service
|
|
|
|
|
|
class TestGetCalendarProviders:
    """Tests for GetCalendarProviders RPC."""

    @pytest.mark.asyncio
    async def test_returns_available_providers(self) -> None:
        """Returns list of available calendar providers."""
        service = _create_mockcalendar_service()
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_get_calendar_providers(
            servicer,
            noteflow_pb2.GetCalendarProvidersRequest(),
            _DummyContext(),
        )

        assert len(response.providers) == 2, "should return google and outlook"
        provider_names = [p.name for p in response.providers]
        assert "google" in provider_names, "should include google"
        assert "outlook" in provider_names, "should include outlook"

    @pytest.mark.asyncio
    async def test_returns_authentication_status_for_each_provider(self) -> None:
        """Returns is_authenticated flag for each provider."""
        service = _create_mockcalendar_service(
            providers_connected={"google": True, "outlook": False}
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_get_calendar_providers(
            servicer,
            noteflow_pb2.GetCalendarProvidersRequest(),
            _DummyContext(),
        )

        google = next(p for p in response.providers if p.name == "google")
        outlook = next(p for p in response.providers if p.name == "outlook")

        assert google.is_authenticated is True, "google should be authenticated"
        assert outlook.is_authenticated is False, "outlook should not be authenticated"

    @pytest.mark.asyncio
    async def test_returns_display_names(self) -> None:
        """Returns human-readable display names for providers."""
        service = _create_mockcalendar_service()
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_get_calendar_providers(
            servicer,
            noteflow_pb2.GetCalendarProvidersRequest(),
            _DummyContext(),
        )

        google = next(p for p in response.providers if p.name == "google")
        outlook = next(p for p in response.providers if p.name == "outlook")

        assert google.display_name == "Google Calendar", "google should have correct display name"
        assert outlook.display_name == "Microsoft Outlook", (
            "outlook should have correct display name"
        )

    @pytest.mark.asyncio
    async def test_aborts_when_calendar_service_not_configured(self) -> None:
        """Aborts when no calendar service is configured.

        NOTE(review): an earlier docstring said the abort uses UNAVAILABLE, but
        the status code is not asserted here. Confirm the servicer's code before
        pinning it via ``context.abort_code``.
        """
        servicer = NoteFlowServicer()
        context = _DummyContext()

        with pytest.raises(AssertionError, match="abort called"):
            await _call_get_calendar_providers(
                servicer,
                noteflow_pb2.GetCalendarProvidersRequest(),
                context,
            )

        assert context.aborted, "should abort when service unavailable"
|
|
|
|
|
|
class TestInitiateOAuth:
    """Exercise the InitiateOAuth RPC against a mocked calendar service."""

    @pytest.mark.asyncio
    async def test_returns_auth_url_and_state(self) -> None:
        """The service's authorization URL and state token come back in the reply."""
        calendar = _create_mockcalendar_service()
        calendar.initiate_oauth.return_value = (
            "https://accounts.google.com/o/oauth2/v2/auth?client_id=...",
            "state-token-123",
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.InitiateOAuthRequest(provider="google")

        reply = await _call_initiate_oauth(servicer, request, _DummyContext())

        assert "accounts.google.com" in reply.auth_url, "should return google auth url"
        assert reply.state == "state-token-123", "should return state token"

    @pytest.mark.asyncio
    async def test_passes_provider_to_service(self) -> None:
        """The requested provider is forwarded; an unset redirect URI becomes None."""
        calendar = _create_mockcalendar_service()
        calendar.initiate_oauth.return_value = ("https://auth.url", "state")
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.InitiateOAuthRequest(provider="outlook")

        await _call_initiate_oauth(servicer, request, _DummyContext())

        calendar.initiate_oauth.assert_awaited_once_with(
            provider="outlook",
            redirect_uri=None,
            workspace_id=DEFAULT_WORKSPACE_ID,
        )

    @pytest.mark.asyncio
    async def test_passes_custom_redirect_uri(self) -> None:
        """A redirect URI set on the request is forwarded to the service as-is."""
        calendar = _create_mockcalendar_service()
        calendar.initiate_oauth.return_value = ("https://auth.url", "state")
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.InitiateOAuthRequest(
            provider="google",
            redirect_uri="noteflow://oauth/callback",
        )

        await _call_initiate_oauth(servicer, request, _DummyContext())

        calendar.initiate_oauth.assert_awaited_once_with(
            provider="google",
            redirect_uri="noteflow://oauth/callback",
            workspace_id=DEFAULT_WORKSPACE_ID,
        )

    @pytest.mark.asyncio
    async def test_aborts_on_invalid_provider(self) -> None:
        """A CalendarServiceError for an unknown provider aborts with its message.

        NOTE(review): the intended status is INVALID_ARGUMENT, but only the
        abort and its details are asserted, not ``context.abort_code``.
        """
        calendar = _create_mockcalendar_service()
        calendar.initiate_oauth.side_effect = CalendarServiceError("Unknown provider: unknown")
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        ctx = _DummyContext()
        request = noteflow_pb2.InitiateOAuthRequest(provider="unknown")

        with pytest.raises(AssertionError, match="abort called"):
            await _call_initiate_oauth(servicer, request, ctx)

        assert ctx.aborted, "should abort on invalid provider"
        assert "Unknown provider" in ctx.abort_details, "should include provider in error"

    @pytest.mark.asyncio
    async def test_aborts_when_initiate_service_unavailable(self) -> None:
        """Without a configured calendar service the RPC aborts."""
        servicer = NoteFlowServicer()
        ctx = _DummyContext()
        request = noteflow_pb2.InitiateOAuthRequest(provider="google")

        with pytest.raises(AssertionError, match="abort called"):
            await _call_initiate_oauth(servicer, request, ctx)

        assert ctx.aborted, "should abort when service unavailable"
|
|
|
|
|
|
class TestCompleteOAuth:
    """Exercise the CompleteOAuth RPC against a mocked calendar service."""

    @pytest.mark.asyncio
    async def test_returns_success_on_valid_code(self) -> None:
        """A completed flow reports success, the account email and an integration id."""
        calendar = _create_mockcalendar_service(
            providers_connected={"google": True},
            provider_emails={"google": "user@gmail.com"},
        )
        # The service hands back the id of the new integration.
        calendar.complete_oauth.return_value = uuid4()
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.CompleteOAuthRequest(
            provider="google",
            code="authorization-code",
            state="state-token-123",
        )

        reply = await _call_complete_oauth(servicer, request, _DummyContext())

        assert reply.success is True, "should succeed on valid code"
        assert reply.provider_email == "user@gmail.com", "should return email"
        assert reply.integration_id, "should return integration_id"

    @pytest.mark.asyncio
    async def test_passes_code_and_state_to_service(self) -> None:
        """Provider, code and state are forwarded to the service with the default workspace."""
        calendar = _create_mockcalendar_service()
        # The service hands back the id of the new integration.
        calendar.complete_oauth.return_value = uuid4()
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.CompleteOAuthRequest(
            provider="google",
            code="my-auth-code",
            state="my-state-token",
        )

        await _call_complete_oauth(servicer, request, _DummyContext())

        calendar.complete_oauth.assert_awaited_once_with(
            provider="google",
            code="my-auth-code",
            state="my-state-token",
            workspace_id=DEFAULT_WORKSPACE_ID,
        )

    @pytest.mark.asyncio
    async def test_returns_error_on_invalid_state(self) -> None:
        """A state-token CalendarServiceError yields success=False with its message."""
        calendar = _create_mockcalendar_service()
        calendar.complete_oauth.side_effect = CalendarServiceError("Invalid or expired state token")
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.CompleteOAuthRequest(
            provider="google",
            code="authorization-code",
            state="invalid-state",
        )

        reply = await _call_complete_oauth(servicer, request, _DummyContext())

        assert reply.success is False, "should fail on invalid state"
        assert "Invalid or expired state" in reply.error_message, (
            "error should mention invalid state"
        )

    @pytest.mark.asyncio
    async def test_returns_error_on_invalid_code(self) -> None:
        """A token-exchange CalendarServiceError yields success=False with its message."""
        calendar = _create_mockcalendar_service()
        calendar.complete_oauth.side_effect = CalendarServiceError(
            "Token exchange failed: invalid_grant"
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.CompleteOAuthRequest(
            provider="google",
            code="invalid-code",
            state="valid-state",
        )

        reply = await _call_complete_oauth(servicer, request, _DummyContext())

        assert reply.success is False, "should fail on invalid code"
        assert "Token exchange failed" in reply.error_message, (
            "error should mention token exchange failure"
        )

    @pytest.mark.asyncio
    async def test_aborts_when_complete_service_unavailable(self) -> None:
        """Without a configured calendar service the RPC aborts."""
        servicer = NoteFlowServicer()
        ctx = _DummyContext()
        request = noteflow_pb2.CompleteOAuthRequest(
            provider="google",
            code="code",
            state="state",
        )

        with pytest.raises(AssertionError, match="abort called"):
            await _call_complete_oauth(servicer, request, ctx)

        assert ctx.aborted, "should abort when service unavailable"
|
|
|
|
|
|
class TestGetOAuthConnectionStatus:
    """Exercise the GetOAuthConnectionStatus RPC against a mocked calendar service."""

    @pytest.mark.asyncio
    async def test_returns_connected_status(self) -> None:
        """A connected provider is reported with its status and email."""
        calendar = _create_mockcalendar_service(
            providers_connected={"google": True},
            provider_emails={"google": "user@gmail.com"},
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google")

        reply = await _call_get_oauth_status(servicer, request, _DummyContext())
        connection = reply.connection

        assert connection.provider == "google", "should return correct provider"
        assert connection.status == IntegrationStatus.CONNECTED.value, (
            "status should be connected"
        )
        assert connection.email == "user@gmail.com", "should return connected email"

    @pytest.mark.asyncio
    async def test_returns_disconnected_status(self) -> None:
        """A provider that is not connected is reported as disconnected."""
        calendar = _create_mockcalendar_service(providers_connected={"google": False})
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google")

        reply = await _call_get_oauth_status(servicer, request, _DummyContext())

        assert reply.connection.status == IntegrationStatus.DISCONNECTED.value, (
            "status should be disconnected"
        )

    @pytest.mark.asyncio
    async def test_returns_integration_type(self) -> None:
        """The requested integration type is echoed back on the connection."""
        calendar = _create_mockcalendar_service()
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.GetOAuthConnectionStatusRequest(
            provider="google",
            integration_type="calendar",
        )

        reply = await _call_get_oauth_status(servicer, request, _DummyContext())

        assert reply.connection.integration_type == "calendar", (
            "should return calendar integration type"
        )

    @pytest.mark.asyncio
    async def test_aborts_when_status_service_unavailable(self) -> None:
        """Without a configured calendar service the RPC aborts."""
        servicer = NoteFlowServicer()
        ctx = _DummyContext()
        request = noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google")

        with pytest.raises(AssertionError, match="abort called"):
            await _call_get_oauth_status(servicer, request, ctx)

        assert ctx.aborted, "should abort when service unavailable"
|
|
|
|
|
|
class TestDisconnectOAuth:
    """Exercise the DisconnectOAuth RPC against a mocked calendar service."""

    @pytest.mark.asyncio
    async def test_returns_success_on_disconnect(self) -> None:
        """success=True is reported when the service disconnects the provider."""
        calendar = _create_mockcalendar_service(providers_connected={"google": True})
        calendar.disconnect.return_value = True
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.DisconnectOAuthRequest(provider="google")

        reply = await _call_disconnect_oauth(servicer, request, _DummyContext())

        assert reply.success is True, "should succeed on disconnect"

    @pytest.mark.asyncio
    async def test_calls_service_disconnect(self) -> None:
        """The service's disconnect receives the provider and the default workspace."""
        calendar = _create_mockcalendar_service()
        calendar.disconnect.return_value = True
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.DisconnectOAuthRequest(provider="outlook")

        await _call_disconnect_oauth(servicer, request, _DummyContext())

        calendar.disconnect.assert_awaited_once_with("outlook", workspace_id=DEFAULT_WORKSPACE_ID)

    @pytest.mark.asyncio
    async def test_returns_false_when_not_connected(self) -> None:
        """success=False is reported when the service says nothing was disconnected."""
        calendar = _create_mockcalendar_service(providers_connected={"google": False})
        calendar.disconnect.return_value = False
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=calendar))
        request = noteflow_pb2.DisconnectOAuthRequest(provider="google")

        reply = await _call_disconnect_oauth(servicer, request, _DummyContext())

        assert reply.success is False, "should fail when not connected"

    @pytest.mark.asyncio
    async def test_aborts_when_disconnect_service_unavailable(self) -> None:
        """Without a configured calendar service the RPC aborts."""
        servicer = NoteFlowServicer()
        ctx = _DummyContext()
        request = noteflow_pb2.DisconnectOAuthRequest(provider="google")

        with pytest.raises(AssertionError, match="abort called"):
            await _call_disconnect_oauth(servicer, request, ctx)

        assert ctx.aborted, "should abort when service unavailable"
|
|
|
|
|
|
class TestOAuthRoundTrip:
    """Integration-style tests that drive a stateful mock through the OAuth flow.

    The ``oauth_flow_service`` fixture keeps per-provider connection and email
    state in plain dicts, which the mocked service methods read and mutate.
    Each test can therefore observe the effect of one RPC through another.
    """

    @pytest.fixture
    def oauth_flow_service(self) -> tuple[MagicMock, dict[str, bool], dict[str, str]]:
        """Create service with mutable state for round-trip testing.

        Returns:
            ``(service, connected_state, email_state)``.
            - ``initiate_oauth`` always returns state ``"state-123"``.
            - ``complete_oauth`` accepts only that state; it marks the provider
              connected and records an email.
            - ``disconnect`` clears both and returns ``True`` only if the
              provider was connected.
        """
        connected_state: dict[str, bool] = {"google": False}
        email_state: dict[str, str] = {}
        service = _create_mockcalendar_service()

        service.initiate_oauth.return_value = (
            "https://accounts.google.com/oauth",
            "state-123",
        )

        async def complete_oauth(
            provider: str, code: str, state: str, workspace_id: UUID | None = None
        ) -> UUID:
            if state != "state-123":
                raise CalendarServiceError("Invalid state")
            connected_state[provider] = True
            email_state[provider] = "user@gmail.com"
            return uuid4()  # Returns integration ID

        service.complete_oauth.side_effect = complete_oauth

        async def get_status(
            provider: str, workspace_id: UUID | None = None
        ) -> OAuthConnectionInfo:
            is_connected = connected_state.get(provider, False)
            return _create_mock_connection_info(
                provider=provider,
                status=IntegrationStatus.CONNECTED.value
                if is_connected
                else IntegrationStatus.DISCONNECTED.value,
                email=email_state.get(provider),
            )

        service.get_connection_status.side_effect = get_status

        async def disconnect(provider: str, workspace_id: UUID | None = None) -> bool:
            if connected_state.get(provider, False):
                connected_state[provider] = False
                email_state.pop(provider, None)
                return True
            return False

        service.disconnect.side_effect = disconnect

        return service, connected_state, email_state

    @pytest.mark.asyncio
    async def test_initiate_returns_auth_url(
        self, oauth_flow_service: tuple[MagicMock, dict[str, bool], dict[str, str]]
    ) -> None:
        """OAuth initiation returns auth URL and state."""
        service, _, _ = oauth_flow_service
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_initiate_oauth(
            servicer,
            noteflow_pb2.InitiateOAuthRequest(provider="google"),
            _DummyContext(),
        )

        assert "google" in response.auth_url, "should contain google in url"
        assert response.state == "state-123", "should return state token"

    @pytest.mark.asyncio
    async def test_complete_updates_connection_status(
        self, oauth_flow_service: tuple[MagicMock, dict[str, bool], dict[str, str]]
    ) -> None:
        """Completing OAuth connects the provider, as seen through the status RPC."""
        service, connected_state, email_state = oauth_flow_service
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        assert connected_state["google"] is False, "should start disconnected"

        response = await _call_complete_oauth(
            servicer,
            noteflow_pb2.CompleteOAuthRequest(
                provider="google",
                code="auth-code",
                state="state-123",
            ),
            _DummyContext(),
        )

        assert response.success is True, "should complete successfully"
        assert connected_state["google"], "should be connected after complete"
        assert email_state.get("google") == "user@gmail.com", "should record connected email"

        # Close the loop: the new state must be visible through the RPC layer,
        # not only in the fixture's backing dict.
        status = await _call_get_oauth_status(
            servicer,
            noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
            _DummyContext(),
        )
        assert status.connection.status == IntegrationStatus.CONNECTED.value, (
            "status RPC should report connected after complete"
        )
        assert status.connection.email == "user@gmail.com", (
            "status RPC should report the connected email"
        )

    @pytest.mark.asyncio
    async def test_disconnect_clears_connection(
        self, oauth_flow_service: tuple[MagicMock, dict[str, bool], dict[str, str]]
    ) -> None:
        """Disconnecting clears connection state and stored email."""
        service, connected_state, email_state = oauth_flow_service
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        # First connect
        connected_state["google"] = True
        email_state["google"] = "user@gmail.com"

        response = await _call_disconnect_oauth(
            servicer,
            noteflow_pb2.DisconnectOAuthRequest(provider="google"),
            _DummyContext(),
        )

        assert response.success is True, "should disconnect successfully"
        assert not connected_state["google"], "should be disconnected"
        assert "google" not in email_state, "should clear stored email"

        status = await _call_get_oauth_status(
            servicer,
            noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
            _DummyContext(),
        )
        assert status.connection.status == IntegrationStatus.DISCONNECTED.value, (
            "status RPC should report disconnected after disconnect"
        )

    @pytest.mark.asyncio
    async def test_complete_with_wrong_state_fails(
        self, oauth_flow_service: tuple[MagicMock, dict[str, bool], dict[str, str]]
    ) -> None:
        """Completing OAuth with a state other than the issued one fails and stays disconnected."""
        service, connected_state, _ = oauth_flow_service
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        initiated = await _call_initiate_oauth(
            servicer,
            noteflow_pb2.InitiateOAuthRequest(provider="google"),
            _DummyContext(),
        )
        assert initiated.state != "wrong-state", "issued state should differ from forged one"

        response = await _call_complete_oauth(
            servicer,
            noteflow_pb2.CompleteOAuthRequest(
                provider="google",
                code="auth-code",
                state="wrong-state",
            ),
            _DummyContext(),
        )

        assert response.success is False, "should fail with wrong state"
        assert "Invalid state" in response.error_message, "error should mention invalid state"
        assert connected_state["google"] is False, "should remain disconnected on rejected state"

    @pytest.mark.asyncio
    async def test_multiple_providers_independent(self) -> None:
        """Multiple providers can be connected independently."""
        service = _create_mockcalendar_service(
            providers_connected={"google": True, "outlook": False},
            provider_emails={"google": "user@gmail.com"},
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))
        ctx = _DummyContext()

        google_status = await _call_get_oauth_status(
            servicer,
            noteflow_pb2.GetOAuthConnectionStatusRequest(provider="google"),
            ctx,
        )
        outlook_status = await _call_get_oauth_status(
            servicer,
            noteflow_pb2.GetOAuthConnectionStatusRequest(provider="outlook"),
            ctx,
        )

        assert google_status.connection.status == IntegrationStatus.CONNECTED.value, (
            "google should be connected"
        )
        assert outlook_status.connection.status == IntegrationStatus.DISCONNECTED.value, (
            "outlook should be disconnected"
        )
|
|
|
|
|
|
class TestOAuthSecurityBehavior:
    """Tests for OAuth security requirements."""

    @pytest.mark.asyncio
    async def test_state_validation_required(self) -> None:
        """A mismatched state is forwarded for validation and the failure is surfaced."""
        service = _create_mockcalendar_service()
        service.initiate_oauth.return_value = ("https://auth.url", "secure-state-123")
        service.complete_oauth.side_effect = CalendarServiceError("State mismatch")
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_complete_oauth(
            servicer,
            noteflow_pb2.CompleteOAuthRequest(
                provider="google",
                code="stolen-code",
                state="attacker-state",
            ),
            _DummyContext(),
        )

        assert response.success is False, "should reject mismatched state"
        # The attacker-supplied state must reach the service unchanged, so that
        # validation happens there rather than being skipped or rewritten.
        service.complete_oauth.assert_awaited_once_with(
            provider="google",
            code="stolen-code",
            state="attacker-state",
            workspace_id=DEFAULT_WORKSPACE_ID,
        )

    @pytest.mark.asyncio
    async def test_tokens_revoked_on_disconnect(self) -> None:
        """Disconnect delegates to the calendar service's ``disconnect``.

        NOTE(review): token revocation itself is the service's responsibility
        and is not observed here; only the delegation is asserted.
        """
        service = _create_mockcalendar_service(providers_connected={"google": True})
        service.disconnect.return_value = True
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        await _call_disconnect_oauth(
            servicer,
            noteflow_pb2.DisconnectOAuthRequest(provider="google"),
            _DummyContext(),
        )

        service.disconnect.assert_awaited_once_with("google", workspace_id=DEFAULT_WORKSPACE_ID)

    @pytest.mark.asyncio
    async def test_no_sensitive_data_in_error_responses(self) -> None:
        """Error responses should not leak sensitive information.

        NOTE(review): the injected service error contains no secrets itself, so
        this only guards against the servicer *adding* sensitive data to the
        message. It does not verify redaction of secrets coming from the service.
        """
        service = _create_mockcalendar_service()
        service.complete_oauth.side_effect = CalendarServiceError(
            "Token exchange failed: invalid_grant"
        )
        servicer = NoteFlowServicer(services=ServicesConfig(calendar_service=service))

        response = await _call_complete_oauth(
            servicer,
            noteflow_pb2.CompleteOAuthRequest(
                provider="google",
                code="code",
                state="state",
            ),
            _DummyContext(),
        )

        # The negative checks below pass trivially on an empty message, so first
        # require that the error path was taken and produced text to inspect.
        assert response.success is False, "should report failure"
        assert response.error_message, "should return an error message to inspect"
        assert "Bearer" not in response.error_message, "error should not leak bearer tokens"
        assert "secret" not in response.error_message.lower(), "error should not leak secrets"
|