feat: implement identity management features in gRPC service

- Introduced `IdentityMixin` to manage user identity operations, including `GetCurrentUser`, `ListWorkspaces`, and `SwitchWorkspace`.
- Added corresponding gRPC methods and message definitions in the proto file for identity management.
- Enhanced `AuthService` to support user authentication and token management.
- Updated `OAuthManager` to include rate limiting for authentication attempts and improved error handling.
- Implemented unit tests for the new identity management features to ensure functionality and reliability.
This commit is contained in:
2026-01-04 02:21:04 -05:00
parent 7fdda7da37
commit c70105f2b8
25 changed files with 3614 additions and 62 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +1,12 @@
"""Application services for NoteFlow use cases."""
from noteflow.application.services.auth_service import (
AuthResult,
AuthService,
AuthServiceError,
LogoutResult,
UserInfo,
)
from noteflow.application.services.export_service import ExportFormat, ExportService
from noteflow.application.services.identity_service import IdentityService
from noteflow.application.services.meeting_service import MeetingService
@@ -15,10 +22,15 @@ from noteflow.application.services.summarization_service import (
from noteflow.application.services.trigger_service import TriggerService, TriggerServiceSettings
__all__ = [
"AuthResult",
"AuthService",
"AuthServiceError",
"ExportFormat",
"ExportService",
"IdentityService",
"LogoutResult",
"MeetingService",
"UserInfo",
"ProjectService",
"RecoveryService",
"RetentionReport",

View File

@@ -0,0 +1,563 @@
"""Authentication service for OAuth-based user login.
Extends the CalendarService OAuth patterns for user authentication.
Uses the same OAuthManager infrastructure but stores tokens with
IntegrationType.AUTH and manages User entities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict, Unpack
from uuid import UUID, uuid4
from noteflow.config.constants import OAUTH_FIELD_ACCESS_TOKEN
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import OAuthManager
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
from noteflow.infrastructure.logging import get_logger
class _AuthServiceDepsKwargs(TypedDict, total=False):
    """Optional dependency overrides for AuthService.

    All keys are optional (``total=False``); when a key is omitted,
    ``AuthService.__init__`` constructs its own default.
    """

    # Pre-built OAuthManager (e.g. a test double). Defaults to
    # OAuthManager(settings) when not supplied.
    oauth_manager: OAuthManager


if TYPE_CHECKING:
    from collections.abc import Callable

    from noteflow.config.settings import CalendarIntegrationSettings
    from noteflow.domain.ports.unit_of_work import UnitOfWork

# Module-level structured logger.
logger = get_logger(__name__)

# Default IDs for local-first mode: well-known sentinel UUIDs used when no
# real user/workspace rows exist (e.g. the backend has no users table).
DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")
class AuthServiceError(Exception):
    """Raised when an auth service operation fails.

    Wraps lower-level OAuth/provider errors so callers only need to
    handle a single exception type.
    """
@dataclass(frozen=True, slots=True)
class AuthResult:
"""Result of successful authentication.
Note: Tokens are stored securely in IntegrationSecretModel and are NOT
exposed to callers. Use get_current_user() to check auth status.
"""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool = True
@dataclass(frozen=True, slots=True)
class UserInfo:
"""Current user information."""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool
provider: str | None
@dataclass(frozen=True, slots=True)
class LogoutResult:
"""Result of logout operation.
Provides visibility into both local logout and remote token revocation.
"""
logged_out: bool
"""Whether local logout succeeded (integration deleted)."""
tokens_revoked: bool
"""Whether remote token revocation succeeded."""
revocation_error: str | None = None
"""Error message if revocation failed (for logging/debugging)."""
class AuthService:
    """Authentication service for OAuth-based user login.

    Orchestrates the OAuth flow for user authentication. Uses:
    - IntegrationRepository for auth token storage
    - OAuthManager for the PKCE OAuth flow
    - UserRepository for user entity management

    Unlike CalendarService, which manages calendar integrations,
    AuthService manages user identity and authentication state.
    """

    def __init__(
        self,
        uow_factory: Callable[[], UnitOfWork],
        settings: CalendarIntegrationSettings,
        **kwargs: Unpack[_AuthServiceDepsKwargs],
    ) -> None:
        """Initialize auth service.

        Args:
            uow_factory: Factory function returning UnitOfWork instances.
            settings: OAuth settings with credentials.
            **kwargs: Optional dependency overrides (see _AuthServiceDepsKwargs).
        """
        self._uow_factory = uow_factory
        self._settings = settings
        # Allow injection of an OAuthManager (tests); default to a real one.
        oauth_manager = kwargs.get("oauth_manager")
        self._oauth_manager = oauth_manager or OAuthManager(settings)

    async def initiate_login(
        self,
        provider: str,
        redirect_uri: str | None = None,
    ) -> tuple[str, str]:
        """Start OAuth login flow.

        Args:
            provider: Provider name ('google' or 'outlook').
            redirect_uri: Optional override for the OAuth callback URI.

        Returns:
            Tuple of (authorization_url, state_token).

        Raises:
            AuthServiceError: If provider is invalid or credentials not configured.
        """
        oauth_provider = self._parse_provider(provider)
        # Fall back to the configured redirect URI when no override given.
        effective_redirect = redirect_uri or self._settings.redirect_uri
        try:
            auth_url, state = self._oauth_manager.initiate_auth(
                provider=oauth_provider,
                redirect_uri=effective_redirect,
            )
            logger.info(
                "auth_login_initiated",
                event_type="security",
                provider=provider,
                redirect_uri=effective_redirect,
            )
            return auth_url, state
        except OAuthError as e:
            logger.warning(
                "auth_login_initiation_failed",
                event_type="security",
                provider=provider,
                error=str(e),
            )
            # Re-wrap so callers only handle AuthServiceError.
            raise AuthServiceError(str(e)) from e

    async def complete_login(
        self,
        provider: str,
        code: str,
        state: str,
    ) -> AuthResult:
        """Complete OAuth login and create/update user.

        Exchanges the authorization code for tokens, fetches user info,
        and creates or updates the User entity. Tokens are persisted as
        integration secrets, not returned to the caller.

        Args:
            provider: Provider name ('google' or 'outlook').
            code: Authorization code from the OAuth callback.
            state: State parameter from the OAuth callback.

        Returns:
            AuthResult with the resolved user/workspace identity.

        Raises:
            AuthServiceError: If the OAuth exchange or user-info fetch fails.
        """
        oauth_provider = self._parse_provider(provider)
        # Exchange code for tokens
        tokens = await self._exchange_tokens(oauth_provider, code, state)
        # Fetch user info from provider
        email, display_name = await self._fetch_user_info(
            oauth_provider, tokens.access_token
        )
        # Create or update user and store tokens
        user_id, workspace_id = await self._store_auth_user(
            provider, email, display_name, tokens
        )
        logger.info(
            "auth_login_completed",
            event_type="security",
            provider=provider,
            email=email,
            user_id=str(user_id),
            workspace_id=str(workspace_id),
        )
        return AuthResult(
            user_id=user_id,
            workspace_id=workspace_id,
            display_name=display_name,
            email=email,
        )

    async def _exchange_tokens(
        self,
        oauth_provider: OAuthProvider,
        code: str,
        state: str,
    ) -> OAuthTokens:
        """Exchange authorization code for tokens.

        Raises:
            AuthServiceError: Wrapping any OAuthError from the manager.
        """
        try:
            return await self._oauth_manager.complete_auth(
                provider=oauth_provider,
                code=code,
                state=state,
            )
        except OAuthError as e:
            raise AuthServiceError(f"OAuth failed: {e}") from e

    async def _fetch_user_info(
        self,
        oauth_provider: OAuthProvider,
        access_token: str,
    ) -> tuple[str, str]:
        """Fetch user email and display name from the provider.

        Returns:
            Tuple of (email, display_name).

        Raises:
            AuthServiceError: If the provider call fails.
        """
        # Use the calendar adapters to get user info (reuse existing
        # infrastructure); imported locally to avoid import cycles.
        from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarAdapter
        from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarAdapter
        try:
            if oauth_provider == OAuthProvider.GOOGLE:
                adapter = GoogleCalendarAdapter()
                email, display_name = await adapter.get_user_info(access_token)
            else:
                adapter = OutlookCalendarAdapter()
                email, display_name = await adapter.get_user_info(access_token)
            return email, display_name
        except (GoogleCalendarError, OutlookCalendarError, OAuthError) as e:
            raise AuthServiceError(f"Failed to get user info: {e}") from e

    async def _store_auth_user(
        self,
        provider: str,
        email: str,
        display_name: str,
        tokens: OAuthTokens,
    ) -> tuple[UUID, UUID]:
        """Create or update the user and store auth tokens.

        Returns:
            Tuple of (user_id, workspace_id) for the authenticated user.
        """
        async with self._uow_factory() as uow:
            # Find or create user by email
            user = None
            if uow.supports_users:
                user = await uow.users.get_by_email(email)
            if user is None:
                # Create new user
                user_id = uuid4()
                if uow.supports_users:
                    user = User(
                        id=user_id,
                        email=email,
                        display_name=display_name,
                        is_default=False,
                    )
                    await uow.users.create(user)
                    logger.info("Created new user: %s (%s)", display_name, email)
                else:
                    # No users table (local-first mode): pin to the sentinel
                    # id; the uuid4 generated above is discarded.
                    user_id = DEFAULT_USER_ID
            else:
                user_id = user.id
                # Update display name if changed
                if user.display_name != display_name:
                    user.display_name = display_name
                    await uow.users.update(user)
            # Get or create default workspace for this user
            workspace_id = DEFAULT_WORKSPACE_ID
            if uow.supports_workspaces:
                workspace = await uow.workspaces.get_default_for_user(user_id)
                if workspace:
                    workspace_id = workspace.id
                else:
                    # Create default "Personal" workspace for new user
                    workspace_id = uuid4()
                    await uow.workspaces.create(
                        workspace_id=workspace_id,
                        name="Personal",
                        owner_id=user_id,
                        is_default=True,
                    )
                    logger.info(
                        "Created default workspace for user_id=%s, workspace_id=%s",
                        user_id,
                        workspace_id,
                    )
            # Store auth integration with tokens.
            # NOTE(review): lookup is keyed by provider only, not scoped to a
            # workspace — confirm this is intended for multi-workspace setups.
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.AUTH.value,
            )
            if integration is None:
                integration = Integration.create(
                    workspace_id=workspace_id,
                    name=f"{provider.title()} Auth",
                    integration_type=IntegrationType.AUTH,
                    config={"provider": provider, "user_id": str(user_id)},
                )
                await uow.integrations.create(integration)
            else:
                integration.config["provider"] = provider
                integration.config["user_id"] = str(user_id)
            # Mark connected (records provider email) and persist.
            integration.connect(provider_email=email)
            await uow.integrations.update(integration)
            # Store tokens
            await uow.integrations.set_secrets(
                integration_id=integration.id,
                secrets=tokens.to_secrets_dict(),
            )
            await uow.commit()
            return user_id, workspace_id

    async def get_current_user(self) -> UserInfo:
        """Get current authenticated user info.

        Returns:
            UserInfo for the first connected auth integration found
            (google checked before outlook), or the local default.
        """
        async with self._uow_factory() as uow:
            # Look for any connected auth integration
            for provider in [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]:
                integration = await uow.integrations.get_by_provider(
                    provider=provider,
                    integration_type=IntegrationType.AUTH.value,
                )
                if integration and integration.is_connected:
                    user_id_str = integration.config.get("user_id")
                    user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
                    # Get user details
                    display_name = "Authenticated User"
                    if uow.supports_users and user_id_str:
                        user = await uow.users.get(user_id)
                        if user:
                            display_name = user.display_name
                    # NOTE(review): always reports DEFAULT_WORKSPACE_ID even
                    # for authenticated users — confirm this is intentional.
                    return UserInfo(
                        user_id=user_id,
                        workspace_id=DEFAULT_WORKSPACE_ID,
                        display_name=display_name,
                        email=integration.provider_email,
                        is_authenticated=True,
                        provider=provider,
                    )
            # Return local default
            return UserInfo(
                user_id=DEFAULT_USER_ID,
                workspace_id=DEFAULT_WORKSPACE_ID,
                display_name="Local User",
                email=None,
                is_authenticated=False,
                provider=None,
            )

    async def logout(self, provider: str | None = None) -> LogoutResult:
        """Logout and revoke auth tokens.

        Args:
            provider: Optional specific provider to logout from.
                If None, logs out from all providers.

        Returns:
            LogoutResult aggregating local logout and token revocation
            outcomes across the targeted providers.
        """
        providers = (
            [provider]
            if provider
            else [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]
        )
        logged_out = False
        all_revoked = True
        revocation_errors: list[str] = []
        for p in providers:
            result = await self._logout_provider(p)
            # Any successful local logout counts as logged out overall.
            logged_out = logged_out or result.logged_out
            if not result.tokens_revoked:
                all_revoked = False
            if result.revocation_error:
                revocation_errors.append(f"{p}: {result.revocation_error}")
        return LogoutResult(
            logged_out=logged_out,
            tokens_revoked=all_revoked,
            revocation_error="; ".join(revocation_errors) if revocation_errors else None,
        )

    async def _logout_provider(self, provider: str) -> LogoutResult:
        """Logout from a specific provider.

        Deletes the stored auth integration, then revokes the access token
        remotely on a best-effort basis (revocation failure does not undo
        the local logout).
        """
        oauth_provider = self._parse_provider(provider)
        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.AUTH.value,
            )
            if integration is None:
                return LogoutResult(
                    logged_out=False,
                    tokens_revoked=True,  # No tokens to revoke
                )
            # Get tokens for revocation
            secrets = await uow.integrations.get_secrets(integration.id)
            access_token = secrets.get(OAUTH_FIELD_ACCESS_TOKEN) if secrets else None
            # Delete integration
            await uow.integrations.delete(integration.id)
            await uow.commit()
        # Revoke tokens (best effort)
        tokens_revoked = True
        revocation_error: str | None = None
        if access_token:
            try:
                await self._oauth_manager.revoke_tokens(oauth_provider, access_token)
                logger.info(
                    "auth_tokens_revoked",
                    event_type="security",
                    provider=provider,
                )
            except OAuthError as e:
                tokens_revoked = False
                revocation_error = str(e)
                logger.warning(
                    "auth_token_revocation_failed",
                    event_type="security",
                    provider=provider,
                    error=revocation_error,
                )
        logger.info(
            "auth_logout_completed",
            event_type="security",
            provider=provider,
            tokens_revoked=tokens_revoked,
        )
        return LogoutResult(
            logged_out=True,
            tokens_revoked=tokens_revoked,
            revocation_error=revocation_error,
        )

    async def refresh_auth_tokens(self, provider: str) -> AuthResult | None:
        """Refresh expired auth tokens.

        Args:
            provider: Provider to refresh tokens for.

        Returns:
            Updated AuthResult, or None if there is nothing to refresh or
            the refresh failed (the integration is then marked errored).
        """
        oauth_provider = self._parse_provider(provider)
        async with self._uow_factory() as uow:
            integration = await uow.integrations.get_by_provider(
                provider=provider,
                integration_type=IntegrationType.AUTH.value,
            )
            if integration is None or not integration.is_connected:
                return None
            secrets = await uow.integrations.get_secrets(integration.id)
            if not secrets:
                return None
            try:
                tokens = OAuthTokens.from_secrets_dict(secrets)
            except (KeyError, ValueError):
                # Stored secrets are malformed; treat as not refreshable.
                return None
            if not tokens.refresh_token:
                return None
            # Only refresh if token is expired or will expire within 5 minutes
            if not tokens.is_expired(buffer_seconds=300):
                logger.debug(
                    "auth_token_still_valid",
                    provider=provider,
                    expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
                )
                # Return existing auth info without refreshing
                user_id_str = integration.config.get("user_id")
                user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
                return AuthResult(
                    user_id=user_id,
                    workspace_id=DEFAULT_WORKSPACE_ID,
                    display_name=integration.provider_email or "User",
                    email=integration.provider_email,
                )
            try:
                new_tokens = await self._oauth_manager.refresh_tokens(
                    provider=oauth_provider,
                    refresh_token=tokens.refresh_token,
                )
                await uow.integrations.set_secrets(
                    integration_id=integration.id,
                    secrets=new_tokens.to_secrets_dict(),
                )
                await uow.commit()
                user_id_str = integration.config.get("user_id")
                user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
                return AuthResult(
                    user_id=user_id,
                    workspace_id=DEFAULT_WORKSPACE_ID,
                    display_name=integration.provider_email or "User",
                    email=integration.provider_email,
                )
            except OAuthError as e:
                # Record the failure on the integration so the UI can surface it.
                integration.mark_error(f"Token refresh failed: {e}")
                await uow.integrations.update(integration)
                await uow.commit()
                return None

    @staticmethod
    def _parse_provider(provider: str) -> OAuthProvider:
        """Parse and validate a provider string (case-insensitive).

        Raises:
            AuthServiceError: If the string is not a known provider.
        """
        try:
            return OAuthProvider(provider.lower())
        except ValueError as e:
            raise AuthServiceError(
                f"Invalid provider: {provider}. Must be 'google' or 'outlook'."
            ) from e

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from enum import Enum, IntEnum, StrEnum
from typing import NewType
from uuid import UUID
@@ -145,9 +145,18 @@ class OAuthTokens:
expires_at: datetime
scope: str
def is_expired(self, buffer_seconds: int = 0) -> bool:
    """Check if the access token has expired or will expire within a buffer.

    Backward compatible with the zero-argument form: calling with no
    arguments checks against the actual expiry time.

    Args:
        buffer_seconds: Consider the token expired this many seconds before
            its actual expiry. Useful for proactive refresh.

    Returns:
        True if the token is expired or will expire within buffer_seconds.
    """
    # Shift the expiry earlier by the buffer so callers can refresh
    # before the token actually lapses.
    effective_expiry = self.expires_at - timedelta(seconds=buffer_seconds)
    # Use the token's own tzinfo so naive and aware datetimes never mix.
    return datetime.now(self.expires_at.tzinfo) > effective_expiry
def to_secrets_dict(self) -> dict[str, str]:
"""Convert to dictionary for encrypted storage."""

View File

@@ -9,6 +9,7 @@ from noteflow.config.constants import DEFAULT_GRPC_PORT
if TYPE_CHECKING:
from noteflow.application.services.calendar_service import CalendarService
from noteflow.application.services.identity_service import IdentityService
from noteflow.application.services.ner_service import NerService
from noteflow.application.services.project_service import ProjectService
from noteflow.application.services.summarization_service import SummarizationService
@@ -203,6 +204,7 @@ class ServicesConfig:
calendar_service: Service for OAuth and calendar event fetching.
webhook_service: Service for webhook event notifications.
project_service: Service for project management.
identity_service: Service for identity and workspace context management.
"""
summarization_service: SummarizationService | None = None
@@ -212,3 +214,4 @@ class ServicesConfig:
calendar_service: CalendarService | None = None
webhook_service: WebhookService | None = None
project_service: ProjectService | None = None
identity_service: IdentityService | None = None

View File

@@ -7,6 +7,7 @@ from .diarization import DiarizationMixin
from .diarization_job import DiarizationJobMixin
from .entities import EntitiesMixin
from .export import ExportMixin
from .identity import IdentityMixin
from .meeting import MeetingMixin
from .observability import ObservabilityMixin
from .oidc import OidcMixin
@@ -26,6 +27,7 @@ __all__ = [
"DiarizationMixin",
"EntitiesMixin",
"ExportMixin",
"IdentityMixin",
"MeetingMixin",
"ObservabilityMixin",
"OidcMixin",

View File

@@ -0,0 +1,186 @@
"""Identity management mixin for gRPC service."""
from __future__ import annotations
from typing import Protocol
from noteflow.application.services.identity_service import IdentityService
from noteflow.domain.entities.integration import IntegrationType
from noteflow.domain.identity.context import OperationContext
from noteflow.domain.ports.unit_of_work import UnitOfWork
from noteflow.infrastructure.logging import get_logger
from ..proto import noteflow_pb2
from .errors import abort_database_required, abort_invalid_argument, abort_not_found, parse_workspace_id
from ._types import GrpcContext
logger = get_logger(__name__)
class IdentityServicer(Protocol):
    """Structural interface a gRPC host must satisfy for identity operations.

    IdentityMixin methods are typed against this protocol via their
    ``self`` annotation, so any servicer providing these members works.
    """

    def create_repository_provider(self) -> UnitOfWork:
        """Return a fresh UnitOfWork, used as an async context manager per request."""
        ...

    def get_operation_context(self, context: GrpcContext) -> OperationContext:
        """Derive the per-request OperationContext from gRPC request metadata."""
        ...

    @property
    def identity_service(self) -> IdentityService:
        """Application service implementing identity/workspace use cases."""
        ...
class IdentityMixin:
    """Mixin providing identity management functionality.

    Implements:
    - GetCurrentUser: Get current user's identity info
    - ListWorkspaces: List workspaces user belongs to
    - SwitchWorkspace: Switch to a different workspace
    """

    async def GetCurrentUser(
        self: IdentityServicer,
        request: noteflow_pb2.GetCurrentUserRequest,
        context: GrpcContext,
    ) -> noteflow_pb2.GetCurrentUserResponse:
        """Get current authenticated user info.

        Ensures a default user and workspace exist (local-first mode), then
        checks for a connected OAuth auth integration to report auth state.
        """
        # Note: op_context from headers provides request metadata; the value
        # is intentionally unused here but header parsing still runs.
        _ = self.get_operation_context(context)
        async with self.create_repository_provider() as uow:
            # Get or create default user/workspace for local-first mode
            user_ctx = await self.identity_service.get_or_create_default_user(uow)
            ws_ctx = await self.identity_service.get_or_create_default_workspace(
                uow, user_ctx.user_id
            )
            # Commit immediately so any newly created rows persist even
            # though the remainder of this method only reads.
            await uow.commit()
            # Check if user has auth integration (authenticated via OAuth)
            is_authenticated = False
            auth_provider = ""
            if hasattr(uow, "supports_integrations") and uow.supports_integrations:
                for provider in ["google", "outlook"]:
                    integration = await uow.integrations.get_by_provider(
                        provider=provider,
                        integration_type=IntegrationType.AUTH.value,
                    )
                    if integration and integration.is_connected:
                        is_authenticated = True
                        auth_provider = provider
                        break
            logger.debug(
                "GetCurrentUser: user_id=%s, workspace_id=%s, authenticated=%s",
                user_ctx.user_id,
                ws_ctx.workspace_id,
                is_authenticated,
            )
            return noteflow_pb2.GetCurrentUserResponse(
                user_id=str(user_ctx.user_id),
                workspace_id=str(ws_ctx.workspace_id),
                display_name=user_ctx.display_name,
                # Proto3 strings cannot be None; empty string means "no email".
                email=user_ctx.email or "",
                is_authenticated=is_authenticated,
                auth_provider=auth_provider,
                workspace_name=ws_ctx.workspace_name,
                role=ws_ctx.role.value,
            )

    async def ListWorkspaces(
        self: IdentityServicer,
        request: noteflow_pb2.ListWorkspacesRequest,
        context: GrpcContext,
    ) -> noteflow_pb2.ListWorkspacesResponse:
        """List workspaces the current user belongs to.

        Aborts with FAILED_PRECONDITION-style error when the backing store
        has no workspace support.
        """
        _ = self.get_operation_context(context)
        async with self.create_repository_provider() as uow:
            if not uow.supports_workspaces:
                # NOTE(review): assumes abort_database_required raises to
                # terminate the RPC — confirm against the errors helpers.
                await abort_database_required(context, "Workspaces")
            user_ctx = await self.identity_service.get_or_create_default_user(uow)
            # Clamp pagination: default limit 50, negative offsets -> 0.
            limit = request.limit if request.limit > 0 else 50
            offset = request.offset if request.offset >= 0 else 0
            workspaces = await self.identity_service.list_workspaces(
                uow, user_ctx.user_id, limit, offset
            )
            workspace_protos: list[noteflow_pb2.WorkspaceProto] = []
            for ws in workspaces:
                # Look up the caller's membership to report their role;
                # fall back to "member" when no membership row exists.
                membership = await uow.workspaces.get_membership(ws.id, user_ctx.user_id)
                role = membership.role.value if membership else "member"
                workspace_protos.append(
                    noteflow_pb2.WorkspaceProto(
                        id=str(ws.id),
                        name=ws.name,
                        slug=ws.slug or "",
                        is_default=ws.is_default,
                        role=role,
                    )
                )
            logger.debug(
                "ListWorkspaces: user_id=%s, count=%d",
                user_ctx.user_id,
                len(workspace_protos),
            )
            # NOTE(review): total_count reflects the returned page size, not
            # the overall row count — confirm intended semantics.
            return noteflow_pb2.ListWorkspacesResponse(
                workspaces=workspace_protos,
                total_count=len(workspace_protos),
            )

    async def SwitchWorkspace(
        self: IdentityServicer,
        request: noteflow_pb2.SwitchWorkspaceRequest,
        context: GrpcContext,
    ) -> noteflow_pb2.SwitchWorkspaceResponse:
        """Switch to a different workspace.

        Validates the workspace id, its existence, and the caller's
        membership before reporting success.
        """
        _ = self.get_operation_context(context)
        if not request.workspace_id:
            await abort_invalid_argument(context, "workspace_id is required")
        # Parse and validate workspace ID (aborts with INVALID_ARGUMENT if invalid)
        workspace_id = await parse_workspace_id(request.workspace_id, context)
        async with self.create_repository_provider() as uow:
            if not uow.supports_workspaces:
                await abort_database_required(context, "Workspaces")
            user_ctx = await self.identity_service.get_or_create_default_user(uow)
            # Verify workspace exists
            workspace = await uow.workspaces.get(workspace_id)
            if not workspace:
                await abort_not_found(context, "Workspace", str(workspace_id))
            # Verify user has access
            membership = await uow.workspaces.get_membership(
                workspace_id, user_ctx.user_id
            )
            if not membership:
                await abort_not_found(
                    context, "Workspace membership", str(workspace_id)
                )
            logger.info(
                "SwitchWorkspace: user_id=%s, workspace_id=%s",
                user_ctx.user_id,
                workspace_id,
            )
            return noteflow_pb2.SwitchWorkspaceResponse(
                success=True,
                workspace=noteflow_pb2.WorkspaceProto(
                    id=str(workspace.id),
                    name=workspace.name,
                    slug=workspace.slug or "",
                    is_default=workspace.is_default,
                    role=membership.role.value,
                ),
            )

View File

@@ -114,6 +114,11 @@ service NoteFlowService {
rpc UpdateProjectMemberRole(UpdateProjectMemberRoleRequest) returns (ProjectMembershipProto);
rpc RemoveProjectMember(RemoveProjectMemberRequest) returns (RemoveProjectMemberResponse);
rpc ListProjectMembers(ListProjectMembersRequest) returns (ListProjectMembersResponse);
// Identity management (Sprint 16+)
rpc GetCurrentUser(GetCurrentUserRequest) returns (GetCurrentUserResponse);
rpc ListWorkspaces(ListWorkspacesRequest) returns (ListWorkspacesResponse);
rpc SwitchWorkspace(SwitchWorkspaceRequest) returns (SwitchWorkspaceResponse);
}
// =============================================================================
@@ -1848,3 +1853,86 @@ message ListProjectMembersResponse {
// Total count
int32 total_count = 2;
}
// =============================================================================
// Identity Management Messages (Sprint 16+)
// =============================================================================
message GetCurrentUserRequest {
// Empty - user ID comes from request headers
}
message GetCurrentUserResponse {
// User ID (UUID string)
string user_id = 1;
// Current workspace ID (UUID string)
string workspace_id = 2;
// User display name
string display_name = 3;
// User email (optional)
string email = 4;
// Whether user is authenticated (vs local mode)
bool is_authenticated = 5;
// OAuth provider if authenticated (google, outlook, etc.)
string auth_provider = 6;
// Workspace name
string workspace_name = 7;
// User's role in workspace
string role = 8;
}
message WorkspaceProto {
// Workspace ID (UUID string)
string id = 1;
// Workspace name
string name = 2;
// URL slug
string slug = 3;
// Whether this is the default workspace
bool is_default = 4;
// User's role in this workspace
string role = 5;
}
message ListWorkspacesRequest {
// Maximum workspaces to return (default: 50)
int32 limit = 1;
// Pagination offset
int32 offset = 2;
}
message ListWorkspacesResponse {
// User's workspaces
repeated WorkspaceProto workspaces = 1;
// Total count
int32 total_count = 2;
}
message SwitchWorkspaceRequest {
// Workspace ID to switch to
string workspace_id = 1;
}
message SwitchWorkspaceResponse {
// Whether switch succeeded
bool success = 1;
// New current workspace info
WorkspaceProto workspace = 2;
// Error message if failed
string error_message = 3;
}

File diff suppressed because one or more lines are too long

View File

@@ -1621,3 +1621,73 @@ class ListProjectMembersResponse(_message.Message):
members: _containers.RepeatedCompositeFieldContainer[ProjectMembershipProto]
total_count: int
def __init__(self, members: _Optional[_Iterable[_Union[ProjectMembershipProto, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ...
class GetCurrentUserRequest(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class GetCurrentUserResponse(_message.Message):
__slots__ = ("user_id", "workspace_id", "display_name", "email", "is_authenticated", "auth_provider", "workspace_name", "role")
USER_ID_FIELD_NUMBER: _ClassVar[int]
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
DISPLAY_NAME_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
IS_AUTHENTICATED_FIELD_NUMBER: _ClassVar[int]
AUTH_PROVIDER_FIELD_NUMBER: _ClassVar[int]
WORKSPACE_NAME_FIELD_NUMBER: _ClassVar[int]
ROLE_FIELD_NUMBER: _ClassVar[int]
user_id: str
workspace_id: str
display_name: str
email: str
is_authenticated: bool
auth_provider: str
workspace_name: str
role: str
def __init__(self, user_id: _Optional[str] = ..., workspace_id: _Optional[str] = ..., display_name: _Optional[str] = ..., email: _Optional[str] = ..., is_authenticated: bool = ..., auth_provider: _Optional[str] = ..., workspace_name: _Optional[str] = ..., role: _Optional[str] = ...) -> None: ...
class WorkspaceProto(_message.Message):
__slots__ = ("id", "name", "slug", "is_default", "role")
ID_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
SLUG_FIELD_NUMBER: _ClassVar[int]
IS_DEFAULT_FIELD_NUMBER: _ClassVar[int]
ROLE_FIELD_NUMBER: _ClassVar[int]
id: str
name: str
slug: str
is_default: bool
role: str
def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., slug: _Optional[str] = ..., is_default: bool = ..., role: _Optional[str] = ...) -> None: ...
class ListWorkspacesRequest(_message.Message):
__slots__ = ("limit", "offset")
LIMIT_FIELD_NUMBER: _ClassVar[int]
OFFSET_FIELD_NUMBER: _ClassVar[int]
limit: int
offset: int
def __init__(self, limit: _Optional[int] = ..., offset: _Optional[int] = ...) -> None: ...
class ListWorkspacesResponse(_message.Message):
__slots__ = ("workspaces", "total_count")
WORKSPACES_FIELD_NUMBER: _ClassVar[int]
TOTAL_COUNT_FIELD_NUMBER: _ClassVar[int]
workspaces: _containers.RepeatedCompositeFieldContainer[WorkspaceProto]
total_count: int
def __init__(self, workspaces: _Optional[_Iterable[_Union[WorkspaceProto, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ...
class SwitchWorkspaceRequest(_message.Message):
__slots__ = ("workspace_id",)
WORKSPACE_ID_FIELD_NUMBER: _ClassVar[int]
workspace_id: str
def __init__(self, workspace_id: _Optional[str] = ...) -> None: ...
class SwitchWorkspaceResponse(_message.Message):
__slots__ = ("success", "workspace", "error_message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
WORKSPACE_FIELD_NUMBER: _ClassVar[int]
ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
workspace: WorkspaceProto
error_message: str
def __init__(self, success: bool = ..., workspace: _Optional[_Union[WorkspaceProto, _Mapping]] = ..., error_message: _Optional[str] = ...) -> None: ...

View File

@@ -363,6 +363,21 @@ class NoteFlowServiceStub(object):
request_serializer=noteflow__pb2.ListProjectMembersRequest.SerializeToString,
response_deserializer=noteflow__pb2.ListProjectMembersResponse.FromString,
_registered_method=True)
self.GetCurrentUser = channel.unary_unary(
'/noteflow.NoteFlowService/GetCurrentUser',
request_serializer=noteflow__pb2.GetCurrentUserRequest.SerializeToString,
response_deserializer=noteflow__pb2.GetCurrentUserResponse.FromString,
_registered_method=True)
self.ListWorkspaces = channel.unary_unary(
'/noteflow.NoteFlowService/ListWorkspaces',
request_serializer=noteflow__pb2.ListWorkspacesRequest.SerializeToString,
response_deserializer=noteflow__pb2.ListWorkspacesResponse.FromString,
_registered_method=True)
self.SwitchWorkspace = channel.unary_unary(
'/noteflow.NoteFlowService/SwitchWorkspace',
request_serializer=noteflow__pb2.SwitchWorkspaceRequest.SerializeToString,
response_deserializer=noteflow__pb2.SwitchWorkspaceResponse.FromString,
_registered_method=True)
class NoteFlowServiceServicer(object):
@@ -782,6 +797,25 @@ class NoteFlowServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetCurrentUser(self, request, context):
"""Identity management (Sprint 16+)
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ListWorkspaces(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SwitchWorkspace(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_NoteFlowServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -1110,6 +1144,21 @@ def add_NoteFlowServiceServicer_to_server(servicer, server):
request_deserializer=noteflow__pb2.ListProjectMembersRequest.FromString,
response_serializer=noteflow__pb2.ListProjectMembersResponse.SerializeToString,
),
'GetCurrentUser': grpc.unary_unary_rpc_method_handler(
servicer.GetCurrentUser,
request_deserializer=noteflow__pb2.GetCurrentUserRequest.FromString,
response_serializer=noteflow__pb2.GetCurrentUserResponse.SerializeToString,
),
'ListWorkspaces': grpc.unary_unary_rpc_method_handler(
servicer.ListWorkspaces,
request_deserializer=noteflow__pb2.ListWorkspacesRequest.FromString,
response_serializer=noteflow__pb2.ListWorkspacesResponse.SerializeToString,
),
'SwitchWorkspace': grpc.unary_unary_rpc_method_handler(
servicer.SwitchWorkspace,
request_deserializer=noteflow__pb2.SwitchWorkspaceRequest.FromString,
response_serializer=noteflow__pb2.SwitchWorkspaceResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'noteflow.NoteFlowService', rpc_method_handlers)
@@ -2879,3 +2928,84 @@ class NoteFlowService(object):
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetCurrentUser(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke the GetCurrentUser RPC as a one-shot unary-unary call.

    Auto-generated experimental client helper: opens/reuses a channel to
    *target* and issues the call with the given credentials and options.
    """
    return grpc.experimental.unary_unary(
        request,
        target,
        '/noteflow.NoteFlowService/GetCurrentUser',
        noteflow__pb2.GetCurrentUserRequest.SerializeToString,
        noteflow__pb2.GetCurrentUserResponse.FromString,
        options,
        channel_credentials,
        insecure,
        call_credentials,
        compression,
        wait_for_ready,
        timeout,
        metadata,
        _registered_method=True)
@staticmethod
def ListWorkspaces(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke the ListWorkspaces RPC as a one-shot unary-unary call.

    Auto-generated experimental client helper: opens/reuses a channel to
    *target* and issues the call with the given credentials and options.
    """
    return grpc.experimental.unary_unary(
        request,
        target,
        '/noteflow.NoteFlowService/ListWorkspaces',
        noteflow__pb2.ListWorkspacesRequest.SerializeToString,
        noteflow__pb2.ListWorkspacesResponse.FromString,
        options,
        channel_credentials,
        insecure,
        call_credentials,
        compression,
        wait_for_ready,
        timeout,
        metadata,
        _registered_method=True)
@staticmethod
def SwitchWorkspace(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke the SwitchWorkspace RPC as a one-shot unary-unary call.

    Auto-generated experimental client helper: opens/reuses a channel to
    *target* and issues the call with the given credentials and options.
    """
    return grpc.experimental.unary_unary(
        request,
        target,
        '/noteflow.NoteFlowService/SwitchWorkspace',
        noteflow__pb2.SwitchWorkspaceRequest.SerializeToString,
        noteflow__pb2.SwitchWorkspaceResponse.FromString,
        options,
        channel_credentials,
        insecure,
        call_credentials,
        compression,
        wait_for_ready,
        timeout,
        metadata,
        _registered_method=True)

View File

@@ -9,7 +9,13 @@ from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Final
from uuid import UUID
from noteflow import __version__
from noteflow.application.services.identity_service import IdentityService
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
from noteflow.domain.identity.roles import WorkspaceRole
from noteflow.infrastructure.logging import request_id_var, user_id_var, workspace_id_var
from noteflow.config.constants import APP_DIR_NAME
from noteflow.config.constants import DEFAULT_SAMPLE_RATE as _DEFAULT_SAMPLE_RATE
from noteflow.domain.entities import Meeting
@@ -34,6 +40,7 @@ from ._mixins import (
DiarizationMixin,
EntitiesMixin,
ExportMixin,
IdentityMixin,
MeetingMixin,
ObservabilityMixin,
OidcMixin,
@@ -71,6 +78,18 @@ else:
pass
# Module-level singleton for identity service (stateless, no dependencies)
_identity_service_instance: IdentityService | None = None
def _default_identity_service() -> IdentityService:
    """Return the process-wide IdentityService, creating it on first use.

    Lazily initializes the module-level singleton; subsequent calls hand
    back the same instance.
    """
    global _identity_service_instance
    instance = _identity_service_instance
    if instance is None:
        instance = IdentityService()
        _identity_service_instance = instance
    return instance
class NoteFlowServicer(
StreamingMixin,
DiarizationMixin,
@@ -86,6 +105,7 @@ class NoteFlowServicer(
ObservabilityMixin,
PreferencesMixin,
OidcMixin,
IdentityMixin,
ProjectMixin,
ProjectMembershipMixin,
NoteFlowServicerStubs,
@@ -138,6 +158,8 @@ class NoteFlowServicer(
self.calendar_service = services.calendar_service
self.webhook_service = services.webhook_service
self.project_service = services.project_service
# Identity service - always available (stateless, no dependencies)
self.identity_service = services.identity_service or _default_identity_service()
self._start_time = time.time()
self.memory_store: MeetingStore | None = MeetingStore() if session_factory is None else None
# Audio infrastructure
@@ -195,6 +217,40 @@ class NoteFlowServicer(
return SqlAlchemyUnitOfWork(self.session_factory, self.meetings_dir)
return MemoryUnitOfWork(self.get_memory_store())
def get_operation_context(self, context: GrpcContext) -> OperationContext:
    """Build an OperationContext from identity context variables.

    The IdentityInterceptor stores request/user/workspace identifiers in
    context variables before the handler runs; this reads them back and
    falls back to the well-known local-first IDs when they are unset.

    Args:
        context: gRPC service context (used for metadata if needed).

    Returns:
        OperationContext with user, workspace, and request info.
    """
    # Fixed fallback identity used when running without authentication
    # (local-first mode); the same UUID serves as user and workspace ID.
    fallback_id = UUID("00000000-0000-0000-0000-000000000001")

    raw_user_id = user_id_var.get()
    raw_workspace_id = workspace_id_var.get()

    resolved_user_id = UUID(raw_user_id) if raw_user_id else fallback_id
    resolved_workspace_id = (
        UUID(raw_workspace_id) if raw_workspace_id else fallback_id
    )

    return OperationContext(
        user=UserContext(user_id=resolved_user_id, display_name=""),
        workspace=WorkspaceContext(
            workspace_id=resolved_workspace_id,
            workspace_name="",
            role=WorkspaceRole.OWNER,
        ),
        request_id=request_id_var.get(),
    )
def init_streaming_state(self, meeting_id: str, next_segment_id: int) -> None:
"""Initialize VAD, Segmenter, speaking state, and partial buffers for a meeting."""
# Create core components

View File

@@ -148,6 +148,21 @@ class GoogleCalendarAdapter(CalendarPort):
Returns:
User's email address.
Raises:
GoogleCalendarError: If API call fails.
"""
email, _ = await self.get_user_info(access_token)
return email
async def get_user_info(self, access_token: str) -> tuple[str, str]:
"""Get authenticated user's email and display name.
Args:
access_token: Valid OAuth access token.
Returns:
Tuple of (email, display_name).
Raises:
GoogleCalendarError: If API call fails.
"""
@@ -168,11 +183,21 @@ class GoogleCalendarAdapter(CalendarPort):
if not isinstance(data_value, dict):
raise GoogleCalendarError("Invalid userinfo response")
data = cast(dict[str, object], data_value)
if email := data.get("email"):
return str(email)
else:
email = data.get("email")
if not email:
raise GoogleCalendarError("No email in userinfo response")
# Get display name from 'name' field, fall back to email prefix
name = data.get("name")
display_name = (
str(name)
if name
else str(email).split("@")[0].replace(".", " ").title()
)
return str(email), display_name
def _parse_event(self, item: _GoogleEvent) -> CalendarEventInfo:
"""Parse Google Calendar event into CalendarEventInfo."""
event_id = str(item.get("id", ""))

View File

@@ -78,6 +78,13 @@ class OAuthManager(OAuthPort):
# State TTL (10 minutes)
STATE_TTL_SECONDS = 600
# Maximum pending states to prevent memory exhaustion
MAX_PENDING_STATES = 100
# Rate limiting for auth attempts (per provider)
MAX_AUTH_ATTEMPTS_PER_MINUTE = 10
AUTH_RATE_LIMIT_WINDOW_SECONDS = 60
def __init__(self, settings: CalendarIntegrationSettings) -> None:
"""Initialize OAuth manager with calendar settings.
@@ -86,6 +93,8 @@ class OAuthManager(OAuthPort):
"""
self._settings = settings
self._pending_states: dict[str, OAuthState] = {}
# Track auth attempt timestamps per provider for rate limiting
self._auth_attempts: dict[str, list[datetime]] = {}
def get_pending_state(self, state_token: str) -> OAuthState | None:
"""Get pending OAuth state by token.
@@ -137,6 +146,16 @@ class OAuthManager(OAuthPort):
"""
self._cleanup_expired_states()
self._validate_provider_config(provider)
self._check_rate_limit(provider)
# Enforce maximum pending states to prevent memory exhaustion
if len(self._pending_states) >= self.MAX_PENDING_STATES:
logger.warning(
"oauth_max_pending_states_exceeded",
count=len(self._pending_states),
max_allowed=self.MAX_PENDING_STATES,
)
raise OAuthError("Too many pending OAuth flows. Please try again later.")
# Generate PKCE code verifier and challenge
code_verifier = self._generate_code_verifier()
@@ -190,12 +209,31 @@ class OAuthManager(OAuthPort):
# Validate and retrieve state
oauth_state = self._pending_states.pop(state, None)
if oauth_state is None:
logger.warning(
"oauth_invalid_state_token",
event_type="security",
provider=provider.value,
state_prefix=state[:8] if len(state) >= 8 else state,
)
raise OAuthError("Invalid or expired state token")
if oauth_state.is_state_expired():
logger.warning(
"oauth_state_expired",
event_type="security",
provider=provider.value,
created_at=oauth_state.created_at.isoformat(),
expires_at=oauth_state.expires_at.isoformat(),
)
raise OAuthError("State token has expired")
if oauth_state.provider != provider:
logger.warning(
"oauth_provider_mismatch",
event_type="security",
expected_provider=oauth_state.provider.value,
received_provider=provider.value,
)
raise OAuthError(
f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
)
@@ -468,3 +506,44 @@ class OAuthManager(OAuthPort):
]
for key in expired_keys:
del self._pending_states[key]
def _check_rate_limit(self, provider: OAuthProvider) -> None:
    """Enforce the per-provider rate limit on auth attempts.

    Uses a sliding window over recorded attempt timestamps to mitigate
    brute-force attempts against the OAuth flow.

    Args:
        provider: OAuth provider being used.

    Raises:
        OAuthError: If rate limit exceeded.
    """
    key = provider.value
    now = datetime.now(UTC)
    window_start = now - timedelta(seconds=self.AUTH_RATE_LIMIT_WINDOW_SECONDS)

    # Keep only timestamps still inside the sliding window.
    recent = [ts for ts in self._auth_attempts.get(key, []) if ts > window_start]
    self._auth_attempts[key] = recent

    if len(recent) >= self.MAX_AUTH_ATTEMPTS_PER_MINUTE:
        logger.warning(
            "oauth_rate_limit_exceeded",
            event_type="security",
            provider=key,
            attempts=len(recent),
            limit=self.MAX_AUTH_ATTEMPTS_PER_MINUTE,
            window_seconds=self.AUTH_RATE_LIMIT_WINDOW_SECONDS,
        )
        raise OAuthError(
            "Too many auth attempts. Please wait before trying again."
        )

    # Record this attempt against the provider's window.
    self._auth_attempts[key].append(now)

View File

@@ -256,11 +256,26 @@ class OutlookCalendarAdapter(CalendarPort):
Returns:
User's email address.
Raises:
OutlookCalendarError: If API call fails.
"""
email, _ = await self.get_user_info(access_token)
return email
async def get_user_info(self, access_token: str) -> tuple[str, str]:
"""Get authenticated user's email and display name.
Args:
access_token: Valid OAuth access token.
Returns:
Tuple of (email, display_name).
Raises:
OutlookCalendarError: If API call fails.
"""
url = f"{self.GRAPH_API_BASE}/me"
params = {"$select": "mail,userPrincipalName"}
params = {"$select": "mail,userPrincipalName,displayName"}
headers = {HTTP_AUTHORIZATION: f"{HTTP_BEARER_PREFIX}{access_token}"}
async with httpx.AsyncClient(
@@ -281,11 +296,21 @@ class OutlookCalendarAdapter(CalendarPort):
if not isinstance(data_value, dict):
raise OutlookCalendarError("Invalid user profile response")
data = cast(_OutlookProfile, data_value)
if email := data.get("mail") or data.get("userPrincipalName"):
return str(email)
else:
email = data.get("mail") or data.get("userPrincipalName")
if not email:
raise OutlookCalendarError("No email in user profile response")
# Get display name, fall back to email prefix
display_name_raw = data.get("displayName")
display_name = (
str(display_name_raw)
if display_name_raw
else str(email).split("@")[0].replace(".", " ").title()
)
return str(email), display_name
def _parse_event(self, item: _OutlookEvent) -> CalendarEventInfo:
"""Parse Microsoft Graph event into CalendarEventInfo."""
event_id = str(item.get("id", ""))

View File

@@ -0,0 +1,166 @@
"""Compatibility patches for pyannote-audio and diart with modern PyTorch/torchaudio.
This module applies runtime monkey-patches to fix compatibility issues between:
- pyannote-audio 3.x and torchaudio 2.9+ (removed AudioMetaData, backend APIs)
- PyTorch 2.6+ (weights_only=True default in torch.load)
- huggingface_hub 0.24+ (use_auth_token renamed to token)
Import this module before importing pyannote.audio or diart to apply patches.
"""
from __future__ import annotations
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Final, cast
from noteflow.infrastructure.logging import get_logger
logger = get_logger(__name__)
_patches_applied = False
# Attribute names for dynamic patching - avoids B010 lint warning
_ATTR_AUDIO_METADATA: Final = "AudioMetaData"
_ATTR_LOAD: Final = "load"
_ATTR_HF_HUB_DOWNLOAD: Final = "hf_hub_download"
_ATTR_LIST_BACKENDS: Final = "list_audio_backends"
_ATTR_GET_BACKEND: Final = "get_audio_backend"
_ATTR_SET_BACKEND: Final = "set_audio_backend"
@dataclass
class AudioMetaData:
    """Replacement for torchaudio.AudioMetaData removed in torchaudio 2.9+.

    Installed onto the torchaudio module by _patch_torchaudio so code that
    still references torchaudio.AudioMetaData keeps working.
    """

    # Sampling rate in Hz.
    sample_rate: int
    # Number of frames (samples per channel).
    num_frames: int
    # Number of audio channels.
    num_channels: int
    # Bit depth per sample — presumably 0 when unknown; confirm against the
    # historical torchaudio.AudioMetaData contract.
    bits_per_sample: int
    # Backend-reported encoding name (e.g. "PCM_S") — TODO confirm.
    encoding: str
def _patch_torchaudio() -> None:
    """Restore torchaudio.AudioMetaData (removed in torchaudio 2.9+)."""
    try:
        import torchaudio
    except ImportError:
        # torchaudio not installed — nothing to patch.
        return
    if hasattr(torchaudio, _ATTR_AUDIO_METADATA):
        return
    setattr(torchaudio, _ATTR_AUDIO_METADATA, AudioMetaData)
    logger.debug("Patched torchaudio.AudioMetaData")
def _patch_torch_load() -> None:
    """Default torch.load to weights_only=False for pyannote model loading.

    PyTorch 2.6+ changed the default to weights_only=True, which breaks
    loading pyannote checkpoints that contain non-tensor objects. Callers
    that pass weights_only explicitly are left untouched.
    """
    try:
        import torch
        from packaging.version import Version
    except ImportError:
        return
    if Version(torch.__version__) < Version("2.6.0"):
        # Older torch already defaults to weights_only=False; no patch needed.
        return

    original_load = cast(Callable[..., object], torch.load)

    def _load_with_legacy_default(*args: object, **kwargs: object) -> object:
        kwargs.setdefault("weights_only", False)
        return original_load(*args, **kwargs)

    setattr(torch, _ATTR_LOAD, _load_with_legacy_default)
    logger.debug("Patched torch.load for weights_only=False default")
def _patch_huggingface_auth() -> None:
    """Accept the legacy use_auth_token parameter on hf_hub_download.

    huggingface_hub 0.24+ renamed use_auth_token to token. The wrapper
    transparently renames the kwarg so pyannote/diart code using the old
    name keeps working.
    """
    try:
        import huggingface_hub
    except ImportError:
        return

    original_download = cast(
        Callable[..., object], huggingface_hub.hf_hub_download
    )

    def _download_with_legacy_kwarg(*args: object, **kwargs: object) -> object:
        if "use_auth_token" in kwargs:
            kwargs["token"] = kwargs.pop("use_auth_token")
        return original_download(*args, **kwargs)

    setattr(huggingface_hub, _ATTR_HF_HUB_DOWNLOAD, _download_with_legacy_kwarg)
    logger.debug("Patched huggingface_hub.hf_hub_download for use_auth_token")
def _patch_speechbrain_backend() -> None:
    """Install stand-ins for torchaudio backend APIs removed upstream.

    Provides no-op list/get/set_audio_backend replacements so code paths
    (e.g. speechbrain via pyannote) that still probe them do not crash.
    """
    try:
        import torchaudio
    except ImportError:
        return

    if not hasattr(torchaudio, _ATTR_LIST_BACKENDS):
        def _list_audio_backends() -> list[str]:
            # Backends historically reported by torchaudio.
            return ["soundfile", "sox"]

        setattr(torchaudio, _ATTR_LIST_BACKENDS, _list_audio_backends)
        logger.debug("Patched torchaudio.list_audio_backends")

    if not hasattr(torchaudio, _ATTR_GET_BACKEND):
        def _get_audio_backend() -> str | None:
            # Backend selection no longer exists; report "unset".
            return None

        setattr(torchaudio, _ATTR_GET_BACKEND, _get_audio_backend)
        logger.debug("Patched torchaudio.get_audio_backend")

    if not hasattr(torchaudio, _ATTR_SET_BACKEND):
        def _set_audio_backend(backend: str | None) -> None:
            # Accept and ignore — backend switching was removed upstream.
            return None

        setattr(torchaudio, _ATTR_SET_BACKEND, _set_audio_backend)
        logger.debug("Patched torchaudio.set_audio_backend")
def apply_patches() -> None:
    """Apply all compatibility patches (idempotent).

    Safe to call multiple times — the work happens only on the first
    call. Should run before importing pyannote.audio or diart.
    """
    global _patches_applied
    if _patches_applied:
        return
    with warnings.catch_warnings():
        # Suppress deprecation chatter from the modules being probed.
        warnings.simplefilter("ignore", DeprecationWarning)
        for patch_fn in (
            _patch_torchaudio,
            _patch_speechbrain_backend,
            _patch_torch_load,
            _patch_huggingface_auth,
        ):
            patch_fn()
    _patches_applied = True
    logger.info("Applied pyannote/diart compatibility patches")
def ensure_compatibility() -> None:
    """Ensure compatibility patches are applied before using diarization.

    This is the recommended entry point - call this at the start of any
    code path that will use pyannote.audio or diart. Delegates to
    apply_patches(), which is idempotent.
    """
    apply_patches()

View File

@@ -151,6 +151,11 @@ class DiarizationEngine:
)
try:
# Apply compatibility patches before importing pyannote/diart
from noteflow.infrastructure.diarization._compat import ensure_compatibility
ensure_compatibility()
import torch
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
@@ -200,8 +205,16 @@ class DiarizationEngine:
logger.info("Loading shared streaming diarization models on %s...", device)
try:
# Apply compatibility patches before importing pyannote/diart
from noteflow.infrastructure.diarization._compat import ensure_compatibility
ensure_compatibility()
from diart.models import EmbeddingModel, SegmentationModel
# Use pyannote/segmentation-3.0 with wespeaker embedding
# Note: Frame rate mismatch between models causes a warning but is
# handled via interpolation in pyannote's StatsPool
self._segmentation_model = SegmentationModel.from_pretrained(
"pyannote/segmentation-3.0",
use_hf_token=self._hf_token,
@@ -237,9 +250,14 @@ class DiarizationEngine:
import torch
from diart import SpeakerDiarization, SpeakerDiarizationConfig
# Duration must match the segmentation model's expected window size
# pyannote/segmentation-3.0 is trained with 10-second windows
model_duration = 10.0
config = SpeakerDiarizationConfig(
segmentation=self._segmentation_model,
embedding=self._embedding_model,
duration=model_duration,
step=self._streaming_latency,
latency=self._streaming_latency,
device=torch.device(self._resolve_device()),
@@ -252,6 +270,7 @@ class DiarizationEngine:
meeting_id=meeting_id,
_pipeline=pipeline,
_sample_rate=DEFAULT_SAMPLE_RATE,
_chunk_duration=model_duration,
)
def load_offline_model(self) -> None:
@@ -272,6 +291,11 @@ class DiarizationEngine:
with log_timing("diarization_offline_model_load", device=device):
try:
# Apply compatibility patches before importing pyannote
from noteflow.infrastructure.diarization._compat import ensure_compatibility
ensure_compatibility()
import torch
from pyannote.audio import Pipeline

View File

@@ -10,17 +10,22 @@ from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from noteflow.config.constants import DEFAULT_SAMPLE_RATE
from noteflow.infrastructure.diarization.dto import SpeakerTurn
from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
import numpy as np
from diart import SpeakerDiarization
from numpy.typing import NDArray
from numpy.typing import NDArray
logger = get_logger(__name__)
# Default chunk duration in seconds (matches pyannote segmentation model)
DEFAULT_CHUNK_DURATION: float = 5.0
@dataclass
class DiarizationSession:
@@ -32,14 +37,20 @@ class DiarizationSession:
The session owns its own SpeakerDiarization pipeline instance but
shares the underlying segmentation and embedding models with other
sessions for memory efficiency.
Audio is buffered until a full chunk (default 5 seconds) is available,
as the diart pipeline requires fixed-size input chunks.
"""
meeting_id: str
_pipeline: SpeakerDiarization | None
_sample_rate: int = DEFAULT_SAMPLE_RATE
_chunk_duration: float = DEFAULT_CHUNK_DURATION
_stream_time: float = field(default=0.0, init=False)
_turns: list[SpeakerTurn] = field(default_factory=list, init=False)
_closed: bool = field(default=False, init=False)
_audio_buffer: list[NDArray[np.float32]] = field(default_factory=list, init=False)
_buffer_samples: int = field(default=0, init=False)
def process_chunk(
self,
@@ -48,13 +59,17 @@ class DiarizationSession:
) -> Sequence[SpeakerTurn]:
"""Process an audio chunk and return new speaker turns.
Audio is buffered until a full chunk (default 5 seconds) is available.
The diart pipeline requires fixed-size input chunks matching the
segmentation model's expected duration.
Args:
audio: Audio samples as float32 array (mono).
audio: Audio samples as float32 array (mono, 1D).
sample_rate: Audio sample rate (defaults to session's configured rate).
Returns:
Sequence of speaker turns detected in this chunk,
with times adjusted to absolute stream position.
Sequence of speaker turns detected. Returns empty list if buffer
is not yet full.
Raises:
RuntimeError: If session is closed.
@@ -65,36 +80,97 @@ class DiarizationSession:
if audio.size == 0:
return []
rate = sample_rate or self._sample_rate
duration = len(audio) / rate
# Ensure audio is 1D
if audio.ndim > 1:
audio = audio.flatten()
# Add to buffer
self._audio_buffer.append(audio)
self._buffer_samples += len(audio)
# Calculate required samples for a full chunk
rate = sample_rate or self._sample_rate
required_samples = int(self._chunk_duration * rate)
# Check if we have enough for a full chunk
if self._buffer_samples < required_samples:
return []
# Concatenate buffered audio
full_audio = np.concatenate(self._audio_buffer)
# Extract exactly required_samples for this chunk
chunk_audio = full_audio[:required_samples]
# Keep remaining audio in buffer for next chunk
remaining = full_audio[required_samples:]
if len(remaining) > 0:
self._audio_buffer = [remaining]
self._buffer_samples = len(remaining)
else:
self._audio_buffer = []
self._buffer_samples = 0
# Process the full chunk
return self._process_full_chunk(chunk_audio, rate)
def _process_full_chunk(
self,
audio: NDArray[np.float32],
sample_rate: int,
) -> list[SpeakerTurn]:
"""Process a full audio chunk through the diarization pipeline.
Args:
audio: Audio samples as 1D float32 array (exactly chunk_duration seconds).
sample_rate: Audio sample rate.
Returns:
List of new speaker turns detected. Returns empty list on error.
"""
if self._pipeline is None:
return []
# Import here to avoid import errors when diart not installed
from pyannote.core import SlidingWindow, SlidingWindowFeature
# Reshape audio for diart: (samples,) -> (1, samples)
if audio.ndim == 1:
audio = audio.reshape(1, -1)
duration = len(audio) / sample_rate
# Reshape to (samples, channels) for diart - mono audio has 1 channel
audio_2d = audio.reshape(-1, 1)
# Create SlidingWindowFeature for diart
window = SlidingWindow(start=0.0, duration=duration, step=duration)
waveform = SlidingWindowFeature(audio, window)
waveform = SlidingWindowFeature(audio_2d, window)
# Process through pipeline
results = self._pipeline([waveform])
try:
# Process through pipeline
# Note: Frame rate mismatch between segmentation-3.0 and embedding models
# may cause warnings and occasional errors, which we handle gracefully
results = self._pipeline([waveform])
# Convert results to turns with absolute time offsets
new_turns: list[SpeakerTurn] = []
for annotation, _ in results:
for track in annotation.itertracks(yield_label=True):
if len(track) == 3:
segment, _, speaker = track
turn = SpeakerTurn(
speaker=str(speaker),
start=segment.start + self._stream_time,
end=segment.end + self._stream_time,
)
new_turns.append(turn)
self._turns.append(turn)
# Convert results to turns with absolute time offsets
new_turns: list[SpeakerTurn] = []
for annotation, _ in results:
for track in annotation.itertracks(yield_label=True):
if len(track) == 3:
segment, _, speaker = track
turn = SpeakerTurn(
speaker=str(speaker),
start=segment.start + self._stream_time,
end=segment.end + self._stream_time,
)
new_turns.append(turn)
self._turns.append(turn)
except (RuntimeError, ZeroDivisionError, ValueError) as e:
# Handle frame/weights mismatch and related errors gracefully
# Streaming diarization continues even if individual chunks fail
logger.warning(
"Diarization chunk processing failed (non-fatal): %s",
str(e),
exc_info=False,
)
new_turns = []
self._stream_time += duration
return new_turns
@@ -102,7 +178,7 @@ class DiarizationSession:
def reset(self) -> None:
"""Reset session state for restarting diarization.
Clears accumulated turns and resets stream time to zero.
Clears accumulated turns, audio buffer, and resets stream time to zero.
The underlying pipeline is also reset.
"""
if self._closed or self._pipeline is None:
@@ -111,6 +187,8 @@ class DiarizationSession:
self._pipeline.reset()
self._stream_time = 0.0
self._turns.clear()
self._audio_buffer.clear()
self._buffer_samples = 0
logger.debug("Session %s reset", self.meeting_id)
def restore(
@@ -154,6 +232,8 @@ class DiarizationSession:
self._closed = True
self._turns.clear()
self._audio_buffer.clear()
self._buffer_samples = 0
# Explicitly release pipeline reference to allow GC and GPU memory release
self._pipeline = None
logger.info("diarization_session_closed", meeting_id=self.meeting_id)

View File

@@ -0,0 +1,771 @@
"""Unit tests for AuthService.
Tests cover:
- initiate_login: OAuth flow initiation with valid/invalid providers
- complete_login: Token exchange, user creation, and auth result construction
- get_current_user: Retrieving current user info, authenticated and unauthenticated
- logout: Token revocation and integration cleanup
- refresh_auth_tokens: Token refresh with success and failure scenarios
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID, uuid4
import pytest
from noteflow.application.services.auth_service import (
DEFAULT_USER_ID,
DEFAULT_WORKSPACE_ID,
AuthResult,
AuthService,
AuthServiceError,
LogoutResult,
UserInfo,
)
from noteflow.config.settings import CalendarIntegrationSettings
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar.oauth_manager import OAuthError
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_oauth_manager() -> MagicMock:
    """Create mock OAuthManager for testing.

    initiate_auth is synchronous and returns a fixed (auth_url, state)
    pair; the async operations are AsyncMocks configured per-test.
    """
    manager = MagicMock()
    manager.initiate_auth = MagicMock(
        return_value=("https://auth.example.com/authorize?...", "state123")
    )
    manager.complete_auth = AsyncMock()
    manager.refresh_tokens = AsyncMock()
    manager.revoke_tokens = AsyncMock()
    return manager
@pytest.fixture
def mock_auth_uow() -> MagicMock:
    """Create mock UnitOfWork with auth-related repositories.

    Supports async context-manager usage and exposes users, workspaces,
    and integrations repositories. All lookups default to None ("not
    found") so tests exercise the creation paths unless overridden.
    """
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.commit = AsyncMock()
    # Users repository
    uow.supports_users = True
    uow.users = MagicMock()
    uow.users.get = AsyncMock(return_value=None)
    uow.users.get_by_email = AsyncMock(return_value=None)
    uow.users.create = AsyncMock()
    uow.users.update = AsyncMock()
    # Workspaces repository
    uow.supports_workspaces = True
    uow.workspaces = MagicMock()
    uow.workspaces.get_default_for_user = AsyncMock(return_value=None)
    uow.workspaces.create = AsyncMock()
    # Integrations repository
    uow.supports_integrations = True
    uow.integrations = MagicMock()
    uow.integrations.get_by_provider = AsyncMock(return_value=None)
    uow.integrations.create = AsyncMock()
    uow.integrations.update = AsyncMock()
    uow.integrations.delete = AsyncMock()
    uow.integrations.set_secrets = AsyncMock()
    uow.integrations.get_secrets = AsyncMock(return_value=None)
    return uow
@pytest.fixture
def auth_service(
    calendar_settings: CalendarIntegrationSettings,
    mock_oauth_manager: MagicMock,
) -> AuthService:
    """Create AuthService with mock dependencies.

    NOTE(review): the factory hands back a bare MagicMock (not
    mock_auth_uow), so this fixture suits tests that never touch
    persistence; calendar_settings presumably comes from conftest.
    """

    def uow_factory() -> MagicMock:
        """Return a new mock UoW each time."""
        return MagicMock()

    return AuthService(
        uow_factory=uow_factory,
        settings=calendar_settings,
        oauth_manager=mock_oauth_manager,
    )
@pytest.fixture
def sample_oauth_tokens(sample_datetime: datetime) -> OAuthTokens:
    """Create sample OAuth tokens for testing.

    Expiry is one hour past sample_datetime (a conftest-provided fixture,
    presumably — verify), so the tokens read as valid in tests.
    """
    return OAuthTokens(
        access_token="access_token_123",
        refresh_token="refresh_token_456",
        token_type="Bearer",
        expires_at=sample_datetime + timedelta(hours=1),
        scope="openid email profile",
    )
@pytest.fixture
def sample_integration() -> Integration:
    """Create sample auth integration for testing.

    Built as a connected AUTH-type integration whose config carries the
    provider name and owning user_id (as a string).
    """
    integration = Integration.create(
        workspace_id=uuid4(),
        name="Google Auth",
        integration_type=IntegrationType.AUTH,
        config={"provider": "google", "user_id": str(uuid4())},
    )
    integration.connect(provider_email="test@example.com")
    return integration
# =============================================================================
# Test: initiate_login
# =============================================================================
class TestInitiateLogin:
    """Tests for AuthService.initiate_login."""

    @pytest.mark.parametrize(
        ("provider", "expected_oauth_provider"),
        [
            pytest.param("google", OAuthProvider.GOOGLE, id="google_provider"),
            pytest.param("Google", OAuthProvider.GOOGLE, id="google_uppercase"),
            pytest.param("outlook", OAuthProvider.OUTLOOK, id="outlook_provider"),
            pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
        ],
    )
    async def test_initiates_login_with_valid_provider(
        self,
        auth_service: AuthService,
        mock_oauth_manager: MagicMock,
        provider: str,
        expected_oauth_provider: OAuthProvider,
    ) -> None:
        """initiate_login returns auth URL and state for valid providers."""
        # NOTE(review): expected_oauth_provider is parametrized but never
        # asserted against the initiate_auth call — consider verifying the
        # provider actually forwarded to the OAuth manager.
        auth_url, state = await auth_service.initiate_login(provider)
        assert auth_url == "https://auth.example.com/authorize?...", "should return auth URL"
        assert state == "state123", "should return state token"
        mock_oauth_manager.initiate_auth.assert_called_once()

    async def test_initiates_login_with_custom_redirect_uri(
        self,
        auth_service: AuthService,
        mock_oauth_manager: MagicMock,
    ) -> None:
        """initiate_login uses custom redirect_uri when provided."""
        custom_redirect = "https://custom.example.com/callback"
        await auth_service.initiate_login("google", redirect_uri=custom_redirect)
        # The redirect URI is expected to be forwarded as a keyword argument.
        call_kwargs = mock_oauth_manager.initiate_auth.call_args[1]
        assert call_kwargs["redirect_uri"] == custom_redirect, "should use custom redirect URI"

    async def test_initiates_login_raises_for_invalid_provider(
        self,
        auth_service: AuthService,
    ) -> None:
        """initiate_login raises AuthServiceError for invalid provider."""
        with pytest.raises(AuthServiceError, match="Invalid provider"):
            await auth_service.initiate_login("invalid_provider")

    async def test_initiates_login_propagates_oauth_error(
        self,
        auth_service: AuthService,
        mock_oauth_manager: MagicMock,
    ) -> None:
        """initiate_login raises AuthServiceError when OAuthManager fails."""
        mock_oauth_manager.initiate_auth.side_effect = OAuthError("OAuth failed")
        with pytest.raises(AuthServiceError, match="OAuth failed"):
            await auth_service.initiate_login("google")
# =============================================================================
# Test: complete_login
# =============================================================================
class TestCompleteLogin:
"""Tests for AuthService.complete_login."""
async def test_completes_login_creates_new_user(
self,
calendar_settings: CalendarIntegrationSettings,
mock_oauth_manager: MagicMock,
mock_auth_uow: MagicMock,
sample_oauth_tokens: OAuthTokens,
) -> None:
"""complete_login creates new user when email not found."""
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
service = AuthService(
uow_factory=lambda: mock_auth_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
)
with patch.object(
service, "_fetch_user_info", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ("test@example.com", "Test User")
result = await service.complete_login("google", "auth_code", "state123")
assert isinstance(result, AuthResult), "should return AuthResult"
assert result.email == "test@example.com", "should include user email"
assert result.display_name == "Test User", "should include display name"
assert result.is_authenticated is True, "should be authenticated"
mock_auth_uow.users.create.assert_called_once()
async def test_completes_login_updates_existing_user(
self,
calendar_settings: CalendarIntegrationSettings,
mock_oauth_manager: MagicMock,
mock_auth_uow: MagicMock,
sample_oauth_tokens: OAuthTokens,
) -> None:
"""complete_login updates existing user when email found."""
existing_user = User(
id=uuid4(),
display_name="Old Name",
email="test@example.com",
is_default=False,
)
mock_auth_uow.users.get_by_email.return_value = existing_user
mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
service = AuthService(
uow_factory=lambda: mock_auth_uow,
settings=calendar_settings,
oauth_manager=mock_oauth_manager,
)
with patch.object(
service, "_fetch_user_info", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ("test@example.com", "New Name")
result = await service.complete_login("google", "auth_code", "state123")
assert result.user_id == existing_user.id, "should use existing user ID"
mock_auth_uow.users.update.assert_called_once()
mock_auth_uow.users.create.assert_not_called()
    async def test_completes_login_stores_integration_tokens(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_oauth_tokens: OAuthTokens,
    ) -> None:
        """complete_login stores tokens in integration secrets."""
        mock_oauth_manager.complete_auth.return_value = sample_oauth_tokens
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        # Patch the provider userinfo fetch so no real HTTP call is made.
        with patch.object(
            service, "_fetch_user_info", new_callable=AsyncMock
        ) as mock_fetch:
            mock_fetch.return_value = ("test@example.com", "Test User")
            await service.complete_login("google", "auth_code", "state123")
        # Tokens must go through the secrets store and the UoW must commit.
        mock_auth_uow.integrations.set_secrets.assert_called_once()
        mock_auth_uow.commit.assert_called()
    async def test_completes_login_raises_on_token_exchange_failure(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
    ) -> None:
        """complete_login raises AuthServiceError when token exchange fails."""
        # Simulate the provider rejecting the authorization code.
        mock_oauth_manager.complete_auth.side_effect = OAuthError("Invalid code")
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        # The low-level OAuthError must be wrapped in AuthServiceError.
        with pytest.raises(AuthServiceError, match="OAuth failed"):
            await service.complete_login("google", "invalid_code", "state123")
# =============================================================================
# Test: get_current_user
# =============================================================================
class TestGetCurrentUser:
    """Tests for AuthService.get_current_user."""
    async def test_returns_authenticated_user_for_connected_integration(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """get_current_user returns authenticated user info when integration exists."""
        # The user's ID lives in the integration config; the service is
        # expected to resolve the full User entity from it via users.get.
        user = User(
            id=UUID(sample_integration.config["user_id"]),
            display_name="Authenticated User",
            email="test@example.com",
        )
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.users.get.return_value = user
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.get_current_user()
        assert isinstance(result, UserInfo), "should return UserInfo"
        assert result.is_authenticated is True, "should be authenticated"
        assert result.provider == "google", "should include provider"
        assert result.email == "test@example.com", "should include email"
    async def test_returns_local_user_when_no_integration(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
    ) -> None:
        """get_current_user returns local default user when no integration exists."""
        # No OAuth integration at all -> fall back to the local default user.
        mock_auth_uow.integrations.get_by_provider.return_value = None
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.get_current_user()
        assert isinstance(result, UserInfo), "should return UserInfo"
        assert result.is_authenticated is False, "should not be authenticated"
        assert result.display_name == "Local User", "should use local user name"
        assert result.user_id == DEFAULT_USER_ID, "should use default user ID"
        assert result.workspace_id == DEFAULT_WORKSPACE_ID, "should use default workspace ID"
    async def test_returns_local_user_for_disconnected_integration(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
    ) -> None:
        """get_current_user returns local user when integration is disconnected."""
        # Build an AUTH integration that was never connected.
        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Auth",
            integration_type=IntegrationType.AUTH,
            config={"provider": "google"},
        )
        # Integration not connected (no provider_email set)
        mock_auth_uow.integrations.get_by_provider.return_value = integration
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.get_current_user()
        assert result.is_authenticated is False, "disconnected integration should not be authenticated"
# =============================================================================
# Test: logout
# =============================================================================
class TestLogout:
    """Tests for AuthService.logout."""
    async def test_logout_deletes_integration_and_revokes_tokens(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """logout deletes integration and revokes tokens."""
        # An integration with a stored access token exists for this provider.
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.integrations.get_secrets.return_value = {
            "access_token": "token_to_revoke"
        }
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.logout("google")
        assert result.logged_out is True, "logout should return logged_out=True on success"
        assert result.tokens_revoked is True, "logout should return tokens_revoked=True"
        # Both local cleanup and remote revocation must happen.
        mock_auth_uow.integrations.delete.assert_called_once_with(sample_integration.id)
        mock_oauth_manager.revoke_tokens.assert_called_once()
    async def test_logout_returns_false_when_no_integration(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
    ) -> None:
        """logout returns False when no integration exists."""
        mock_auth_uow.integrations.get_by_provider.return_value = None
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.logout("google")
        assert result.logged_out is False, "logout should return logged_out=False when no integration"
        # Nothing to delete when there was never an integration.
        mock_auth_uow.integrations.delete.assert_not_called()
    async def test_logout_all_providers_when_none_specified(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """logout attempts all providers when none specified."""
        # Return integration only for google, not for outlook
        def get_by_provider(provider: str, integration_type: str) -> Integration | None:
            return sample_integration if provider == "google" else None
        mock_auth_uow.integrations.get_by_provider.side_effect = get_by_provider
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.logout()  # No provider specified
        assert result.logged_out is True, "logout should return logged_out=True if any provider logged out"
        # Should have checked both providers
        EXPECTED_PROVIDER_CHECK_COUNT = 2
        assert (
            mock_auth_uow.integrations.get_by_provider.call_count
            == EXPECTED_PROVIDER_CHECK_COUNT
        ), "should check both providers"
    async def test_logout_handles_revocation_error_gracefully(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """logout continues even if token revocation fails."""
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.integrations.get_secrets.return_value = {"access_token": "token"}
        # Remote revocation fails; local logout must still proceed.
        mock_oauth_manager.revoke_tokens.side_effect = OAuthError("Revocation failed")
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.logout("google")
        assert result.logged_out is True, "logout should succeed despite revocation failure"
        assert result.tokens_revoked is False, "tokens_revoked should be False on revocation failure"
        assert result.revocation_error is not None, "should include revocation error message"
        mock_auth_uow.integrations.delete.assert_called_once()
# =============================================================================
# Test: refresh_auth_tokens
# =============================================================================
class TestRefreshAuthTokens:
    """Tests for AuthService.refresh_auth_tokens."""
    async def test_refreshes_tokens_successfully(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
        sample_datetime: datetime,
    ) -> None:
        """refresh_auth_tokens updates tokens and returns new AuthResult."""
        # Stored secrets carry an expired/old token pair (expires_at is the
        # sample timestamp, presumably in the past relative to the service's
        # expiry check).
        old_tokens = {
            "access_token": "old_access",
            "refresh_token": "old_refresh",
            "expires_at": sample_datetime.isoformat(),
        }
        new_tokens = OAuthTokens(
            access_token="new_access",
            refresh_token="new_refresh",
            token_type="Bearer",
            expires_at=sample_datetime + timedelta(hours=1),
            scope="openid email profile",
        )
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.integrations.get_secrets.return_value = old_tokens
        mock_oauth_manager.refresh_tokens.return_value = new_tokens
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        assert result is not None, "should return AuthResult on success"
        assert result.is_authenticated is True, "should return authenticated result"
        # Verify tokens were stored (not exposed on result per security design)
        mock_auth_uow.integrations.set_secrets.assert_called_once()
        mock_auth_uow.commit.assert_called()
    async def test_returns_none_when_no_integration(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
    ) -> None:
        """refresh_auth_tokens returns None when no integration exists."""
        mock_auth_uow.integrations.get_by_provider.return_value = None
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        assert result is None, "should return None when no integration"
    async def test_returns_none_when_no_refresh_token(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
        sample_datetime: datetime,
    ) -> None:
        """refresh_auth_tokens returns None when no refresh token available."""
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        # Secrets deliberately omit refresh_token: refresh is impossible.
        mock_auth_uow.integrations.get_secrets.return_value = {
            "access_token": "access",
            "expires_at": sample_datetime.isoformat(),
            # No refresh_token
        }
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        assert result is None, "should return None when no refresh token"
    async def test_marks_error_on_refresh_failure(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
        sample_datetime: datetime,
    ) -> None:
        """refresh_auth_tokens marks integration error on failure."""
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.integrations.get_secrets.return_value = {
            "access_token": "access",
            "refresh_token": "refresh",
            "expires_at": sample_datetime.isoformat(),
        }
        # Provider rejects the refresh token.
        mock_oauth_manager.refresh_tokens.side_effect = OAuthError("Token expired")
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        assert result is None, "should return None on refresh failure"
        # The integration record is updated (error marked) and committed.
        mock_auth_uow.integrations.update.assert_called_once()
        mock_auth_uow.commit.assert_called()
# =============================================================================
# Test: _store_auth_user (workspace creation)
# =============================================================================
class TestStoreAuthUser:
    """Tests for AuthService._store_auth_user workspace handling."""
    async def test_creates_default_workspace_for_new_user(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_oauth_tokens: OAuthTokens,
    ) -> None:
        """_store_auth_user creates default workspace when none exists."""
        # The user has no default workspace yet.
        mock_auth_uow.workspaces.get_default_for_user.return_value = None
        mock_auth_uow.workspaces.create = AsyncMock()
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        # Exercise the private helper directly (white-box test).
        await service._store_auth_user(
            "google", "test@example.com", "Test User", sample_oauth_tokens
        )
        # Verify workspace.create was called (new workspace for new user)
        mock_auth_uow.workspaces.create.assert_called_once()
        # call_args[1] is the keyword-argument dict of the create() call.
        call_kwargs = mock_auth_uow.workspaces.create.call_args[1]
        assert call_kwargs["name"] == "Personal", "should create 'Personal' workspace"
        assert call_kwargs["is_default"] is True, "should be default workspace"
# =============================================================================
# Test: refresh_auth_tokens (additional edge cases)
# =============================================================================
class TestRefreshAuthTokensEdgeCases:
    """Additional tests for AuthService.refresh_auth_tokens edge cases."""
    async def test_skips_refresh_when_token_still_valid(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """refresh_auth_tokens returns existing auth when token not expired."""
        # Token expires in 1 hour (more than 5 minute buffer)
        future_expiry = datetime.now(UTC) + timedelta(hours=1)
        secrets = {
            "access_token": "valid_token",
            "refresh_token": "refresh_token",
            "expires_at": future_expiry.isoformat(),
        }
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        mock_auth_uow.integrations.get_secrets.return_value = secrets
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        # Should return result without calling refresh
        assert result is not None, "should return AuthResult"
        mock_oauth_manager.refresh_tokens.assert_not_called()
    async def test_returns_none_on_invalid_secrets_dict(
        self,
        calendar_settings: CalendarIntegrationSettings,
        mock_oauth_manager: MagicMock,
        mock_auth_uow: MagicMock,
        sample_integration: Integration,
    ) -> None:
        """refresh_auth_tokens returns None when secrets can't be parsed."""
        mock_auth_uow.integrations.get_by_provider.return_value = sample_integration
        # Malformed secrets payload: required token fields are absent.
        mock_auth_uow.integrations.get_secrets.return_value = {
            "invalid": "data",  # Missing required fields
        }
        service = AuthService(
            uow_factory=lambda: mock_auth_uow,
            settings=calendar_settings,
            oauth_manager=mock_oauth_manager,
        )
        result = await service.refresh_auth_tokens("google")
        assert result is None, "should return None for invalid secrets"
# =============================================================================
# Test: _parse_provider (static method)
# =============================================================================
class TestParseProvider:
    """Tests for AuthService._parse_provider static method."""
    @pytest.mark.parametrize(
        ("input_provider", "expected_output"),
        [
            # Parsing must be case-insensitive for supported providers.
            pytest.param("google", OAuthProvider.GOOGLE, id="google_lowercase"),
            pytest.param("GOOGLE", OAuthProvider.GOOGLE, id="google_uppercase"),
            pytest.param("Google", OAuthProvider.GOOGLE, id="google_mixed_case"),
            pytest.param("outlook", OAuthProvider.OUTLOOK, id="outlook_lowercase"),
            pytest.param("OUTLOOK", OAuthProvider.OUTLOOK, id="outlook_uppercase"),
        ],
    )
    def test_parses_valid_providers(
        self,
        input_provider: str,
        expected_output: OAuthProvider,
    ) -> None:
        """_parse_provider correctly parses valid provider strings."""
        result = AuthService._parse_provider(input_provider)
        assert result == expected_output, f"should parse {input_provider} correctly"
    @pytest.mark.parametrize(
        "invalid_provider",
        [
            # Unknown providers and empty input must be rejected uniformly.
            pytest.param("github", id="github_not_supported"),
            pytest.param("facebook", id="facebook_not_supported"),
            pytest.param("", id="empty_string"),
            pytest.param("invalid", id="random_string"),
        ],
    )
    def test_raises_for_invalid_providers(
        self,
        invalid_provider: str,
    ) -> None:
        """_parse_provider raises AuthServiceError for invalid providers."""
        with pytest.raises(AuthServiceError, match="Invalid provider"):
            AuthService._parse_provider(invalid_provider)

View File

@@ -0,0 +1,578 @@
"""Tests for IdentityMixin gRPC endpoints.
Tests cover:
- GetCurrentUser: Returns user identity, workspace, and auth status
- ListWorkspaces: Lists user's workspaces with pagination
- SwitchWorkspace: Validates workspace access and returns workspace info
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from noteflow.application.services.identity_service import IdentityService
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
from noteflow.domain.identity.entities import Workspace, WorkspaceMembership
from noteflow.domain.identity.roles import WorkspaceRole
from noteflow.grpc._mixins._types import GrpcContext
from noteflow.grpc._mixins.identity import IdentityMixin
from noteflow.grpc.proto import noteflow_pb2
if TYPE_CHECKING:
from datetime import datetime
# =============================================================================
# Mock Servicer Host
# =============================================================================
class MockIdentityServicerHost(IdentityMixin):
    """Mock servicer host implementing required protocol for IdentityMixin."""
    def __init__(self) -> None:
        """Initialize mock servicer with identity service."""
        self._identity_service = IdentityService()
        # Injected per-test via set_mock_uow(); None means "not configured".
        self._mock_uow: MagicMock | None = None
    def create_repository_provider(self) -> MagicMock:
        """Return mock UnitOfWork.

        Raises:
            RuntimeError: If no mock UoW was set via ``set_mock_uow``.
        """
        if self._mock_uow is None:
            msg = "Mock UoW not configured"
            raise RuntimeError(msg)
        return self._mock_uow
    def get_operation_context(self, context: GrpcContext) -> OperationContext:
        """Return mock operation context.

        Note: a fresh user/workspace pair (random UUIDs, OWNER role) is
        generated on every call.
        """
        return OperationContext(
            user=UserContext(
                user_id=uuid4(),
                display_name="Test User",
            ),
            workspace=WorkspaceContext(
                workspace_id=uuid4(),
                workspace_name="Test Workspace",
                role=WorkspaceRole.OWNER,
            ),
        )
    @property
    def identity_service(self) -> IdentityService:
        """Return identity service."""
        return self._identity_service
    def set_mock_uow(self, uow: MagicMock) -> None:
        """Set mock UnitOfWork for testing."""
        self._mock_uow = uow
    # Type stubs for mixin methods
    if TYPE_CHECKING:
        # These declarations exist only for static type checkers; at runtime
        # the real implementations are inherited from IdentityMixin.
        async def GetCurrentUser(
            self,
            request: noteflow_pb2.GetCurrentUserRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.GetCurrentUserResponse: ...
        async def ListWorkspaces(
            self,
            request: noteflow_pb2.ListWorkspacesRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.ListWorkspacesResponse: ...
        async def SwitchWorkspace(
            self,
            request: noteflow_pb2.SwitchWorkspaceRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.SwitchWorkspaceResponse: ...
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def identity_servicer() -> MockIdentityServicerHost:
    """Provide a fresh mock servicer host for each identity-mixin test."""
    host = MockIdentityServicerHost()
    return host
@pytest.fixture
def mock_identity_uow() -> MagicMock:
    """Assemble a MagicMock UnitOfWork wired with identity repositories.

    Async context-manager hooks hand back the mock itself, and every
    repository lookup defaults to "nothing found" so each test overrides
    only the pieces it cares about.
    """
    uow = MagicMock()
    uow.__aenter__ = AsyncMock(return_value=uow)
    uow.__aexit__ = AsyncMock(return_value=None)
    uow.commit = AsyncMock()
    # Users repository: empty by default.
    users_repo = MagicMock()
    users_repo.get = AsyncMock(return_value=None)
    users_repo.get_default = AsyncMock(return_value=None)
    users_repo.create_default = AsyncMock()
    uow.supports_users = True
    uow.users = users_repo
    # Workspaces repository: no workspaces or memberships by default.
    workspaces_repo = MagicMock()
    workspaces_repo.get = AsyncMock(return_value=None)
    workspaces_repo.get_default_for_user = AsyncMock(return_value=None)
    workspaces_repo.get_membership = AsyncMock(return_value=None)
    workspaces_repo.list_for_user = AsyncMock(return_value=[])
    workspaces_repo.create = AsyncMock()
    uow.supports_workspaces = True
    uow.workspaces = workspaces_repo
    # Integrations repository: no OAuth integration configured.
    integrations_repo = MagicMock()
    integrations_repo.get_by_provider = AsyncMock(return_value=None)
    uow.supports_integrations = True
    uow.integrations = integrations_repo
    # Projects repository is unsupported (relevant to workspace creation).
    uow.supports_projects = False
    return uow
@pytest.fixture
def sample_user_context() -> UserContext:
    """Build a representative signed-in user context."""
    generated_user_id = uuid4()
    return UserContext(
        user_id=generated_user_id,
        display_name="Test User",
        email="test@example.com",
    )
@pytest.fixture
def sample_workspace_context() -> WorkspaceContext:
    """Build a workspace context in which the user is the workspace owner."""
    generated_workspace_id = uuid4()
    return WorkspaceContext(
        workspace_id=generated_workspace_id,
        workspace_name="Test Workspace",
        role=WorkspaceRole.OWNER,
    )
@pytest.fixture
def sample_workspace(sample_datetime: datetime) -> Workspace:
    """Build a default workspace with deterministic timestamps."""
    workspace_id = uuid4()
    return Workspace(
        id=workspace_id,
        name="Test Workspace",
        slug="test-workspace",
        is_default=True,
        created_at=sample_datetime,
        updated_at=sample_datetime,
    )
@pytest.fixture
def sample_membership() -> WorkspaceMembership:
    """Build a plain MEMBER-role membership linking two random IDs."""
    member_workspace_id = uuid4()
    member_user_id = uuid4()
    return WorkspaceMembership(
        workspace_id=member_workspace_id,
        user_id=member_user_id,
        role=WorkspaceRole.MEMBER,
    )
# =============================================================================
# Test: GetCurrentUser
# =============================================================================
class TestGetCurrentUser:
    """Tests for IdentityMixin.GetCurrentUser."""
    async def test_returns_default_user_in_memory_mode(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """GetCurrentUser returns default user when in memory mode."""
        # Build a minimal UoW with every repository capability disabled to
        # simulate in-memory (no database) mode.
        uow = MagicMock()
        uow.__aenter__ = AsyncMock(return_value=uow)
        uow.__aexit__ = AsyncMock(return_value=None)
        uow.commit = AsyncMock()
        uow.supports_users = False
        uow.supports_workspaces = False
        uow.supports_integrations = False
        identity_servicer.set_mock_uow(uow)
        request = noteflow_pb2.GetCurrentUserRequest()
        response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
        assert response.user_id, "should return user_id"
        assert response.workspace_id, "should return workspace_id"
        assert response.display_name, "should return display_name"
        assert response.is_authenticated is False, "should not be authenticated in memory mode"
    async def test_returns_authenticated_user_with_oauth(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
        sample_datetime: datetime,
    ) -> None:
        """GetCurrentUser returns authenticated user when OAuth integration exists."""
        # Configure connected integration
        integration = Integration.create(
            workspace_id=uuid4(),
            name="Google Auth",
            integration_type=IntegrationType.AUTH,
            config={"provider": "google"},
        )
        integration.connect(provider_email="test@example.com")
        mock_identity_uow.integrations.get_by_provider.return_value = integration
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.GetCurrentUserRequest()
        response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
        assert response.is_authenticated is True, "should be authenticated with OAuth"
        assert response.auth_provider == "google", "should return auth provider"
    async def test_returns_workspace_role(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
    ) -> None:
        """GetCurrentUser returns user's workspace role."""
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.GetCurrentUserRequest()
        response = await identity_servicer.GetCurrentUser(request, mock_grpc_context)
        # Role should be set (default is owner for first user)
        assert response.role, "should return workspace role"
        assert response.workspace_name, "should return workspace name"
# =============================================================================
# Test: ListWorkspaces
# =============================================================================
class TestListWorkspaces:
    """Tests for IdentityMixin.ListWorkspaces."""
    async def test_returns_empty_list_when_no_workspaces(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
    ) -> None:
        """ListWorkspaces returns empty list when user has no workspaces."""
        # Fixture default: list_for_user returns [].
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.ListWorkspacesRequest()
        response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        assert response.total_count == 0, "should return zero count"
        assert len(response.workspaces) == 0, "should return empty list"
    async def test_returns_user_workspaces(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
        sample_workspace: Workspace,
        sample_membership: WorkspaceMembership,
    ) -> None:
        """ListWorkspaces returns workspaces the user belongs to."""
        mock_identity_uow.workspaces.list_for_user.return_value = [sample_workspace]
        mock_identity_uow.workspaces.get_membership.return_value = sample_membership
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.ListWorkspacesRequest()
        response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        assert response.total_count == 1, "should return correct count"
        assert len(response.workspaces) == 1, "should return one workspace"
        assert response.workspaces[0].name == "Test Workspace", "should include workspace name"
        assert response.workspaces[0].is_default is True, "should include is_default flag"
    async def test_respects_pagination_parameters(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
    ) -> None:
        """ListWorkspaces passes pagination parameters to repository."""
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.ListWorkspacesRequest(limit=10, offset=5)
        await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        # Check that list_for_user was called with pagination params
        mock_identity_uow.workspaces.list_for_user.assert_called()
        # call_args[0] is the positional-args tuple; limit/offset are at
        # indices 1 and 2 (index 0 is presumably the user id — confirm
        # against the mixin implementation).
        call_args = mock_identity_uow.workspaces.list_for_user.call_args
        EXPECTED_LIMIT = 10
        EXPECTED_OFFSET = 5
        assert call_args[0][1] == EXPECTED_LIMIT, "should pass limit"
        assert call_args[0][2] == EXPECTED_OFFSET, "should pass offset"
    async def test_uses_default_pagination_values(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
    ) -> None:
        """ListWorkspaces uses default pagination when not specified."""
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.ListWorkspacesRequest()  # No limit/offset
        await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        call_args = mock_identity_uow.workspaces.list_for_user.call_args
        DEFAULT_LIMIT = 50
        DEFAULT_OFFSET = 0
        assert call_args[0][1] == DEFAULT_LIMIT, "should use default limit"
        assert call_args[0][2] == DEFAULT_OFFSET, "should use default offset"
    async def test_includes_workspace_role(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_identity_uow: MagicMock,
        mock_grpc_context: MagicMock,
        sample_workspace: Workspace,
    ) -> None:
        """ListWorkspaces includes user's role in each workspace."""
        # Membership with OWNER role should surface as the string "owner".
        owner_membership = WorkspaceMembership(
            workspace_id=sample_workspace.id,
            user_id=uuid4(),
            role=WorkspaceRole.OWNER,
        )
        mock_identity_uow.workspaces.list_for_user.return_value = [sample_workspace]
        mock_identity_uow.workspaces.get_membership.return_value = owner_membership
        identity_servicer.set_mock_uow(mock_identity_uow)
        request = noteflow_pb2.ListWorkspacesRequest()
        response = await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        assert response.workspaces[0].role == "owner", "should include role"
# =============================================================================
# Test: SwitchWorkspace
# =============================================================================
class TestSwitchWorkspace:
"""Tests for IdentityMixin.SwitchWorkspace."""
async def test_aborts_when_workspace_id_missing(
self,
identity_servicer: MockIdentityServicerHost,
mock_identity_uow: MagicMock,
mock_grpc_context: MagicMock,
) -> None:
"""SwitchWorkspace aborts with INVALID_ARGUMENT when workspace_id not provided."""
identity_servicer.set_mock_uow(mock_identity_uow)
request = noteflow_pb2.SwitchWorkspaceRequest() # No workspace_id
with pytest.raises(AssertionError, match="Unreachable"):
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
mock_grpc_context.abort.assert_called_once()
async def test_aborts_for_invalid_uuid(
self,
identity_servicer: MockIdentityServicerHost,
mock_identity_uow: MagicMock,
mock_grpc_context: MagicMock,
) -> None:
"""SwitchWorkspace aborts with INVALID_ARGUMENT for invalid workspace_id format."""
identity_servicer.set_mock_uow(mock_identity_uow)
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id="not-a-uuid")
with pytest.raises(AssertionError, match="Unreachable"):
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
mock_grpc_context.abort.assert_called_once()
async def test_aborts_when_workspace_not_found(
self,
identity_servicer: MockIdentityServicerHost,
mock_identity_uow: MagicMock,
mock_grpc_context: MagicMock,
) -> None:
"""SwitchWorkspace aborts with NOT_FOUND when workspace doesn't exist."""
mock_identity_uow.workspaces.get.return_value = None
identity_servicer.set_mock_uow(mock_identity_uow)
workspace_id = uuid4()
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(workspace_id))
with pytest.raises(AssertionError, match="Unreachable"):
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
mock_grpc_context.abort.assert_called_once()
async def test_aborts_when_user_not_member(
self,
identity_servicer: MockIdentityServicerHost,
mock_identity_uow: MagicMock,
mock_grpc_context: MagicMock,
sample_workspace: Workspace,
) -> None:
"""SwitchWorkspace aborts with NOT_FOUND when user is not a member of workspace."""
mock_identity_uow.workspaces.get.return_value = sample_workspace
mock_identity_uow.workspaces.get_membership.return_value = None # No membership
identity_servicer.set_mock_uow(mock_identity_uow)
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id))
with pytest.raises(AssertionError, match="Unreachable"):
await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
mock_grpc_context.abort.assert_called_once()
async def test_switches_workspace_successfully(
self,
identity_servicer: MockIdentityServicerHost,
mock_identity_uow: MagicMock,
mock_grpc_context: MagicMock,
sample_workspace: Workspace,
sample_membership: WorkspaceMembership,
) -> None:
"""SwitchWorkspace returns workspace info on success."""
mock_identity_uow.workspaces.get.return_value = sample_workspace
mock_identity_uow.workspaces.get_membership.return_value = sample_membership
identity_servicer.set_mock_uow(mock_identity_uow)
request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id))
response = await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
assert response.success is True, "should return success=True"
assert response.workspace.id == str(sample_workspace.id), "should return workspace ID"
assert response.workspace.name == "Test Workspace", "should return workspace name"
assert response.workspace.role == "member", "should return user's role"
@pytest.mark.parametrize(
    ("role", "expected_role_str"),
    [
        pytest.param(WorkspaceRole.OWNER, "owner", id="owner_role"),
        pytest.param(WorkspaceRole.ADMIN, "admin", id="admin_role"),
        pytest.param(WorkspaceRole.MEMBER, "member", id="member_role"),
    ],
)
async def test_returns_correct_role_for_membership(
    self,
    identity_servicer: MockIdentityServicerHost,
    mock_identity_uow: MagicMock,
    mock_grpc_context: MagicMock,
    sample_workspace: Workspace,
    role: WorkspaceRole,
    expected_role_str: str,
) -> None:
    """SwitchWorkspace returns correct role string for different memberships."""
    # Membership for an arbitrary user carrying the parametrized role.
    mock_identity_uow.workspaces.get.return_value = sample_workspace
    mock_identity_uow.workspaces.get_membership.return_value = WorkspaceMembership(
        workspace_id=sample_workspace.id,
        user_id=uuid4(),
        role=role,
    )
    identity_servicer.set_mock_uow(mock_identity_uow)

    response = await identity_servicer.SwitchWorkspace(
        noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(sample_workspace.id)),
        mock_grpc_context,
    )
    assert response.workspace.role == expected_role_str, f"should return role as {expected_role_str}"
async def test_includes_workspace_metadata(
    self,
    identity_servicer: MockIdentityServicerHost,
    mock_identity_uow: MagicMock,
    mock_grpc_context: MagicMock,
    sample_membership: WorkspaceMembership,
) -> None:
    """SwitchWorkspace includes workspace slug and is_default flag."""
    # A non-default workspace with a distinctive slug so both fields are observable.
    custom_workspace = Workspace(
        id=uuid4(),
        name="My Custom Workspace",
        slug="my-custom-workspace",
        is_default=False,
    )
    mock_identity_uow.workspaces.get.return_value = custom_workspace
    mock_identity_uow.workspaces.get_membership.return_value = sample_membership
    identity_servicer.set_mock_uow(mock_identity_uow)

    response = await identity_servicer.SwitchWorkspace(
        noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(custom_workspace.id)),
        mock_grpc_context,
    )
    assert response.workspace.slug == "my-custom-workspace", "should include slug"
    assert response.workspace.is_default is False, "should include is_default flag"
# =============================================================================
# Test: Database Required Error
# =============================================================================
class TestDatabaseRequired:
    """Tests for database requirement handling in identity endpoints."""

    @staticmethod
    def _unsupported_uow() -> MagicMock:
        """Build an async-context-manager UoW mock with users/workspaces disabled.

        Previously this setup was duplicated verbatim in both tests; the helper
        keeps the two tests in sync if the UoW capability flags ever change.
        """
        uow = MagicMock()
        uow.__aenter__ = AsyncMock(return_value=uow)
        uow.__aexit__ = AsyncMock(return_value=None)
        uow.commit = AsyncMock()
        uow.supports_users = False
        uow.supports_workspaces = False
        return uow

    async def test_list_workspaces_aborts_without_database(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """ListWorkspaces aborts when workspaces not supported."""
        identity_servicer.set_mock_uow(self._unsupported_uow())
        request = noteflow_pb2.ListWorkspacesRequest()
        # abort helpers raise AssertionError after mock context.abort()
        with pytest.raises(AssertionError, match="Unreachable"):
            await identity_servicer.ListWorkspaces(request, mock_grpc_context)
        mock_grpc_context.abort.assert_called_once()

    async def test_switch_workspace_aborts_without_database(
        self,
        identity_servicer: MockIdentityServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """SwitchWorkspace aborts when workspaces not supported."""
        identity_servicer.set_mock_uow(self._unsupported_uow())
        workspace_id = uuid4()
        request = noteflow_pb2.SwitchWorkspaceRequest(workspace_id=str(workspace_id))
        # abort helpers raise AssertionError after mock context.abort()
        with pytest.raises(AssertionError, match="Unreachable"):
            await identity_servicer.SwitchWorkspace(request, mock_grpc_context)
        mock_grpc_context.abort.assert_called_once()

View File

@@ -267,3 +267,126 @@ class TestGoogleCalendarAdapterDateParsing:
events = await adapter.list_events("access-token")
assert events[0].is_recurring is True, "event with recurringEventId should be recurring"
# =============================================================================
# Test: get_user_info
# =============================================================================
class TestGoogleCalendarAdapterGetUserInfo:
    """Tests for GoogleCalendarAdapter.get_user_info."""

    @staticmethod
    def _response(status_code: int, *, json_body: object = None, text: str = "") -> MagicMock:
        """Build a mock httpx response with the given status, JSON payload, and body text.

        Extracted because every test previously repeated the same three-line
        MagicMock setup; one helper keeps the response shape consistent.
        """
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = json_body
        response.text = text
        return response

    @pytest.mark.asyncio
    async def test_returns_email_and_display_name(self) -> None:
        """get_user_info should return email and display name."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter

        adapter = GoogleCalendarAdapter()
        payload = {"email": "user@example.com", "name": "Test User"}
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(200, json_body=payload)
            email, display_name = await adapter.get_user_info("access-token")
        assert email == "user@example.com", "should return user email"
        assert display_name == "Test User", "should return display name"

    @pytest.mark.asyncio
    async def test_falls_back_to_email_prefix_for_display_name(self) -> None:
        """get_user_info should use email prefix when name is missing."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter

        adapter = GoogleCalendarAdapter()
        # No "name" field in the payload.
        payload = {"email": "john.doe@example.com"}
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(200, json_body=payload)
            email, display_name = await adapter.get_user_info("access-token")
        assert email == "john.doe@example.com", "should return email"
        assert display_name == "John Doe", "should format email prefix as title"

    @pytest.mark.asyncio
    async def test_raises_on_expired_token(self) -> None:
        """get_user_info should raise error on 401 response."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter
        from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError

        adapter = GoogleCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(401, text="Token expired")
            with pytest.raises(GoogleCalendarError, match="expired or invalid"):
                await adapter.get_user_info("expired-token")

    @pytest.mark.asyncio
    async def test_raises_on_api_error(self) -> None:
        """get_user_info should raise error on non-200 response."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter
        from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError

        adapter = GoogleCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(500, text="Internal server error")
            with pytest.raises(GoogleCalendarError, match="API error"):
                await adapter.get_user_info("access-token")

    @pytest.mark.asyncio
    async def test_raises_on_invalid_response_type(self) -> None:
        """get_user_info should raise error when response is not dict."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter
        from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError

        adapter = GoogleCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(200, json_body=["not", "a", "dict"])
            with pytest.raises(GoogleCalendarError, match="Invalid userinfo"):
                await adapter.get_user_info("access-token")

    @pytest.mark.asyncio
    async def test_raises_on_missing_email(self) -> None:
        """get_user_info should raise error when email is missing."""
        from noteflow.infrastructure.calendar import GoogleCalendarAdapter
        from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError

        adapter = GoogleCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._response(200, json_body={"name": "No Email User"})
            with pytest.raises(GoogleCalendarError, match="No email"):
                await adapter.get_user_info("access-token")

View File

@@ -0,0 +1,159 @@
"""Tests for Outlook Calendar adapter.
Tests cover:
- get_user_info: Fetching user email and display name from Microsoft Graph API
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# =============================================================================
# Test: get_user_info
# =============================================================================
class TestOutlookCalendarAdapterGetUserInfo:
    """Tests for OutlookCalendarAdapter.get_user_info."""

    @staticmethod
    def _graph_response(status_code: int, *, json_body: object = None, text: str = "") -> MagicMock:
        """Build a mock Microsoft Graph httpx response.

        Extracted because every test previously repeated the same three-line
        MagicMock setup; one helper keeps the response shape consistent.
        """
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = json_body
        response.text = text
        return response

    @pytest.mark.asyncio
    async def test_returns_email_and_display_name(self) -> None:
        """get_user_info should return email and display name."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter

        adapter = OutlookCalendarAdapter()
        payload = {"mail": "user@example.com", "displayName": "Test User"}
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(200, json_body=payload)
            email, display_name = await adapter.get_user_info("access-token")
        assert email == "user@example.com", "should return user email"
        assert display_name == "Test User", "should return display name"

    @pytest.mark.asyncio
    async def test_uses_userPrincipalName_when_mail_missing(self) -> None:
        """get_user_info should fall back to userPrincipalName for email."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter

        adapter = OutlookCalendarAdapter()
        # No "mail" field in the payload.
        payload = {
            "userPrincipalName": "user@company.onmicrosoft.com",
            "displayName": "Test User",
        }
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(200, json_body=payload)
            email, display_name = await adapter.get_user_info("access-token")
        assert email == "user@company.onmicrosoft.com", "should use userPrincipalName"
        assert display_name == "Test User", "should return display name"

    @pytest.mark.asyncio
    async def test_falls_back_to_email_prefix_for_display_name(self) -> None:
        """get_user_info should use email prefix when displayName is missing."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter

        adapter = OutlookCalendarAdapter()
        # No "displayName" field in the payload.
        payload = {"mail": "john.doe@example.com"}
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(200, json_body=payload)
            email, display_name = await adapter.get_user_info("access-token")
        assert email == "john.doe@example.com", "should return email"
        assert display_name == "John Doe", "should format email prefix as title"

    @pytest.mark.asyncio
    async def test_raises_on_expired_token(self) -> None:
        """get_user_info should raise error on 401 response."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter
        from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError

        adapter = OutlookCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(401, text="Token expired")
            with pytest.raises(OutlookCalendarError, match="expired or invalid"):
                await adapter.get_user_info("expired-token")

    @pytest.mark.asyncio
    async def test_raises_on_api_error(self) -> None:
        """get_user_info should raise error on non-200 response."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter
        from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError

        adapter = OutlookCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(500, text="Internal server error")
            with pytest.raises(OutlookCalendarError, match="API error"):
                await adapter.get_user_info("access-token")

    @pytest.mark.asyncio
    async def test_raises_on_invalid_response_type(self) -> None:
        """get_user_info should raise error when response is not dict."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter
        from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError

        adapter = OutlookCalendarAdapter()
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(200, json_body=["not", "a", "dict"])
            with pytest.raises(OutlookCalendarError, match="Invalid user profile"):
                await adapter.get_user_info("access-token")

    @pytest.mark.asyncio
    async def test_raises_on_missing_email(self) -> None:
        """get_user_info should raise error when no email fields present."""
        from noteflow.infrastructure.calendar import OutlookCalendarAdapter
        from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError

        adapter = OutlookCalendarAdapter()
        # Neither "mail" nor "userPrincipalName" present.
        payload = {"displayName": "No Email User"}
        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
            mock_get.return_value = self._graph_response(200, json_body=payload)
            with pytest.raises(OutlookCalendarError, match="No email"):
                await adapter.get_user_info("access-token")

View File

@@ -0,0 +1 @@
"""Tests for diarization infrastructure."""

View File

@@ -0,0 +1,379 @@
"""Tests for diarization compatibility patches.
Tests cover:
- _patch_torchaudio: AudioMetaData class injection
- _patch_torch_load: weights_only=False default for PyTorch 2.6+
- _patch_huggingface_auth: use_auth_token → token parameter conversion
- _patch_speechbrain_backend: torchaudio backend API restoration
- apply_patches: Idempotency and warning suppression
- ensure_compatibility: Alias for apply_patches
"""
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from noteflow.infrastructure.diarization._compat import (
AudioMetaData,
_patch_huggingface_auth,
_patch_speechbrain_backend,
_patch_torch_load,
_patch_torchaudio,
apply_patches,
ensure_compatibility,
)
if TYPE_CHECKING:
from collections.abc import Generator
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def reset_patches_state() -> Generator[None, None, None]:
    """Force _patches_applied to False for the test, restoring it afterwards."""
    import noteflow.infrastructure.diarization._compat as compat_module

    # Remember whatever state the process is in, then pretend patches never ran.
    saved_flag = compat_module._patches_applied
    compat_module._patches_applied = False
    yield
    compat_module._patches_applied = saved_flag
@pytest.fixture
def mock_torchaudio() -> MagicMock:
    """Create mock torchaudio module without AudioMetaData."""
    # spec=[] forbids auto-created attributes, so hasattr() checks stay meaningful.
    return MagicMock(spec=[])
@pytest.fixture
def mock_torch() -> MagicMock:
    """Create mock torch module reporting version 2.6.0 with a stub load()."""
    fake_torch = MagicMock()
    fake_torch.__version__ = "2.6.0"
    fake_torch.load = MagicMock(return_value={"model": "weights"})
    return fake_torch
@pytest.fixture
def mock_huggingface_hub() -> MagicMock:
    """Create mock huggingface_hub module with a stub hf_hub_download()."""
    fake_hub = MagicMock()
    fake_hub.hf_hub_download = MagicMock(return_value="/path/to/file")
    return fake_hub
# =============================================================================
# Test: AudioMetaData Dataclass
# =============================================================================
class TestAudioMetaData:
    """Tests for the replacement AudioMetaData dataclass."""

    @staticmethod
    def _sample() -> AudioMetaData:
        """Create a representative metadata instance shared by the tests."""
        return AudioMetaData(
            sample_rate=16000,
            num_frames=48000,
            num_channels=1,
            bits_per_sample=16,
            encoding="PCM_S",
        )

    def test_audiometadata_has_required_fields(self) -> None:
        """AudioMetaData has all fields expected by pyannote.audio."""
        metadata = self._sample()
        assert metadata.sample_rate == 16000, "should store sample_rate"
        assert metadata.num_frames == 48000, "should store num_frames"
        assert metadata.num_channels == 1, "should store num_channels"
        assert metadata.bits_per_sample == 16, "should store bits_per_sample"
        assert metadata.encoding == "PCM_S", "should store encoding"

    def test_audiometadata_is_immutable(self) -> None:
        """AudioMetaData mutability matches its dataclass frozen setting.

        The previous version assigned unconditionally with a comment claiming it
        would "validate" a frozen dataclass — in fact the assignment would raise
        FrozenInstanceError and fail the test. Branch on the actual frozen flag
        so both configurations are validated explicitly.
        """
        import dataclasses

        metadata = self._sample()
        if metadata.__dataclass_params__.frozen:  # type: ignore[attr-defined]
            with pytest.raises(dataclasses.FrozenInstanceError):
                metadata.sample_rate = 44100  # type: ignore[misc]
        else:
            metadata.sample_rate = 44100
            assert metadata.sample_rate == 44100, "non-frozen dataclass should accept assignment"
# =============================================================================
# Test: _patch_torchaudio
# =============================================================================
class TestPatchTorchaudio:
    """Tests for torchaudio AudioMetaData patching."""

    def test_patches_audiometadata_when_missing(
        self, mock_torchaudio: MagicMock
    ) -> None:
        """_patch_torchaudio adds AudioMetaData when not present."""
        with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
            _patch_torchaudio()
            assert hasattr(mock_torchaudio, "AudioMetaData"), "should add AudioMetaData"
            assert mock_torchaudio.AudioMetaData is AudioMetaData, "should use our AudioMetaData class"

    def test_does_not_override_existing_audiometadata(self) -> None:
        """_patch_torchaudio preserves existing AudioMetaData if present."""
        fake_module = MagicMock()
        sentinel_class = type("ExistingAudioMetaData", (), {})
        fake_module.AudioMetaData = sentinel_class
        with patch.dict(sys.modules, {"torchaudio": fake_module}):
            _patch_torchaudio()
            assert fake_module.AudioMetaData is sentinel_class, "should not override existing AudioMetaData"

    def test_handles_import_error_gracefully(self) -> None:
        """_patch_torchaudio doesn't raise when torchaudio not installed."""
        # Mapping the entry to None makes `import torchaudio` raise ImportError,
        # simulating an environment without the package.
        with patch.dict(sys.modules, {"torchaudio": None}):
            _patch_torchaudio()  # must not raise
# =============================================================================
# Test: _patch_torch_load
# =============================================================================
class TestPatchTorchLoad:
    """Tests for torch.load weights_only patching."""

    def test_patches_torch_load_for_pytorch_2_6_plus(
        self, mock_torch: MagicMock
    ) -> None:
        """_patch_torch_load adds weights_only=False default for PyTorch 2.6+."""
        original_load = mock_torch.load
        with patch.dict(sys.modules, {"torch": mock_torch}):
            # Stub the version comparison: Version() returns the mock itself and
            # its __ge__ always answers True, so the patch code sees ">= 2.6".
            # NOTE(review): presumably _patch_torch_load compares via
            # Version(torch.__version__) >= Version("2.6") — confirm against
            # the implementation.
            with patch("packaging.version.Version") as mock_version:
                mock_version.return_value = mock_version
                mock_version.__ge__ = MagicMock(return_value=True)
                _patch_torch_load()
        # Verify torch.load was replaced (not the same function)
        assert mock_torch.load is not original_load, "load should be patched"

    def test_does_not_patch_older_pytorch(self) -> None:
        """_patch_torch_load skips patching for PyTorch < 2.6."""
        mock = MagicMock()
        mock.__version__ = "2.5.0"
        original_load = mock.load
        with patch.dict(sys.modules, {"torch": mock}):
            # Same Version stubbing as above, but __ge__ answers False so the
            # patch code sees an older PyTorch and should leave load untouched.
            with patch("packaging.version.Version") as mock_version:
                mock_version.return_value = mock_version
                mock_version.__ge__ = MagicMock(return_value=False)
                _patch_torch_load()
        # load should not have been replaced
        assert mock.load is original_load, "should not patch older PyTorch"

    def test_handles_import_error_gracefully(self) -> None:
        """_patch_torch_load doesn't raise when torch not installed."""
        # Mapping the entry to None makes `import torch` raise ImportError.
        with patch.dict(sys.modules, {"torch": None}):
            _patch_torch_load()
# =============================================================================
# Test: _patch_huggingface_auth
# =============================================================================
class TestPatchHuggingfaceAuth:
    """Tests for huggingface_hub use_auth_token patching."""

    def test_converts_use_auth_token_to_token(
        self, mock_huggingface_hub: MagicMock
    ) -> None:
        """_patch_huggingface_auth converts use_auth_token to token parameter."""
        # Capture the pre-patch callable; the patch wraps the module attribute.
        unpatched_download = mock_huggingface_hub.hf_hub_download
        with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
            _patch_huggingface_auth()
            # Call through the wrapped attribute using the legacy kwarg.
            mock_huggingface_hub.hf_hub_download(
                repo_id="test/repo",
                filename="model.bin",
                use_auth_token="my_token",
            )
        unpatched_download.assert_called_once()
        forwarded_kwargs = unpatched_download.call_args[1]
        assert "token" in forwarded_kwargs, "should convert to token parameter"
        assert forwarded_kwargs["token"] == "my_token", "should preserve token value"
        assert "use_auth_token" not in forwarded_kwargs, "should remove use_auth_token"

    def test_preserves_token_parameter(
        self, mock_huggingface_hub: MagicMock
    ) -> None:
        """_patch_huggingface_auth preserves token if already using new API."""
        unpatched_download = mock_huggingface_hub.hf_hub_download
        with patch.dict(sys.modules, {"huggingface_hub": mock_huggingface_hub}):
            _patch_huggingface_auth()
            mock_huggingface_hub.hf_hub_download(
                repo_id="test/repo",
                filename="model.bin",
                token="my_token",
            )
        unpatched_download.assert_called_once()
        assert unpatched_download.call_args[1]["token"] == "my_token", "should preserve token"

    def test_handles_import_error_gracefully(self) -> None:
        """_patch_huggingface_auth doesn't raise when huggingface_hub not installed."""
        # Mapping the entry to None makes the import raise ImportError.
        with patch.dict(sys.modules, {"huggingface_hub": None}):
            _patch_huggingface_auth()  # must not raise
# =============================================================================
# Test: _patch_speechbrain_backend
# =============================================================================
class TestPatchSpeechbrainBackend:
    """Tests for torchaudio backend API patching."""

    def test_patches_list_audio_backends(self, mock_torchaudio: MagicMock) -> None:
        """_patch_speechbrain_backend adds list_audio_backends when missing."""
        with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
            _patch_speechbrain_backend()
            assert hasattr(mock_torchaudio, "list_audio_backends"), "should add list_audio_backends"
            backends = mock_torchaudio.list_audio_backends()
            assert isinstance(backends, list), "should return list"

    def test_patches_get_audio_backend(self, mock_torchaudio: MagicMock) -> None:
        """_patch_speechbrain_backend adds get_audio_backend when missing."""
        with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
            _patch_speechbrain_backend()
            assert hasattr(mock_torchaudio, "get_audio_backend"), "should add get_audio_backend"
            backend = mock_torchaudio.get_audio_backend()
            assert backend is None, "should return None"

    def test_patches_set_audio_backend(self, mock_torchaudio: MagicMock) -> None:
        """_patch_speechbrain_backend adds set_audio_backend when missing."""
        with patch.dict(sys.modules, {"torchaudio": mock_torchaudio}):
            _patch_speechbrain_backend()
            assert hasattr(mock_torchaudio, "set_audio_backend"), "should add set_audio_backend"
            mock_torchaudio.set_audio_backend("sox")  # must not raise

    def test_does_not_override_existing_functions(self) -> None:
        """_patch_speechbrain_backend preserves existing backend functions."""
        fake_module = MagicMock()
        preexisting = MagicMock(return_value=["ffmpeg"])
        fake_module.list_audio_backends = preexisting
        with patch.dict(sys.modules, {"torchaudio": fake_module}):
            _patch_speechbrain_backend()
            assert fake_module.list_audio_backends is preexisting, "should not override existing function"
# =============================================================================
# Test: apply_patches
# =============================================================================
class TestApplyPatches:
    """Tests for the main apply_patches function."""

    def test_apply_patches_is_idempotent(
        self, reset_patches_state: None
    ) -> None:
        """apply_patches only applies patches once."""
        import noteflow.infrastructure.diarization._compat as compat_module

        with patch.object(compat_module, "_patch_torchaudio") as ta_mock, \
                patch.object(compat_module, "_patch_torch_load") as tl_mock, \
                patch.object(compat_module, "_patch_huggingface_auth") as hf_mock, \
                patch.object(compat_module, "_patch_speechbrain_backend") as sb_mock:
            # Repeated invocations must be no-ops after the first.
            for _ in range(3):
                apply_patches()
            ta_mock.assert_called_once()
            tl_mock.assert_called_once()
            hf_mock.assert_called_once()
            sb_mock.assert_called_once()

    def test_apply_patches_sets_flag(self, reset_patches_state: None) -> None:
        """apply_patches sets _patches_applied flag."""
        import noteflow.infrastructure.diarization._compat as compat_module

        assert compat_module._patches_applied is False, "should start False"
        with patch.object(compat_module, "_patch_torchaudio"), \
                patch.object(compat_module, "_patch_torch_load"), \
                patch.object(compat_module, "_patch_huggingface_auth"), \
                patch.object(compat_module, "_patch_speechbrain_backend"):
            apply_patches()
        assert compat_module._patches_applied is True, "should be True after apply"
# =============================================================================
# Test: ensure_compatibility
# =============================================================================
class TestEnsureCompatibility:
    """Tests for the ensure_compatibility entry point."""

    def test_ensure_compatibility_calls_apply_patches(
        self, reset_patches_state: None
    ) -> None:
        """ensure_compatibility delegates to apply_patches."""
        import noteflow.infrastructure.diarization._compat as compat_module

        # Stub out the real patching so only the delegation is exercised.
        with patch.object(compat_module, "apply_patches") as apply_mock:
            ensure_compatibility()
        apply_mock.assert_called_once()

View File

@@ -13,9 +13,18 @@ class SpeakerDiarizationConfig:
*,
segmentation: SegmentationModel,
embedding: EmbeddingModel,
step: float,
latency: float,
device: TorchDevice,
duration: float = ...,
step: float = ...,
latency: float = ...,
tau_active: float = ...,
rho_update: float = ...,
delta_new: float = ...,
gamma: float = ...,
beta: float = ...,
max_speakers: int = ...,
normalize_embedding_weights: bool = ...,
device: TorchDevice | None = ...,
sample_rate: int = ...,
) -> None: ...