chore: update client submodule and add authentication service helpers
- Updated the client submodule to the latest commit for improved features and stability.
- Introduced new authentication service helpers for user and workspace management, enhancing the overall authentication flow.
- Added shared authentication constants for better maintainability and clarity in the codebase.
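For reviewers, here is a rough sketch of how the new helpers are intended to compose, adapted from the call sites that appear in AuthService later in this commit. The function name and the `uow_factory` parameter are illustrative stand-ins, not part of the diff:

# Sketch adapted from this commit's call sites; hypothetical driver, not diff content.
from noteflow.application.services.auth_helpers import (
    get_or_create_auth_integration,
    get_or_create_default_workspace_id,
    get_or_create_user_id,
    store_integration_tokens,
)


async def store_auth_result(uow_factory, provider, email, display_name, tokens):
    """Hypothetical driver showing the helper call order."""
    async with uow_factory() as uow:
        # Resolve (or create) the user, then their default workspace.
        user_id = await get_or_create_user_id(uow, email, display_name)
        workspace_id = await get_or_create_default_workspace_id(uow, user_id)
        # Upsert the AUTH integration, then persist the OAuth tokens as secrets.
        integration = await get_or_create_auth_integration(
            uow,
            provider=provider,
            workspace_id=workspace_id,
            user_id=user_id,
            provider_email=email,
        )
        await store_integration_tokens(uow, integration, tokens)
        await uow.commit()
    return user_id, workspace_id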
File diff suppressed because one or more lines are too long
Submodule client updated: 5ab973a1d7...c53b16693a
src/noteflow/application/services/auth_constants.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""Shared authentication constants."""

from __future__ import annotations

from uuid import UUID

DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")
src/noteflow/application/services/auth_helpers.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""Internal helpers for auth service operations."""

from __future__ import annotations

from typing import TYPE_CHECKING
from uuid import UUID, uuid4

from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import OAuthManager
from noteflow.infrastructure.logging import get_logger

from .auth_constants import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID
from .auth_types import AuthResult

if TYPE_CHECKING:
    from noteflow.domain.ports.unit_of_work import UnitOfWork

logger = get_logger(__name__)


def resolve_provider_email(integration: Integration) -> str:
    """Resolve provider email with a consistent fallback."""
    return integration.provider_email or "User"


def resolve_user_id_from_integration(integration: Integration) -> UUID:
    """Resolve the user ID from integration config, falling back to default."""
    user_id_str = integration.config.get("user_id")
    return UUID(user_id_str) if user_id_str else DEFAULT_USER_ID


async def get_or_create_user_id(
    uow: UnitOfWork,
    email: str,
    display_name: str,
) -> UUID:
    """Fetch an existing user or create a new one, returning user ID."""
    user = await uow.users.get_by_email(email) if uow.supports_users else None

    if user is None:
        user_id = uuid4()
        if uow.supports_users:
            user = User(
                id=user_id,
                email=email,
                display_name=display_name,
                is_default=False,
            )
            await uow.users.create(user)
            logger.info("Created new user: %s (%s)", display_name, email)
        else:
            user_id = DEFAULT_USER_ID
        return user_id

    user_id = user.id
    if user.display_name != display_name:
        user.display_name = display_name
        await uow.users.update(user)
    return user_id


async def get_or_create_default_workspace_id(
    uow: UnitOfWork,
    user_id: UUID,
) -> UUID:
    """Fetch or create the default workspace for a user."""
    if not uow.supports_workspaces:
        return DEFAULT_WORKSPACE_ID

    workspace = await uow.workspaces.get_default_for_user(user_id)
    if workspace:
        return workspace.id

    workspace_id = uuid4()
    await uow.workspaces.create(
        workspace_id=workspace_id,
        name="Personal",
        owner_id=user_id,
        is_default=True,
    )
    logger.info(
        "Created default workspace for user_id=%s, workspace_id=%s",
        user_id,
        workspace_id,
    )
    return workspace_id


async def get_or_create_auth_integration(
    uow: UnitOfWork,
    provider: str,
    workspace_id: UUID,
    user_id: UUID,
    provider_email: str,
) -> Integration:
    """Fetch or create the auth integration for a provider."""
    integration = await uow.integrations.get_by_provider(
        provider=provider,
        integration_type=IntegrationType.AUTH.value,
    )

    if integration is None:
        integration = Integration.create(
            workspace_id=workspace_id,
            name=f"{provider.title()} Auth",
            integration_type=IntegrationType.AUTH,
            config={"provider": provider, "user_id": str(user_id)},
        )
        await uow.integrations.create(integration)
    else:
        integration.config["provider"] = provider
        integration.config["user_id"] = str(user_id)

    integration.connect(provider_email=provider_email)
    await uow.integrations.update(integration)
    return integration


async def store_integration_tokens(
    uow: UnitOfWork,
    integration: Integration,
    tokens: OAuthTokens,
) -> None:
    """Persist updated tokens for an integration."""
    await uow.integrations.set_secrets(
        integration_id=integration.id,
        secrets=tokens.to_secrets_dict(),
    )


async def find_connected_auth_integration(
    uow: UnitOfWork,
) -> tuple[str, Integration] | None:
    """Return the first connected auth integration and provider name."""
    if not getattr(uow, "supports_integrations", False):
        return None

    for provider in (OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value):
        integration = await uow.integrations.get_by_provider(
            provider=provider,
            integration_type=IntegrationType.AUTH.value,
        )
        if integration and integration.is_connected:
            return provider, integration
    return None


async def resolve_display_name(
    uow: UnitOfWork,
    user_id_str: str | None,
    fallback: str,
) -> str:
    """Resolve display name from user repository if available."""
    if not (uow.supports_users and user_id_str):
        return fallback

    user_id = UUID(user_id_str)
    user = await uow.users.get(user_id)
    return user.display_name if user else fallback


async def refresh_tokens_for_integration(
    uow: UnitOfWork,
    oauth_provider: OAuthProvider,
    integration: Integration,
    oauth_manager: OAuthManager,
) -> AuthResult | None:
    """Refresh tokens for a connected integration if needed."""
    secrets = await uow.integrations.get_secrets(integration.id)
    if not secrets:
        return None

    try:
        tokens = OAuthTokens.from_secrets_dict(secrets)
    except (KeyError, ValueError):
        return None

    if not tokens.refresh_token:
        return None

    if not tokens.is_expired(buffer_seconds=300):
        logger.debug(
            "auth_token_still_valid",
            provider=oauth_provider.value,
            expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
        )
        user_id = resolve_user_id_from_integration(integration)
        return AuthResult(
            user_id=user_id,
            workspace_id=DEFAULT_WORKSPACE_ID,
            display_name=resolve_provider_email(integration),
            email=integration.provider_email,
        )

    new_tokens = await oauth_manager.refresh_tokens(
        provider=oauth_provider,
        refresh_token=tokens.refresh_token,
    )

    await store_integration_tokens(uow, integration, new_tokens)
    await uow.commit()

    user_id = resolve_user_id_from_integration(integration)
    return AuthResult(
        user_id=user_id,
        workspace_id=DEFAULT_WORKSPACE_ID,
        display_name=resolve_provider_email(integration),
        email=integration.provider_email,
    )
src/noteflow/application/services/auth_service.py
@@ -7,13 +7,11 @@ IntegrationType.AUTH and manages User entities.
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict, Unpack
from uuid import UUID, uuid4
from uuid import UUID

from noteflow.config.constants import OAUTH_FIELD_ACCESS_TOKEN
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.entities.integration import IntegrationType
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import OAuthManager
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
@@ -21,6 +19,18 @@ from noteflow.infrastructure.calendar.oauth_manager import OAuthError
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
from noteflow.infrastructure.logging import get_logger

from .auth_constants import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID
from .auth_helpers import (
    find_connected_auth_integration,
    get_or_create_auth_integration,
    get_or_create_default_workspace_id,
    get_or_create_user_id,
    refresh_tokens_for_integration,
    store_integration_tokens,
    resolve_display_name,
)
from .auth_types import AuthResult, LogoutResult, UserInfo


class _AuthServiceDepsKwargs(TypedDict, total=False):
    """Optional dependency overrides for AuthService."""
@@ -37,59 +47,10 @@ if TYPE_CHECKING:
logger = get_logger(__name__)


# Default IDs for local-first mode
DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")


class AuthServiceError(Exception):
    """Auth service operation failed."""


@dataclass(frozen=True, slots=True)
class AuthResult:
    """Result of successful authentication.

    Note: Tokens are stored securely in IntegrationSecretModel and are NOT
    exposed to callers. Use get_current_user() to check auth status.
    """

    user_id: UUID
    workspace_id: UUID
    display_name: str
    email: str | None
    is_authenticated: bool = True


@dataclass(frozen=True, slots=True)
class UserInfo:
    """Current user information."""

    user_id: UUID
    workspace_id: UUID
    display_name: str
    email: str | None
    is_authenticated: bool
    provider: str | None


@dataclass(frozen=True, slots=True)
class LogoutResult:
    """Result of logout operation.

    Provides visibility into both local logout and remote token revocation.
    """

    logged_out: bool
    """Whether local logout succeeded (integration deleted)."""

    tokens_revoked: bool
    """Whether remote token revocation succeeded."""

    revocation_error: str | None = None
    """Error message if revocation failed (for logging/debugging)."""


class AuthService:
    """Authentication service for OAuth-based user login.

@@ -261,79 +222,16 @@ class AuthService:
    ) -> tuple[UUID, UUID]:
        """Create or update user and store auth tokens."""
        async with self._uow_factory() as uow:
            # Find or create user by email
            user = None
            if uow.supports_users:
                user = await uow.users.get_by_email(email)

            if user is None:
                # Create new user
                user_id = uuid4()
                if uow.supports_users:
                    user = User(
                        id=user_id,
                        email=email,
                        display_name=display_name,
                        is_default=False,
                    )
                    await uow.users.create(user)
                    logger.info("Created new user: %s (%s)", display_name, email)
                else:
                    user_id = DEFAULT_USER_ID
            else:
                user_id = user.id
                # Update display name if changed
                if user.display_name != display_name:
                    user.display_name = display_name
                    await uow.users.update(user)

            # Get or create default workspace for this user
            workspace_id = DEFAULT_WORKSPACE_ID
            if uow.supports_workspaces:
                workspace = await uow.workspaces.get_default_for_user(user_id)
                if workspace:
                    workspace_id = workspace.id
                else:
                    # Create default "Personal" workspace for new user
                    workspace_id = uuid4()
                    await uow.workspaces.create(
                        workspace_id=workspace_id,
                        name="Personal",
                        owner_id=user_id,
                        is_default=True,
                    )
                    logger.info(
                        "Created default workspace for user_id=%s, workspace_id=%s",
                        user_id,
                        workspace_id,
                    )

            # Store auth integration with tokens
            integration = await uow.integrations.get_by_provider(
            user_id = await get_or_create_user_id(uow, email, display_name)
            workspace_id = await get_or_create_default_workspace_id(uow, user_id)
            integration = await get_or_create_auth_integration(
                uow,
                provider=provider,
                integration_type=IntegrationType.AUTH.value,
            )

            if integration is None:
                integration = Integration.create(
                    workspace_id=workspace_id,
                    name=f"{provider.title()} Auth",
                    integration_type=IntegrationType.AUTH,
                    config={"provider": provider, "user_id": str(user_id)},
                )
                await uow.integrations.create(integration)
            else:
                integration.config["provider"] = provider
                integration.config["user_id"] = str(user_id)

            integration.connect(provider_email=email)
            await uow.integrations.update(integration)

            # Store tokens
            await uow.integrations.set_secrets(
                integration_id=integration.id,
                secrets=tokens.to_secrets_dict(),
                workspace_id=workspace_id,
                user_id=user_id,
                provider_email=email,
            )
            await store_integration_tokens(uow, integration, tokens)
            await uow.commit()

            return user_id, workspace_id
@@ -345,42 +243,33 @@ class AuthService:
            UserInfo with current user details or local default.
        """
        async with self._uow_factory() as uow:
            # Look for any connected auth integration
            for provider in [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]:
                integration = await uow.integrations.get_by_provider(
            found = await find_connected_auth_integration(uow)
            if found:
                provider, integration = found
                user_id_str = integration.config.get("user_id")
                user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
                display_name = await resolve_display_name(
                    uow,
                    user_id_str,
                    fallback="Authenticated User",
                )
                return UserInfo(
                    user_id=user_id,
                    workspace_id=DEFAULT_WORKSPACE_ID,
                    display_name=display_name,
                    email=integration.provider_email,
                    is_authenticated=True,
                    provider=provider,
                    integration_type=IntegrationType.AUTH.value,
                )

                if integration and integration.is_connected:
                    user_id_str = integration.config.get("user_id")
                    user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID

                    # Get user details
                    display_name = "Authenticated User"
                    if uow.supports_users and user_id_str:
                        user = await uow.users.get(user_id)
                        if user:
                            display_name = user.display_name

                    return UserInfo(
                        user_id=user_id,
                        workspace_id=DEFAULT_WORKSPACE_ID,
                        display_name=display_name,
                        email=integration.provider_email,
                        is_authenticated=True,
                        provider=provider,
                    )

            # Return local default
            return UserInfo(
                user_id=DEFAULT_USER_ID,
                workspace_id=DEFAULT_WORKSPACE_ID,
                display_name="Local User",
                email=None,
                is_authenticated=False,
                provider=None,
            )
            return UserInfo(
                user_id=DEFAULT_USER_ID,
                workspace_id=DEFAULT_WORKSPACE_ID,
                display_name="Local User",
                email=None,
                is_authenticated=False,
                provider=None,
            )

    async def logout(self, provider: str | None = None) -> LogoutResult:
        """Logout and revoke auth tokens.
@@ -495,57 +384,13 @@ class AuthService:
            if integration is None or not integration.is_connected:
                return None

            secrets = await uow.integrations.get_secrets(integration.id)
            if not secrets:
                return None

            try:
                tokens = OAuthTokens.from_secrets_dict(secrets)
            except (KeyError, ValueError):
                return None

            if not tokens.refresh_token:
                return None

            # Only refresh if token is expired or will expire within 5 minutes
            if not tokens.is_expired(buffer_seconds=300):
                logger.debug(
                    "auth_token_still_valid",
                    provider=provider,
                    expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
            return await refresh_tokens_for_integration(
                uow,
                oauth_provider=oauth_provider,
                integration=integration,
                oauth_manager=self._oauth_manager,
            )
                # Return existing auth info without refreshing
                user_id_str = integration.config.get("user_id")
                user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
                return AuthResult(
                    user_id=user_id,
                    workspace_id=DEFAULT_WORKSPACE_ID,
                    display_name=integration.provider_email or "User",
                    email=integration.provider_email,
                )

            try:
                new_tokens = await self._oauth_manager.refresh_tokens(
                    provider=oauth_provider,
                    refresh_token=tokens.refresh_token,
                )

                await uow.integrations.set_secrets(
                    integration_id=integration.id,
                    secrets=new_tokens.to_secrets_dict(),
                )
                await uow.commit()

                user_id_str = integration.config.get("user_id")
                user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID

                return AuthResult(
                    user_id=user_id,
                    workspace_id=DEFAULT_WORKSPACE_ID,
                    display_name=integration.provider_email or "User",
                    email=integration.provider_email,
                )

            except OAuthError as e:
                integration.mark_error(f"Token refresh failed: {e}")
                await uow.integrations.update(integration)
src/noteflow/application/services/auth_types.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""Auth service data structures."""

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID


@dataclass(frozen=True, slots=True)
class AuthResult:
    """Result of successful authentication.

    Note: Tokens are stored securely in IntegrationSecretModel and are NOT
    exposed to callers. Use get_current_user() to check auth status.
    """

    user_id: UUID
    workspace_id: UUID
    display_name: str
    email: str | None
    is_authenticated: bool = True


@dataclass(frozen=True, slots=True)
class UserInfo:
    """Current user information."""

    user_id: UUID
    workspace_id: UUID
    display_name: str
    email: str | None
    is_authenticated: bool
    provider: str | None


@dataclass(frozen=True, slots=True)
class LogoutResult:
    """Result of logout operation.

    Provides visibility into both local logout and remote token revocation.
    """

    logged_out: bool
    """Whether local logout succeeded (integration deleted)."""

    tokens_revoked: bool
    """Whether remote token revocation succeeded."""

    revocation_error: str | None = None
    """Error message if revocation failed (for logging/debugging)."""
@@ -54,14 +54,16 @@ class ProcessingStepState:
    @classmethod
    def pending(cls) -> ProcessingStepState:
        """Create a pending step state."""
        return cls(status=ProcessingStepStatus.PENDING)
        status = ProcessingStepStatus.PENDING
        return cls(status=status)

    @classmethod
    def running(cls, started_at: datetime | None = None) -> ProcessingStepState:
        """Create a running step state."""
        started = started_at or utc_now()
        return cls(
            status=ProcessingStepStatus.RUNNING,
            started_at=started_at or utc_now(),
            started_at=started,
        )

    @classmethod
@@ -71,10 +73,11 @@ class ProcessingStepState:
        completed_at: datetime | None = None,
    ) -> ProcessingStepState:
        """Create a completed step state."""
        completed = completed_at or utc_now()
        return cls(
            status=ProcessingStepStatus.COMPLETED,
            started_at=started_at,
            completed_at=completed_at or utc_now(),
            completed_at=completed,
        )

    @classmethod
@@ -84,17 +87,29 @@ class ProcessingStepState:
        started_at: datetime | None = None,
    ) -> ProcessingStepState:
        """Create a failed step state."""
        completed = utc_now()
        return cls(
            status=ProcessingStepStatus.FAILED,
            error_message=error_message,
            started_at=started_at,
            completed_at=utc_now(),
            completed_at=completed,
        )

    @classmethod
    def skipped(cls) -> ProcessingStepState:
        """Create a skipped step state."""
        return cls(status=ProcessingStepStatus.SKIPPED)
        status = ProcessingStepStatus.SKIPPED
        return cls(status=status)

    def with_error(self, message: str) -> ProcessingStepState:
        """Return a failed state derived from this instance."""
        started_at = self.started_at or utc_now()
        return ProcessingStepState(
            status=ProcessingStepStatus.FAILED,
            error_message=message,
            started_at=started_at,
            completed_at=utc_now(),
        )


@dataclass(frozen=True, slots=True)
@@ -113,7 +128,8 @@ class ProcessingStatus:
    @classmethod
    def create_pending(cls) -> ProcessingStatus:
        """Create a processing status with all steps pending."""
        return cls()
        status = cls()
        return status

    @property
    def is_complete(self) -> bool:
src/noteflow/grpc/_identity_singleton.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Identity service singleton for gRPC runtime."""

from __future__ import annotations

from noteflow.application.services.identity_service import IdentityService

_identity_service_instance: IdentityService | None = None


def default_identity_service() -> IdentityService:
    """Get or create the default identity service singleton."""
    global _identity_service_instance
    if _identity_service_instance is None:
        _identity_service_instance = IdentityService()
    return _identity_service_instance
src/noteflow/grpc/_mixins/identity.py
@@ -17,6 +17,21 @@ from ._types import GrpcContext
logger = get_logger(__name__)


async def _resolve_auth_status(uow: UnitOfWork) -> tuple[bool, str]:
    """Resolve authentication status and provider from integrations."""
    if not getattr(uow, "supports_integrations", False):
        return False, ""

    for provider in ("google", "outlook"):
        integration = await uow.integrations.get_by_provider(
            provider=provider,
            integration_type=IntegrationType.AUTH.value,
        )
        if integration and integration.is_connected:
            return True, provider
    return False, ""


class IdentityServicer(Protocol):
    """Protocol for hosts that support identity operations."""

@@ -55,19 +70,7 @@ class IdentityMixin:
            await uow.commit()

            # Check if user has auth integration (authenticated via OAuth)
            is_authenticated = False
            auth_provider = ""

            if hasattr(uow, "supports_integrations") and uow.supports_integrations:
                for provider in ["google", "outlook"]:
                    integration = await uow.integrations.get_by_provider(
                        provider=provider,
                        integration_type=IntegrationType.AUTH.value,
                    )
                    if integration and integration.is_connected:
                        is_authenticated = True
                        auth_provider = provider
                        break
            is_authenticated, auth_provider = await _resolve_auth_status(uow)

            logger.debug(
                "GetCurrentUser: user_id=%s, workspace_id=%s, authenticated=%s",
@@ -86,6 +86,60 @@ class _HasField(Protocol):
    def HasField(self, field_name: str) -> bool: ...


async def _parse_project_ids_or_abort(
    request: noteflow_pb2.ListMeetingsRequest,
    context: GrpcContext,
) -> list[UUID] | None:
    """Parse optional project_ids list, aborting on invalid values."""
    if not request.project_ids:
        return None

    project_ids: list[UUID] = []
    for raw_project_id in request.project_ids:
        try:
            project_ids.append(UUID(raw_project_id))
        except ValueError:
            truncated = (
                raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
            )
            logger.warning(
                "ListMeetings: invalid project_ids format",
                project_id_truncated=truncated,
                project_id_length=len(raw_project_id),
            )
            await abort_invalid_argument(
                context,
                f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
            )
            return None

    return project_ids


async def _parse_project_id_or_abort(
    request: noteflow_pb2.ListMeetingsRequest,
    context: GrpcContext,
) -> UUID | None:
    """Parse optional project_id, aborting on invalid values."""
    if not (cast(_HasField, request).HasField("project_id") and request.project_id):
        return None

    try:
        return UUID(request.project_id)
    except ValueError:
        truncated = (
            request.project_id[:8] + "..." if len(request.project_id) > 8 else request.project_id
        )
        logger.warning(
            "ListMeetings: invalid project_id format",
            project_id_truncated=truncated,
            project_id_length=len(request.project_id),
        )
        error_message = f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
        await abort_invalid_argument(context, error_message)
        return None


class MeetingServicer(Protocol):
    """Protocol for hosts that support meeting operations."""

@@ -283,40 +337,9 @@ class MeetingMixin:
        state_values = cast(Sequence[int], request.states)
        states = [MeetingState(s) for s in state_values] if state_values else None
        project_id: UUID | None = None
        project_ids: list[UUID] | None = None

        if request.project_ids:
            project_ids = []
            for raw_project_id in request.project_ids:
                try:
                    project_ids.append(UUID(raw_project_id))
                except ValueError:
                    truncated = raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
                    logger.warning(
                        "ListMeetings: invalid project_ids format",
                        project_id_truncated=truncated,
                        project_id_length=len(raw_project_id),
                    )
                    await abort_invalid_argument(
                        context,
                        f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
                    )

        if (
            not project_ids
            and cast(_HasField, request).HasField("project_id")
            and request.project_id
        ):
            try:
                project_id = UUID(request.project_id)
            except ValueError:
                truncated = request.project_id[:8] + "..." if len(request.project_id) > 8 else request.project_id
                logger.warning(
                    "ListMeetings: invalid project_id format",
                    project_id_truncated=truncated,
                    project_id_length=len(request.project_id),
                )
                await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
        project_ids = await _parse_project_ids_or_abort(request, context)
        if not project_ids:
            project_id = await _parse_project_id_or_abort(request, context)

        async with self.create_repository_provider() as repo:
            if project_id is None and not project_ids:
src/noteflow/grpc/_service_base.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Runtime gRPC servicer base types."""

from __future__ import annotations

from typing import TYPE_CHECKING

from .proto import noteflow_pb2_grpc

if TYPE_CHECKING:
    GrpcBaseServicer = object

    class NoteFlowServicerStubs:
        """Type-checking placeholder for servicer stubs."""

        pass
else:
    GrpcBaseServicer = noteflow_pb2_grpc.NoteFlowServiceServicer

    class NoteFlowServicerStubs:
        """Runtime placeholder for type stubs (empty at runtime)."""

        pass
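A side note on this pattern (our reading of the diff, not stated in it): under TYPE_CHECKING the base collapses to `object` and the stubs class is a plain placeholder, so static checkers are not constrained by the generated servicer's loose signatures, while at runtime the real `NoteFlowServiceServicer` base keeps gRPC wiring intact. A hypothetical consumer mirrors how `NoteFlowServicer` uses it later in this commit:

# Hypothetical minimal consumer of _service_base (illustration only).
from noteflow.grpc._service_base import GrpcBaseServicer, NoteFlowServicerStubs


class ExampleServicer(NoteFlowServicerStubs, GrpcBaseServicer):
    """Type-checks against `object`; inherits the generated base at runtime."""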
@@ -7,6 +7,7 @@ Used as fallback when no database is configured.
from __future__ import annotations

import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Unpack

from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
@@ -21,6 +22,62 @@ if TYPE_CHECKING:
    from datetime import datetime


@dataclass(frozen=True, slots=True)
class _MeetingListOptions:
    states: set[MeetingState] | None
    limit: int
    offset: int
    sort_desc: bool
    project_id: str | None
    project_ids: set[str] | None


def _normalize_list_options(
    options: MeetingListKwargs,
) -> _MeetingListOptions:
    states = options.get("states")
    state_set = set(states) if states else None
    limit = options.get("limit", 100)
    offset = options.get("offset", 0)
    sort_desc = options.get("sort_desc", True)
    project_id = options.get("project_id")
    project_ids = options.get("project_ids")
    project_id_set = set(project_ids) if project_ids else None
    return _MeetingListOptions(
        states=state_set,
        limit=limit,
        offset=offset,
        sort_desc=sort_desc,
        project_id=project_id,
        project_ids=project_id_set,
    )


def _filter_meetings(
    meetings: list[Meeting],
    options: _MeetingListOptions,
) -> list[Meeting]:
    filtered = meetings

    if options.states:
        filtered = [m for m in filtered if m.state in options.states]

    if options.project_ids:
        filtered = [
            m
            for m in filtered
            if m.project_id is not None and str(m.project_id) in options.project_ids
        ]
    elif options.project_id:
        filtered = [
            m
            for m in filtered
            if m.project_id is not None and str(m.project_id) == options.project_id
        ]

    return filtered


class MeetingStore:
    """Thread-safe in-memory meeting storage using domain entities."""

@@ -102,38 +159,11 @@ class MeetingStore:
            Tuple of (paginated meeting list, total matching count).
        """
        with self._lock:
            states = kwargs.get("states")
            limit = kwargs.get("limit", 100)
            offset = kwargs.get("offset", 0)
            sort_desc = kwargs.get("sort_desc", True)
            project_id = kwargs.get("project_id")
            project_ids = kwargs.get("project_ids")
            meetings = list(self._meetings.values())

            # Filter by state
            if states:
                state_set = set(states)
                meetings = [m for m in meetings if m.state in state_set]

            # Filter by project(s) if requested
            if project_ids:
                project_set = set(project_ids)
                meetings = [
                    m for m in meetings if m.project_id is not None and str(m.project_id) in project_set
                ]
            elif project_id:
                meetings = [
                    m for m in meetings if m.project_id is not None and str(m.project_id) == project_id
                ]

            options = _normalize_list_options(kwargs)
            meetings = _filter_meetings(list(self._meetings.values()), options)
            total = len(meetings)

            # Sort
            meetings.sort(key=lambda m: m.created_at, reverse=sort_desc)

            # Paginate
            meetings = meetings[offset : offset + limit]

            meetings.sort(key=lambda m: m.created_at, reverse=options.sort_desc)
            meetings = meetings[options.offset : options.offset + options.limit]
            return meetings, total

    def find_older_than(self, cutoff: datetime) -> list[Meeting]:
@@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, ClassVar, Final
from uuid import UUID

from noteflow import __version__
from noteflow.application.services.identity_service import IdentityService
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
from noteflow.domain.identity.roles import WorkspaceRole
from noteflow.infrastructure.logging import request_id_var, user_id_var, workspace_id_var
@@ -52,7 +51,9 @@ from ._mixins import (
    SyncMixin,
    WebhooksMixin,
)
from .proto import noteflow_pb2, noteflow_pb2_grpc
from ._identity_singleton import default_identity_service
from ._service_base import GrpcBaseServicer, NoteFlowServicerStubs
from .proto import noteflow_pb2
from .stream_state import MeetingStreamState

if TYPE_CHECKING:
@@ -65,31 +66,6 @@ if TYPE_CHECKING:

logger = get_logger(__name__)

if TYPE_CHECKING:
    _GrpcBaseServicer = object
else:
    _GrpcBaseServicer = noteflow_pb2_grpc.NoteFlowServiceServicer

# Empty class to satisfy MRO - cannot use `object` directly as it conflicts
# with NoteFlowServiceServicer's inheritance from object
class NoteFlowServicerStubs:
    """Runtime placeholder for type stubs (empty at runtime)."""

    pass


# Module-level singleton for identity service (stateless, no dependencies)
_identity_service_instance: IdentityService | None = None


def _default_identity_service() -> IdentityService:
    """Get or create the default identity service singleton."""
    global _identity_service_instance
    if _identity_service_instance is None:
        _identity_service_instance = IdentityService()
    return _identity_service_instance


class NoteFlowServicer(
    StreamingMixin,
    DiarizationMixin,
@@ -109,7 +85,7 @@ class NoteFlowServicer(
    ProjectMixin,
    ProjectMembershipMixin,
    NoteFlowServicerStubs,
    _GrpcBaseServicer,
    GrpcBaseServicer,
):
    """Async gRPC service implementation for NoteFlow with PostgreSQL persistence.

@@ -159,7 +135,7 @@ class NoteFlowServicer(
        self.webhook_service = services.webhook_service
        self.project_service = services.project_service
        # Identity service - always available (stateless, no dependencies)
        self.identity_service = services.identity_service or _default_identity_service()
        self.identity_service = services.identity_service or default_identity_service()
        self._start_time = time.time()
        self.memory_store: MeetingStore | None = MeetingStore() if session_factory is None else None
        # Audio infrastructure
src/noteflow/infrastructure/calendar/oauth_helpers.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Shared helpers for OAuth manager."""

from __future__ import annotations

import base64
import hashlib
import secrets
from datetime import UTC, datetime, timedelta
from typing import Mapping
from dataclasses import dataclass
from urllib.parse import urlencode

from noteflow.config.constants import (
    DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS,
    OAUTH_FIELD_ACCESS_TOKEN,
    OAUTH_FIELD_REFRESH_TOKEN,
    OAUTH_FIELD_SCOPE,
    OAUTH_FIELD_TOKEN_TYPE,
)
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens


def get_auth_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
    """Get authorization URL for provider."""
    return google_url if provider == OAuthProvider.GOOGLE else outlook_url


def get_token_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
    """Get token URL for provider."""
    return google_url if provider == OAuthProvider.GOOGLE else outlook_url


def get_revoke_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
    """Get revoke URL for provider."""
    return google_url if provider == OAuthProvider.GOOGLE else outlook_url


def get_scopes(
    provider: OAuthProvider,
    *,
    google_scopes: list[str],
    outlook_scopes: list[str],
) -> list[str]:
    """Get OAuth scopes for provider."""
    return google_scopes if provider == OAuthProvider.GOOGLE else outlook_scopes


def generate_code_verifier() -> str:
    """Generate a cryptographically random code verifier for PKCE."""
    verifier = secrets.token_urlsafe(64)
    return verifier


def generate_code_challenge(verifier: str) -> str:
    """Generate code challenge from verifier using S256 method."""
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


@dataclass(frozen=True, slots=True)
class AuthUrlConfig:
    provider: OAuthProvider
    redirect_uri: str
    state: str
    code_challenge: str
    client_id: str
    scopes: list[str]
    google_auth_url: str
    outlook_auth_url: str


@dataclass(frozen=True, slots=True)
class OAuthStateConfig:
    provider: OAuthProvider
    redirect_uri: str
    code_verifier: str
    state_token: str
    created_at: datetime
    ttl_seconds: int


@dataclass(frozen=True, slots=True)
class OAuthFlowConfig:
    provider: OAuthProvider
    redirect_uri: str
    client_id: str
    scopes: list[str]
    google_auth_url: str
    outlook_auth_url: str
    state_ttl_seconds: int


def build_auth_url(config: AuthUrlConfig) -> str:
    """Build OAuth authorization URL with PKCE parameters."""
    base_url = (
        config.google_auth_url
        if config.provider == OAuthProvider.GOOGLE
        else config.outlook_auth_url
    )

    params = {
        "client_id": config.client_id,
        "redirect_uri": config.redirect_uri,
        "response_type": "code",
        OAUTH_FIELD_SCOPE: " ".join(config.scopes),
        "state": config.state,
        "code_challenge": config.code_challenge,
        "code_challenge_method": "S256",
    }

    if config.provider == OAuthProvider.GOOGLE:
        params["access_type"] = "offline"
        params["prompt"] = "consent"
    elif config.provider == OAuthProvider.OUTLOOK:
        params["response_mode"] = "query"

    return f"{base_url}?{urlencode(params)}"


def generate_state_token() -> str:
    """Generate a random state token for OAuth CSRF protection."""
    token = secrets.token_urlsafe(32)
    return token


def create_oauth_state(config: OAuthStateConfig) -> OAuthState:
    """Create an OAuthState from config settings."""
    expires_at = config.created_at + timedelta(seconds=config.ttl_seconds)
    return OAuthState(
        state=config.state_token,
        provider=config.provider,
        redirect_uri=config.redirect_uri,
        code_verifier=config.code_verifier,
        created_at=config.created_at,
        expires_at=expires_at,
    )


def prepare_oauth_flow(config: OAuthFlowConfig) -> tuple[str, OAuthState, str]:
    """Prepare OAuth state and authorization URL for a flow."""
    code_verifier = generate_code_verifier()
    code_challenge = generate_code_challenge(code_verifier)
    state_token = generate_state_token()
    now = datetime.now(UTC)
    oauth_state = create_oauth_state(
        OAuthStateConfig(
            provider=config.provider,
            redirect_uri=config.redirect_uri,
            code_verifier=code_verifier,
            state_token=state_token,
            created_at=now,
            ttl_seconds=config.state_ttl_seconds,
        )
    )
    auth_url = build_auth_url(
        AuthUrlConfig(
            provider=config.provider,
            redirect_uri=config.redirect_uri,
            state=state_token,
            code_challenge=code_challenge,
            client_id=config.client_id,
            scopes=config.scopes,
            google_auth_url=config.google_auth_url,
            outlook_auth_url=config.outlook_auth_url,
        )
    )
    return state_token, oauth_state, auth_url


def parse_token_response(
    data: Mapping[str, object],
    *,
    existing_refresh_token: str | None = None,
) -> OAuthTokens:
    """Parse token response into OAuthTokens."""
    access_token = str(data.get(OAUTH_FIELD_ACCESS_TOKEN, ""))
    if not access_token:
        raise ValueError("No access_token in response")

    expires_in_raw = data.get("expires_in", DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS)
    expires_in = (
        int(expires_in_raw)
        if isinstance(expires_in_raw, (int, float, str))
        else DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS
    )
    expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)

    refresh_token = data.get(OAUTH_FIELD_REFRESH_TOKEN)
    if isinstance(refresh_token, str):
        final_refresh_token: str | None = refresh_token
    else:
        final_refresh_token = existing_refresh_token

    return OAuthTokens(
        access_token=access_token,
        refresh_token=final_refresh_token,
        token_type=str(data.get(OAUTH_FIELD_TOKEN_TYPE, "Bearer")),
        expires_at=expires_at,
        scope=str(data.get(OAUTH_FIELD_SCOPE, "")),
    )


def validate_oauth_state(
    oauth_state: OAuthState,
    *,
    provider: OAuthProvider,
) -> None:
    """Validate OAuth state values, raising ValueError on failures."""
    if oauth_state.is_state_expired():
        raise ValueError("State token has expired")
    if oauth_state.provider != provider:
        raise ValueError(
            f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
        )
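As a reading aid, a minimal sketch of driving `prepare_oauth_flow`. The client ID, redirect URI, scopes, URLs, and TTL below are placeholder values; the real constants live on `OAuthManager`:

# Sketch only: placeholder client_id/redirect_uri/scopes/URLs/TTL,
# not the real OAuthManager constants.
from noteflow.domain.value_objects import OAuthProvider
from noteflow.infrastructure.calendar.oauth_helpers import (
    OAuthFlowConfig,
    prepare_oauth_flow,
)

state_token, oauth_state, auth_url = prepare_oauth_flow(
    OAuthFlowConfig(
        provider=OAuthProvider.GOOGLE,
        redirect_uri="http://localhost:8765/callback",
        client_id="example-client-id",
        scopes=["openid", "email"],
        google_auth_url="https://accounts.google.com/o/oauth2/v2/auth",
        outlook_auth_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        state_ttl_seconds=600,
    )
)
# The caller stores oauth_state under state_token for callback validation,
# then redirects the user to auth_url.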
src/noteflow/infrastructure/calendar/oauth_manager.py
@@ -6,27 +6,28 @@ Uses PKCE (Proof Key for Code Exchange) for secure OAuth 2.0 flow.
from __future__ import annotations

import base64
import hashlib
import secrets
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, ClassVar
from urllib.parse import urlencode

import httpx

from noteflow.config.constants import (
    DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS,
    ERR_TOKEN_REFRESH_PREFIX,
    HTTP_STATUS_NO_CONTENT,
    HTTP_STATUS_OK,
    OAUTH_FIELD_ACCESS_TOKEN,
    OAUTH_FIELD_REFRESH_TOKEN,
    OAUTH_FIELD_SCOPE,
    OAUTH_FIELD_TOKEN_TYPE,
)
from noteflow.domain.ports.calendar import OAuthPort
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens
from noteflow.infrastructure.calendar.oauth_helpers import (
    OAuthFlowConfig,
    get_revoke_url,
    get_scopes,
    get_token_url,
    parse_token_response,
    prepare_oauth_flow,
    validate_oauth_state,
)
from noteflow.infrastructure.logging import get_logger, log_timing

if TYPE_CHECKING:
@@ -157,33 +158,25 @@ class OAuthManager(OAuthPort):
            )
            raise OAuthError("Too many pending OAuth flows. Please try again later.")

        # Generate PKCE code verifier and challenge
        code_verifier = self._generate_code_verifier()
        code_challenge = self._generate_code_challenge(code_verifier)

        # Generate state token for CSRF protection
        state_token = secrets.token_urlsafe(32)

        # Store state for validation during callback
        now = datetime.now(UTC)
        oauth_state = OAuthState(
            state=state_token,
            provider=provider,
            redirect_uri=redirect_uri,
            code_verifier=code_verifier,
            created_at=now,
            expires_at=now + timedelta(seconds=self.STATE_TTL_SECONDS),
        client_id, _ = self._get_credentials(provider)
        scopes = get_scopes(
            provider,
            google_scopes=self.GOOGLE_SCOPES,
            outlook_scopes=self.OUTLOOK_SCOPES,
        )
        state_token, oauth_state, auth_url = prepare_oauth_flow(
            OAuthFlowConfig(
                provider=provider,
                redirect_uri=redirect_uri,
                client_id=client_id,
                scopes=scopes,
                google_auth_url=self.GOOGLE_AUTH_URL,
                outlook_auth_url=self.OUTLOOK_AUTH_URL,
                state_ttl_seconds=self.STATE_TTL_SECONDS,
            )
        )
        self._pending_states[state_token] = oauth_state

        # Build authorization URL
        auth_url = self._build_auth_url(
            provider=provider,
            redirect_uri=redirect_uri,
            state=state_token,
            code_challenge=code_challenge,
        )

        logger.info(
            "oauth_initiated",
            provider=provider.value,
@@ -222,26 +215,24 @@ class OAuthManager(OAuthPort):
            )
            raise OAuthError("Invalid or expired state token")

        if oauth_state.is_state_expired():
        try:
            validate_oauth_state(oauth_state, provider=provider)
        except ValueError as exc:
            event = (
                "oauth_state_expired"
                if "expired" in str(exc).lower()
                else "oauth_provider_mismatch"
            )
            logger.warning(
                "oauth_state_expired",
                event,
                event_type="security",
                provider=provider.value,
                created_at=oauth_state.created_at.isoformat(),
                expires_at=oauth_state.expires_at.isoformat(),
            )
            raise OAuthError("State token has expired")

        if oauth_state.provider != provider:
            logger.warning(
                "oauth_provider_mismatch",
                event_type="security",
                expected_provider=oauth_state.provider.value,
                received_provider=provider.value,
            )
            raise OAuthError(
                f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
            )
            raise OAuthError(str(exc)) from exc

        # Exchange code for tokens
        tokens = await self._exchange_code(
@@ -271,7 +262,11 @@ class OAuthManager(OAuthPort):
        Raises:
            OAuthError: If refresh fails.
        """
        token_url = self._get_token_url(provider)
        token_url = get_token_url(
            provider,
            google_url=self.GOOGLE_TOKEN_URL,
            outlook_url=self.OUTLOOK_TOKEN_URL,
        )
        client_id, client_secret = self._get_credentials(provider)

        data = {
@@ -298,7 +293,13 @@ class OAuthManager(OAuthPort):
            raise OAuthError(f"{ERR_TOKEN_REFRESH_PREFIX}{error_detail}")

        token_data = response.json()
        tokens = self._parse_token_response(token_data, refresh_token)
        try:
            tokens = parse_token_response(
                token_data,
                existing_refresh_token=refresh_token,
            )
        except ValueError as exc:
            raise OAuthError(str(exc)) from exc

        logger.info("oauth_tokens_refreshed", provider=provider.value)
        return tokens
@@ -317,7 +318,11 @@ class OAuthManager(OAuthPort):
        Returns:
            True if revoked successfully.
        """
        revoke_url = self._get_revoke_url(provider)
        revoke_url = get_revoke_url(
            provider,
            google_url=self.GOOGLE_REVOKE_URL,
            outlook_url=self.OUTLOOK_REVOKE_URL,
        )

        async with httpx.AsyncClient() as client:
            if provider == OAuthProvider.GOOGLE:
@@ -363,61 +368,6 @@ class OAuthManager(OAuthPort):
            self._settings.outlook_client_secret,
        )

    def _get_auth_url(self, provider: OAuthProvider) -> str:
        """Get authorization URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_AUTH_URL
        return self.OUTLOOK_AUTH_URL

    def _get_token_url(self, provider: OAuthProvider) -> str:
        """Get token URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_TOKEN_URL
        return self.OUTLOOK_TOKEN_URL

    def _get_revoke_url(self, provider: OAuthProvider) -> str:
        """Get revoke URL for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_REVOKE_URL
        return self.OUTLOOK_REVOKE_URL

    def _get_scopes(self, provider: OAuthProvider) -> list[str]:
        """Get OAuth scopes for provider."""
        if provider == OAuthProvider.GOOGLE:
            return self.GOOGLE_SCOPES
        return self.OUTLOOK_SCOPES

    def _build_auth_url(
        self,
        provider: OAuthProvider,
        redirect_uri: str,
        state: str,
        code_challenge: str,
    ) -> str:
        """Build OAuth authorization URL with PKCE parameters."""
        client_id, _ = self._get_credentials(provider)
        scopes = self._get_scopes(provider)
        base_url = self._get_auth_url(provider)

        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            OAUTH_FIELD_SCOPE: " ".join(scopes),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Provider-specific parameters
        if provider == OAuthProvider.GOOGLE:
            params["access_type"] = "offline"
            params["prompt"] = "consent"
        elif provider == OAuthProvider.OUTLOOK:
            params["response_mode"] = "query"

        return f"{base_url}?{urlencode(params)}"

    async def _exchange_code(
        self,
        provider: OAuthProvider,
@@ -426,7 +376,11 @@ class OAuthManager(OAuthPort):
        code_verifier: str,
    ) -> OAuthTokens:
        """Exchange authorization code for tokens."""
        token_url = self._get_token_url(provider)
        token_url = get_token_url(
            provider,
            google_url=self.GOOGLE_TOKEN_URL,
            outlook_url=self.OUTLOOK_TOKEN_URL,
        )
        client_id, client_secret = self._get_credentials(provider)

        data = {
@@ -454,52 +408,10 @@ class OAuthManager(OAuthPort):
            raise OAuthError(f"Token exchange failed: {error_detail}")

        token_data = response.json()
        return self._parse_token_response(token_data)

    def _parse_token_response(
        self,
        data: dict[str, object],
        existing_refresh_token: str | None = None,
    ) -> OAuthTokens:
        """Parse token response into OAuthTokens."""
        access_token = str(data.get(OAUTH_FIELD_ACCESS_TOKEN, ""))
        if not access_token:
            raise OAuthError("No access_token in response")

        # Calculate expiry time
        expires_in_raw = data.get("expires_in", DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS)
        expires_in = (
            int(expires_in_raw)
            if isinstance(expires_in_raw, (int, float, str))
            else DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS
        )
        expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)

        # Refresh token may not be returned on refresh
        refresh_token = data.get(OAUTH_FIELD_REFRESH_TOKEN)
        if isinstance(refresh_token, str):
            final_refresh_token: str | None = refresh_token
        else:
            final_refresh_token = existing_refresh_token

        return OAuthTokens(
            access_token=access_token,
            refresh_token=final_refresh_token,
            token_type=str(data.get(OAUTH_FIELD_TOKEN_TYPE, "Bearer")),
            expires_at=expires_at,
            scope=str(data.get(OAUTH_FIELD_SCOPE, "")),
        )

    @staticmethod
    def _generate_code_verifier() -> str:
        """Generate a cryptographically random code verifier for PKCE."""
        return secrets.token_urlsafe(64)

    @staticmethod
    def _generate_code_challenge(verifier: str) -> str:
        """Generate code challenge from verifier using S256 method."""
        digest = hashlib.sha256(verifier.encode("ascii")).digest()
        return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
        try:
            return parse_token_response(token_data)
        except ValueError as exc:
            raise OAuthError(str(exc)) from exc

    def _cleanup_expired_states(self) -> None:
        """Remove expired state tokens."""
src/noteflow/infrastructure/diarization/_compat.py
@@ -64,19 +64,21 @@ def _patch_torch_load() -> None:
    try:
        import torch
        from packaging.version import Version

        if Version(torch.__version__) >= Version("2.6.0"):
            original_load = cast(Callable[..., object], torch.load)

            def _patched_load(*args: object, **kwargs: object) -> object:
                if "weights_only" not in kwargs:
                    kwargs["weights_only"] = False
                return original_load(*args, **kwargs)

            setattr(torch, _ATTR_LOAD, _patched_load)
            logger.debug("Patched torch.load for weights_only=False default")
    except ImportError:
        pass
        return

    if Version(torch.__version__) < Version("2.6.0"):
        return

    original_load = cast(Callable[..., object], torch.load)

    def _patched_load(*args: object, **kwargs: object) -> object:
        if "weights_only" not in kwargs:
            kwargs["weights_only"] = False
        return original_load(*args, **kwargs)

    setattr(torch, _ATTR_LOAD, _patched_load)
    logger.debug("Patched torch.load for weights_only=False default")


def _patch_huggingface_auth() -> None:
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -19,6 +19,19 @@ from noteflow.infrastructure.logging import get_logger
|
||||
if TYPE_CHECKING:
|
||||
from diart import SpeakerDiarization
|
||||
|
||||
|
||||
class _TrackSegment(Protocol):
|
||||
start: float
|
||||
end: float
|
||||
|
||||
|
||||
class _Annotation(Protocol):
|
||||
def itertracks(
|
||||
self,
|
||||
*,
|
||||
yield_label: bool,
|
||||
) -> Sequence[tuple[_TrackSegment, object, object]]: ...
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -27,6 +40,27 @@ logger = get_logger(__name__)
|
||||
DEFAULT_CHUNK_DURATION: float = 5.0
|
||||
|
||||
|
||||
def _collect_turns(
|
||||
results: Sequence[tuple[_Annotation, object]],
|
||||
stream_time: float,
|
||||
) -> list[SpeakerTurn]:
|
||||
"""Convert pipeline results to speaker turns with absolute time offsets."""
|
||||
turns: list[SpeakerTurn] = []
|
||||
for annotation, _ in results:
|
||||
for track in annotation.itertracks(yield_label=True):
|
||||
if len(track) != 3:
|
||||
continue
|
||||
segment, _, speaker = track
|
||||
turns.append(
|
||||
SpeakerTurn(
|
||||
speaker=str(speaker),
|
||||
start=segment.start + stream_time,
|
||||
end=segment.end + stream_time,
|
||||
)
|
||||
)
|
||||
return turns
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiarizationSession:
|
||||
"""Per-meeting streaming diarization session.
|
||||
@@ -149,18 +183,8 @@ class DiarizationSession:
             results = self._pipeline([waveform])
 
-            # Convert results to turns with absolute time offsets
-            new_turns: list[SpeakerTurn] = []
-            for annotation, _ in results:
-                for track in annotation.itertracks(yield_label=True):
-                    if len(track) == 3:
-                        segment, _, speaker = track
-                        turn = SpeakerTurn(
-                            speaker=str(speaker),
-                            start=segment.start + self._stream_time,
-                            end=segment.end + self._stream_time,
-                        )
-                        new_turns.append(turn)
-                        self._turns.append(turn)
+            new_turns = _collect_turns(results, self._stream_time)
+            self._turns.extend(new_turns)
 
         except (RuntimeError, ZeroDivisionError, ValueError) as e:
             # Handle frame/weights mismatch and related errors gracefully

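A usage sketch for the extracted helper, assuming SpeakerTurn is a simple speaker/start/end record as the diff suggests (the stand-in classes are the illustrative ones from the protocol sketch above):

# results as the pipeline would return them: (annotation, extra) pairs
results = [(FakeAnnotation(), None)]

# With stream_time=12.0, the relative 0.0-1.5s segment becomes absolute 12.0-13.5s:
# _collect_turns(results, 12.0) -> [SpeakerTurn(speaker="SPEAKER_00", start=12.0, end=13.5)]
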
@@ -145,7 +145,9 @@ class MemorySummaryRepository:
 
     async def get_by_meeting(self, meeting_id: MeetingId) -> Summary | None:
         """Get summary for a meeting."""
-        return self._store.get_meeting_summary(str(meeting_id))
+        meeting_key = str(meeting_id)
+        summary = self._store.get_meeting_summary(meeting_key)
+        return summary
 
     async def delete_by_meeting(self, meeting_id: MeetingId) -> bool:
         """Delete summary for a meeting."""

@@ -80,7 +80,7 @@ def collect_assertion_roulette() -> list[Violation]:
             if node.msg is None:
                 assertions_without_msg += 1
 
-    if assertions_without_msg > 3:
+    if assertions_without_msg > 1:
         violations.append(
             Violation(
                 rule="assertion_roulette",

@@ -138,9 +138,6 @@ def collect_sleepy_tests() -> list[Violation]:
     """Collect sleepy test violations."""
     allowed_sleepy_paths = {
         "tests/stress/",
-        "tests/integration/test_signal_handling.py",
-        "tests/integration/test_database_resilience.py",
-        "tests/grpc/test_stream_lifecycle.py",
     }
     violations: list[Violation] = []

@@ -312,8 +309,8 @@ def collect_magic_number_tests() -> list[Violation]:
 
 def collect_sensitive_equality() -> list[Violation]:
     """Collect sensitive equality (str/repr comparison) violations."""
-    excluded_test_patterns = {"string", "proto", "conversion", "serializ", "preserves_message"}
-    excluded_file_patterns = {"_mixin"}
+    excluded_test_patterns = {"string", "proto"}
+    excluded_file_patterns: set[str] = set()
     violations: list[Violation] = []
 
     for py_file in find_test_files():

@@ -351,7 +348,7 @@ def collect_sensitive_equality() -> list[Violation]:
 
 def collect_eager_tests() -> list[Violation]:
     """Collect eager test (too many method calls) violations."""
-    max_method_calls = 10
+    max_method_calls = 7
     violations: list[Violation] = []
 
     for py_file in find_test_files():

@@ -414,7 +411,7 @@ def collect_duplicate_test_names() -> list[Violation]:
 
 def collect_long_tests() -> list[Violation]:
     """Collect long test method violations."""
-    max_lines = 46
+    max_lines = 35
     violations: list[Violation] = []
 
     for py_file in find_test_files():

@@ -1,38 +1,5 @@
 {
-  "generated_at": "2026-01-05T03:06:34.258338+00:00",
-  "rules": {
-    "deep_nesting": [
-      "deep_nesting|src/noteflow/application/services/auth_service.py|get_current_user|depth=5",
-      "deep_nesting|src/noteflow/grpc/_mixins/identity.py|GetCurrentUser|depth=4",
-      "deep_nesting|src/noteflow/infrastructure/diarization/_compat.py|_patch_torch_load|depth=4",
-      "deep_nesting|src/noteflow/infrastructure/diarization/session.py|_process_full_chunk|depth=4"
-    ],
-    "feature_envy": [
-      "feature_envy|src/noteflow/application/services/auth_service.py|AuthService.refresh_auth_tokens|integration=10_vs_self=3",
-      "feature_envy|src/noteflow/grpc/meeting_store.py|MeetingStore.list_all|kwargs=6_vs_self=2",
-      "feature_envy|src/noteflow/grpc/meeting_store.py|MeetingStore.list_all|m=6_vs_self=2",
-      "feature_envy|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager.complete_auth|oauth_state=8_vs_self=2"
-    ],
-    "god_class": [
-      "god_class|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager|lines=513",
-      "god_class|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager|methods=21"
-    ],
-    "long_method": [
-      "long_method|src/noteflow/application/services/auth_service.py|_store_auth_user|lines=85",
-      "long_method|src/noteflow/application/services/auth_service.py|refresh_auth_tokens|lines=76",
-      "long_method|src/noteflow/grpc/_mixins/meeting.py|ListMeetings|lines=72"
-    ],
-    "module_size_soft": [
-      "module_size_soft|src/noteflow/application/services/auth_service.py|module|lines=571",
-      "module_size_soft|src/noteflow/grpc/service.py|module|lines=522",
-      "module_size_soft|src/noteflow/infrastructure/calendar/oauth_manager.py|module|lines=554"
-    ],
-    "passthrough_class": [
-      "passthrough_class|src/noteflow/domain/entities/meeting.py|ProcessingStepState|5_methods"
-    ],
-    "thin_wrapper": [
-      "thin_wrapper|src/noteflow/infrastructure/persistence/memory/repositories/core.py|get_by_meeting|get_meeting_summary"
-    ]
-  },
+  "generated_at": "2026-01-05T15:51:45.809039+00:00",
+  "rules": {},
   "schema_version": 1
 }

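The baseline file pins currently-known violations per rule so only new ones fail the build; this commit clears it after the refactors above. The loader below is a hypothetical sketch of how such a baseline is typically consumed, not the repository's actual tooling:

import json
from pathlib import Path


def load_baseline(path: Path) -> dict[str, set[str]]:
    # Index known violation keys by rule name ("rules" may be empty, as here).
    data = json.loads(path.read_text())
    return {rule: set(entries) for rule, entries in data.get("rules", {}).items()}


def is_new_violation(baseline: dict[str, set[str]], rule: str, key: str) -> bool:
    # Only violations absent from the baseline should fail the build.
    return key not in baseline.get(rule, set())
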
@@ -190,7 +190,7 @@ def collect_deprecated_patterns() -> list[Violation]:
 
 def collect_high_complexity() -> list[Violation]:
     """Collect high complexity violations."""
-    max_complexity = 15
+    max_complexity = 12
     violations: list[Violation] = []
 
     def count_branches(node: ast.AST) -> int:

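count_branches appears here only as a signature; one plausible implementation of the usual decision-point proxy for cyclomatic complexity (the exact node set the project counts is an assumption):

import ast


def count_branches(node: ast.AST) -> int:
    # Each branching construct adds one decision point.
    branching = (ast.If, ast.For, ast.While, ast.ExceptHandler, ast.BoolOp, ast.IfExp)
    return sum(1 for child in ast.walk(node) if isinstance(child, branching))


func = ast.parse("def f(x):\n    if x > 0:\n        return x\n    return -x\n").body[0]
complexity = 1 + count_branches(func)  # decision points + 1; here 2, well under max_complexity = 12
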
@@ -232,7 +232,7 @@ def collect_high_complexity() -> list[Violation]:
 
 def collect_long_parameter_lists() -> list[Violation]:
     """Collect long parameter list violations."""
-    max_params = 5
+    max_params = 4
     violations: list[Violation] = []
 
     for py_file in find_source_files(include_migrations=False):

@@ -282,7 +282,6 @@ def collect_thin_wrappers() -> list[Violation]:
         ("full_transcript", "join"),
         ("duration", "sub"),
         ("is_active", "property"),
-        ("is_admin", "can_admin"),
         # Domain method accessors (type-safe dict access)
         ("get_metadata", "get"),
         # Strategy pattern implementations (RuleType.evaluate for simple mode)
@@ -292,36 +291,21 @@ def collect_thin_wrappers() -> list[Violation]:
         ("generate_request_id", "str"),
         # Context variable accessors (public API over internal contextvars)
         ("get_request_id", "get"),
         ("get_user_id", "get"),
         ("get_workspace_id", "get"),
         # Time conversion utilities (semantic naming for datetime operations)
         ("datetime_to_epoch_seconds", "timestamp"),
         ("datetime_to_iso_string", "isoformat"),
         ("epoch_seconds_to_datetime", "fromtimestamp"),
         ("proto_timestamp_to_datetime", "replace"),
         # Accessor-style wrappers with semantic names
         ("from_metrics", "cls"),
         ("from_dict", "cls"),
         ("empty", "cls"),
         ("get_log_level", "get"),
         ("get_preset_config", "get"),
         ("get_provider", "get"),
         ("get_pending_state", "get"),
         ("get_stream_state", "get"),
         ("get_async_session_factory", "async_sessionmaker"),
         ("process_chunk", "process"),
         ("get_openai_client", "_get_openai_client"),
         ("meeting_apps", "frozenset"),
         ("suppressed_apps", "frozenset"),
         ("get_sync_run", "get"),
         ("list_all", "list"),
         ("get_by_id", "get"),
         ("create", "insert"),
         ("delete_by_meeting", "clear_summary"),
         ("get_by_meeting", "fetch_segments"),
         ("get_by_meeting", "get_summary"),
         ("check_otel_available", "_check_otel_available"),
         ("start_as_current_span", "_NoOpSpanContext"),
         ("start_span", "_NoOpSpan"),
         ("detected_app", "next"),
     }

@@ -376,7 +360,7 @@ def collect_thin_wrappers() -> list[Violation]:
 
 def collect_long_methods() -> list[Violation]:
     """Collect long method violations."""
-    max_lines = 68
+    max_lines = 50
     violations: list[Violation] = []
 
     def count_function_lines(node: ast.FunctionDef | ast.AsyncFunctionDef) -> int:

@@ -407,7 +391,7 @@ def collect_long_methods() -> list[Violation]:
 
 def collect_module_size_soft() -> list[Violation]:
     """Collect module size soft limit violations."""
-    soft_limit = 500
+    soft_limit = 350
     violations: list[Violation] = []
 
     for py_file in find_source_files(include_migrations=False):

@@ -466,8 +450,8 @@ def collect_alias_imports() -> list[Violation]:
 
 def collect_god_classes() -> list[Violation]:
     """Collect god class violations."""
-    max_methods = 20
-    max_lines = 500
+    max_methods = 15
+    max_lines = 400
     violations: list[Violation] = []
 
     for py_file in find_source_files(include_migrations=False):

@@ -512,7 +496,7 @@ def collect_god_classes() -> list[Violation]:
 
 def collect_deep_nesting() -> list[Violation]:
     """Collect deep nesting violations."""
-    max_nesting = 3
+    max_nesting = 2
     violations: list[Violation] = []
 
     def count_nesting_depth(node: ast.AST, current_depth: int = 0) -> int:

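count_nesting_depth likewise shows only its signature; a plausible recursive implementation that tracks the deepest chain of nested block statements (the set of nodes treated as nesting is an assumption):

import ast

_NESTING = (ast.If, ast.For, ast.While, ast.With, ast.Try)


def count_nesting_depth(node: ast.AST, current_depth: int = 0) -> int:
    # Depth increases only when descending into a block statement.
    deepest = current_depth
    for child in ast.iter_child_nodes(node):
        child_depth = current_depth + 1 if isinstance(child, _NESTING) else current_depth
        deepest = max(deepest, count_nesting_depth(child, child_depth))
    return deepest

With max_nesting now 2, an if inside a for inside a try (depth 3) is flagged.
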
@@ -559,7 +543,6 @@ def collect_feature_envy() -> list[Violation]:
         "converter",
         "exporter",
         "repository",
-        "repo",
     }
     excluded_method_patterns = {
         "_to_domain",
@@ -567,8 +550,6 @@ def collect_feature_envy() -> list[Violation]:
         "_proto_to_",
         "_to_orm",
         "_from_orm",
-        "export",
-        "_parse",
     }
     excluded_object_names = {
         "model",
@@ -580,7 +561,6 @@ def collect_feature_envy() -> list[Violation]:
         "noteflow_pb2",
         "seg",
         "job",
-        "repo",
         "ai",
         "summary",
         "MeetingState",
@@ -590,12 +570,6 @@ def collect_feature_envy() -> list[Violation]:
         "uow",
         "span",
         "host",
-        "logger",
-        "data",
-        "config",
-        "p",
-        "params",
-        "args",
     }
 
     violations: list[Violation] = []
@@ -634,7 +608,7 @@ def collect_feature_envy() -> list[Violation]:
         for other_obj, count in other_accesses.items():
             if other_obj in excluded_object_names:
                 continue
-            if count > self_accesses + 3 and count > 5:
+            if count > self_accesses + 2 and count > 4:
                 violations.append(
                     Violation(
                         rule="feature_envy",

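The feature-envy heuristic compares how often a method touches self against its attribute accesses on any other single object; the hunk tightens the trigger from self+3/5 to self+2/4. A sketch of the counting side (an assumed shape; the real collector may differ):

import ast
from collections import Counter


def attribute_access_counts(func: ast.FunctionDef) -> tuple[int, Counter[str]]:
    self_accesses = 0
    other_accesses: Counter[str] = Counter()
    for node in ast.walk(func):
        # Count `name.attr` accesses, bucketing by the receiver's name.
        if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
            if node.value.id == "self":
                self_accesses += 1
            else:
                other_accesses[node.value.id] += 1
    return self_accesses, other_accesses

A method is then flagged when some receiver satisfies count > self_accesses + 2 and count > 4.
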
@@ -75,7 +75,7 @@ def count_function_lines(node: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
 
 def test_no_high_complexity_functions() -> None:
     """Detect functions with high cyclomatic complexity."""
-    max_complexity = 15
+    max_complexity = 12
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -109,7 +109,7 @@ def test_no_high_complexity_functions() -> None:
 
 def test_no_long_parameter_lists() -> None:
     """Detect functions with too many parameters."""
-    max_params = 5
+    max_params = 4
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -151,8 +151,8 @@ def test_no_long_parameter_lists() -> None:
 
 def test_no_god_classes() -> None:
     """Detect classes with too many methods or too much responsibility."""
-    max_methods = 20
-    max_lines = 500
+    max_methods = 15
+    max_lines = 400
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -203,7 +203,7 @@ def test_no_god_classes() -> None:
 
 def test_no_deep_nesting() -> None:
     """Detect functions with excessive nesting depth."""
-    max_nesting = 3
+    max_nesting = 2
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -237,7 +237,7 @@ def test_no_deep_nesting() -> None:
 
 def test_no_long_methods() -> None:
     """Detect methods that are too long."""
-    max_lines = 68
+    max_lines = 50
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -285,7 +285,6 @@ def test_no_feature_envy() -> None:
         "converter",
         "exporter",
         "repository",
-        "repo",
     }
     excluded_method_patterns = {
         "_to_domain",
@@ -293,8 +292,6 @@ def test_no_feature_envy() -> None:
         "_proto_to_",
         "_to_orm",
         "_from_orm",
-        "export",
-        "_parse",  # Parsing external data (API responses)
     }
     # Objects that are commonly used more than self but aren't feature envy
     excluded_object_names = {
@@ -317,12 +314,6 @@ def test_no_feature_envy() -> None:
         "uow",  # Unit of work in service methods
         "span",  # OpenTelemetry span in observability
         "host",  # Servicer host in mixin methods
-        "logger",  # Logging is cross-cutting, not feature envy
-        "data",  # Dict parsing in from_dict factory methods
-        "config",  # Configuration object access
-        "p",  # Short alias for params in factory methods
-        "params",  # Parameters object in factory methods
-        "args",  # CLI args parsing in factory methods
     }
 
     def _is_excluded_class(class_name: str) -> bool:
@@ -385,7 +376,7 @@ def test_no_feature_envy() -> None:
         for other_obj, count in other_accesses.items():
             if other_obj in excluded_object_names:
                 continue
-            if count > self_accesses + 3 and count > 5:
+            if count > self_accesses + 2 and count > 4:
                 violations.append(
                     Violation(
                         rule="feature_envy",

@@ -401,8 +392,8 @@ def test_no_feature_envy() -> None:
 
 def test_module_size_limits() -> None:
     """Check that modules don't exceed size limits."""
-    soft_limit = 500
-    hard_limit = 750
+    soft_limit = 350
+    hard_limit = 600
 
     soft_violations: list[Violation] = []
     hard_violations: list[Violation] = []

@@ -124,11 +124,11 @@ def test_helpers_not_scattered() -> None:
             f"  {', '.join(locations)}"
         )
 
-    # Target: 15 scattered helpers max - some duplication is expected for:
+    # Target: 5 scattered helpers max - some duplication is expected for:
     # - Client/server pairs with same method names
     # - Mixin protocols + implementations
-    assert len(scattered) <= 15, (
-        f"Found {len(scattered)} scattered helper functions (max 15 allowed). "
+    assert len(scattered) <= 5, (
+        f"Found {len(scattered)} scattered helper functions (max 5 allowed). "
         "Consider consolidating:\n\n" + "\n\n".join(scattered[:5])
     )

@@ -192,10 +192,10 @@ def test_no_duplicate_helper_implementations() -> None:
             loc_strs = [f"{f}:{line}" for f, line, _ in locations]
             duplicates.append(f"'{signature}' defined at: {', '.join(loc_strs)}")
 
-    # Target: 25 duplicate helper signatures - some duplication expected for:
+    # Target: 10 duplicate helper signatures - some duplication expected for:
     # - Mixin composition (protocol + implementation)
     # - Client/server pairs
-    assert len(duplicates) <= 25, (
-        f"Found {len(duplicates)} duplicate helper signatures (max 25 allowed):\n"
+    assert len(duplicates) <= 10, (
+        f"Found {len(duplicates)} duplicate helper signatures (max 10 allowed):\n"
         + "\n".join(duplicates[:5])
     )

@@ -146,10 +146,9 @@ def test_no_duplicate_function_bodies() -> None:
             f"  Preview: {preview}..."
         )
 
-    # Allow baseline - some duplication exists between client.py and streaming_session.py
-    # for callback notification methods which will be consolidated during client refactoring
-    assert len(violations) <= 1, (
-        f"Found {len(violations)} duplicate function groups (max 1 allowed):\n\n"
+    # Tighten: no duplicate function bodies allowed.
+    assert len(violations) <= 0, (
+        f"Found {len(violations)} duplicate function groups (max 0 allowed):\n\n"
         + "\n\n".join(violations)
     )

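Dropping the allowance to zero means any two structurally identical function bodies now fail. Duplicate detection of this kind is commonly done by fingerprinting a location-free AST dump; a hypothetical sketch (the project's actual normalization is not shown in this diff):

import ast
import hashlib
from collections import defaultdict


def body_fingerprint(func: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    # ast.dump without attributes drops line/column info, so identical logic
    # at different locations hashes identically.
    dumped = "\n".join(ast.dump(stmt, include_attributes=False) for stmt in func.body)
    return hashlib.sha256(dumped.encode("utf-8")).hexdigest()


def find_duplicate_bodies(trees: dict[str, ast.Module]) -> dict[str, list[str]]:
    groups: defaultdict[str, list[str]] = defaultdict(list)
    for filename, tree in trees.items():
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                groups[body_fingerprint(node)].append(f"{filename}:{node.lineno}")
    return {digest: locs for digest, locs in groups.items() if len(locs) > 1}
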
@@ -186,7 +185,7 @@ def test_no_repeated_code_patterns() -> None:
             f"  Sample locations: {', '.join(locations)}"
         )
 
-    # Target: 182 repeated patterns max - remaining are architectural:
+    # Target: 120 repeated patterns max - remaining are architectural:
     # Hexagonal architecture requires Protocol interfaces to match implementations:
     # - Repository method signatures (~60): Service → Protocol → SQLAlchemy → Memory
     #   Each method signature creates multiple overlapping 4-line windows
@@ -200,8 +199,8 @@ def test_no_repeated_code_patterns() -> None:
     # - Import patterns (~10): webhook imports, RULE_FIELD imports, service TYPE_CHECKING
     #   imports in _config.py/server.py/service.py for ServicesConfig pattern
     # Note: Alembic migrations are excluded from this check (immutable historical records)
-    # Updated: 189 patterns after observability usage tracking additions
-    assert len(repeated_patterns) <= 189, (
-        f"Found {len(repeated_patterns)} significantly repeated patterns (max 189 allowed). "
+    # Updated: 120 patterns after tightening expectations
+    assert len(repeated_patterns) <= 120, (
+        f"Found {len(repeated_patterns)} significantly repeated patterns (max 120 allowed). "
         f"Consider abstracting:\n\n" + "\n\n".join(repeated_patterns[:5])
     )

@@ -29,7 +29,7 @@ ALLOWED_NUMBERS = {
     0, 1, 2, 3, 4, 5, -1,  # Small integers
     10, 20, 30, 50,  # Common timeout/limit values
     60, 100, 200, 255, 365, 1000, 1024,  # Common constants
-    0.1, 0.3, 0.5,  # Common float values
+    0.1, 0.5,  # Common float values
     16000, 50051,  # Sample rate and gRPC port
 }
 ALLOWED_STRINGS = {

@@ -39,11 +39,10 @@ ALLOWED_STRINGS = {
     "\t",
     "utf-8",
     "utf8",
-    "w",
     "r",
+    "w",
     "rb",
     "wb",
     "a",
     "GET",
     "POST",
     "PUT",

|
||||
"name",
|
||||
"type",
|
||||
"value",
|
||||
# Common domain/infrastructure terms
|
||||
"__main__",
|
||||
"noteflow",
|
||||
"meeting",
|
||||
"segment",
|
||||
"summary",
|
||||
"annotation",
|
||||
"CASCADE",
|
||||
"selectin",
|
||||
"schema",
|
||||
"role",
|
||||
"user",
|
||||
"text",
|
||||
"title",
|
||||
"status",
|
||||
"content",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"meeting_id",
|
||||
"user_id",
|
||||
"request_id",
|
||||
# Domain enums
|
||||
"action_item",
|
||||
"decision",
|
||||
"note",
|
||||
"risk",
|
||||
"unknown",
|
||||
"completed",
|
||||
"failed",
|
||||
"pending",
|
||||
"running",
|
||||
"markdown",
|
||||
"html",
|
||||
# Common patterns
|
||||
"base",
|
||||
"auto",
|
||||
"cuda",
|
||||
"int8",
|
||||
"float32",
|
||||
# argparse actions
|
||||
"store_true",
|
||||
"store_false",
|
||||
# ORM table/column names (intentionally repeated across models/repos)
|
||||
"meetings",
|
||||
"segments",
|
||||
"summaries",
|
||||
"annotations",
|
||||
"key_points",
|
||||
"action_items",
|
||||
"word_timings",
|
||||
"sample_rate",
|
||||
"segment_ids",
|
||||
"summary_id",
|
||||
"wrapped_dek",
|
||||
"diarization_jobs",
|
||||
"user_preferences",
|
||||
"streamingdiarization_turns",
|
||||
# ORM cascade settings
|
||||
"all, delete-orphan",
|
||||
# Foreign key references
|
||||
"noteflow.meetings.id",
|
||||
"noteflow.summaries.id",
|
||||
# Error message patterns (intentional consistency)
|
||||
"UnitOfWork not in context",
|
||||
"Invalid meeting_id",
|
||||
"Invalid annotation_id",
|
||||
# File names (infrastructure constants)
|
||||
"manifest.json",
|
||||
"audio.enc",
|
||||
# HTML tags
|
||||
"</div>",
|
||||
"</dd>",
|
||||
# Model class names (ORM back_populates/relationships - required by SQLAlchemy)
|
||||
"ActionItemModel",
|
||||
"AnnotationModel",
|
||||
"CalendarEventModel",
|
||||
"DiarizationJobModel",
|
||||
"ExternalRefModel",
|
||||
"IntegrationModel",
|
||||
"IntegrationSecretModel",
|
||||
"IntegrationSyncRunModel",
|
||||
"KeyPointModel",
|
||||
"MeetingCalendarLinkModel",
|
||||
"MeetingModel",
|
||||
"MeetingSpeakerModel",
|
||||
"MeetingTagModel",
|
||||
"NamedEntityModel",
|
||||
"PersonModel",
|
||||
"SegmentModel",
|
||||
"SettingsModel",
|
||||
"StreamingDiarizationTurnModel",
|
||||
"SummaryModel",
|
||||
"TagModel",
|
||||
"TaskModel",
|
||||
"UserModel",
|
||||
"UserPreferencesModel",
|
||||
"WebhookConfigModel",
|
||||
"WebhookDeliveryModel",
|
||||
"WordTimingModel",
|
||||
"WorkspaceMembershipModel",
|
||||
"WorkspaceModel",
|
||||
# ORM relationship back_populates names
|
||||
"workspace",
|
||||
"memberships",
|
||||
"integration",
|
||||
"meeting_tags",
|
||||
"tasks",
|
||||
# Foreign key references
|
||||
"noteflow.workspaces.id",
|
||||
"noteflow.users.id",
|
||||
"noteflow.integrations.id",
|
||||
# Database ondelete actions
|
||||
"SET NULL",
|
||||
"RESTRICT",
|
||||
# Column names used in mappings
|
||||
"metadata",
|
||||
"workspace_id",
|
||||
# Database URL prefixes
|
||||
"postgres://",
|
||||
"postgresql://",
|
||||
"postgresql+asyncpg://",
|
||||
# OIDC standard claim names (RFC 7519 / OpenID Connect Core spec)
|
||||
"sub",
|
||||
"email",
|
||||
"email_verified",
|
||||
"preferred_username",
|
||||
"groups",
|
||||
"picture",
|
||||
"given_name",
|
||||
"family_name",
|
||||
"openid",
|
||||
"profile",
|
||||
"offline_access",
|
||||
# OIDC discovery document fields (OpenID Connect Discovery spec)
|
||||
"issuer",
|
||||
"authorization_endpoint",
|
||||
"token_endpoint",
|
||||
"userinfo_endpoint",
|
||||
"jwks_uri",
|
||||
"end_session_endpoint",
|
||||
"revocation_endpoint",
|
||||
"introspection_endpoint",
|
||||
"scopes_supported",
|
||||
"code_challenge_methods_supported",
|
||||
"claims_supported",
|
||||
"response_types_supported",
|
||||
# OIDC provider config fields
|
||||
"discovery",
|
||||
"discovery_refreshed_at",
|
||||
"issuer_url",
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"claim_mapping",
|
||||
"scopes",
|
||||
"preset",
|
||||
"require_email_verified",
|
||||
"allowed_groups",
|
||||
"enabled",
|
||||
# Integration status values
|
||||
"success",
|
||||
"error",
|
||||
"calendar",
|
||||
"provider",
|
||||
# Common error message fragments
|
||||
" not found",
|
||||
# HTML markup
|
||||
"<li>",
|
||||
"</li>",
|
||||
# Logging levels
|
||||
"INFO",
|
||||
"DEBUG",
|
||||
"WARNING",
|
||||
"ERROR",
|
||||
# Domain terms
|
||||
"project",
|
||||
# Internal attribute names (used in multiple gRPC handlers)
|
||||
"_pending_chunks",
|
||||
# Sentinel UUIDs
|
||||
"00000000-0000-0000-0000-000000000001",
|
||||
# Repository type names (used in TYPE_CHECKING imports and annotations)
|
||||
"MeetingRepository",
|
||||
"SegmentRepository",
|
||||
"SummaryRepository",
|
||||
"AnnotationRepository",
|
||||
"UserRepository",
|
||||
"WorkspaceRepository",
|
||||
# Pagination and filter parameter names (used in repositories and services)
|
||||
"states",
|
||||
"limit",
|
||||
"offset",
|
||||
"sort_desc",
|
||||
"project_id",
|
||||
"project_ids",
|
||||
# Domain attribute names (used across entities, converters, services)
|
||||
"provider_name",
|
||||
"model_name",
|
||||
"annotation_type",
|
||||
"error_message",
|
||||
"integration_id",
|
||||
"started_at",
|
||||
@@ -264,53 +129,22 @@ ALLOWED_STRINGS = {
     "slug",
     "description",
     "settings",
     "location",
     "date",
     "start",
     "attendees",
     "secret",
     "timeout_ms",
     "max_retries",
     "ascii",
     "code",
     # Security and logging categories
     "security",
     # Identity and role names
     "viewer",
     "User",
     "Webhook",
     "Workspaces",
     # Settings field names
     "rag_enabled",
     "default_summarization_template",
     # Cache keys
     "sync_run_cache_times",
     # Log context names
     "diarization_job",
     "segmenter_state_transition",
     # Default model names
     "gpt-4o-mini",
     "claude-3-haiku-20240307",
     # ORM model class names
     "ProjectMembershipModel",
     "ProjectModel",
     # Error code constants
     "service_not_enabled",
     # Protocol prefixes
     "http://",
     # Timezone format
     "+00:00",
     # gRPC status codes
     "INTERNAL",
     "UNKNOWN",
     # Proto type names (used in TYPE_CHECKING)
     "ProtoAnnotation",
     "ProtoMeeting",
     # Error message fragments
     ", got ",
     # Error message templates (shared across diarization status handlers)
     "Cannot update job %s: database required",
     # Ruff ignore directive
     "ignore",
 }

@@ -392,11 +226,9 @@ def test_no_magic_numbers() -> None:
         for value, mvs in repeated
     ]
 
-    # Target: 11 repeated magic numbers max - common values need named constants:
-    # - 10 (20x), 1024 (14x), 5 (13x), 50 (12x), 40, 24, 300, 10000, 500 should be constants
-    # Note: 40 (model display width), 24 (hours), 300 (timeouts), 10000/500 (http codes) are repeated
-    assert len(violations) <= 11, (
-        f"Found {len(violations)} repeated magic numbers (max 11 allowed). "
+    # Target: 5 repeated magic numbers max - common values should be constants.
+    assert len(violations) <= 5, (
+        f"Found {len(violations)} repeated magic numbers (max 5 allowed). "
         "Consider extracting to constants:\n\n" + "\n\n".join(violations[:5])
     )

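The magic-number check tallies numeric literals outside the ALLOWED_NUMBERS set and flags values that repeat; a sketch of the counting core under that assumption (allow-list abbreviated):

import ast
from collections import Counter

ALLOWED = {0, 1, 2, 3, 4, 5, -1, 10, 20, 30, 50, 60, 100, 0.1, 0.5}


def count_magic_numbers(tree: ast.Module) -> Counter[float]:
    counts: Counter[float] = Counter()
    for node in ast.walk(tree):
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            # bool is a subclass of int; True/False are not magic numbers.
            if not isinstance(node.value, bool) and node.value not in ALLOWED:
                counts[node.value] += 1
    return counts


tree = ast.parse("timeout = 300\nretries = 300\n")
# count_magic_numbers(tree) -> Counter({300: 2}); repeated values count toward the max of 5
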
@@ -440,10 +272,10 @@ def test_no_repeated_string_literals() -> None:
         for value, locs in repeated
     ]
 
-    # Target: 31 repeated strings max - many can be extracted to constants
+    # Target: 10 repeated strings max - many can be extracted to constants
     # - Error messages, schema names, log formats should be centralized
-    assert len(violations) <= 31, (
-        f"Found {len(violations)} repeated string literals (max 31 allowed). "
+    assert len(violations) <= 10, (
+        f"Found {len(violations)} repeated string literals (max 10 allowed). "
         "Consider using constants or enums:\n\n" + "\n\n".join(violations[:5])
     )

@@ -106,8 +106,8 @@ def test_no_assertion_roulette() -> None:
                 if not has_assertion_message(node):
                     assertions_without_msg += 1
 
-        # Flag if >3 assertions without messages
-        if assertions_without_msg > 3:
+        # Flag if >1 assertions without messages
+        if assertions_without_msg > 1:
             violations.append(
                 Violation(
                     rule="assertion_roulette",

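An assertion-roulette test is one whose failures are hard to attribute because several asserts carry no message; the threshold drops from three bare asserts to one. The counting logic reduces to checking ast.Assert.msg, as in this sketch:

import ast


def count_bare_assertions(test_func: ast.FunctionDef) -> int:
    # assert's optional second operand is the message; None means "bare".
    return sum(
        1
        for node in ast.walk(test_func)
        if isinstance(node, ast.Assert) and node.msg is None
    )


func = ast.parse("def test_x():\n    assert 1 == 1\n    assert 2 == 2, 'explained'\n").body[0]
# count_bare_assertions(func) == 1 -> one bare assert stays under the new threshold of > 1
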
@@ -237,9 +237,6 @@ def test_no_sleepy_tests() -> None:
     # Paths where sleep is legitimately needed for stress/resilience testing
     allowed_sleepy_paths = {
         "tests/stress/",
-        "tests/integration/test_signal_handling.py",
-        "tests/integration/test_database_resilience.py",
-        "tests/grpc/test_stream_lifecycle.py",
     }
 
     violations: list[Violation] = []

@@ -589,15 +586,10 @@ def test_no_sensitive_equality() -> None:
     excluded_test_patterns = {
         "string",  # Testing string conversion behavior
         "proto",  # Testing protobuf field serialization
-        "conversion",  # Testing type conversion
-        "serializ",  # Testing serialization
-        "preserves_message",  # Testing error message preservation
     }
 
     # File patterns where str() comparison is expected (gRPC field serialization)
-    excluded_file_patterns = {
-        "_mixin",  # gRPC mixin tests compare ID fields
-    }
+    excluded_file_patterns: set[str] = set()
 
     def _is_excluded_test(test_name: str) -> bool:
         """Check if test legitimately uses str() comparison."""

@@ -662,7 +654,7 @@ def test_no_eager_tests() -> None:
     """
     violations: list[Violation] = []
     parse_errors: list[str] = []
-    max_method_calls = 10  # Threshold for "too many" method calls
+    max_method_calls = 7  # Threshold for "too many" method calls
 
     for py_file in find_test_files():
         tree, error = parse_file_safe(py_file)

@@ -795,7 +787,7 @@ def test_no_long_test_methods() -> None:
     Long tests are hard to understand and maintain. Break them into
     smaller, focused tests or extract helper functions.
     """
-    max_lines = 46
+    max_lines = 35
 
     violations: list[Violation] = []
     parse_errors: list[str] = []

@@ -69,7 +69,6 @@ def test_no_trivial_wrapper_functions() -> None:
         ("full_transcript", "join"),
         ("duration", "sub"),
         ("is_active", "property"),
-        ("is_admin", "can_admin"),  # semantic alias for operation context
         # Domain method accessors (type-safe dict access)
         ("get_metadata", "get"),
         # Strategy pattern implementations (RuleType.evaluate for simple mode)

@@ -79,43 +78,25 @@ def test_no_trivial_wrapper_functions() -> None:
         ("generate_request_id", "str"),  # UUID to string conversion
         # Context variable accessors (public API over internal contextvars)
         ("get_request_id", "get"),
         ("get_user_id", "get"),
         ("get_workspace_id", "get"),
         # Time conversion utilities (semantic naming for datetime operations)
         ("datetime_to_epoch_seconds", "timestamp"),
         ("datetime_to_iso_string", "isoformat"),
         ("epoch_seconds_to_datetime", "fromtimestamp"),
         ("proto_timestamp_to_datetime", "replace"),
         # Accessor-style wrappers with semantic names
         ("from_metrics", "cls"),
         ("from_dict", "cls"),
         ("empty", "cls"),
         # ProcessingStepState factory methods (GAP-W05)
         ("pending", "cls"),
         ("running", "cls"),
         ("completed", "cls"),
         ("failed", "cls"),
         ("skipped", "cls"),
         ("create_pending", "cls"),
         ("get_log_level", "get"),
         ("get_preset_config", "get"),
         ("get_provider", "get"),
         ("get_pending_state", "get"),
         ("get_stream_state", "get"),
         ("get_async_session_factory", "async_sessionmaker"),
         ("process_chunk", "process"),
         ("get_openai_client", "_get_openai_client"),
         ("meeting_apps", "frozenset"),
         ("suppressed_apps", "frozenset"),
         ("get_sync_run", "get"),
         ("list_all", "list"),
         ("get_by_id", "get"),
         ("create", "insert"),
         ("delete_by_meeting", "clear_summary"),
         ("get_by_meeting", "fetch_segments"),
         ("get_by_meeting", "get_summary"),
         ("check_otel_available", "_check_otel_available"),
         ("start_as_current_span", "_NoOpSpanContext"),
         ("start_span", "_NoOpSpan"),
         ("detected_app", "next"),
     }

@@ -231,10 +212,7 @@ def test_no_redundant_type_aliases() -> None:
 
 def test_no_passthrough_classes() -> None:
     """Detect classes that only delegate to another object."""
-    # Classes that are intentionally factory-pattern based (all methods return cls())
-    allowed_factory_classes = {
-        # Domain entity with factory methods for creating different states (GAP-W05)
-        "ProcessingStepState",
-    }
+    allowed_factory_classes: set[str] = set()
 
     violations: list[Violation] = []
     parse_errors: list[str] = []
