feat: implement identity management features in gRPC service
- Introduced `IdentityMixin` to manage user identity operations, including `GetCurrentUser`, `ListWorkspaces`, and `SwitchWorkspace`. - Added corresponding gRPC methods and message definitions in the proto file for identity management. - Enhanced `AuthService` to support user authentication and token management. - Updated `OAuthManager` to include rate limiting for authentication attempts and improved error handling. - Implemented unit tests for the new identity management features to ensure functionality and reliability.
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,5 +1,12 @@
|
||||
"""Application services for NoteFlow use cases."""
|
||||
|
||||
from noteflow.application.services.auth_service import (
|
||||
AuthResult,
|
||||
AuthService,
|
||||
AuthServiceError,
|
||||
LogoutResult,
|
||||
UserInfo,
|
||||
)
|
||||
from noteflow.application.services.export_service import ExportFormat, ExportService
|
||||
from noteflow.application.services.identity_service import IdentityService
|
||||
from noteflow.application.services.meeting_service import MeetingService
|
||||
@@ -15,10 +22,15 @@ from noteflow.application.services.summarization_service import (
|
||||
from noteflow.application.services.trigger_service import TriggerService, TriggerServiceSettings
|
||||
|
||||
__all__ = [
|
||||
"AuthResult",
|
||||
"AuthService",
|
||||
"AuthServiceError",
|
||||
"ExportFormat",
|
||||
"ExportService",
|
||||
"IdentityService",
|
||||
"LogoutResult",
|
||||
"MeetingService",
|
||||
"UserInfo",
|
||||
"ProjectService",
|
||||
"RecoveryService",
|
||||
"RetentionReport",
|
||||
|
||||
563
src/noteflow/application/services/auth_service.py
Normal file
563
src/noteflow/application/services/auth_service.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""Authentication service for OAuth-based user login.
|
||||
|
||||
Extends the CalendarService OAuth patterns for user authentication.
|
||||
Uses the same OAuthManager infrastructure but stores tokens with
|
||||
IntegrationType.AUTH and manages User entities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, TypedDict, Unpack
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from noteflow.config.constants import OAUTH_FIELD_ACCESS_TOKEN
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationType
|
||||
from noteflow.domain.identity.entities import User
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.infrastructure.calendar import OAuthManager
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
|
||||
class _AuthServiceDepsKwargs(TypedDict, total=False):
|
||||
"""Optional dependency overrides for AuthService."""
|
||||
|
||||
oauth_manager: OAuthManager
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Default IDs for local-first mode
|
||||
DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
|
||||
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class AuthServiceError(Exception):
|
||||
"""Auth service operation failed."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AuthResult:
|
||||
"""Result of successful authentication.
|
||||
|
||||
Note: Tokens are stored securely in IntegrationSecretModel and are NOT
|
||||
exposed to callers. Use get_current_user() to check auth status.
|
||||
"""
|
||||
|
||||
user_id: UUID
|
||||
workspace_id: UUID
|
||||
display_name: str
|
||||
email: str | None
|
||||
is_authenticated: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class UserInfo:
|
||||
"""Current user information."""
|
||||
|
||||
user_id: UUID
|
||||
workspace_id: UUID
|
||||
display_name: str
|
||||
email: str | None
|
||||
is_authenticated: bool
|
||||
provider: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LogoutResult:
|
||||
"""Result of logout operation.
|
||||
|
||||
Provides visibility into both local logout and remote token revocation.
|
||||
"""
|
||||
|
||||
logged_out: bool
|
||||
"""Whether local logout succeeded (integration deleted)."""
|
||||
|
||||
tokens_revoked: bool
|
||||
"""Whether remote token revocation succeeded."""
|
||||
|
||||
revocation_error: str | None = None
|
||||
"""Error message if revocation failed (for logging/debugging)."""
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Authentication service for OAuth-based user login.
|
||||
|
||||
Orchestrates OAuth flow for user authentication. Uses:
|
||||
- IntegrationRepository for auth token storage
|
||||
- OAuthManager for PKCE OAuth flow
|
||||
- UserRepository for user entity management
|
||||
|
||||
Unlike CalendarService which manages calendar integrations,
|
||||
AuthService manages user identity and authentication state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uow_factory: Callable[[], UnitOfWork],
|
||||
settings: CalendarIntegrationSettings,
|
||||
**kwargs: Unpack[_AuthServiceDepsKwargs],
|
||||
) -> None:
|
||||
"""Initialize auth service.
|
||||
|
||||
Args:
|
||||
uow_factory: Factory function returning UnitOfWork instances.
|
||||
settings: OAuth settings with credentials.
|
||||
**kwargs: Optional dependency overrides.
|
||||
"""
|
||||
self._uow_factory = uow_factory
|
||||
self._settings = settings
|
||||
oauth_manager = kwargs.get("oauth_manager")
|
||||
self._oauth_manager = oauth_manager or OAuthManager(settings)
|
||||
|
||||
async def initiate_login(
|
||||
self,
|
||||
provider: str,
|
||||
redirect_uri: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Start OAuth login flow.
|
||||
|
||||
Args:
|
||||
provider: Provider name ('google' or 'outlook').
|
||||
redirect_uri: Optional override for OAuth callback URI.
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state_token).
|
||||
|
||||
Raises:
|
||||
AuthServiceError: If provider is invalid or credentials not configured.
|
||||
"""
|
||||
oauth_provider = self._parse_provider(provider)
|
||||
effective_redirect = redirect_uri or self._settings.redirect_uri
|
||||
|
||||
try:
|
||||
auth_url, state = self._oauth_manager.initiate_auth(
|
||||
provider=oauth_provider,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
logger.info(
|
||||
"auth_login_initiated",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
return auth_url, state
|
||||
except OAuthError as e:
|
||||
logger.warning(
|
||||
"auth_login_initiation_failed",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
error=str(e),
|
||||
)
|
||||
raise AuthServiceError(str(e)) from e
|
||||
|
||||
async def complete_login(
|
||||
self,
|
||||
provider: str,
|
||||
code: str,
|
||||
state: str,
|
||||
) -> AuthResult:
|
||||
"""Complete OAuth login and create/update user.
|
||||
|
||||
Exchanges authorization code for tokens, fetches user info,
|
||||
and creates or updates the User entity.
|
||||
|
||||
Args:
|
||||
provider: Provider name ('google' or 'outlook').
|
||||
code: Authorization code from OAuth callback.
|
||||
state: State parameter from OAuth callback.
|
||||
|
||||
Returns:
|
||||
AuthResult with user identity and tokens.
|
||||
|
||||
Raises:
|
||||
AuthServiceError: If OAuth exchange fails.
|
||||
"""
|
||||
oauth_provider = self._parse_provider(provider)
|
||||
|
||||
# Exchange code for tokens
|
||||
tokens = await self._exchange_tokens(oauth_provider, code, state)
|
||||
|
||||
# Fetch user info from provider
|
||||
email, display_name = await self._fetch_user_info(
|
||||
oauth_provider, tokens.access_token
|
||||
)
|
||||
|
||||
# Create or update user and store tokens
|
||||
user_id, workspace_id = await self._store_auth_user(
|
||||
provider, email, display_name, tokens
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"auth_login_completed",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
email=email,
|
||||
user_id=str(user_id),
|
||||
workspace_id=str(workspace_id),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
display_name=display_name,
|
||||
email=email,
|
||||
)
|
||||
|
||||
async def _exchange_tokens(
|
||||
self,
|
||||
oauth_provider: OAuthProvider,
|
||||
code: str,
|
||||
state: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange authorization code for tokens."""
|
||||
try:
|
||||
return await self._oauth_manager.complete_auth(
|
||||
provider=oauth_provider,
|
||||
code=code,
|
||||
state=state,
|
||||
)
|
||||
except OAuthError as e:
|
||||
raise AuthServiceError(f"OAuth failed: {e}") from e
|
||||
|
||||
async def _fetch_user_info(
|
||||
self,
|
||||
oauth_provider: OAuthProvider,
|
||||
access_token: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Fetch user email and display name from provider."""
|
||||
# Use the calendar adapter to get user info (reuse existing infrastructure)
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarAdapter
|
||||
|
||||
try:
|
||||
if oauth_provider == OAuthProvider.GOOGLE:
|
||||
adapter = GoogleCalendarAdapter()
|
||||
email, display_name = await adapter.get_user_info(access_token)
|
||||
else:
|
||||
adapter = OutlookCalendarAdapter()
|
||||
email, display_name = await adapter.get_user_info(access_token)
|
||||
|
||||
return email, display_name
|
||||
except (GoogleCalendarError, OutlookCalendarError, OAuthError) as e:
|
||||
raise AuthServiceError(f"Failed to get user info: {e}") from e
|
||||
|
||||
async def _store_auth_user(
|
||||
self,
|
||||
provider: str,
|
||||
email: str,
|
||||
display_name: str,
|
||||
tokens: OAuthTokens,
|
||||
) -> tuple[UUID, UUID]:
|
||||
"""Create or update user and store auth tokens."""
|
||||
async with self._uow_factory() as uow:
|
||||
# Find or create user by email
|
||||
user = None
|
||||
if uow.supports_users:
|
||||
user = await uow.users.get_by_email(email)
|
||||
|
||||
if user is None:
|
||||
# Create new user
|
||||
user_id = uuid4()
|
||||
if uow.supports_users:
|
||||
user = User(
|
||||
id=user_id,
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
is_default=False,
|
||||
)
|
||||
await uow.users.create(user)
|
||||
logger.info("Created new user: %s (%s)", display_name, email)
|
||||
else:
|
||||
user_id = DEFAULT_USER_ID
|
||||
else:
|
||||
user_id = user.id
|
||||
# Update display name if changed
|
||||
if user.display_name != display_name:
|
||||
user.display_name = display_name
|
||||
await uow.users.update(user)
|
||||
|
||||
# Get or create default workspace for this user
|
||||
workspace_id = DEFAULT_WORKSPACE_ID
|
||||
if uow.supports_workspaces:
|
||||
workspace = await uow.workspaces.get_default_for_user(user_id)
|
||||
if workspace:
|
||||
workspace_id = workspace.id
|
||||
else:
|
||||
# Create default "Personal" workspace for new user
|
||||
workspace_id = uuid4()
|
||||
await uow.workspaces.create(
|
||||
workspace_id=workspace_id,
|
||||
name="Personal",
|
||||
owner_id=user_id,
|
||||
is_default=True,
|
||||
)
|
||||
logger.info(
|
||||
"Created default workspace for user_id=%s, workspace_id=%s",
|
||||
user_id,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
# Store auth integration with tokens
|
||||
integration = await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.AUTH.value,
|
||||
)
|
||||
|
||||
if integration is None:
|
||||
integration = Integration.create(
|
||||
workspace_id=workspace_id,
|
||||
name=f"{provider.title()} Auth",
|
||||
integration_type=IntegrationType.AUTH,
|
||||
config={"provider": provider, "user_id": str(user_id)},
|
||||
)
|
||||
await uow.integrations.create(integration)
|
||||
else:
|
||||
integration.config["provider"] = provider
|
||||
integration.config["user_id"] = str(user_id)
|
||||
|
||||
integration.connect(provider_email=email)
|
||||
await uow.integrations.update(integration)
|
||||
|
||||
# Store tokens
|
||||
await uow.integrations.set_secrets(
|
||||
integration_id=integration.id,
|
||||
secrets=tokens.to_secrets_dict(),
|
||||
)
|
||||
await uow.commit()
|
||||
|
||||
return user_id, workspace_id
|
||||
|
||||
async def get_current_user(self) -> UserInfo:
|
||||
"""Get current authenticated user info.
|
||||
|
||||
Returns:
|
||||
UserInfo with current user details or local default.
|
||||
"""
|
||||
async with self._uow_factory() as uow:
|
||||
# Look for any connected auth integration
|
||||
for provider in [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]:
|
||||
integration = await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.AUTH.value,
|
||||
)
|
||||
|
||||
if integration and integration.is_connected:
|
||||
user_id_str = integration.config.get("user_id")
|
||||
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
|
||||
|
||||
# Get user details
|
||||
display_name = "Authenticated User"
|
||||
if uow.supports_users and user_id_str:
|
||||
user = await uow.users.get(user_id)
|
||||
if user:
|
||||
display_name = user.display_name
|
||||
|
||||
return UserInfo(
|
||||
user_id=user_id,
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
display_name=display_name,
|
||||
email=integration.provider_email,
|
||||
is_authenticated=True,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Return local default
|
||||
return UserInfo(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
display_name="Local User",
|
||||
email=None,
|
||||
is_authenticated=False,
|
||||
provider=None,
|
||||
)
|
||||
|
||||
async def logout(self, provider: str | None = None) -> LogoutResult:
|
||||
"""Logout and revoke auth tokens.
|
||||
|
||||
Args:
|
||||
provider: Optional specific provider to logout from.
|
||||
If None, logs out from all providers.
|
||||
|
||||
Returns:
|
||||
LogoutResult with details on local logout and token revocation.
|
||||
"""
|
||||
providers = (
|
||||
[provider]
|
||||
if provider
|
||||
else [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]
|
||||
)
|
||||
|
||||
logged_out = False
|
||||
all_revoked = True
|
||||
revocation_errors: list[str] = []
|
||||
|
||||
for p in providers:
|
||||
result = await self._logout_provider(p)
|
||||
logged_out = logged_out or result.logged_out
|
||||
if not result.tokens_revoked:
|
||||
all_revoked = False
|
||||
if result.revocation_error:
|
||||
revocation_errors.append(f"{p}: {result.revocation_error}")
|
||||
|
||||
return LogoutResult(
|
||||
logged_out=logged_out,
|
||||
tokens_revoked=all_revoked,
|
||||
revocation_error="; ".join(revocation_errors) if revocation_errors else None,
|
||||
)
|
||||
|
||||
async def _logout_provider(self, provider: str) -> LogoutResult:
|
||||
"""Logout from a specific provider."""
|
||||
oauth_provider = self._parse_provider(provider)
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.AUTH.value,
|
||||
)
|
||||
|
||||
if integration is None:
|
||||
return LogoutResult(
|
||||
logged_out=False,
|
||||
tokens_revoked=True, # No tokens to revoke
|
||||
)
|
||||
|
||||
# Get tokens for revocation
|
||||
secrets = await uow.integrations.get_secrets(integration.id)
|
||||
access_token = secrets.get(OAUTH_FIELD_ACCESS_TOKEN) if secrets else None
|
||||
|
||||
# Delete integration
|
||||
await uow.integrations.delete(integration.id)
|
||||
await uow.commit()
|
||||
|
||||
# Revoke tokens (best effort)
|
||||
tokens_revoked = True
|
||||
revocation_error: str | None = None
|
||||
|
||||
if access_token:
|
||||
try:
|
||||
await self._oauth_manager.revoke_tokens(oauth_provider, access_token)
|
||||
logger.info(
|
||||
"auth_tokens_revoked",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
)
|
||||
except OAuthError as e:
|
||||
tokens_revoked = False
|
||||
revocation_error = str(e)
|
||||
logger.warning(
|
||||
"auth_token_revocation_failed",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
error=revocation_error,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"auth_logout_completed",
|
||||
event_type="security",
|
||||
provider=provider,
|
||||
tokens_revoked=tokens_revoked,
|
||||
)
|
||||
|
||||
return LogoutResult(
|
||||
logged_out=True,
|
||||
tokens_revoked=tokens_revoked,
|
||||
revocation_error=revocation_error,
|
||||
)
|
||||
|
||||
async def refresh_auth_tokens(self, provider: str) -> AuthResult | None:
|
||||
"""Refresh expired auth tokens.
|
||||
|
||||
Args:
|
||||
provider: Provider to refresh tokens for.
|
||||
|
||||
Returns:
|
||||
Updated AuthResult or None if refresh failed.
|
||||
"""
|
||||
oauth_provider = self._parse_provider(provider)
|
||||
|
||||
async with self._uow_factory() as uow:
|
||||
integration = await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.AUTH.value,
|
||||
)
|
||||
|
||||
if integration is None or not integration.is_connected:
|
||||
return None
|
||||
|
||||
secrets = await uow.integrations.get_secrets(integration.id)
|
||||
if not secrets:
|
||||
return None
|
||||
|
||||
try:
|
||||
tokens = OAuthTokens.from_secrets_dict(secrets)
|
||||
except (KeyError, ValueError):
|
||||
return None
|
||||
|
||||
if not tokens.refresh_token:
|
||||
return None
|
||||
|
||||
# Only refresh if token is expired or will expire within 5 minutes
|
||||
if not tokens.is_expired(buffer_seconds=300):
|
||||
logger.debug(
|
||||
"auth_token_still_valid",
|
||||
provider=provider,
|
||||
expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
|
||||
)
|
||||
# Return existing auth info without refreshing
|
||||
user_id_str = integration.config.get("user_id")
|
||||
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
|
||||
return AuthResult(
|
||||
user_id=user_id,
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
display_name=integration.provider_email or "User",
|
||||
email=integration.provider_email,
|
||||
)
|
||||
|
||||
try:
|
||||
new_tokens = await self._oauth_manager.refresh_tokens(
|
||||
provider=oauth_provider,
|
||||
refresh_token=tokens.refresh_token,
|
||||
)
|
||||
|
||||
await uow.integrations.set_secrets(
|
||||
integration_id=integration.id,
|
||||
secrets=new_tokens.to_secrets_dict(),
|
||||
)
|
||||
await uow.commit()
|
||||
|
||||
user_id_str = integration.config.get("user_id")
|
||||
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
|
||||
|
||||
return AuthResult(
|
||||
user_id=user_id,
|
||||
workspace_id=DEFAULT_WORKSPACE_ID,
|
||||
display_name=integration.provider_email or "User",
|
||||
email=integration.provider_email,
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
integration.mark_error(f"Token refresh failed: {e}")
|
||||
await uow.integrations.update(integration)
|
||||
await uow.commit()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_provider(provider: str) -> OAuthProvider:
|
||||
"""Parse and validate provider string."""
|
||||
try:
|
||||
return OAuthProvider(provider.lower())
|
||||
except ValueError as e:
|
||||
raise AuthServiceError(
|
||||
f"Invalid provider: {provider}. Must be 'google' or 'outlook'."
|
||||
) from e
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum, IntEnum, StrEnum
|
||||
from typing import NewType
|
||||
from uuid import UUID
|
||||
@@ -145,9 +145,18 @@ class OAuthTokens:
|
||||
expires_at: datetime
|
||||
scope: str
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the access token has expired."""
|
||||
return datetime.now(self.expires_at.tzinfo) > self.expires_at
|
||||
def is_expired(self, buffer_seconds: int = 0) -> bool:
|
||||
"""Check if the access token has expired or will expire within buffer.
|
||||
|
||||
Args:
|
||||
buffer_seconds: Consider token expired this many seconds before
|
||||
actual expiry. Useful for proactive refresh.
|
||||
|
||||
Returns:
|
||||
True if token is expired or will expire within buffer_seconds.
|
||||
"""
|
||||
effective_expiry = self.expires_at - timedelta(seconds=buffer_seconds)
|
||||
return datetime.now(self.expires_at.tzinfo) > effective_expiry
|
||||
|
||||
def to_secrets_dict(self) -> dict[str, str]:
|
||||
"""Convert to dictionary for encrypted storage."""
|
||||
|
||||
@@ -9,6 +9,7 @@ from noteflow.config.constants import DEFAULT_GRPC_PORT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from noteflow.application.services.calendar_service import CalendarService
|
||||
from noteflow.application.services.identity_service import IdentityService
|
||||
from noteflow.application.services.ner_service import NerService
|
||||
from noteflow.application.services.project_service import ProjectService
|
||||
from noteflow.application.services.summarization_service import SummarizationService
|
||||
@@ -203,6 +204,7 @@ class ServicesConfig:
|
||||
calendar_service: Service for OAuth and calendar event fetching.
|
||||
webhook_service: Service for webhook event notifications.
|
||||
project_service: Service for project management.
|
||||
identity_service: Service for identity and workspace context management.
|
||||
"""
|
||||
|
||||
summarization_service: SummarizationService | None = None
|
||||
@@ -212,3 +214,4 @@ class ServicesConfig:
|
||||
calendar_service: CalendarService | None = None
|
||||
webhook_service: WebhookService | None = None
|
||||
project_service: ProjectService | None = None
|
||||
identity_service: IdentityService | None = None
|
||||
|
||||
@@ -7,6 +7,7 @@ from .diarization import DiarizationMixin
|
||||
from .diarization_job import DiarizationJobMixin
|
||||
from .entities import EntitiesMixin
|
||||
from .export import ExportMixin
|
||||
from .identity import IdentityMixin
|
||||
from .meeting import MeetingMixin
|
||||
from .observability import ObservabilityMixin
|
||||
from .oidc import OidcMixin
|
||||
@@ -26,6 +27,7 @@ __all__ = [
|
||||
"DiarizationMixin",
|
||||
"EntitiesMixin",
|
||||
"ExportMixin",
|
||||
"IdentityMixin",
|
||||
"MeetingMixin",
|
||||
"ObservabilityMixin",
|
||||
"OidcMixin",
|
||||
|
||||
186
src/noteflow/grpc/_mixins/identity.py
Normal file
186
src/noteflow/grpc/_mixins/identity.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Identity management mixin for gRPC service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from noteflow.application.services.identity_service import IdentityService
|
||||
from noteflow.domain.entities.integration import IntegrationType
|
||||
from noteflow.domain.identity.context import OperationContext
|
||||
from noteflow.domain.ports.unit_of_work import UnitOfWork
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
from ..proto import noteflow_pb2
|
||||
from .errors import abort_database_required, abort_invalid_argument, abort_not_found, parse_workspace_id
|
||||
from ._types import GrpcContext
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class IdentityServicer(Protocol):
|
||||
"""Protocol for hosts that support identity operations."""
|
||||
|
||||
def create_repository_provider(self) -> UnitOfWork: ...
|
||||
|
||||
def get_operation_context(self, context: GrpcContext) -> OperationContext: ...
|
||||
|
||||
@property
|
||||
def identity_service(self) -> IdentityService: ...
|
||||
|
||||
|
||||
class IdentityMixin:
|
||||
"""Mixin providing identity management functionality.
|
||||
|
||||
Implements:
|
||||
- GetCurrentUser: Get current user's identity info
|
||||
- ListWorkspaces: List workspaces user belongs to
|
||||
- SwitchWorkspace: Switch to a different workspace
|
||||
"""
|
||||
|
||||
async def GetCurrentUser(
|
||||
self: IdentityServicer,
|
||||
request: noteflow_pb2.GetCurrentUserRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.GetCurrentUserResponse:
|
||||
"""Get current authenticated user info."""
|
||||
# Note: op_context from headers provides request metadata
|
||||
_ = self.get_operation_context(context)
|
||||
|
||||
async with self.create_repository_provider() as uow:
|
||||
# Get or create default user/workspace for local-first mode
|
||||
user_ctx = await self.identity_service.get_or_create_default_user(uow)
|
||||
ws_ctx = await self.identity_service.get_or_create_default_workspace(
|
||||
uow, user_ctx.user_id
|
||||
)
|
||||
await uow.commit()
|
||||
|
||||
# Check if user has auth integration (authenticated via OAuth)
|
||||
is_authenticated = False
|
||||
auth_provider = ""
|
||||
|
||||
if hasattr(uow, "supports_integrations") and uow.supports_integrations:
|
||||
for provider in ["google", "outlook"]:
|
||||
integration = await uow.integrations.get_by_provider(
|
||||
provider=provider,
|
||||
integration_type=IntegrationType.AUTH.value,
|
||||
)
|
||||
if integration and integration.is_connected:
|
||||
is_authenticated = True
|
||||
auth_provider = provider
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"GetCurrentUser: user_id=%s, workspace_id=%s, authenticated=%s",
|
||||
user_ctx.user_id,
|
||||
ws_ctx.workspace_id,
|
||||
is_authenticated,
|
||||
)
|
||||
|
||||
return noteflow_pb2.GetCurrentUserResponse(
|
||||
user_id=str(user_ctx.user_id),
|
||||
workspace_id=str(ws_ctx.workspace_id),
|
||||
display_name=user_ctx.display_name,
|
||||
email=user_ctx.email or "",
|
||||
is_authenticated=is_authenticated,
|
||||
auth_provider=auth_provider,
|
||||
workspace_name=ws_ctx.workspace_name,
|
||||
role=ws_ctx.role.value,
|
||||
)
|
||||
|
||||
async def ListWorkspaces(
|
||||
self: IdentityServicer,
|
||||
request: noteflow_pb2.ListWorkspacesRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.ListWorkspacesResponse:
|
||||
"""List workspaces the current user belongs to."""
|
||||
_ = self.get_operation_context(context)
|
||||
|
||||
async with self.create_repository_provider() as uow:
|
||||
if not uow.supports_workspaces:
|
||||
await abort_database_required(context, "Workspaces")
|
||||
|
||||
user_ctx = await self.identity_service.get_or_create_default_user(uow)
|
||||
|
||||
limit = request.limit if request.limit > 0 else 50
|
||||
offset = request.offset if request.offset >= 0 else 0
|
||||
|
||||
workspaces = await self.identity_service.list_workspaces(
|
||||
uow, user_ctx.user_id, limit, offset
|
||||
)
|
||||
|
||||
workspace_protos: list[noteflow_pb2.WorkspaceProto] = []
|
||||
for ws in workspaces:
|
||||
membership = await uow.workspaces.get_membership(ws.id, user_ctx.user_id)
|
||||
role = membership.role.value if membership else "member"
|
||||
|
||||
workspace_protos.append(
|
||||
noteflow_pb2.WorkspaceProto(
|
||||
id=str(ws.id),
|
||||
name=ws.name,
|
||||
slug=ws.slug or "",
|
||||
is_default=ws.is_default,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"ListWorkspaces: user_id=%s, count=%d",
|
||||
user_ctx.user_id,
|
||||
len(workspace_protos),
|
||||
)
|
||||
|
||||
return noteflow_pb2.ListWorkspacesResponse(
|
||||
workspaces=workspace_protos,
|
||||
total_count=len(workspace_protos),
|
||||
)
|
||||
|
||||
async def SwitchWorkspace(
|
||||
self: IdentityServicer,
|
||||
request: noteflow_pb2.SwitchWorkspaceRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.SwitchWorkspaceResponse:
|
||||
"""Switch to a different workspace."""
|
||||
_ = self.get_operation_context(context)
|
||||
|
||||
if not request.workspace_id:
|
||||
await abort_invalid_argument(context, "workspace_id is required")
|
||||
|
||||
# Parse and validate workspace ID (aborts with INVALID_ARGUMENT if invalid)
|
||||
workspace_id = await parse_workspace_id(request.workspace_id, context)
|
||||
|
||||
async with self.create_repository_provider() as uow:
|
||||
if not uow.supports_workspaces:
|
||||
await abort_database_required(context, "Workspaces")
|
||||
|
||||
user_ctx = await self.identity_service.get_or_create_default_user(uow)
|
||||
|
||||
# Verify workspace exists
|
||||
workspace = await uow.workspaces.get(workspace_id)
|
||||
if not workspace:
|
||||
await abort_not_found(context, "Workspace", str(workspace_id))
|
||||
|
||||
# Verify user has access
|
||||
membership = await uow.workspaces.get_membership(
|
||||
workspace_id, user_ctx.user_id
|
||||
)
|
||||
if not membership:
|
||||
await abort_not_found(
|
||||
context, "Workspace membership", str(workspace_id)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"SwitchWorkspace: user_id=%s, workspace_id=%s",
|
||||
user_ctx.user_id,
|
||||
workspace_id,
|
||||
)
|
||||
|
||||
return noteflow_pb2.SwitchWorkspaceResponse(
|
||||
success=True,
|
||||
workspace=noteflow_pb2.WorkspaceProto(
|
||||
id=str(workspace.id),
|
||||
name=workspace.name,
|
||||
slug=workspace.slug or "",
|
||||
is_default=workspace.is_default,
|
||||
role=membership.role.value,
|
||||
),
|
||||
)
|
||||
@@ -114,6 +114,11 @@ service NoteFlowService {
|
||||
rpc UpdateProjectMemberRole(UpdateProjectMemberRoleRequest) returns (ProjectMembershipProto);
|
||||
rpc RemoveProjectMember(RemoveProjectMemberRequest) returns (RemoveProjectMemberResponse);
|
||||
rpc ListProjectMembers(ListProjectMembersRequest) returns (ListProjectMembersResponse);
|
||||
|
||||
// Identity management (Sprint 16+)
|
||||
rpc GetCurrentUser(GetCurrentUserRequest) returns (GetCurrentUserResponse);
|
||||
rpc ListWorkspaces(ListWorkspacesRequest) returns (ListWorkspacesResponse);
|
||||
rpc SwitchWorkspace(SwitchWorkspaceRequest) returns (SwitchWorkspaceResponse);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
@@ -1848,3 +1853,86 @@ message ListProjectMembersResponse {
|
||||
// Total count
|
||||
int32 total_count = 2;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Identity Management Messages (Sprint 16+)
|
||||
// =============================================================================
|
||||
|
||||
message GetCurrentUserRequest {
|
||||
// Empty - user ID comes from request headers
|
||||
}
|
||||
|
||||
message GetCurrentUserResponse {
|
||||
// User ID (UUID string)
|
||||
string user_id = 1;
|
||||
|
||||
// Current workspace ID (UUID string)
|
||||
string workspace_id = 2;
|
||||
|
||||
// User display name
|
||||
string display_name = 3;
|
||||
|
||||
// User email (optional)
|
||||
string email = 4;
|
||||
|
||||
// Whether user is authenticated (vs local mode)
|
||||
bool is_authenticated = 5;
|
||||
|
||||
// OAuth provider if authenticated (google, outlook, etc.)
|
||||
string auth_provider = 6;
|
||||
|
||||
// Workspace name
|
||||
string workspace_name = 7;
|
||||
|
||||
// User's role in workspace
|
||||
string role = 8;
|
||||
}
|
||||
|
||||
message WorkspaceProto {
|
||||
// Workspace ID (UUID string)
|
||||
string id = 1;
|
||||
|
||||
// Workspace name
|
||||
string name = 2;
|
||||
|
||||
// URL slug
|
||||
string slug = 3;
|
||||
|
||||
// Whether this is the default workspace
|
||||
bool is_default = 4;
|
||||
|
||||
// User's role in this workspace
|
||||
string role = 5;
|
||||
}
|
||||
|
||||
message ListWorkspacesRequest {
|
||||
// Maximum workspaces to return (default: 50)
|
||||
int32 limit = 1;
|
||||
|
||||
// Pagination offset
|
||||
int32 offset = 2;
|
||||
}
|
||||
|
||||
message ListWorkspacesResponse {
|
||||
// User's workspaces
|
||||
repeated WorkspaceProto workspaces = 1;
|
||||
|
||||
// Total count
|
||||
int32 total_count = 2;
|
||||
}
|
||||
|
||||
message SwitchWorkspaceRequest {
|
||||
// Workspace ID to switch to
|
||||
string workspace_id = 1;
|
||||
}
|
||||
|
||||
message SwitchWorkspaceResponse {
|
||||
// Whether switch succeeded
|
||||
bool success = 1;
|
||||
|
||||
// New current workspace info
|
||||
WorkspaceProto workspace = 2;
|
||||
|
||||
// Error message if failed
|
||||
string error_message = 3;
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1621,3 +1621,73 @@ class ListProjectMembersResponse(_message.Message):
|
||||
members: _containers.RepeatedCompositeFieldContainer[ProjectMembershipProto]
|
||||
total_count: int
|
||||
def __init__(self, members: _Optional[_Iterable[_Union[ProjectMembershipProto, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class GetCurrentUserRequest(_message.Message):
|
||||
__slots__ = ()
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class GetCurrentUserResponse(_message.Message):
|
||||
__slots__ = ("user_id", "workspace_id", "display_name", "email", "is_authenticated", "auth_provider", "workspace_name", "role")
|
||||
USER_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
DISPLAY_NAME_FIELD_NUMBER: _ClassVar[int]
|
||||
EMAIL_FIELD_NUMBER: _ClassVar[int]
|
||||
IS_AUTHENTICATED_FIELD_NUMBER: _ClassVar[int]
|
||||
AUTH_PROVIDER_FIELD_NUMBER: _ClassVar[int]
|
||||
WORKSPACE_NAME_FIELD_NUMBER: _ClassVar[int]
|
||||
ROLE_FIELD_NUMBER: _ClassVar[int]
|
||||
user_id: str
|
||||
workspace_id: str
|
||||
display_name: str
|
||||
email: str
|
||||
is_authenticated: bool
|
||||
auth_provider: str
|
||||
workspace_name: str
|
||||
role: str
|
||||
def __init__(self, user_id: _Optional[str] = ..., workspace_id: _Optional[str] = ..., display_name: _Optional[str] = ..., email: _Optional[str] = ..., is_authenticated: bool = ..., auth_provider: _Optional[str] = ..., workspace_name: _Optional[str] = ..., role: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class WorkspaceProto(_message.Message):
|
||||
__slots__ = ("id", "name", "slug", "is_default", "role")
|
||||
ID_FIELD_NUMBER: _ClassVar[int]
|
||||
NAME_FIELD_NUMBER: _ClassVar[int]
|
||||
SLUG_FIELD_NUMBER: _ClassVar[int]
|
||||
IS_DEFAULT_FIELD_NUMBER: _ClassVar[int]
|
||||
ROLE_FIELD_NUMBER: _ClassVar[int]
|
||||
id: str
|
||||
name: str
|
||||
slug: str
|
||||
is_default: bool
|
||||
role: str
|
||||
def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., slug: _Optional[str] = ..., is_default: bool = ..., role: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class ListWorkspacesRequest(_message.Message):
|
||||
__slots__ = ("limit", "offset")
|
||||
LIMIT_FIELD_NUMBER: _ClassVar[int]
|
||||
OFFSET_FIELD_NUMBER: _ClassVar[int]
|
||||
limit: int
|
||||
offset: int
|
||||
def __init__(self, limit: _Optional[int] = ..., offset: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class ListWorkspacesResponse(_message.Message):
|
||||
__slots__ = ("workspaces", "total_count")
|
||||
WORKSPACES_FIELD_NUMBER: _ClassVar[int]
|
||||
TOTAL_COUNT_FIELD_NUMBER: _ClassVar[int]
|
||||
workspaces: _containers.RepeatedCompositeFieldContainer[WorkspaceProto]
|
||||
total_count: int
|
||||
def __init__(self, workspaces: _Optional[_Iterable[_Union[WorkspaceProto, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class SwitchWorkspaceRequest(_message.Message):
|
||||
__slots__ = ("workspace_id",)
|
||||
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
workspace_id: str
|
||||
def __init__(self, workspace_id: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class SwitchWorkspaceResponse(_message.Message):
|
||||
__slots__ = ("success", "workspace", "error_message")
|
||||
SUCCESS_FIELD_NUMBER: _ClassVar[int]
|
||||
WORKSPACE_FIELD_NUMBER: _ClassVar[int]
|
||||
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
|
||||
success: bool
|
||||
workspace: WorkspaceProto
|
||||
error_message: str
|
||||
def __init__(self, success: bool = ..., workspace: _Optional[_Union[WorkspaceProto, _Mapping]] = ..., error_message: _Optional[str] = ...) -> None: ...
|
||||
|
||||
@@ -363,6 +363,21 @@ class NoteFlowServiceStub(object):
|
||||
request_serializer=noteflow__pb2.ListProjectMembersRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.ListProjectMembersResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetCurrentUser = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/GetCurrentUser',
|
||||
request_serializer=noteflow__pb2.GetCurrentUserRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.GetCurrentUserResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.ListWorkspaces = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/ListWorkspaces',
|
||||
request_serializer=noteflow__pb2.ListWorkspacesRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.ListWorkspacesResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.SwitchWorkspace = channel.unary_unary(
|
||||
'/noteflow.NoteFlowService/SwitchWorkspace',
|
||||
request_serializer=noteflow__pb2.SwitchWorkspaceRequest.SerializeToString,
|
||||
response_deserializer=noteflow__pb2.SwitchWorkspaceResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class NoteFlowServiceServicer(object):
|
||||
@@ -782,6 +797,25 @@ class NoteFlowServiceServicer(object):
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetCurrentUser(self, request, context):
|
||||
"""Identity management (Sprint 16+)
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ListWorkspaces(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SwitchWorkspace(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_NoteFlowServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
@@ -1110,6 +1144,21 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
|
||||
request_deserializer=noteflow__pb2.ListProjectMembersRequest.FromString,
|
||||
response_serializer=noteflow__pb2.ListProjectMembersResponse.SerializeToString,
|
||||
),
|
||||
'GetCurrentUser': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetCurrentUser,
|
||||
request_deserializer=noteflow__pb2.GetCurrentUserRequest.FromString,
|
||||
response_serializer=noteflow__pb2.GetCurrentUserResponse.SerializeToString,
|
||||
),
|
||||
'ListWorkspaces': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ListWorkspaces,
|
||||
request_deserializer=noteflow__pb2.ListWorkspacesRequest.FromString,
|
||||
response_serializer=noteflow__pb2.ListWorkspacesResponse.SerializeToString,
|
||||
),
|
||||
'SwitchWorkspace': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SwitchWorkspace,
|
||||
request_deserializer=noteflow__pb2.SwitchWorkspaceRequest.FromString,
|
||||
response_serializer=noteflow__pb2.SwitchWorkspaceResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'noteflow.NoteFlowService', rpc_method_handlers)
|
||||
@@ -2879,3 +2928,84 @@ class NoteFlowService(object):
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetCurrentUser(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/noteflow.NoteFlowService/GetCurrentUser',
|
||||
noteflow__pb2.GetCurrentUserRequest.SerializeToString,
|
||||
noteflow__pb2.GetCurrentUserResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def ListWorkspaces(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/noteflow.NoteFlowService/ListWorkspaces',
|
||||
noteflow__pb2.ListWorkspacesRequest.SerializeToString,
|
||||
noteflow__pb2.ListWorkspacesResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SwitchWorkspace(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/noteflow.NoteFlowService/SwitchWorkspace',
|
||||
noteflow__pb2.SwitchWorkspaceRequest.SerializeToString,
|
||||
noteflow__pb2.SwitchWorkspaceResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@@ -9,7 +9,13 @@ from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, ClassVar, Final
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from noteflow import __version__
|
||||
from noteflow.application.services.identity_service import IdentityService
|
||||
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
|
||||
from noteflow.domain.identity.roles import WorkspaceRole
|
||||
from noteflow.infrastructure.logging import request_id_var, user_id_var, workspace_id_var
|
||||
from noteflow.config.constants import APP_DIR_NAME
|
||||
from noteflow.config.constants import DEFAULT_SAMPLE_RATE as _DEFAULT_SAMPLE_RATE
|
||||
from noteflow.domain.entities import Meeting
|
||||
@@ -34,6 +40,7 @@ from ._mixins import (
|
||||
DiarizationMixin,
|
||||
EntitiesMixin,
|
||||
ExportMixin,
|
||||
IdentityMixin,
|
||||
MeetingMixin,
|
||||
ObservabilityMixin,
|
||||
OidcMixin,
|
||||
@@ -71,6 +78,18 @@ else:
|
||||
pass
|
||||
|
||||
|
||||
# Module-level singleton for identity service (stateless, no dependencies)
|
||||
_identity_service_instance: IdentityService | None = None
|
||||
|
||||
|
||||
def _default_identity_service() -> IdentityService:
|
||||
"""Get or create the default identity service singleton."""
|
||||
global _identity_service_instance
|
||||
if _identity_service_instance is None:
|
||||
_identity_service_instance = IdentityService()
|
||||
return _identity_service_instance
|
||||
|
||||
|
||||
class NoteFlowServicer(
|
||||
StreamingMixin,
|
||||
DiarizationMixin,
|
||||
@@ -86,6 +105,7 @@ class NoteFlowServicer(
|
||||
ObservabilityMixin,
|
||||
PreferencesMixin,
|
||||
OidcMixin,
|
||||
IdentityMixin,
|
||||
ProjectMixin,
|
||||
ProjectMembershipMixin,
|
||||
NoteFlowServicerStubs,
|
||||
@@ -138,6 +158,8 @@ class NoteFlowServicer(
|
||||
self.calendar_service = services.calendar_service
|
||||
self.webhook_service = services.webhook_service
|
||||
self.project_service = services.project_service
|
||||
# Identity service - always available (stateless, no dependencies)
|
||||
self.identity_service = services.identity_service or _default_identity_service()
|
||||
self._start_time = time.time()
|
||||
self.memory_store: MeetingStore | None = MeetingStore() if session_factory is None else None
|
||||
# Audio infrastructure
|
||||
@@ -195,6 +217,40 @@ class NoteFlowServicer(
|
||||
return SqlAlchemyUnitOfWork(self.session_factory, self.meetings_dir)
|
||||
return MemoryUnitOfWork(self.get_memory_store())
|
||||
|
||||
def get_operation_context(self, context: GrpcContext) -> OperationContext:
|
||||
"""Get operation context from gRPC context variables.
|
||||
|
||||
Read identity information set by the IdentityInterceptor from
|
||||
context variables and construct an OperationContext.
|
||||
|
||||
Args:
|
||||
context: gRPC service context (used for metadata if needed).
|
||||
|
||||
Returns:
|
||||
OperationContext with user, workspace, and request info.
|
||||
"""
|
||||
# Read from context variables set by IdentityInterceptor
|
||||
request_id = request_id_var.get()
|
||||
user_id_str = user_id_var.get()
|
||||
workspace_id_str = workspace_id_var.get()
|
||||
|
||||
# Default IDs for local-first mode
|
||||
default_user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
default_workspace_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
user_id = UUID(user_id_str) if user_id_str else default_user_id
|
||||
workspace_id = UUID(workspace_id_str) if workspace_id_str else default_workspace_id
|
||||
|
||||
return OperationContext(
|
||||
user=UserContext(user_id=user_id, display_name=""),
|
||||
workspace=WorkspaceContext(
|
||||
workspace_id=workspace_id,
|
||||
workspace_name="",
|
||||
role=WorkspaceRole.OWNER,
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
|
||||
"""Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting."""
|
||||
# Create core components
|
||||
|
||||
@@ -148,6 +148,21 @@ class GoogleCalendarAdapter(CalendarPort):
|
||||
Returns:
|
||||
User's email address.
|
||||
|
||||
Raises:
|
||||
GoogleCalendarError: If API call fails.
|
||||
"""
|
||||
email, _ = await self.get_user_info(access_token)
|
||||
return email
|
||||
|
||||
async def get_user_info(self, access_token: str) -> tuple[str, str]:
|
||||
"""Get authenticated user's email and display name.
|
||||
|
||||
Args:
|
||||
access_token: Valid OAuth access token.
|
||||
|
||||
Returns:
|
||||
Tuple of (email, display_name).
|
||||
|
||||
Raises:
|
||||
GoogleCalendarError: If API call fails.
|
||||
"""
|
||||
@@ -168,11 +183,21 @@ class GoogleCalendarAdapter(CalendarPort):
|
||||
if not isinstance(data_value, dict):
|
||||
raise GoogleCalendarError("Invalid userinfo response")
|
||||
data = cast(dict[str, object], data_value)
|
||||
if email := data.get("email"):
|
||||
return str(email)
|
||||
else:
|
||||
|
||||
email = data.get("email")
|
||||
if not email:
|
||||
raise GoogleCalendarError("No email in userinfo response")
|
||||
|
||||
# Get display name from 'name' field, fall back to email prefix
|
||||
name = data.get("name")
|
||||
display_name = (
|
||||
str(name)
|
||||
if name
|
||||
else str(email).split("@")[0].replace(".", " ").title()
|
||||
)
|
||||
|
||||
return str(email), display_name
|
||||
|
||||
def _parse_event(self, item: _GoogleEvent) -> CalendarEventInfo:
|
||||
"""Parse Google Calendar event into CalendarEventInfo."""
|
||||
event_id = str(item.get("id", ""))
|
||||
|
||||
@@ -78,6 +78,13 @@ class OAuthManager(OAuthPort):
|
||||
# State TTL (10 minutes)
|
||||
STATE_TTL_SECONDS = 600
|
||||
|
||||
# Maximum pending states to prevent memory exhaustion
|
||||
MAX_PENDING_STATES = 100
|
||||
|
||||
# Rate limiting for auth attempts (per provider)
|
||||
MAX_AUTH_ATTEMPTS_PER_MINUTE = 10
|
||||
AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
|
||||
def __init__(self, settings: CalendarIntegrationSettings) -> None:
|
||||
"""Initialize OAuth manager with calendar settings.
|
||||
|
||||
@@ -86,6 +93,8 @@ class OAuthManager(OAuthPort):
|
||||
"""
|
||||
self._settings = settings
|
||||
self._pending_states: dict[str, OAuthState] = {}
|
||||
# Track auth attempt timestamps per provider for rate limiting
|
||||
self._auth_attempts: dict[str, list[datetime]] = {}
|
||||
|
||||
def get_pending_state(self, state_token: str) -> OAuthState | None:
|
||||
"""Get pending OAuth state by token.
|
||||
@@ -137,6 +146,16 @@ class OAuthManager(OAuthPort):
|
||||
"""
|
||||
self._cleanup_expired_states()
|
||||
self._validate_provider_config(provider)
|
||||
self._check_rate_limit(provider)
|
||||
|
||||
# Enforce maximum pending states to prevent memory exhaustion
|
||||
if len(self._pending_states) >= self.MAX_PENDING_STATES:
|
||||
logger.warning(
|
||||
"oauth_max_pending_states_exceeded",
|
||||
count=len(self._pending_states),
|
||||
max_allowed=self.MAX_PENDING_STATES,
|
||||
)
|
||||
raise OAuthError("Too many pending OAuth flows. Please try again later.")
|
||||
|
||||
# Generate PKCE code verifier and challenge
|
||||
code_verifier = self._generate_code_verifier()
|
||||
@@ -190,12 +209,31 @@ class OAuthManager(OAuthPort):
|
||||
# Validate and retrieve state
|
||||
oauth_state = self._pending_states.pop(state, None)
|
||||
if oauth_state is None:
|
||||
logger.warning(
|
||||
"oauth_invalid_state_token",
|
||||
event_type="security",
|
||||
provider=provider.value,
|
||||
state_prefix=state[:8] if len(state) >= 8 else state,
|
||||
)
|
||||
raise OAuthError("Invalid or expired state token")
|
||||
|
||||
if oauth_state.is_state_expired():
|
||||
logger.warning(
|
||||
"oauth_state_expired",
|
||||
event_type="security",
|
||||
provider=provider.value,
|
||||
created_at=oauth_state.created_at.isoformat(),
|
||||
expires_at=oauth_state.expires_at.isoformat(),
|
||||
)
|
||||
raise OAuthError("State token has expired")
|
||||
|
||||
if oauth_state.provider != provider:
|
||||
logger.warning(
|
||||
"oauth_provider_mismatch",
|
||||
event_type="security",
|
||||
expected_provider=oauth_state.provider.value,
|
||||
received_provider=provider.value,
|
||||
)
|
||||
raise OAuthError(
|
||||
f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
|
||||
)
|
||||
@@ -468,3 +506,44 @@ class OAuthManager(OAuthPort):
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._pending_states[key]
|
||||
|
||||
def _check_rate_limit(self, provider: OAuthProvider) -> None:
|
||||
"""Check and enforce rate limiting for auth attempts.
|
||||
|
||||
Prevents brute force attacks by limiting auth attempts per provider.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider being used.
|
||||
|
||||
Raises:
|
||||
OAuthError: If rate limit exceeded.
|
||||
"""
|
||||
provider_key = provider.value
|
||||
now = datetime.now(UTC)
|
||||
cutoff = now - timedelta(seconds=self.AUTH_RATE_LIMIT_WINDOW_SECONDS)
|
||||
|
||||
# Clean up old attempts and count recent ones
|
||||
if provider_key not in self._auth_attempts:
|
||||
self._auth_attempts[provider_key] = []
|
||||
|
||||
# Filter to only keep recent attempts within the window
|
||||
recent_attempts = [
|
||||
ts for ts in self._auth_attempts[provider_key] if ts > cutoff
|
||||
]
|
||||
self._auth_attempts[provider_key] = recent_attempts
|
||||
|
||||
if len(recent_attempts) >= self.MAX_AUTH_ATTEMPTS_PER_MINUTE:
|
||||
logger.warning(
|
||||
"oauth_rate_limit_exceeded",
|
||||
event_type="security",
|
||||
provider=provider_key,
|
||||
attempts=len(recent_attempts),
|
||||
limit=self.MAX_AUTH_ATTEMPTS_PER_MINUTE,
|
||||
window_seconds=self.AUTH_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
raise OAuthError(
|
||||
"Too many auth attempts. Please wait before trying again."
|
||||
)
|
||||
|
||||
# Record this attempt
|
||||
self._auth_attempts[provider_key].append(now)
|
||||
|
||||
@@ -256,11 +256,26 @@ class OutlookCalendarAdapter(CalendarPort):
|
||||
Returns:
|
||||
User's email address.
|
||||
|
||||
Raises:
|
||||
OutlookCalendarError: If API call fails.
|
||||
"""
|
||||
email, _ = await self.get_user_info(access_token)
|
||||
return email
|
||||
|
||||
async def get_user_info(self, access_token: str) -> tuple[str, str]:
|
||||
"""Get authenticated user's email and display name.
|
||||
|
||||
Args:
|
||||
access_token: Valid OAuth access token.
|
||||
|
||||
Returns:
|
||||
Tuple of (email, display_name).
|
||||
|
||||
Raises:
|
||||
OutlookCalendarError: If API call fails.
|
||||
"""
|
||||
url = f"{self.GRAPH_API_BASE}/me"
|
||||
params = {"$select": "mail,userPrincipalName"}
|
||||
params = {"$select": "mail,userPrincipalName,displayName"}
|
||||
headers = {HTTP_AUTHORIZATION: f"{HTTP_BEARER_PREFIX}{access_token}"}
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
@@ -281,11 +296,21 @@ class OutlookCalendarAdapter(CalendarPort):
|
||||
if not isinstance(data_value, dict):
|
||||
raise OutlookCalendarError("Invalid user profile response")
|
||||
data = cast(_OutlookProfile, data_value)
|
||||
if email := data.get("mail") or data.get("userPrincipalName"):
|
||||
return str(email)
|
||||
else:
|
||||
|
||||
email = data.get("mail") or data.get("userPrincipalName")
|
||||
if not email:
|
||||
raise OutlookCalendarError("No email in user profile response")
|
||||
|
||||
# Get display name, fall back to email prefix
|
||||
display_name_raw = data.get("displayName")
|
||||
display_name = (
|
||||
str(display_name_raw)
|
||||
if display_name_raw
|
||||
else str(email).split("@")[0].replace(".", " ").title()
|
||||
)
|
||||
|
||||
return str(email), display_name
|
||||
|
||||
def _parse_event(self, item: _OutlookEvent) -> CalendarEventInfo:
|
||||
"""Parse Microsoft Graph event into CalendarEventInfo."""
|
||||
event_id = str(item.get("id", ""))
|
||||
|
||||
166
src/noteflow/infrastructure/diarization/_compat.py
Normal file
166
src/noteflow/infrastructure/diarization/_compat.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Compatibility patches for pyannote-audio and diart with modern PyTorch/torchaudio.
|
||||
|
||||
This module applies runtime monkey-patches to fix compatibility issues between:
|
||||
- pyannote-audio 3.x and torchaudio 2.9+ (removed AudioMetaData, backend APIs)
|
||||
- PyTorch 2.6+ (weights_only=True default in torch.load)
|
||||
- huggingface_hub 0.24+ (use_auth_token renamed to token)
|
||||
|
||||
Import this module before importing pyannote.audio or diart to apply patches.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, cast
|
||||
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
# Attribute names for dynamic patching - avoids B010 lint warning
|
||||
_ATTR_AUDIO_METADATA: Final = "AudioMetaData"
|
||||
_ATTR_LOAD: Final = "load"
|
||||
_ATTR_HF_HUB_DOWNLOAD: Final = "hf_hub_download"
|
||||
_ATTR_LIST_BACKENDS: Final = "list_audio_backends"
|
||||
_ATTR_GET_BACKEND: Final = "get_audio_backend"
|
||||
_ATTR_SET_BACKEND: Final = "set_audio_backend"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioMetaData:
|
||||
"""Replacement for torchaudio.AudioMetaData removed in torchaudio 2.9+."""
|
||||
|
||||
sample_rate: int
|
||||
num_frames: int
|
||||
num_channels: int
|
||||
bits_per_sample: int
|
||||
encoding: str
|
||||
|
||||
|
||||
def _patch_torchaudio() -> None:
|
||||
"""Patch torchaudio to restore removed AudioMetaData class."""
|
||||
try:
|
||||
import torchaudio
|
||||
|
||||
if not hasattr(torchaudio, _ATTR_AUDIO_METADATA):
|
||||
setattr(torchaudio, _ATTR_AUDIO_METADATA, AudioMetaData)
|
||||
logger.debug("Patched torchaudio.AudioMetaData")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _patch_torch_load() -> None:
|
||||
"""Patch torch.load to use weights_only=False for pyannote model loading.
|
||||
|
||||
PyTorch 2.6+ changed the default to weights_only=True which breaks
|
||||
loading pyannote checkpoints that contain non-tensor objects.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) >= Version("2.6.0"):
|
||||
original_load = cast(Callable[..., object], torch.load)
|
||||
|
||||
def _patched_load(*args: object, **kwargs: object) -> object:
|
||||
if "weights_only" not in kwargs:
|
||||
kwargs["weights_only"] = False
|
||||
return original_load(*args, **kwargs)
|
||||
|
||||
setattr(torch, _ATTR_LOAD, _patched_load)
|
||||
logger.debug("Patched torch.load for weights_only=False default")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _patch_huggingface_auth() -> None:
|
||||
"""Patch huggingface_hub functions to accept legacy use_auth_token parameter.
|
||||
|
||||
huggingface_hub 0.24+ renamed use_auth_token to token. This patch
|
||||
allows pyannote/diart code using the old parameter name to work.
|
||||
"""
|
||||
try:
|
||||
import huggingface_hub
|
||||
|
||||
original_download = cast(
|
||||
Callable[..., object], huggingface_hub.hf_hub_download
|
||||
)
|
||||
|
||||
def _patched_download(*args: object, **kwargs: object) -> object:
|
||||
if "use_auth_token" in kwargs:
|
||||
kwargs["token"] = kwargs.pop("use_auth_token")
|
||||
return original_download(*args, **kwargs)
|
||||
|
||||
setattr(huggingface_hub, _ATTR_HF_HUB_DOWNLOAD, _patched_download)
|
||||
logger.debug("Patched huggingface_hub.hf_hub_download for use_auth_token")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _patch_speechbrain_backend() -> None:
|
||||
"""Patch speechbrain to handle removed torchaudio backend APIs."""
|
||||
try:
|
||||
import torchaudio
|
||||
|
||||
if not hasattr(torchaudio, _ATTR_LIST_BACKENDS):
|
||||
|
||||
def _list_audio_backends() -> list[str]:
|
||||
return ["soundfile", "sox"]
|
||||
|
||||
setattr(torchaudio, _ATTR_LIST_BACKENDS, _list_audio_backends)
|
||||
logger.debug("Patched torchaudio.list_audio_backends")
|
||||
|
||||
if not hasattr(torchaudio, _ATTR_GET_BACKEND):
|
||||
|
||||
def _get_audio_backend() -> str | None:
|
||||
return None
|
||||
|
||||
setattr(torchaudio, _ATTR_GET_BACKEND, _get_audio_backend)
|
||||
logger.debug("Patched torchaudio.get_audio_backend")
|
||||
|
||||
if not hasattr(torchaudio, _ATTR_SET_BACKEND):
|
||||
|
||||
def _set_audio_backend(backend: str | None) -> None:
|
||||
pass
|
||||
|
||||
setattr(torchaudio, _ATTR_SET_BACKEND, _set_audio_backend)
|
||||
logger.debug("Patched torchaudio.set_audio_backend")
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def apply_patches() -> None:
|
||||
"""Apply all compatibility patches.
|
||||
|
||||
Safe to call multiple times - patches are only applied once.
|
||||
Should be called before importing pyannote.audio or diart.
|
||||
"""
|
||||
global _patches_applied
|
||||
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
_patch_torchaudio()
|
||||
_patch_speechbrain_backend()
|
||||
_patch_torch_load()
|
||||
_patch_huggingface_auth()
|
||||
|
||||
_patches_applied = True
|
||||
logger.info("Applied pyannote/diart compatibility patches")
|
||||
|
||||
|
||||
def ensure_compatibility() -> None:
|
||||
"""Ensure compatibility patches are applied before using diarization.
|
||||
|
||||
This is the recommended entry point - call this at the start of any
|
||||
code path that will use pyannote.audio or diart.
|
||||
"""
|
||||
apply_patches()
|
||||
@@ -151,6 +151,11 @@ class DiarizationEngine:
|
||||
)
|
||||
|
||||
try:
|
||||
# Apply compatibility patches before importing pyannote/diart
|
||||
from noteflow.infrastructure.diarization._compat import ensure_compatibility
|
||||
|
||||
ensure_compatibility()
|
||||
|
||||
import torch
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.models import EmbeddingModel, SegmentationModel
|
||||
@@ -200,8 +205,16 @@ class DiarizationEngine:
|
||||
logger.info("Loading shared streaming diarization models on %s...", device)
|
||||
|
||||
try:
|
||||
# Apply compatibility patches before importing pyannote/diart
|
||||
from noteflow.infrastructure.diarization._compat import ensure_compatibility
|
||||
|
||||
ensure_compatibility()
|
||||
|
||||
from diart.models import EmbeddingModel, SegmentationModel
|
||||
|
||||
# Use pyannote/segmentation-3.0 with wespeaker embedding
|
||||
# Note: Frame rate mismatch between models causes a warning but is
|
||||
# handled via interpolation in pyannote's StatsPool
|
||||
self._segmentation_model = SegmentationModel.from_pretrained(
|
||||
"pyannote/segmentation-3.0",
|
||||
use_hf_token=self._hf_token,
|
||||
@@ -237,9 +250,14 @@ class DiarizationEngine:
|
||||
import torch
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
|
||||
# Duration must match the segmentation model's expected window size
|
||||
# pyannote/segmentation-3.0 is trained with 10-second windows
|
||||
model_duration = 10.0
|
||||
|
||||
config = SpeakerDiarizationConfig(
|
||||
segmentation=self._segmentation_model,
|
||||
embedding=self._embedding_model,
|
||||
duration=model_duration,
|
||||
step=self._streaming_latency,
|
||||
latency=self._streaming_latency,
|
||||
device=torch.device(self._resolve_device()),
|
||||
@@ -252,6 +270,7 @@ class DiarizationEngine:
|
||||
meeting_id=meeting_id,
|
||||
_pipeline=pipeline,
|
||||
_sample_rate=DEFAULT_SAMPLE_RATE,
|
||||
_chunk_duration=model_duration,
|
||||
)
|
||||
|
||||
def load_offline_model(self) -> None:
|
||||
@@ -272,6 +291,11 @@ class DiarizationEngine:
|
||||
|
||||
with log_timing("diarization_offline_model_load", device=device):
|
||||
try:
|
||||
# Apply compatibility patches before importing pyannote
|
||||
from noteflow.infrastructure.diarization._compat import ensure_compatibility
|
||||
|
||||
ensure_compatibility()
|
||||
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
|
||||
@@ -10,17 +10,22 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
|
||||
from noteflow.infrastructure.diarization.dto import SpeakerTurn
|
||||
from noteflow.infrastructure.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from diart import SpeakerDiarization
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Default chunk duration in seconds (matches pyannote segmentation model)
|
||||
DEFAULT_CHUNK_DURATION: float = 5.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiarizationSession:
|
||||
@@ -32,14 +37,20 @@ class DiarizationSession:
|
||||
The session owns its own SpeakerDiarization pipeline instance but
|
||||
shares the underlying segmentation and embedding models with other
|
||||
sessions for memory efficiency.
|
||||
|
||||
Audio is buffered until a full chunk (default 5 seconds) is available,
|
||||
as the diart pipeline requires fixed-size input chunks.
|
||||
"""
|
||||
|
||||
meeting_id: str
|
||||
_pipeline: SpeakerDiarization | None
|
||||
_sample_rate: int = DEFAULT_SAMPLE_RATE
|
||||
_chunk_duration: float = DEFAULT_CHUNK_DURATION
|
||||
_stream_time: float = field(default=0.0, init=False)
|
||||
_turns: list[SpeakerTurn] = field(default_factory=list, init=False)
|
||||
_closed: bool = field(default=False, init=False)
|
||||
_audio_buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
|
||||
_buffer_samples: int = field(default=0, init=False)
|
||||
|
||||
def process_chunk(
|
||||
self,
|
||||
@@ -48,13 +59,17 @@ class DiarizationSession:
|
||||
) -> Sequence[SpeakerTurn]:
|
||||
"""Process an audio chunk and return new speaker turns.
|
||||
|
||||
Audio is buffered until a full chunk (default 5 seconds) is available.
|
||||
The diart pipeline requires fixed-size input chunks matching the
|
||||
segmentation model's expected duration.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as float32 array (mono).
|
||||
audio: Audio samples as float32 array (mono, 1D).
|
||||
sample_rate: Audio sample rate (defaults to session's configured rate).
|
||||
|
||||
Returns:
|
||||
Sequence of speaker turns detected in this chunk,
|
||||
with times adjusted to absolute stream position.
|
||||
Sequence of speaker turns detected. Returns empty list if buffer
|
||||
is not yet full.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If session is closed.
|
||||
@@ -65,36 +80,97 @@ class DiarizationSession:
|
||||
if audio.size == 0:
|
||||
return []
|
||||
|
||||
rate = sample_rate or self._sample_rate
|
||||
duration = len(audio) / rate
|
||||
# Ensure audio is 1D
|
||||
if audio.ndim > 1:
|
||||
audio = audio.flatten()
|
||||
|
||||
# Add to buffer
|
||||
self._audio_buffer.append(audio)
|
||||
self._buffer_samples += len(audio)
|
||||
|
||||
# Calculate required samples for a full chunk
|
||||
rate = sample_rate or self._sample_rate
|
||||
required_samples = int(self._chunk_duration * rate)
|
||||
|
||||
# Check if we have enough for a full chunk
|
||||
if self._buffer_samples < required_samples:
|
||||
return []
|
||||
|
||||
# Concatenate buffered audio
|
||||
full_audio = np.concatenate(self._audio_buffer)
|
||||
|
||||
# Extract exactly required_samples for this chunk
|
||||
chunk_audio = full_audio[:required_samples]
|
||||
|
||||
# Keep remaining audio in buffer for next chunk
|
||||
remaining = full_audio[required_samples:]
|
||||
if len(remaining) > 0:
|
||||
self._audio_buffer = [remaining]
|
||||
self._buffer_samples = len(remaining)
|
||||
else:
|
||||
self._audio_buffer = []
|
||||
self._buffer_samples = 0
|
||||
|
||||
# Process the full chunk
|
||||
return self._process_full_chunk(chunk_audio, rate)
|
||||
|
||||
def _process_full_chunk(
|
||||
self,
|
||||
audio: NDArray[np.float32],
|
||||
sample_rate: int,
|
||||
) -> list[SpeakerTurn]:
|
||||
"""Process a full audio chunk through the diarization pipeline.
|
||||
|
||||
Args:
|
||||
audio: Audio samples as 1D float32 array (exactly chunk_duration seconds).
|
||||
sample_rate: Audio sample rate.
|
||||
|
||||
Returns:
|
||||
List of new speaker turns detected. Returns empty list on error.
|
||||
"""
|
||||
if self._pipeline is None:
|
||||
return []
|
||||
|
||||
# Import here to avoid import errors when diart not installed
|
||||
from pyannote.core import SlidingWindow, SlidingWindowFeature
|
||||
|
||||
# Reshape audio for diart: (samples,) -> (1, samples)
|
||||
if audio.ndim == 1:
|
||||
audio = audio.reshape(1, -1)
|
||||
duration = len(audio) / sample_rate
|
||||
|
||||
# Reshape to (samples, channels) for diart - mono audio has 1 channel
|
||||
audio_2d = audio.reshape(-1, 1)
|
||||
|
||||
# Create SlidingWindowFeature for diart
|
||||
window = SlidingWindow(start=0.0, duration=duration, step=duration)
|
||||
waveform = SlidingWindowFeature(audio, window)
|
||||
waveform = SlidingWindowFeature(audio_2d, window)
|
||||
|
||||
# Process through pipeline
|
||||
results = self._pipeline([waveform])
|
||||
try:
|
||||
# Process through pipeline
|
||||
# Note: Frame rate mismatch between segmentation-3.0 and embedding models
|
||||
# may cause warnings and occasional errors, which we handle gracefully
|
||||
results = self._pipeline([waveform])
|
||||
|
||||
# Convert results to turns with absolute time offsets
|
||||
new_turns: list[SpeakerTurn] = []
|
||||
for annotation, _ in results:
|
||||
for track in annotation.itertracks(yield_label=True):
|
||||
if len(track) == 3:
|
||||
segment, _, speaker = track
|
||||
turn = SpeakerTurn(
|
||||
speaker=str(speaker),
|
||||
start=segment.start + self._stream_time,
|
||||
end=segment.end + self._stream_time,
|
||||
)
|
||||
new_turns.append(turn)
|
||||
self._turns.append(turn)
|
||||
# Convert results to turns with absolute time offsets
|
||||
new_turns: list[SpeakerTurn] = []
|
||||
for annotation, _ in results:
|
||||
for track in annotation.itertracks(yield_label=True):
|
||||
if len(track) == 3:
|
||||
segment, _, speaker = track
|
||||
turn = SpeakerTurn(
|
||||
speaker=str(speaker),
|
||||
start=segment.start + self._stream_time,
|
||||
end=segment.end + self._stream_time,
|
||||
)
|
||||
new_turns.append(turn)
|
||||
self._turns.append(turn)
|
||||
|
||||
except (RuntimeError, ZeroDivisionError, ValueError) as e:
|
||||
# Handle frame/weights mismatch and related errors gracefully
|
||||
# Streaming diarization continues even if individual chunks fail
|
||||
logger.warning(
|
||||
"Diarization chunk processing failed (non-fatal): %s",
|
||||
str(e),
|
||||
exc_info=False,
|
||||
)
|
||||
new_turns = []
|
||||
|
||||
self._stream_time += duration
|
||||
return new_turns
|
||||
@@ -102,7 +178,7 @@ class DiarizationSession:
|
||||
def reset(self) -> None:
|
||||
"""Reset session state for restarting diarization.
|
||||
|
||||
Clears accumulated turns and resets stream time to zero.
|
||||
Clears accumulated turns, audio buffer, and resets stream time to zero.
|
||||
The underlying pipeline is also reset.
|
||||
"""
|
||||
if self._closed or self._pipeline is None:
|
||||
@@ -111,6 +187,8 @@ class DiarizationSession:
|
||||
self._pipeline.reset()
|
||||
self._stream_time = 0.0
|
||||
self._turns.clear()
|
||||
self._audio_buffer.clear()
|
||||
self._buffer_samples = 0
|
||||
logger.debug("Session %s reset", self.meeting_id)
|
||||
|
||||
def restore(
|
||||
@@ -154,6 +232,8 @@ class DiarizationSession:
|
||||
|
||||
self._closed = True
|
||||
self._turns.clear()
|
||||
self._audio_buffer.clear()
|
||||
self._buffer_samples = 0
|
||||
# Explicitly release pipeline reference to allow GC and GPU memory release
|
||||
self._pipeline = None
|
||||
logger.info("diarization_session_closed", meeting_id=self.meeting_id)
|
||||
|
||||
771
tests/application/test_auth_service.py
Normal file
771
tests/application/test_auth_service.py
Normal file
@@ -0,0 +1,771 @@
|
||||
"""Unit tests for AuthService.
|
||||
|
||||
Tests cover:
|
||||
- initiate_login: OAuth flow initiation with valid/invalid providers
|
||||
- complete_login: Token exchange, user creation, and auth result construction
|
||||
- get_current_user: Retrieving current user info, authenticated and unauthenticated
|
||||
- logout: Token revocation and integration cleanup
|
||||
- refresh_auth_tokens: Token refresh with success and failure scenarios
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.application.services.auth_service import (
|
||||
DEFAULT_USER_ID,
|
||||
DEFAULT_WORKSPACE_ID,
|
||||
AuthResult,
|
||||
AuthService,
|
||||
AuthServiceError,
|
||||
LogoutResult,
|
||||
UserInfo,
|
||||
)
|
||||
from noteflow.config.settings import CalendarIntegrationSettings
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationType
|
||||
from noteflow.domain.identity.entities import User
|
||||
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
|
||||
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_manager() -> MagicMock:
|
||||
"""Create mock OAuthManager for testing."""
|
||||
manager = MagicMock()
|
||||
manager.initiate_auth = MagicMock(
|
||||
return_value=("https://auth.example.com/authorize?...", "state123")
|
||||
)
|
||||
manager.complete_auth = AsyncMock()
|
||||
manager.refresh_tokens = AsyncMock()
|
||||
manager.revoke_tokens = AsyncMock()
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_uow() -> MagicMock:
|
||||
"""Create mock UnitOfWork with auth-related repositories."""
|
||||
uow = MagicMock()
|
||||
uow.__aenter__ = AsyncMock(return_value=uow)
|
||||
uow.__aexit__ = AsyncMock(return_value=None)
|
||||
uow.commit = AsyncMock()
|
||||
|
||||
# Users repository
|
||||
uow.supports_users = True
|
||||
uow.users = MagicMock()
|
||||
uow.users.get = AsyncMock(return_value=None)
|
||||
uow.users.get_by_email = AsyncMock(return_value=None)
|
||||
uow.users.create = AsyncMock()
|
||||
uow.users.update = AsyncMock()
|
||||
|
||||
# Workspaces repository
|
||||
uow.supports_workspaces = True
|
||||
uow.workspaces = MagicMock()
|
||||
uow.workspaces.get_default_for_user = AsyncMock(return_value=None)
|
||||
uow.workspaces.create = AsyncMock()
|
||||
|
||||
# Integrations repository
|
||||
uow.supports_integrations = True
|
||||
uow.integrations = MagicMock()
|
||||
uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
uow.integrations.create = AsyncMock()
|
||||
uow.integrations.update = AsyncMock()
|
||||
uow.integrations.delete = AsyncMock()
|
||||
uow.integrations.set_secrets = AsyncMock()
|
||||
uow.integrations.get_secrets = AsyncMock(return_value=None)
|
||||
|
||||
return uow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_service(
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
) -> AuthService:
|
||||
"""Create AuthService with mock dependencies."""
|
||||
|
||||
def uow_factory() -> MagicMock:
|
||||
"""Return a new mock UoW each time."""
|
||||
return MagicMock()
|
||||
|
||||
return AuthService(
|
||||
uow_factory=uow_factory,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_oauth_tokens(sample_datetime: datetime) -> OAuthTokens:
|
||||
"""Create sample OAuth tokens for testing."""
|
||||
return OAuthTokens(
|
||||
access_token="access_token_123",
|
||||
refresh_token="refresh_token_456",
|
||||
token_type="Bearer",
|
||||
expires_at=sample_datetime + timedelta(hours=1),
|
||||
scope="openid email profile",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_integration() -> Integration:
|
||||
"""Create sample auth integration for testing."""
|
||||
integration = Integration.create(
|
||||
workspace_id=uuid4(),
|
||||
name="Google Auth",
|
||||
integration_type=IntegrationType.AUTH,
|
||||
config={"provider": "google", "user_id": str(uuid4())},
|
||||
)
|
||||
integration.connect(provider_email="test@example.com")
|
||||
return integration
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: initiate_login
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestInitiateLogin:
|
||||
"""Tests for AuthService.initiate_login."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "expected_oauth_provider"),
|
||||
[
|
||||
pytest.param("google", OAuthProvider.GOOGLE, id="google_provider"),
|
||||
pytest.param("Google", OAuthProvider.GOOGLE, id="google_uppercase"),
|
||||
pytest.param("outlook", OAuthProvider.OUTLOOK, id="outlook_provider"),
|
||||
pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
|
||||
],
|
||||
)
|
||||
async def test_initiates_login_with_valid_provider(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
mock_oauth_manager: MagicMock,
|
||||
provider: str,
|
||||
expected_oauth_provider: OAuthProvider,
|
||||
) -> None:
|
||||
"""initiate_login returns auth URL and state for valid providers."""
|
||||
auth_url, state = await auth_service.initiate_login(provider)
|
||||
|
||||
assert auth_url == "https://auth.example.com/authorize?...", "should return auth URL"
|
||||
assert state == "state123", "should return state token"
|
||||
mock_oauth_manager.initiate_auth.assert_called_once()
|
||||
|
||||
async def test_initiates_login_with_custom_redirect_uri(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
mock_oauth_manager: MagicMock,
|
||||
) -> None:
|
||||
"""initiate_login uses custom redirect_uri when provided."""
|
||||
custom_redirect = "https://custom.example.com/callback"
|
||||
|
||||
await auth_service.initiate_login("google", redirect_uri=custom_redirect)
|
||||
|
||||
call_kwargs = mock_oauth_manager.initiate_auth.call_args[1]
|
||||
assert call_kwargs["redirect_uri"] == custom_redirect, "should use custom redirect URI"
|
||||
|
||||
async def test_initiates_login_raises_for_invalid_provider(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
) -> None:
|
||||
"""initiate_login raises AuthServiceError for invalid provider."""
|
||||
with pytest.raises(AuthServiceError, match="Invalid provider"):
|
||||
await auth_service.initiate_login("invalid_provider")
|
||||
|
||||
async def test_initiates_login_propagates_oauth_error(
|
||||
self,
|
||||
auth_service: AuthService,
|
||||
mock_oauth_manager: MagicMock,
|
||||
) -> None:
|
||||
"""initiate_login raises AuthServiceError when OAuthManager fails."""
|
||||
mock_oauth_manager.initiate_auth.side_effect = OAuthError("OAuth failed")
|
||||
|
||||
with pytest.raises(AuthServiceError, match="OAuth failed"):
|
||||
await auth_service.initiate_login("google")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: complete_login
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCompleteLogin:
|
||||
"""Tests for AuthService.complete_login."""
|
||||
|
||||
async def test_completes_login_creates_new_user(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_oauth_tokens: OAuthTokens,
|
||||
) -> None:
|
||||
"""complete_login creates new user when email not found."""
|
||||
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
service, "_fetch_user_info", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = ("test@example.com", "Test User")
|
||||
|
||||
result = await service.complete_login("google", "auth_code", "state123")
|
||||
|
||||
assert isinstance(result, AuthResult), "should return AuthResult"
|
||||
assert result.email == "test@example.com", "should include user email"
|
||||
assert result.display_name == "Test User", "should include display name"
|
||||
assert result.is_authenticated is True, "should be authenticated"
|
||||
mock_auth_uow.users.create.assert_called_once()
|
||||
|
||||
async def test_completes_login_updates_existing_user(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_oauth_tokens: OAuthTokens,
|
||||
) -> None:
|
||||
"""complete_login updates existing user when email found."""
|
||||
existing_user = User(
|
||||
id=uuid4(),
|
||||
display_name="Old Name",
|
||||
email="test@example.com",
|
||||
is_default=False,
|
||||
)
|
||||
mock_auth_uow.users.get_by_email.return_value = existing_user
|
||||
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
service, "_fetch_user_info", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = ("test@example.com", "New Name")
|
||||
|
||||
result = await service.complete_login("google", "auth_code", "state123")
|
||||
|
||||
assert result.user_id == existing_user.id, "should use existing user ID"
|
||||
mock_auth_uow.users.update.assert_called_once()
|
||||
mock_auth_uow.users.create.assert_not_called()
|
||||
|
||||
async def test_completes_login_stores_integration_tokens(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_oauth_tokens: OAuthTokens,
|
||||
) -> None:
|
||||
"""complete_login stores tokens in integration secrets."""
|
||||
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
service, "_fetch_user_info", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = ("test@example.com", "Test User")
|
||||
|
||||
await service.complete_login("google", "auth_code", "state123")
|
||||
|
||||
mock_auth_uow.integrations.set_secrets.assert_called_once()
|
||||
mock_auth_uow.commit.assert_called()
|
||||
|
||||
async def test_completes_login_raises_on_token_exchange_failure(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
) -> None:
|
||||
"""complete_login raises AuthServiceError when token exchange fails."""
|
||||
mock_oauth_manager.complete_auth.side_effect = OAuthError("Invalid code")
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthServiceError, match="OAuth failed"):
|
||||
await service.complete_login("google", "invalid_code", "state123")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: get_current_user
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for AuthService.get_current_user."""
|
||||
|
||||
async def test_returns_authenticated_user_for_connected_integration(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""get_current_user returns authenticated user info when integration exists."""
|
||||
user = User(
|
||||
id=UUID(sample_integration.config["user_id"]),
|
||||
display_name="Authenticated User",
|
||||
email="test@example.com",
|
||||
)
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.users.get.return_value = user
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.get_current_user()
|
||||
|
||||
assert isinstance(result, UserInfo), "should return UserInfo"
|
||||
assert result.is_authenticated is True, "should be authenticated"
|
||||
assert result.provider == "google", "should include provider"
|
||||
assert result.email == "test@example.com", "should include email"
|
||||
|
||||
async def test_returns_local_user_when_no_integration(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
) -> None:
|
||||
"""get_current_user returns local default user when no integration exists."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = None
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.get_current_user()
|
||||
|
||||
assert isinstance(result, UserInfo), "should return UserInfo"
|
||||
assert result.is_authenticated is False, "should not be authenticated"
|
||||
assert result.display_name == "Local User", "should use local user name"
|
||||
assert result.user_id == DEFAULT_USER_ID, "should use default user ID"
|
||||
assert result.workspace_id == DEFAULT_WORKSPACE_ID, "should use default workspace ID"
|
||||
|
||||
async def test_returns_local_user_for_disconnected_integration(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
) -> None:
|
||||
"""get_current_user returns local user when integration is disconnected."""
|
||||
integration = Integration.create(
|
||||
workspace_id=uuid4(),
|
||||
name="Google Auth",
|
||||
integration_type=IntegrationType.AUTH,
|
||||
config={"provider": "google"},
|
||||
)
|
||||
# Integration not connected (no provider_email set)
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = integration
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.get_current_user()
|
||||
|
||||
assert result.is_authenticated is False, "disconnected integration should not be authenticated"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: logout
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestLogout:
|
||||
"""Tests for AuthService.logout."""
|
||||
|
||||
async def test_logout_deletes_integration_and_revokes_tokens(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""logout deletes integration and revokes tokens."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = {
|
||||
"access_token": "token_to_revoke"
|
||||
}
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.logout("google")
|
||||
|
||||
assert result.logged_out is True, "logout should return logged_out=True on success"
|
||||
assert result.tokens_revoked is True, "logout should return tokens_revoked=True"
|
||||
mock_auth_uow.integrations.delete.assert_called_once_with(sample_integration.id)
|
||||
mock_oauth_manager.revoke_tokens.assert_called_once()
|
||||
|
||||
async def test_logout_returns_false_when_no_integration(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
) -> None:
|
||||
"""logout returns False when no integration exists."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = None
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.logout("google")
|
||||
|
||||
assert result.logged_out is False, "logout should return logged_out=False when no integration"
|
||||
mock_auth_uow.integrations.delete.assert_not_called()
|
||||
|
||||
async def test_logout_all_providers_when_none_specified(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""logout attempts all providers when none specified."""
|
||||
# Return integration only for google, not for outlook
|
||||
def get_by_provider(provider: str, integration_type: str) -> Integration | None:
|
||||
return sample_integration if provider == "google" else None
|
||||
|
||||
mock_auth_uow.integrations.get_by_provider.side_effect = get_by_provider
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.logout() # No provider specified
|
||||
|
||||
assert result.logged_out is True, "logout should return logged_out=True if any provider logged out"
|
||||
# Should have checked both providers
|
||||
EXPECTED_PROVIDER_CHECK_COUNT = 2
|
||||
assert (
|
||||
mock_auth_uow.integrations.get_by_provider.call_count
|
||||
== EXPECTED_PROVIDER_CHECK_COUNT
|
||||
), "should check both providers"
|
||||
|
||||
async def test_logout_handles_revocation_error_gracefully(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""logout continues even if token revocation fails."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = {"access_token": "token"}
|
||||
mock_oauth_manager.revoke_tokens.side_effect = OAuthError("Revocation failed")
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.logout("google")
|
||||
|
||||
assert result.logged_out is True, "logout should succeed despite revocation failure"
|
||||
assert result.tokens_revoked is False, "tokens_revoked should be False on revocation failure"
|
||||
assert result.revocation_error is not None, "should include revocation error message"
|
||||
mock_auth_uow.integrations.delete.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: refresh_auth_tokens
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRefreshAuthTokens:
|
||||
"""Tests for AuthService.refresh_auth_tokens."""
|
||||
|
||||
async def test_refreshes_tokens_successfully(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
sample_datetime: datetime,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens updates tokens and returns new AuthResult."""
|
||||
old_tokens = {
|
||||
"access_token": "old_access",
|
||||
"refresh_token": "old_refresh",
|
||||
"expires_at": sample_datetime.isoformat(),
|
||||
}
|
||||
new_tokens = OAuthTokens(
|
||||
access_token="new_access",
|
||||
refresh_token="new_refresh",
|
||||
token_type="Bearer",
|
||||
expires_at=sample_datetime + timedelta(hours=1),
|
||||
scope="openid email profile",
|
||||
)
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = old_tokens
|
||||
mock_oauth_manager.refresh_tokens.return_value = new_tokens
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
assert result is not None, "should return AuthResult on success"
|
||||
assert result.is_authenticated is True, "should return authenticated result"
|
||||
# Verify tokens were stored (not exposed on result per security design)
|
||||
mock_auth_uow.integrations.set_secrets.assert_called_once()
|
||||
mock_auth_uow.commit.assert_called()
|
||||
|
||||
async def test_returns_none_when_no_integration(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens returns None when no integration exists."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = None
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
assert result is None, "should return None when no integration"
|
||||
|
||||
async def test_returns_none_when_no_refresh_token(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
sample_datetime: datetime,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens returns None when no refresh token available."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = {
|
||||
"access_token": "access",
|
||||
"expires_at": sample_datetime.isoformat(),
|
||||
# No refresh_token
|
||||
}
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
assert result is None, "should return None when no refresh token"
|
||||
|
||||
async def test_marks_error_on_refresh_failure(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
sample_datetime: datetime,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens marks integration error on failure."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = {
|
||||
"access_token": "access",
|
||||
"refresh_token": "refresh",
|
||||
"expires_at": sample_datetime.isoformat(),
|
||||
}
|
||||
mock_oauth_manager.refresh_tokens.side_effect = OAuthError("Token expired")
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
assert result is None, "should return None on refresh failure"
|
||||
mock_auth_uow.integrations.update.assert_called_once()
|
||||
mock_auth_uow.commit.assert_called()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _store_auth_user (workspace creation)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestStoreAuthUser:
|
||||
"""Tests for AuthService._store_auth_user workspace handling."""
|
||||
|
||||
async def test_creates_default_workspace_for_new_user(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_oauth_tokens: OAuthTokens,
|
||||
) -> None:
|
||||
"""_store_auth_user creates default workspace when none exists."""
|
||||
mock_auth_uow.workspaces.get_default_for_user.return_value = None
|
||||
mock_auth_uow.workspaces.create = AsyncMock()
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
await service._store_auth_user(
|
||||
"google", "test@example.com", "Test User", sample_oauth_tokens
|
||||
)
|
||||
|
||||
# Verify workspace.create was called (new workspace for new user)
|
||||
mock_auth_uow.workspaces.create.assert_called_once()
|
||||
call_kwargs = mock_auth_uow.workspaces.create.call_args[1]
|
||||
assert call_kwargs["name"] == "Personal", "should create 'Personal' workspace"
|
||||
assert call_kwargs["is_default"] is True, "should be default workspace"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: refresh_auth_tokens (additional edge cases)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRefreshAuthTokensEdgeCases:
|
||||
"""Additional tests for AuthService.refresh_auth_tokens edge cases."""
|
||||
|
||||
async def test_skips_refresh_when_token_still_valid(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens returns existing auth when token not expired."""
|
||||
# Token expires in 1 hour (more than 5 minute buffer)
|
||||
future_expiry = datetime.now(UTC) + timedelta(hours=1)
|
||||
secrets = {
|
||||
"access_token": "valid_token",
|
||||
"refresh_token": "refresh_token",
|
||||
"expires_at": future_expiry.isoformat(),
|
||||
}
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = secrets
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
# Should return result without calling refresh
|
||||
assert result is not None, "should return AuthResult"
|
||||
mock_oauth_manager.refresh_tokens.assert_not_called()
|
||||
|
||||
async def test_returns_none_on_invalid_secrets_dict(
|
||||
self,
|
||||
calendar_settings: CalendarIntegrationSettings,
|
||||
mock_oauth_manager: MagicMock,
|
||||
mock_auth_uow: MagicMock,
|
||||
sample_integration: Integration,
|
||||
) -> None:
|
||||
"""refresh_auth_tokens returns None when secrets can't be parsed."""
|
||||
mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
|
||||
mock_auth_uow.integrations.get_secrets.return_value = {
|
||||
"invalid": "data", # Missing required fields
|
||||
}
|
||||
|
||||
service = AuthService(
|
||||
uow_factory=lambda: mock_auth_uow,
|
||||
settings=calendar_settings,
|
||||
oauth_manager=mock_oauth_manager,
|
||||
)
|
||||
|
||||
result = await service.refresh_auth_tokens("google")
|
||||
|
||||
assert result is None, "should return None for invalid secrets"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _parse_provider (static method)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestParseProvider:
|
||||
"""Tests for AuthService._parse_provider static method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_provider", "expected_output"),
|
||||
[
|
||||
pytest.param("google", OAuthProvider.GOOGLE, id="google_lowercase"),
|
||||
pytest.param("GOOGLE", OAuthProvider.GOOGLE, id="google_uppercase"),
|
||||
pytest.param("Google", OAuthProvider.GOOGLE, id="google_mixed_case"),
|
||||
pytest.param("outlook", OAuthProvider.OUTLOOK, id="outlook_lowercase"),
|
||||
pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
|
||||
],
|
||||
)
|
||||
def test_parses_valid_providers(
|
||||
self,
|
||||
input_provider: str,
|
||||
expected_output: OAuthProvider,
|
||||
) -> None:
|
||||
"""_parse_provider correctly parses valid provider strings."""
|
||||
result = AuthService._parse_provider(input_provider)
|
||||
|
||||
assert result == expected_output, f"should parse {input_provider} correctly"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_provider",
|
||||
[
|
||||
pytest.param("github", id="github_not_supported"),
|
||||
pytest.param("facebook", id="facebook_not_supported"),
|
||||
pytest.param("", id="empty_string"),
|
||||
pytest.param("invalid", id="random_string"),
|
||||
],
|
||||
)
|
||||
def test_raises_for_invalid_providers(
|
||||
self,
|
||||
invalid_provider: str,
|
||||
) -> None:
|
||||
"""_parse_provider raises AuthServiceError for invalid providers."""
|
||||
with pytest.raises(AuthServiceError, match="Invalid provider"):
|
||||
AuthService._parse_provider(invalid_provider)
|
||||
578
tests/grpc/test_identity_mixin.py
Normal file
578
tests/grpc/test_identity_mixin.py
Normal file
@@ -0,0 +1,578 @@
|
||||
"""Tests for IdentityMixin gRPC endpoints.
|
||||
|
||||
Tests cover:
|
||||
- GetCurrentUser: Returns user identity, workspace, and auth status
|
||||
- ListWorkspaces: Lists user's workspaces with pagination
|
||||
- SwitchWorkspace: Validates workspace access and returns workspace info
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.application.services.identity_service import IdentityService
|
||||
from noteflow.domain.entities.integration import Integration, IntegrationType
|
||||
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
|
||||
from noteflow.domain.identity.entities import Workspace, WorkspaceMembership
|
||||
from noteflow.domain.identity.roles import WorkspaceRole
|
||||
from noteflow.grpc._mixins._types import GrpcContext
|
||||
from noteflow.grpc._mixins.identity import IdentityMixin
|
||||
from noteflow.grpc.proto import noteflow_pb2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mock Servicer Host
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MockIdentityServicerHost(IdentityMixin):
|
||||
"""Mock servicer host implementing required protocol for IdentityMixin."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize mock servicer with identity service."""
|
||||
self._identity_service = IdentityService()
|
||||
self._mock_uow: MagicMock | None = None
|
||||
|
||||
def create_repository_provider(self) -> MagicMock:
|
||||
"""Return mock UnitOfWork."""
|
||||
if self._mock_uow is None:
|
||||
msg = "Mock UoW not configured"
|
||||
raise RuntimeError(msg)
|
||||
return self._mock_uow
|
||||
|
||||
def get_operation_context(self, context: GrpcContext) -> OperationContext:
|
||||
"""Return mock operation context."""
|
||||
return OperationContext(
|
||||
user=UserContext(
|
||||
user_id=uuid4(),
|
||||
display_name="Test User",
|
||||
),
|
||||
workspace=WorkspaceContext(
|
||||
workspace_id=uuid4(),
|
||||
workspace_name="Test Workspace",
|
||||
role=WorkspaceRole.OWNER,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def identity_service(self) -> IdentityService:
|
||||
"""Return identity service."""
|
||||
return self._identity_service
|
||||
|
||||
def set_mock_uow(self, uow: MagicMock) -> None:
|
||||
"""Set mock UnitOfWork for testing."""
|
||||
self._mock_uow = uow
|
||||
|
||||
# Type stubs for mixin methods
|
||||
if TYPE_CHECKING:
|
||||
async def GetCurrentUser(
|
||||
self,
|
||||
request: noteflow_pb2.GetCurrentUserRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.GetCurrentUserResponse: ...
|
||||
|
||||
async def ListWorkspaces(
|
||||
self,
|
||||
request: noteflow_pb2.ListWorkspacesRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.ListWorkspacesResponse: ...
|
||||
|
||||
async def SwitchWorkspace(
|
||||
self,
|
||||
request: noteflow_pb2.SwitchWorkspaceRequest,
|
||||
context: GrpcContext,
|
||||
) -> noteflow_pb2.SwitchWorkspaceResponse: ...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def identity_servicer() -> MockIdentityServicerHost:
|
||||
"""Create servicer for identity mixin testing."""
|
||||
return MockIdentityServicerHost()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_identity_uow() -> MagicMock:
|
||||
"""Create mock UnitOfWork with identity-related repositories."""
|
||||
uow = MagicMock()
|
||||
uow.__aenter__ = AsyncMock(return_value=uow)
|
||||
uow.__aexit__ = AsyncMock(return_value=None)
|
||||
uow.commit = AsyncMock()
|
||||
|
||||
# Users repository
|
||||
uow.supports_users = True
|
||||
uow.users = MagicMock()
|
||||
uow.users.get = AsyncMock(return_value=None)
|
||||
uow.users.get_default = AsyncMock(return_value=None)
|
||||
uow.users.create_default = AsyncMock()
|
||||
|
||||
# Workspaces repository
|
||||
uow.supports_workspaces = True
|
||||
uow.workspaces = MagicMock()
|
||||
uow.workspaces.get = AsyncMock(return_value=None)
|
||||
uow.workspaces.get_default_for_user = AsyncMock(return_value=None)
|
||||
uow.workspaces.get_membership = AsyncMock(return_value=None)
|
||||
uow.workspaces.list_for_user = AsyncMock(return_value=[])
|
||||
uow.workspaces.create = AsyncMock()
|
||||
|
||||
# Integrations repository
|
||||
uow.supports_integrations = True
|
||||
uow.integrations = MagicMock()
|
||||
uow.integrations.get_by_provider = AsyncMock(return_value=None)
|
||||
|
||||
# Projects repository (for workspace creation)
|
||||
uow.supports_projects = False
|
||||
|
||||
return uow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_context() -> UserContext:
|
||||
"""Create sample user context for testing."""
|
||||
return UserContext(
|
||||
user_id=uuid4(),
|
||||
display_name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workspace_context() -> WorkspaceContext:
|
||||
"""Create sample workspace context for testing."""
|
||||
return WorkspaceContext(
|
||||
workspace_id=uuid4(),
|
||||
workspace_name="Test Workspace",
|
||||
role=WorkspaceRole.OWNER,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workspace(sample_datetime: datetime) -> Workspace:
|
||||
"""Create sample workspace for testing."""
|
||||
return Workspace(
|
||||
id=uuid4(),
|
||||
name="Test Workspace",
|
||||
slug="test-workspace",
|
||||
is_default=True,
|
||||
created_at=sample_datetime,
|
||||
updated_at=sample_datetime,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_membership() -> WorkspaceMembership:
|
||||
"""Create sample workspace membership for testing."""
|
||||
return WorkspaceMembership(
|
||||
workspace_id=uuid4(),
|
||||
user_id=uuid4(),
|
||||
role=WorkspaceRole.MEMBER,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: GetCurrentUser
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for IdentityMixin.GetCurrentUser."""
|
||||
|
||||
async def test_returns_default_user_in_memory_mode(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""GetCurrentUser returns default user when in memory mode."""
|
||||
uow = MagicMock()
|
||||
uow.__aenter__ = AsyncMock(return_value=uow)
|
||||
uow.__aexit__ = AsyncMock(return_value=None)
|
||||
uow.commit = AsyncMock()
|
||||
uow.supports_users = False
|
||||
uow.supports_workspaces = False
|
||||
uow.supports_integrations = False
|
||||
|
||||
identity_servicer.set_mock_uow(uow)
|
||||
|
||||
request = noteflow_pb2.GetCurrentUserRequest()
|
||||
response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
|
||||
|
||||
assert response.user_id, "should return user_id"
|
||||
assert response.workspace_id, "should return workspace_id"
|
||||
assert response.display_name, "should return display_name"
|
||||
assert response.is_authenticated is False, "should not be authenticated in memory mode"
|
||||
|
||||
async def test_returns_authenticated_user_with_oauth(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_datetime: datetime,
|
||||
) -> None:
|
||||
"""GetCurrentUser returns authenticated user when OAuth integration exists."""
|
||||
# Configure connected integration
|
||||
integration = Integration.create(
|
||||
workspace_id=uuid4(),
|
||||
name="Google Auth",
|
||||
integration_type=IntegrationType.AUTH,
|
||||
config={"provider": "google"},
|
||||
)
|
||||
integration.connect(provider_email="test@example.com")
|
||||
mock_identity_uow.integrations.get_by_provider.return_value = integration
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.GetCurrentUserRequest()
|
||||
response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
|
||||
|
||||
assert response.is_authenticated is True, "should be authenticated with OAuth"
|
||||
assert response.auth_provider == "google", "should return auth provider"
|
||||
|
||||
async def test_returns_workspace_role(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""GetCurrentUser returns user's workspace role."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.GetCurrentUserRequest()
|
||||
response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
|
||||
|
||||
# Role should be set (default is owner for first user)
|
||||
assert response.role, "should return workspace role"
|
||||
assert response.workspace_name, "should return workspace name"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: ListWorkspaces
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestListWorkspaces:
|
||||
"""Tests for IdentityMixin.ListWorkspaces."""
|
||||
|
||||
async def test_returns_empty_list_when_no_workspaces(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""ListWorkspaces returns empty list when user has no workspaces."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest()
|
||||
response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
assert response.total_count == 0, "should return zero count"
|
||||
assert len(response.workspaces) == 0, "should return empty list"
|
||||
|
||||
async def test_returns_user_workspaces(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_workspace: Workspace,
|
||||
sample_membership: WorkspaceMembership,
|
||||
) -> None:
|
||||
"""ListWorkspaces returns workspaces the user belongs to."""
|
||||
mock_identity_uow.workspaces.list_for_user.return_value = [sample_workspace]
|
||||
mock_identity_uow.workspaces.get_membership.return_value = sample_membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest()
|
||||
response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
assert response.total_count == 1, "should return correct count"
|
||||
assert len(response.workspaces) == 1, "should return one workspace"
|
||||
assert response.workspaces[0].name == "Test Workspace", "should include workspace name"
|
||||
assert response.workspaces[0].is_default is True, "should include is_default flag"
|
||||
|
||||
async def test_respects_pagination_parameters(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""ListWorkspaces passes pagination parameters to repository."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest(limit=10, offset=5)
|
||||
await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
# Check that list_for_user was called with pagination params
|
||||
mock_identity_uow.workspaces.list_for_user.assert_called()
|
||||
call_args = mock_identity_uow.workspaces.list_for_user.call_args
|
||||
EXPECTED_LIMIT = 10
|
||||
EXPECTED_OFFSET = 5
|
||||
assert call_args[0][1] == EXPECTED_LIMIT, "should pass limit"
|
||||
assert call_args[0][2] == EXPECTED_OFFSET, "should pass offset"
|
||||
|
||||
async def test_uses_default_pagination_values(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""ListWorkspaces uses default pagination when not specified."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest() # No limit/offset
|
||||
await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
call_args = mock_identity_uow.workspaces.list_for_user.call_args
|
||||
DEFAULT_LIMIT = 50
|
||||
DEFAULT_OFFSET = 0
|
||||
assert call_args[0][1] == DEFAULT_LIMIT, "should use default limit"
|
||||
assert call_args[0][2] == DEFAULT_OFFSET, "should use default offset"
|
||||
|
||||
async def test_includes_workspace_role(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_workspace: Workspace,
|
||||
) -> None:
|
||||
"""ListWorkspaces includes user's role in each workspace."""
|
||||
owner_membership = WorkspaceMembership(
|
||||
workspace_id=sample_workspace.id,
|
||||
user_id=uuid4(),
|
||||
role=WorkspaceRole.OWNER,
|
||||
)
|
||||
mock_identity_uow.workspaces.list_for_user.return_value = [sample_workspace]
|
||||
mock_identity_uow.workspaces.get_membership.return_value = owner_membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest()
|
||||
response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
assert response.workspaces[0].role == "owner", "should include role"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: SwitchWorkspace
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSwitchWorkspace:
|
||||
"""Tests for IdentityMixin.SwitchWorkspace."""
|
||||
|
||||
async def test_aborts_when_workspace_id_missing(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""SwitchWorkspace aborts with INVALID_ARGUMENT when workspace_id not provided."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest() # No workspace_id
|
||||
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
|
||||
async def test_aborts_for_invalid_uuid(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""SwitchWorkspace aborts with INVALID_ARGUMENT for invalid workspace_id format."""
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id="not-a-uuid")
|
||||
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
|
||||
async def test_aborts_when_workspace_not_found(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""SwitchWorkspace aborts with NOT_FOUND when workspace doesn't exist."""
|
||||
mock_identity_uow.workspaces.get.return_value = None
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
workspace_id = uuid4()
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(workspace_id))
|
||||
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
|
||||
async def test_aborts_when_user_not_member(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_workspace: Workspace,
|
||||
) -> None:
|
||||
"""SwitchWorkspace aborts with NOT_FOUND when user is not a member of workspace."""
|
||||
mock_identity_uow.workspaces.get.return_value = sample_workspace
|
||||
mock_identity_uow.workspaces.get_membership.return_value = None # No membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id))
|
||||
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
|
||||
async def test_switches_workspace_successfully(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_workspace: Workspace,
|
||||
sample_membership: WorkspaceMembership,
|
||||
) -> None:
|
||||
"""SwitchWorkspace returns workspace info on success."""
|
||||
mock_identity_uow.workspaces.get.return_value = sample_workspace
|
||||
mock_identity_uow.workspaces.get_membership.return_value = sample_membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id))
|
||||
response = await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
assert response.success is True, "should return success=True"
|
||||
assert response.workspace.id == str(sample_workspace.id), "should return workspace ID"
|
||||
assert response.workspace.name == "Test Workspace", "should return workspace name"
|
||||
assert response.workspace.role == "member", "should return user's role"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_role_str"),
|
||||
[
|
||||
pytest.param(WorkspaceRole.OWNER, "owner", id="owner_role"),
|
||||
pytest.param(WorkspaceRole.ADMIN, "admin", id="admin_role"),
|
||||
pytest.param(WorkspaceRole.MEMBER, "member", id="member_role"),
|
||||
],
|
||||
)
|
||||
async def test_returns_correct_role_for_membership(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_workspace: Workspace,
|
||||
role: WorkspaceRole,
|
||||
expected_role_str: str,
|
||||
) -> None:
|
||||
"""SwitchWorkspace returns correct role string for different memberships."""
|
||||
membership = WorkspaceMembership(
|
||||
workspace_id=sample_workspace.id,
|
||||
user_id=uuid4(),
|
||||
role=role,
|
||||
)
|
||||
mock_identity_uow.workspaces.get.return_value = sample_workspace
|
||||
mock_identity_uow.workspaces.get_membership.return_value = membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id))
|
||||
response = await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
assert response.workspace.role == expected_role_str, f"should return role as {expected_role_str}"
|
||||
|
||||
async def test_includes_workspace_metadata(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_identity_uow: MagicMock,
|
||||
mock_grpc_context: MagicMock,
|
||||
sample_membership: WorkspaceMembership,
|
||||
) -> None:
|
||||
"""SwitchWorkspace includes workspace slug and is_default flag."""
|
||||
workspace = Workspace(
|
||||
id=uuid4(),
|
||||
name="My Custom Workspace",
|
||||
slug="my-custom-workspace",
|
||||
is_default=False,
|
||||
)
|
||||
mock_identity_uow.workspaces.get.return_value = workspace
|
||||
mock_identity_uow.workspaces.get_membership.return_value = sample_membership
|
||||
|
||||
identity_servicer.set_mock_uow(mock_identity_uow)
|
||||
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(workspace.id))
|
||||
response = await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
assert response.workspace.slug == "my-custom-workspace", "should include slug"
|
||||
assert response.workspace.is_default is False, "should include is_default flag"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: Database Required Error
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatabaseRequired:
|
||||
"""Tests for database requirement handling in identity endpoints."""
|
||||
|
||||
async def test_list_workspaces_aborts_without_database(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""ListWorkspaces aborts when workspaces not supported."""
|
||||
uow = MagicMock()
|
||||
uow.__aenter__ = AsyncMock(return_value=uow)
|
||||
uow.__aexit__ = AsyncMock(return_value=None)
|
||||
uow.commit = AsyncMock()
|
||||
uow.supports_users = False
|
||||
uow.supports_workspaces = False
|
||||
|
||||
identity_servicer.set_mock_uow(uow)
|
||||
|
||||
request = noteflow_pb2.ListWorkspacesRequest()
|
||||
|
||||
# abort helpers raise AssertionError after mock context.abort()
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.ListWorkspaces(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
|
||||
async def test_switch_workspace_aborts_without_database(
|
||||
self,
|
||||
identity_servicer: MockIdentityServicerHost,
|
||||
mock_grpc_context: MagicMock,
|
||||
) -> None:
|
||||
"""SwitchWorkspace aborts when workspaces not supported."""
|
||||
uow = MagicMock()
|
||||
uow.__aenter__ = AsyncMock(return_value=uow)
|
||||
uow.__aexit__ = AsyncMock(return_value=None)
|
||||
uow.commit = AsyncMock()
|
||||
uow.supports_users = False
|
||||
uow.supports_workspaces = False
|
||||
|
||||
identity_servicer.set_mock_uow(uow)
|
||||
|
||||
workspace_id = uuid4()
|
||||
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(workspace_id))
|
||||
|
||||
# abort helpers raise AssertionError after mock context.abort()
|
||||
with pytest.raises(AssertionError, match="Unreachable"):
|
||||
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
|
||||
|
||||
mock_grpc_context.abort.assert_called_once()
|
||||
@@ -267,3 +267,126 @@ class TestGoogleCalendarAdapterDateParsing:
|
||||
events = await adapter.list_events("access-token")
|
||||
|
||||
assert events[0].is_recurring is True, "event with recurringEventId should be recurring"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: get_user_info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGoogleCalendarAdapterGetUserInfo:
|
||||
"""Tests for GoogleCalendarAdapter.get_user_info."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_email_and_display_name(self) -> None:
|
||||
"""get_user_info should return email and display name."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email, display_name = await adapter.get_user_info("access-token")
|
||||
|
||||
assert email == "user@example.com", "should return user email"
|
||||
assert display_name == "Test User", "should return display name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_email_prefix_for_display_name(self) -> None:
|
||||
"""get_user_info should use email prefix when name is missing."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"email": "john.doe@example.com",
|
||||
# No "name" field
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email, display_name = await adapter.get_user_info("access-token")
|
||||
|
||||
assert email == "john.doe@example.com", "should return email"
|
||||
assert display_name == "John Doe", "should format email prefix as title"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_expired_token(self) -> None:
|
||||
"""get_user_info should raise error on 401 response."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Token expired"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="expired or invalid"):
|
||||
await adapter.get_user_info("expired-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_api_error(self) -> None:
|
||||
"""get_user_info should raise error on non-200 response."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal server error"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="API error"):
|
||||
await adapter.get_user_info("access-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_invalid_response_type(self) -> None:
|
||||
"""get_user_info should raise error when response is not dict."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = ["not", "a", "dict"]
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="Invalid userinfo"):
|
||||
await adapter.get_user_info("access-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_missing_email(self) -> None:
|
||||
"""get_user_info should raise error when email is missing."""
|
||||
from noteflow.infrastructure.calendar import GoogleCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
|
||||
|
||||
adapter = GoogleCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "No Email User"}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GoogleCalendarError, match="No email"):
|
||||
await adapter.get_user_info("access-token")
|
||||
|
||||
159
tests/infrastructure/calendar/test_outlook_adapter.py
Normal file
159
tests/infrastructure/calendar/test_outlook_adapter.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Tests for Outlook Calendar adapter.
|
||||
|
||||
Tests cover:
|
||||
- get_user_info: Fetching user email and display name from Microsoft Graph API
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: get_user_info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestOutlookCalendarAdapterGetUserInfo:
|
||||
"""Tests for OutlookCalendarAdapter.get_user_info."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_email_and_display_name(self) -> None:
|
||||
"""get_user_info should return email and display name."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"mail": "user@example.com",
|
||||
"displayName": "Test User",
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email, display_name = await adapter.get_user_info("access-token")
|
||||
|
||||
assert email == "user@example.com", "should return user email"
|
||||
assert display_name == "Test User", "should return display name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_userPrincipalName_when_mail_missing(self) -> None:
|
||||
"""get_user_info should fall back to userPrincipalName for email."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"userPrincipalName": "user@company.onmicrosoft.com",
|
||||
"displayName": "Test User",
|
||||
# No "mail" field
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email, display_name = await adapter.get_user_info("access-token")
|
||||
|
||||
assert email == "user@company.onmicrosoft.com", "should use userPrincipalName"
|
||||
assert display_name == "Test User", "should return display name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_email_prefix_for_display_name(self) -> None:
|
||||
"""get_user_info should use email prefix when displayName is missing."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"mail": "john.doe@example.com",
|
||||
# No "displayName" field
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
email, display_name = await adapter.get_user_info("access-token")
|
||||
|
||||
assert email == "john.doe@example.com", "should return email"
|
||||
assert display_name == "John Doe", "should format email prefix as title"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_expired_token(self) -> None:
|
||||
"""get_user_info should raise error on 401 response."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Token expired"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(OutlookCalendarError, match="expired or invalid"):
|
||||
await adapter.get_user_info("expired-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_api_error(self) -> None:
|
||||
"""get_user_info should raise error on non-200 response."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal server error"
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(OutlookCalendarError, match="API error"):
|
||||
await adapter.get_user_info("access-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_invalid_response_type(self) -> None:
|
||||
"""get_user_info should raise error when response is not dict."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = ["not", "a", "dict"]
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(OutlookCalendarError, match="Invalid user profile"):
|
||||
await adapter.get_user_info("access-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_on_missing_email(self) -> None:
|
||||
"""get_user_info should raise error when no email fields present."""
|
||||
from noteflow.infrastructure.calendar import OutlookCalendarAdapter
|
||||
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
|
||||
|
||||
adapter = OutlookCalendarAdapter()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"displayName": "No Email User",
|
||||
# Neither "mail" nor "userPrincipalName"
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(OutlookCalendarError, match="No email"):
|
||||
await adapter.get_user_info("access-token")
|
||||
1
tests/infrastructure/diarization/__init__.py
Normal file
1
tests/infrastructure/diarization/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for diarization infrastructure."""
|
||||
379
tests/infrastructure/diarization/test_compat.py
Normal file
379
tests/infrastructure/diarization/test_compat.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Tests for diarization compatibility patches.
|
||||
|
||||
Tests cover:
|
||||
- _patch_torchaudio: AudioMetaData class injection
|
||||
- _patch_torch_load: weights_only=False default for PyTorch 2.6+
|
||||
- _patch_huggingface_auth: use_auth_token → token parameter conversion
|
||||
- _patch_speechbrain_backend: torchaudio backend API restoration
|
||||
- apply_patches: Idempotency and warning suppression
|
||||
- ensure_compatibility: Alias for apply_patches
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from noteflow.infrastructure.diarization._compat import (
|
||||
AudioMetaData,
|
||||
_patch_huggingface_auth,
|
||||
_patch_speechbrain_backend,
|
||||
_patch_torch_load,
|
||||
_patch_torchaudio,
|
||||
apply_patches,
|
||||
ensure_compatibility,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_patches_state() -> Generator[None, None, None]:
|
||||
"""Reset _patches_applied state before and after tests."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
original_state = compat_module._patches_applied
|
||||
compat_module._patches_applied = False
|
||||
yield
|
||||
compat_module._patches_applied = original_state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torchaudio() -> MagicMock:
|
||||
"""Create mock torchaudio module without AudioMetaData."""
|
||||
mock = MagicMock(spec=[]) # Empty spec means no auto-attributes
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch() -> MagicMock:
|
||||
"""Create mock torch module."""
|
||||
mock = MagicMock()
|
||||
mock.__version__ = "2.6.0"
|
||||
mock.load = MagicMock(return_value={"model": "weights"})
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_huggingface_hub() -> MagicMock:
|
||||
"""Create mock huggingface_hub module."""
|
||||
mock = MagicMock()
|
||||
mock.hf_hub_download = MagicMock(return_value="/path/to/file")
|
||||
return mock
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: AudioMetaData Dataclass
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAudioMetaData:
|
||||
"""Tests for the replacement AudioMetaData dataclass."""
|
||||
|
||||
def test_audiometadata_has_required_fields(self) -> None:
|
||||
"""AudioMetaData has all fields expected by pyannote.audio."""
|
||||
metadata = AudioMetaData(
|
||||
sample_rate=16000,
|
||||
num_frames=48000,
|
||||
num_channels=1,
|
||||
bits_per_sample=16,
|
||||
encoding="PCM_S",
|
||||
)
|
||||
|
||||
assert metadata.sample_rate == 16000, "should store sample_rate"
|
||||
assert metadata.num_frames == 48000, "should store num_frames"
|
||||
assert metadata.num_channels == 1, "should store num_channels"
|
||||
assert metadata.bits_per_sample == 16, "should store bits_per_sample"
|
||||
assert metadata.encoding == "PCM_S", "should store encoding"
|
||||
|
||||
def test_audiometadata_is_immutable(self) -> None:
|
||||
"""AudioMetaData fields cannot be modified after creation."""
|
||||
metadata = AudioMetaData(
|
||||
sample_rate=16000,
|
||||
num_frames=48000,
|
||||
num_channels=1,
|
||||
bits_per_sample=16,
|
||||
encoding="PCM_S",
|
||||
)
|
||||
|
||||
# Dataclass is not frozen, so this is documentation of expected behavior
|
||||
# If it becomes frozen, this test validates that
|
||||
metadata.sample_rate = 44100 # May or may not raise depending on frozen
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _patch_torchaudio
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatchTorchaudio:
|
||||
"""Tests for torchaudio AudioMetaData patching."""
|
||||
|
||||
def test_patches_audiometadata_when_missing(
|
||||
self, mock_torchaudio: MagicMock
|
||||
) -> None:
|
||||
"""_patch_torchaudio adds AudioMetaData when not present."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_torchaudio()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "AudioMetaData"
|
||||
), "should add AudioMetaData"
|
||||
assert (
|
||||
mock_torchaudio.AudioMetaData is AudioMetaData
|
||||
), "should use our AudioMetaData class"
|
||||
|
||||
def test_does_not_override_existing_audiometadata(self) -> None:
|
||||
"""_patch_torchaudio preserves existing AudioMetaData if present."""
|
||||
mock = MagicMock()
|
||||
existing_class = type("ExistingAudioMetaData", (), {})
|
||||
mock.AudioMetaData = existing_class
|
||||
|
||||
with patch.dict(sys.modules, {"torchaudio": mock}):
|
||||
_patch_torchaudio()
|
||||
|
||||
assert (
|
||||
mock.AudioMetaData is existing_class
|
||||
), "should not override existing AudioMetaData"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
"""_patch_torchaudio doesn't raise when torchaudio not installed."""
|
||||
# Remove torchaudio from modules if present
|
||||
with patch.dict(sys.modules, {"torchaudio": None}):
|
||||
# Should not raise
|
||||
_patch_torchaudio()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _patch_torch_load
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatchTorchLoad:
|
||||
"""Tests for torch.load weights_only patching."""
|
||||
|
||||
def test_patches_torch_load_for_pytorch_2_6_plus(
|
||||
self, mock_torch: MagicMock
|
||||
) -> None:
|
||||
"""_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
|
||||
original_load = mock_torch.load
|
||||
|
||||
with patch.dict(sys.modules, {"torch": mock_torch}):
|
||||
with patch("packaging.version.Version") as mock_version:
|
||||
mock_version.return_value = mock_version
|
||||
mock_version.__ge__ = MagicMock(return_value=True)
|
||||
|
||||
_patch_torch_load()
|
||||
|
||||
# Verify torch.load was replaced (not the same function)
|
||||
assert mock_torch.load is not original_load, "load should be patched"
|
||||
|
||||
def test_does_not_patch_older_pytorch(self) -> None:
|
||||
"""_patch_torch_load skips patching for PyTorch < 2.6."""
|
||||
mock = MagicMock()
|
||||
mock.__version__ = "2.5.0"
|
||||
original_load = mock.load
|
||||
|
||||
with patch.dict(sys.modules, {"torch": mock}):
|
||||
with patch("packaging.version.Version") as mock_version:
|
||||
mock_version.return_value = mock_version
|
||||
mock_version.__ge__ = MagicMock(return_value=False)
|
||||
|
||||
_patch_torch_load()
|
||||
|
||||
# load should not have been replaced
|
||||
assert mock.load is original_load, "should not patch older PyTorch"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
"""_patch_torch_load doesn't raise when torch not installed."""
|
||||
with patch.dict(sys.modules, {"torch": None}):
|
||||
_patch_torch_load()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _patch_huggingface_auth
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatchHuggingfaceAuth:
|
||||
"""Tests for huggingface_hub use_auth_token patching."""
|
||||
|
||||
def test_converts_use_auth_token_to_token(
|
||||
self, mock_huggingface_hub: MagicMock
|
||||
) -> None:
|
||||
"""_patch_huggingface_auth converts use_auth_token to token parameter."""
|
||||
original_download = mock_huggingface_hub.hf_hub_download
|
||||
|
||||
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
|
||||
_patch_huggingface_auth()
|
||||
|
||||
# Call with legacy use_auth_token
|
||||
mock_huggingface_hub.hf_hub_download(
|
||||
repo_id="test/repo",
|
||||
filename="model.bin",
|
||||
use_auth_token="my_token",
|
||||
)
|
||||
|
||||
# Verify original was called with token instead
|
||||
original_download.assert_called_once()
|
||||
call_kwargs = original_download.call_args[1]
|
||||
assert "token" in call_kwargs, "should convert to token parameter"
|
||||
assert call_kwargs["token"] == "my_token", "should preserve token value"
|
||||
assert (
|
||||
"use_auth_token" not in call_kwargs
|
||||
), "should remove use_auth_token"
|
||||
|
||||
def test_preserves_token_parameter(
|
||||
self, mock_huggingface_hub: MagicMock
|
||||
) -> None:
|
||||
"""_patch_huggingface_auth preserves token if already using new API."""
|
||||
original_download = mock_huggingface_hub.hf_hub_download
|
||||
|
||||
with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
|
||||
_patch_huggingface_auth()
|
||||
|
||||
mock_huggingface_hub.hf_hub_download(
|
||||
repo_id="test/repo",
|
||||
filename="model.bin",
|
||||
token="my_token",
|
||||
)
|
||||
|
||||
original_download.assert_called_once()
|
||||
call_kwargs = original_download.call_args[1]
|
||||
assert call_kwargs["token"] == "my_token", "should preserve token"
|
||||
|
||||
def test_handles_import_error_gracefully(self) -> None:
|
||||
"""_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
|
||||
with patch.dict(sys.modules, {"huggingface_hub": None}):
|
||||
_patch_huggingface_auth()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: _patch_speechbrain_backend
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPatchSpeechbrainBackend:
|
||||
"""Tests for torchaudio backend API patching."""
|
||||
|
||||
def test_patches_list_audio_backends(self, mock_torchaudio: MagicMock) -> None:
|
||||
"""_patch_speechbrain_backend adds list_audio_backends when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "list_audio_backends"
|
||||
), "should add list_audio_backends"
|
||||
result = mock_torchaudio.list_audio_backends()
|
||||
assert isinstance(result, list), "should return list"
|
||||
|
||||
def test_patches_get_audio_backend(self, mock_torchaudio: MagicMock) -> None:
|
||||
"""_patch_speechbrain_backend adds get_audio_backend when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "get_audio_backend"
|
||||
), "should add get_audio_backend"
|
||||
result = mock_torchaudio.get_audio_backend()
|
||||
assert result is None, "should return None"
|
||||
|
||||
def test_patches_set_audio_backend(self, mock_torchaudio: MagicMock) -> None:
|
||||
"""_patch_speechbrain_backend adds set_audio_backend when missing."""
|
||||
with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
|
||||
_patch_speechbrain_backend()
|
||||
|
||||
assert hasattr(
|
||||
mock_torchaudio, "set_audio_backend"
|
||||
), "should add set_audio_backend"
|
||||
# Should not raise
|
||||
mock_torchaudio.set_audio_backend("sox")
|
||||
|
||||
def test_does_not_override_existing_functions(self) -> None:
|
||||
"""_patch_speechbrain_backend preserves existing backend functions."""
|
||||
mock = MagicMock()
|
||||
existing_list = MagicMock(return_value=["ffmpeg"])
|
||||
mock.list_audio_backends = existing_list
|
||||
|
||||
with patch.dict(sys.modules, {"torchaudio": mock}):
|
||||
_patch_speechbrain_backend()
|
||||
|
||||
assert (
|
||||
mock.list_audio_backends is existing_list
|
||||
), "should not override existing function"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: apply_patches
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestApplyPatches:
|
||||
"""Tests for the main apply_patches function."""
|
||||
|
||||
def test_apply_patches_is_idempotent(
|
||||
self, reset_patches_state: None
|
||||
) -> None:
|
||||
"""apply_patches only applies patches once."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
with patch.object(compat_module, "_patch_torchaudio") as mock_torchaudio:
|
||||
with patch.object(compat_module, "_patch_torch_load") as mock_torch:
|
||||
with patch.object(
|
||||
compat_module, "_patch_huggingface_auth"
|
||||
) as mock_hf:
|
||||
with patch.object(
|
||||
compat_module, "_patch_speechbrain_backend"
|
||||
) as mock_sb:
|
||||
apply_patches()
|
||||
apply_patches() # Second call
|
||||
apply_patches() # Third call
|
||||
|
||||
# Each patch function should only be called once
|
||||
mock_torchaudio.assert_called_once()
|
||||
mock_torch.assert_called_once()
|
||||
mock_hf.assert_called_once()
|
||||
mock_sb.assert_called_once()
|
||||
|
||||
def test_apply_patches_sets_flag(self, reset_patches_state: None) -> None:
|
||||
"""apply_patches sets _patches_applied flag."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
assert compat_module._patches_applied is False, "should start False"
|
||||
|
||||
with patch.object(compat_module, "_patch_torchaudio"):
|
||||
with patch.object(compat_module, "_patch_torch_load"):
|
||||
with patch.object(compat_module, "_patch_huggingface_auth"):
|
||||
with patch.object(compat_module, "_patch_speechbrain_backend"):
|
||||
apply_patches()
|
||||
|
||||
assert compat_module._patches_applied is True, "should be True after apply"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: ensure_compatibility
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestEnsureCompatibility:
|
||||
"""Tests for the ensure_compatibility entry point."""
|
||||
|
||||
def test_ensure_compatibility_calls_apply_patches(
|
||||
self, reset_patches_state: None
|
||||
) -> None:
|
||||
"""ensure_compatibility delegates to apply_patches."""
|
||||
import noteflow.infrastructure.diarization._compat as compat_module
|
||||
|
||||
with patch.object(compat_module, "apply_patches") as mock_apply:
|
||||
ensure_compatibility()
|
||||
|
||||
mock_apply.assert_called_once()
|
||||
@@ -13,9 +13,18 @@ class SpeakerDiarizationConfig:
|
||||
*,
|
||||
segmentation: SegmentationModel,
|
||||
embedding: EmbeddingModel,
|
||||
step: float,
|
||||
latency: float,
|
||||
device: TorchDevice,
|
||||
duration: float = ...,
|
||||
step: float = ...,
|
||||
latency: float = ...,
|
||||
tau_active: float = ...,
|
||||
rho_update: float = ...,
|
||||
delta_new: float = ...,
|
||||
gamma: float = ...,
|
||||
beta: float = ...,
|
||||
max_speakers: int = ...,
|
||||
normalize_embedding_weights: bool = ...,
|
||||
device: TorchDevice | None = ...,
|
||||
sample_rate: int = ...,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user