chore: update client submodule and add authentication service helpers

- Updated the client submodule to the latest commit for improved features and stability.
- Introduced new authentication service helpers for user and workspace management, enhancing the overall authentication flow.
- Added shared authentication constants for better maintainability and clarity in the codebase.
This commit is contained in:
2026-01-05 16:10:12 +00:00
parent fdb97fb0fd
commit fdb9b69256
27 changed files with 900 additions and 817 deletions

File diff suppressed because one or more lines are too long

2
client

Submodule client updated: 5ab973a1d7...c53b16693a

View File

@@ -0,0 +1,8 @@
"""Shared authentication constants."""
from __future__ import annotations
from uuid import UUID
DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")

View File

@@ -0,0 +1,211 @@
"""Internal helpers for auth service operations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import OAuthManager
from noteflow.infrastructure.logging import get_logger
from .auth_constants import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID
from .auth_types import AuthResult
if TYPE_CHECKING:
from noteflow.domain.ports.unit_of_work import UnitOfWork
logger = get_logger(__name__)
def resolve_provider_email(integration: Integration) -> str:
"""Resolve provider email with a consistent fallback."""
return integration.provider_email or "User"
def resolve_user_id_from_integration(integration: Integration) -> UUID:
"""Resolve the user ID from integration config, falling back to default."""
user_id_str = integration.config.get("user_id")
return UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
async def get_or_create_user_id(
uow: UnitOfWork,
email: str,
display_name: str,
) -> UUID:
"""Fetch an existing user or create a new one, returning user ID."""
user = await uow.users.get_by_email(email) if uow.supports_users else None
if user is None:
user_id = uuid4()
if uow.supports_users:
user = User(
id=user_id,
email=email,
display_name=display_name,
is_default=False,
)
await uow.users.create(user)
logger.info("Created new user: %s (%s)", display_name, email)
else:
user_id = DEFAULT_USER_ID
return user_id
user_id = user.id
if user.display_name != display_name:
user.display_name = display_name
await uow.users.update(user)
return user_id
async def get_or_create_default_workspace_id(
uow: UnitOfWork,
user_id: UUID,
) -> UUID:
"""Fetch or create the default workspace for a user."""
if not uow.supports_workspaces:
return DEFAULT_WORKSPACE_ID
workspace = await uow.workspaces.get_default_for_user(user_id)
if workspace:
return workspace.id
workspace_id = uuid4()
await uow.workspaces.create(
workspace_id=workspace_id,
name="Personal",
owner_id=user_id,
is_default=True,
)
logger.info(
"Created default workspace for user_id=%s, workspace_id=%s",
user_id,
workspace_id,
)
return workspace_id
async def get_or_create_auth_integration(
uow: UnitOfWork,
provider: str,
workspace_id: UUID,
user_id: UUID,
provider_email: str,
) -> Integration:
"""Fetch or create the auth integration for a provider."""
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration is None:
integration = Integration.create(
workspace_id=workspace_id,
name=f"{provider.title()} Auth",
integration_type=IntegrationType.AUTH,
config={"provider": provider, "user_id": str(user_id)},
)
await uow.integrations.create(integration)
else:
integration.config["provider"] = provider
integration.config["user_id"] = str(user_id)
integration.connect(provider_email=provider_email)
await uow.integrations.update(integration)
return integration
async def store_integration_tokens(
uow: UnitOfWork,
integration: Integration,
tokens: OAuthTokens,
) -> None:
"""Persist updated tokens for an integration."""
await uow.integrations.set_secrets(
integration_id=integration.id,
secrets=tokens.to_secrets_dict(),
)
async def find_connected_auth_integration(
uow: UnitOfWork,
) -> tuple[str, Integration] | None:
"""Return the first connected auth integration and provider name."""
if not getattr(uow, "supports_integrations", False):
return None
for provider in (OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value):
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration and integration.is_connected:
return provider, integration
return None
async def resolve_display_name(
uow: UnitOfWork,
user_id_str: str | None,
fallback: str,
) -> str:
"""Resolve display name from user repository if available."""
if not (uow.supports_users and user_id_str):
return fallback
user_id = UUID(user_id_str)
user = await uow.users.get(user_id)
return user.display_name if user else fallback
async def refresh_tokens_for_integration(
uow: UnitOfWork,
oauth_provider: OAuthProvider,
integration: Integration,
oauth_manager: OAuthManager,
) -> AuthResult | None:
"""Refresh tokens for a connected integration if needed."""
secrets = await uow.integrations.get_secrets(integration.id)
if not secrets:
return None
try:
tokens = OAuthTokens.from_secrets_dict(secrets)
except (KeyError, ValueError):
return None
if not tokens.refresh_token:
return None
if not tokens.is_expired(buffer_seconds=300):
logger.debug(
"auth_token_still_valid",
provider=oauth_provider.value,
expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
)
user_id = resolve_user_id_from_integration(integration)
return AuthResult(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=resolve_provider_email(integration),
email=integration.provider_email,
)
new_tokens = await oauth_manager.refresh_tokens(
provider=oauth_provider,
refresh_token=tokens.refresh_token,
)
await store_integration_tokens(uow, integration, new_tokens)
await uow.commit()
user_id = resolve_user_id_from_integration(integration)
return AuthResult(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=resolve_provider_email(integration),
email=integration.provider_email,
)

View File

@@ -7,13 +7,11 @@ IntegrationType.AUTH and manages User entities.
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict, Unpack
from uuid import UUID, uuid4
from uuid import UUID
from noteflow.config.constants import OAUTH_FIELD_ACCESS_TOKEN
from noteflow.domain.entities.integration import Integration, IntegrationType
from noteflow.domain.identity.entities import User
from noteflow.domain.entities.integration import IntegrationType
from noteflow.domain.value_objects import OAuthProvider, OAuthTokens
from noteflow.infrastructure.calendar import OAuthManager
from noteflow.infrastructure.calendar.google_adapter import GoogleCalendarError
@@ -21,6 +19,18 @@ from noteflow.infrastructure.calendar.oauth_manager import OAuthError
from noteflow.infrastructure.calendar.outlook_adapter import OutlookCalendarError
from noteflow.infrastructure.logging import get_logger
from .auth_constants import DEFAULT_USER_ID, DEFAULT_WORKSPACE_ID
from .auth_helpers import (
find_connected_auth_integration,
get_or_create_auth_integration,
get_or_create_default_workspace_id,
get_or_create_user_id,
refresh_tokens_for_integration,
store_integration_tokens,
resolve_display_name,
)
from .auth_types import AuthResult, LogoutResult, UserInfo
class _AuthServiceDepsKwargs(TypedDict, total=False):
"""Optional dependency overrides for AuthService."""
@@ -37,59 +47,10 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
# Default IDs for local-first mode
DEFAULT_USER_ID = UUID("00000000-0000-0000-0000-000000000001")
DEFAULT_WORKSPACE_ID = UUID("00000000-0000-0000-0000-000000000001")
class AuthServiceError(Exception):
"""Auth service operation failed."""
@dataclass(frozen=True, slots=True)
class AuthResult:
"""Result of successful authentication.
Note: Tokens are stored securely in IntegrationSecretModel and are NOT
exposed to callers. Use get_current_user() to check auth status.
"""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool = True
@dataclass(frozen=True, slots=True)
class UserInfo:
"""Current user information."""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool
provider: str | None
@dataclass(frozen=True, slots=True)
class LogoutResult:
"""Result of logout operation.
Provides visibility into both local logout and remote token revocation.
"""
logged_out: bool
"""Whether local logout succeeded (integration deleted)."""
tokens_revoked: bool
"""Whether remote token revocation succeeded."""
revocation_error: str | None = None
"""Error message if revocation failed (for logging/debugging)."""
class AuthService:
"""Authentication service for OAuth-based user login.
@@ -261,79 +222,16 @@ class AuthService:
) -> tuple[UUID, UUID]:
"""Create or update user and store auth tokens."""
async with self._uow_factory() as uow:
# Find or create user by email
user = None
if uow.supports_users:
user = await uow.users.get_by_email(email)
if user is None:
# Create new user
user_id = uuid4()
if uow.supports_users:
user = User(
id=user_id,
email=email,
display_name=display_name,
is_default=False,
)
await uow.users.create(user)
logger.info("Created new user: %s (%s)", display_name, email)
else:
user_id = DEFAULT_USER_ID
else:
user_id = user.id
# Update display name if changed
if user.display_name != display_name:
user.display_name = display_name
await uow.users.update(user)
# Get or create default workspace for this user
workspace_id = DEFAULT_WORKSPACE_ID
if uow.supports_workspaces:
workspace = await uow.workspaces.get_default_for_user(user_id)
if workspace:
workspace_id = workspace.id
else:
# Create default "Personal" workspace for new user
workspace_id = uuid4()
await uow.workspaces.create(
workspace_id=workspace_id,
name="Personal",
owner_id=user_id,
is_default=True,
)
logger.info(
"Created default workspace for user_id=%s, workspace_id=%s",
user_id,
workspace_id,
)
# Store auth integration with tokens
integration = await uow.integrations.get_by_provider(
user_id = await get_or_create_user_id(uow, email, display_name)
workspace_id = await get_or_create_default_workspace_id(uow, user_id)
integration = await get_or_create_auth_integration(
uow,
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration is None:
integration = Integration.create(
workspace_id=workspace_id,
name=f"{provider.title()} Auth",
integration_type=IntegrationType.AUTH,
config={"provider": provider, "user_id": str(user_id)},
)
await uow.integrations.create(integration)
else:
integration.config["provider"] = provider
integration.config["user_id"] = str(user_id)
integration.connect(provider_email=email)
await uow.integrations.update(integration)
# Store tokens
await uow.integrations.set_secrets(
integration_id=integration.id,
secrets=tokens.to_secrets_dict(),
workspace_id=workspace_id,
user_id=user_id,
provider_email=email,
)
await store_integration_tokens(uow, integration, tokens)
await uow.commit()
return user_id, workspace_id
@@ -345,42 +243,33 @@ class AuthService:
UserInfo with current user details or local default.
"""
async with self._uow_factory() as uow:
# Look for any connected auth integration
for provider in [OAuthProvider.GOOGLE.value, OAuthProvider.OUTLOOK.value]:
integration = await uow.integrations.get_by_provider(
found = await find_connected_auth_integration(uow)
if found:
provider, integration = found
user_id_str = integration.config.get("user_id")
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
display_name = await resolve_display_name(
uow,
user_id_str,
fallback="Authenticated User",
)
return UserInfo(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=display_name,
email=integration.provider_email,
is_authenticated=True,
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration and integration.is_connected:
user_id_str = integration.config.get("user_id")
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
# Get user details
display_name = "Authenticated User"
if uow.supports_users and user_id_str:
user = await uow.users.get(user_id)
if user:
display_name = user.display_name
return UserInfo(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=display_name,
email=integration.provider_email,
is_authenticated=True,
provider=provider,
)
# Return local default
return UserInfo(
user_id=DEFAULT_USER_ID,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name="Local User",
email=None,
is_authenticated=False,
provider=None,
)
return UserInfo(
user_id=DEFAULT_USER_ID,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name="Local User",
email=None,
is_authenticated=False,
provider=None,
)
async def logout(self, provider: str | None = None) -> LogoutResult:
"""Logout and revoke auth tokens.
@@ -495,57 +384,13 @@ class AuthService:
if integration is None or not integration.is_connected:
return None
secrets = await uow.integrations.get_secrets(integration.id)
if not secrets:
return None
try:
tokens = OAuthTokens.from_secrets_dict(secrets)
except (KeyError, ValueError):
return None
if not tokens.refresh_token:
return None
# Only refresh if token is expired or will expire within 5 minutes
if not tokens.is_expired(buffer_seconds=300):
logger.debug(
"auth_token_still_valid",
provider=provider,
expires_at=tokens.expires_at.isoformat() if tokens.expires_at else None,
return await refresh_tokens_for_integration(
uow,
oauth_provider=oauth_provider,
integration=integration,
oauth_manager=self._oauth_manager,
)
# Return existing auth info without refreshing
user_id_str = integration.config.get("user_id")
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
return AuthResult(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=integration.provider_email or "User",
email=integration.provider_email,
)
try:
new_tokens = await self._oauth_manager.refresh_tokens(
provider=oauth_provider,
refresh_token=tokens.refresh_token,
)
await uow.integrations.set_secrets(
integration_id=integration.id,
secrets=new_tokens.to_secrets_dict(),
)
await uow.commit()
user_id_str = integration.config.get("user_id")
user_id = UUID(user_id_str) if user_id_str else DEFAULT_USER_ID
return AuthResult(
user_id=user_id,
workspace_id=DEFAULT_WORKSPACE_ID,
display_name=integration.provider_email or "User",
email=integration.provider_email,
)
except OAuthError as e:
integration.mark_error(f"Token refresh failed: {e}")
await uow.integrations.update(integration)

View File

@@ -0,0 +1,50 @@
"""Auth service data structures."""
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
@dataclass(frozen=True, slots=True)
class AuthResult:
"""Result of successful authentication.
Note: Tokens are stored securely in IntegrationSecretModel and are NOT
exposed to callers. Use get_current_user() to check auth status.
"""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool = True
@dataclass(frozen=True, slots=True)
class UserInfo:
"""Current user information."""
user_id: UUID
workspace_id: UUID
display_name: str
email: str | None
is_authenticated: bool
provider: str | None
@dataclass(frozen=True, slots=True)
class LogoutResult:
"""Result of logout operation.
Provides visibility into both local logout and remote token revocation.
"""
logged_out: bool
"""Whether local logout succeeded (integration deleted)."""
tokens_revoked: bool
"""Whether remote token revocation succeeded."""
revocation_error: str | None = None
"""Error message if revocation failed (for logging/debugging)."""

View File

@@ -54,14 +54,16 @@ class ProcessingStepState:
@classmethod
def pending(cls) -> ProcessingStepState:
"""Create a pending step state."""
return cls(status=ProcessingStepStatus.PENDING)
status = ProcessingStepStatus.PENDING
return cls(status=status)
@classmethod
def running(cls, started_at: datetime | None = None) -> ProcessingStepState:
"""Create a running step state."""
started = started_at or utc_now()
return cls(
status=ProcessingStepStatus.RUNNING,
started_at=started_at or utc_now(),
started_at=started,
)
@classmethod
@@ -71,10 +73,11 @@ class ProcessingStepState:
completed_at: datetime | None = None,
) -> ProcessingStepState:
"""Create a completed step state."""
completed = completed_at or utc_now()
return cls(
status=ProcessingStepStatus.COMPLETED,
started_at=started_at,
completed_at=completed_at or utc_now(),
completed_at=completed,
)
@classmethod
@@ -84,17 +87,29 @@ class ProcessingStepState:
started_at: datetime | None = None,
) -> ProcessingStepState:
"""Create a failed step state."""
completed = utc_now()
return cls(
status=ProcessingStepStatus.FAILED,
error_message=error_message,
started_at=started_at,
completed_at=utc_now(),
completed_at=completed,
)
@classmethod
def skipped(cls) -> ProcessingStepState:
"""Create a skipped step state."""
return cls(status=ProcessingStepStatus.SKIPPED)
status = ProcessingStepStatus.SKIPPED
return cls(status=status)
def with_error(self, message: str) -> ProcessingStepState:
"""Return a failed state derived from this instance."""
started_at = self.started_at or utc_now()
return ProcessingStepState(
status=ProcessingStepStatus.FAILED,
error_message=message,
started_at=started_at,
completed_at=utc_now(),
)
@dataclass(frozen=True, slots=True)
@@ -113,7 +128,8 @@ class ProcessingStatus:
@classmethod
def create_pending(cls) -> ProcessingStatus:
"""Create a processing status with all steps pending."""
return cls()
status = cls()
return status
@property
def is_complete(self) -> bool:

View File

@@ -0,0 +1,15 @@
"""Identity service singleton for gRPC runtime."""
from __future__ import annotations
from noteflow.application.services.identity_service import IdentityService
_identity_service_instance: IdentityService | None = None
def default_identity_service() -> IdentityService:
"""Get or create the default identity service singleton."""
global _identity_service_instance
if _identity_service_instance is None:
_identity_service_instance = IdentityService()
return _identity_service_instance

View File

@@ -17,6 +17,21 @@ from ._types import GrpcContext
logger = get_logger(__name__)
async def _resolve_auth_status(uow: UnitOfWork) -> tuple[bool, str]:
"""Resolve authentication status and provider from integrations."""
if not getattr(uow, "supports_integrations", False):
return False, ""
for provider in ("google", "outlook"):
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration and integration.is_connected:
return True, provider
return False, ""
class IdentityServicer(Protocol):
"""Protocol for hosts that support identity operations."""
@@ -55,19 +70,7 @@ class IdentityMixin:
await uow.commit()
# Check if user has auth integration (authenticated via OAuth)
is_authenticated = False
auth_provider = ""
if hasattr(uow, "supports_integrations") and uow.supports_integrations:
for provider in ["google", "outlook"]:
integration = await uow.integrations.get_by_provider(
provider=provider,
integration_type=IntegrationType.AUTH.value,
)
if integration and integration.is_connected:
is_authenticated = True
auth_provider = provider
break
is_authenticated, auth_provider = await _resolve_auth_status(uow)
logger.debug(
"GetCurrentUser: user_id=%s, workspace_id=%s, authenticated=%s",

View File

@@ -86,6 +86,60 @@ class _HasField(Protocol):
def HasField(self, field_name: str) -> bool: ...
async def _parse_project_ids_or_abort(
request: noteflow_pb2.ListMeetingsRequest,
context: GrpcContext,
) -> list[UUID] | None:
"""Parse optional project_ids list, aborting on invalid values."""
if not request.project_ids:
return None
project_ids: list[UUID] = []
for raw_project_id in request.project_ids:
try:
project_ids.append(UUID(raw_project_id))
except ValueError:
truncated = (
raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
)
logger.warning(
"ListMeetings: invalid project_ids format",
project_id_truncated=truncated,
project_id_length=len(raw_project_id),
)
await abort_invalid_argument(
context,
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
)
return None
return project_ids
async def _parse_project_id_or_abort(
request: noteflow_pb2.ListMeetingsRequest,
context: GrpcContext,
) -> UUID | None:
"""Parse optional project_id, aborting on invalid values."""
if not (cast(_HasField, request).HasField("project_id") and request.project_id):
return None
try:
return UUID(request.project_id)
except ValueError:
truncated = (
request.project_id[:8] + "..." if len(request.project_id) > 8 else request.project_id
)
logger.warning(
"ListMeetings: invalid project_id format",
project_id_truncated=truncated,
project_id_length=len(request.project_id),
)
error_message = f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
await abort_invalid_argument(context, error_message)
return None
class MeetingServicer(Protocol):
"""Protocol for hosts that support meeting operations."""
@@ -283,40 +337,9 @@ class MeetingMixin:
state_values = cast(Sequence[int], request.states)
states = [MeetingState(s) for s in state_values] if state_values else None
project_id: UUID | None = None
project_ids: list[UUID] | None = None
if request.project_ids:
project_ids = []
for raw_project_id in request.project_ids:
try:
project_ids.append(UUID(raw_project_id))
except ValueError:
truncated = raw_project_id[:8] + "..." if len(raw_project_id) > 8 else raw_project_id
logger.warning(
"ListMeetings: invalid project_ids format",
project_id_truncated=truncated,
project_id_length=len(raw_project_id),
)
await abort_invalid_argument(
context,
f"{ERROR_INVALID_PROJECT_ID_PREFIX}{raw_project_id}",
)
if (
not project_ids
and cast(_HasField, request).HasField("project_id")
and request.project_id
):
try:
project_id = UUID(request.project_id)
except ValueError:
truncated = request.project_id[:8] + "..." if len(request.project_id) > 8 else request.project_id
logger.warning(
"ListMeetings: invalid project_id format",
project_id_truncated=truncated,
project_id_length=len(request.project_id),
)
await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
project_ids = await _parse_project_ids_or_abort(request, context)
if not project_ids:
project_id = await _parse_project_id_or_abort(request, context)
async with self.create_repository_provider() as repo:
if project_id is None and not project_ids:

View File

@@ -0,0 +1,22 @@
"""Runtime gRPC servicer base types."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .proto import noteflow_pb2_grpc
if TYPE_CHECKING:
GrpcBaseServicer = object
class NoteFlowServicerStubs:
"""Type-checking placeholder for servicer stubs."""
pass
else:
GrpcBaseServicer = noteflow_pb2_grpc.NoteFlowServiceServicer
class NoteFlowServicerStubs:
"""Runtime placeholder for type stubs (empty at runtime)."""
pass

View File

@@ -7,6 +7,7 @@ Used as fallback when no database is configured.
from __future__ import annotations
import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Unpack
from noteflow.config.constants import ERROR_MSG_MEETING_PREFIX
@@ -21,6 +22,62 @@ if TYPE_CHECKING:
from datetime import datetime
@dataclass(frozen=True, slots=True)
class _MeetingListOptions:
states: set[MeetingState] | None
limit: int
offset: int
sort_desc: bool
project_id: str | None
project_ids: set[str] | None
def _normalize_list_options(
options: MeetingListKwargs,
) -> _MeetingListOptions:
states = options.get("states")
state_set = set(states) if states else None
limit = options.get("limit", 100)
offset = options.get("offset", 0)
sort_desc = options.get("sort_desc", True)
project_id = options.get("project_id")
project_ids = options.get("project_ids")
project_id_set = set(project_ids) if project_ids else None
return _MeetingListOptions(
states=state_set,
limit=limit,
offset=offset,
sort_desc=sort_desc,
project_id=project_id,
project_ids=project_id_set,
)
def _filter_meetings(
meetings: list[Meeting],
options: _MeetingListOptions,
) -> list[Meeting]:
filtered = meetings
if options.states:
filtered = [m for m in filtered if m.state in options.states]
if options.project_ids:
filtered = [
m
for m in filtered
if m.project_id is not None and str(m.project_id) in options.project_ids
]
elif options.project_id:
filtered = [
m
for m in filtered
if m.project_id is not None and str(m.project_id) == options.project_id
]
return filtered
class MeetingStore:
"""Thread-safe in-memory meeting storage using domain entities."""
@@ -102,38 +159,11 @@ class MeetingStore:
Tuple of (paginated meeting list, total matching count).
"""
with self._lock:
states = kwargs.get("states")
limit = kwargs.get("limit", 100)
offset = kwargs.get("offset", 0)
sort_desc = kwargs.get("sort_desc", True)
project_id = kwargs.get("project_id")
project_ids = kwargs.get("project_ids")
meetings = list(self._meetings.values())
# Filter by state
if states:
state_set = set(states)
meetings = [m for m in meetings if m.state in state_set]
# Filter by project(s) if requested
if project_ids:
project_set = set(project_ids)
meetings = [
m for m in meetings if m.project_id is not None and str(m.project_id) in project_set
]
elif project_id:
meetings = [
m for m in meetings if m.project_id is not None and str(m.project_id) == project_id
]
options = _normalize_list_options(kwargs)
meetings = _filter_meetings(list(self._meetings.values()), options)
total = len(meetings)
# Sort
meetings.sort(key=lambda m: m.created_at, reverse=sort_desc)
# Paginate
meetings = meetings[offset : offset + limit]
meetings.sort(key=lambda m: m.created_at, reverse=options.sort_desc)
meetings = meetings[options.offset : options.offset + options.limit]
return meetings, total
def find_older_than(self, cutoff: datetime) -> list[Meeting]:

View File

@@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, ClassVar, Final
from uuid import UUID
from noteflow import __version__
from noteflow.application.services.identity_service import IdentityService
from noteflow.domain.identity.context import OperationContext, UserContext, WorkspaceContext
from noteflow.domain.identity.roles import WorkspaceRole
from noteflow.infrastructure.logging import request_id_var, user_id_var, workspace_id_var
@@ -52,7 +51,9 @@ from ._mixins import (
SyncMixin,
WebhooksMixin,
)
from .proto import noteflow_pb2, noteflow_pb2_grpc
from ._identity_singleton import default_identity_service
from ._service_base import GrpcBaseServicer, NoteFlowServicerStubs
from .proto import noteflow_pb2
from .stream_state import MeetingStreamState
if TYPE_CHECKING:
@@ -65,31 +66,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
if TYPE_CHECKING:
_GrpcBaseServicer = object
else:
_GrpcBaseServicer = noteflow_pb2_grpc.NoteFlowServiceServicer
# Empty class to satisfy MRO - cannot use `object` directly as it conflicts
# with NoteFlowServiceServicer's inheritance from object
class NoteFlowServicerStubs:
"""Runtime placeholder for type stubs (empty at runtime)."""
pass
# Module-level singleton for identity service (stateless, no dependencies)
_identity_service_instance: IdentityService | None = None
def _default_identity_service() -> IdentityService:
"""Get or create the default identity service singleton."""
global _identity_service_instance
if _identity_service_instance is None:
_identity_service_instance = IdentityService()
return _identity_service_instance
class NoteFlowServicer(
StreamingMixin,
DiarizationMixin,
@@ -109,7 +85,7 @@ class NoteFlowServicer(
ProjectMixin,
ProjectMembershipMixin,
NoteFlowServicerStubs,
_GrpcBaseServicer,
GrpcBaseServicer,
):
"""Async gRPC service implementation for NoteFlow with PostgreSQL persistence.
@@ -159,7 +135,7 @@ class NoteFlowServicer(
self.webhook_service = services.webhook_service
self.project_service = services.project_service
# Identity service - always available (stateless, no dependencies)
self.identity_service = services.identity_service or _default_identity_service()
self.identity_service = services.identity_service or default_identity_service()
self._start_time = time.time()
self.memory_store: MeetingStore | None = MeetingStore() if session_factory is None else None
# Audio infrastructure

View File

@@ -0,0 +1,214 @@
"""Shared helpers for OAuth manager."""
from __future__ import annotations
import base64
import hashlib
import secrets
from datetime import UTC, datetime, timedelta
from typing import Mapping
from dataclasses import dataclass
from urllib.parse import urlencode
from noteflow.config.constants import (
DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS,
OAUTH_FIELD_ACCESS_TOKEN,
OAUTH_FIELD_REFRESH_TOKEN,
OAUTH_FIELD_SCOPE,
OAUTH_FIELD_TOKEN_TYPE,
)
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens
def get_auth_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
"""Get authorization URL for provider."""
return google_url if provider == OAuthProvider.GOOGLE else outlook_url
def get_token_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
"""Get token URL for provider."""
return google_url if provider == OAuthProvider.GOOGLE else outlook_url
def get_revoke_url(provider: OAuthProvider, *, google_url: str, outlook_url: str) -> str:
"""Get revoke URL for provider."""
return google_url if provider == OAuthProvider.GOOGLE else outlook_url
def get_scopes(
provider: OAuthProvider,
*,
google_scopes: list[str],
outlook_scopes: list[str],
) -> list[str]:
"""Get OAuth scopes for provider."""
return google_scopes if provider == OAuthProvider.GOOGLE else outlook_scopes
def generate_code_verifier() -> str:
"""Generate a cryptographically random code verifier for PKCE."""
verifier = secrets.token_urlsafe(64)
return verifier
def generate_code_challenge(verifier: str) -> str:
"""Generate code challenge from verifier using S256 method."""
digest = hashlib.sha256(verifier.encode("ascii")).digest()
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
@dataclass(frozen=True, slots=True)
class AuthUrlConfig:
provider: OAuthProvider
redirect_uri: str
state: str
code_challenge: str
client_id: str
scopes: list[str]
google_auth_url: str
outlook_auth_url: str
@dataclass(frozen=True, slots=True)
class OAuthStateConfig:
provider: OAuthProvider
redirect_uri: str
code_verifier: str
state_token: str
created_at: datetime
ttl_seconds: int
@dataclass(frozen=True, slots=True)
class OAuthFlowConfig:
provider: OAuthProvider
redirect_uri: str
client_id: str
scopes: list[str]
google_auth_url: str
outlook_auth_url: str
state_ttl_seconds: int
def build_auth_url(config: AuthUrlConfig) -> str:
"""Build OAuth authorization URL with PKCE parameters."""
base_url = (
config.google_auth_url
if config.provider == OAuthProvider.GOOGLE
else config.outlook_auth_url
)
params = {
"client_id": config.client_id,
"redirect_uri": config.redirect_uri,
"response_type": "code",
OAUTH_FIELD_SCOPE: " ".join(config.scopes),
"state": config.state,
"code_challenge": config.code_challenge,
"code_challenge_method": "S256",
}
if config.provider == OAuthProvider.GOOGLE:
params["access_type"] = "offline"
params["prompt"] = "consent"
elif config.provider == OAuthProvider.OUTLOOK:
params["response_mode"] = "query"
return f"{base_url}?{urlencode(params)}"
def generate_state_token() -> str:
"""Generate a random state token for OAuth CSRF protection."""
token = secrets.token_urlsafe(32)
return token
def create_oauth_state(config: OAuthStateConfig) -> OAuthState:
"""Create an OAuthState from config settings."""
expires_at = config.created_at + timedelta(seconds=config.ttl_seconds)
return OAuthState(
state=config.state_token,
provider=config.provider,
redirect_uri=config.redirect_uri,
code_verifier=config.code_verifier,
created_at=config.created_at,
expires_at=expires_at,
)
def prepare_oauth_flow(config: OAuthFlowConfig) -> tuple[str, OAuthState, str]:
"""Prepare OAuth state and authorization URL for a flow."""
code_verifier = generate_code_verifier()
code_challenge = generate_code_challenge(code_verifier)
state_token = generate_state_token()
now = datetime.now(UTC)
oauth_state = create_oauth_state(
OAuthStateConfig(
provider=config.provider,
redirect_uri=config.redirect_uri,
code_verifier=code_verifier,
state_token=state_token,
created_at=now,
ttl_seconds=config.state_ttl_seconds,
)
)
auth_url = build_auth_url(
AuthUrlConfig(
provider=config.provider,
redirect_uri=config.redirect_uri,
state=state_token,
code_challenge=code_challenge,
client_id=config.client_id,
scopes=config.scopes,
google_auth_url=config.google_auth_url,
outlook_auth_url=config.outlook_auth_url,
)
)
return state_token, oauth_state, auth_url
def parse_token_response(
data: Mapping[str, object],
*,
existing_refresh_token: str | None = None,
) -> OAuthTokens:
"""Parse token response into OAuthTokens."""
access_token = str(data.get(OAUTH_FIELD_ACCESS_TOKEN, ""))
if not access_token:
raise ValueError("No access_token in response")
expires_in_raw = data.get("expires_in", DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS)
expires_in = (
int(expires_in_raw)
if isinstance(expires_in_raw, (int, float, str))
else DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS
)
expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
refresh_token = data.get(OAUTH_FIELD_REFRESH_TOKEN)
if isinstance(refresh_token, str):
final_refresh_token: str | None = refresh_token
else:
final_refresh_token = existing_refresh_token
return OAuthTokens(
access_token=access_token,
refresh_token=final_refresh_token,
token_type=str(data.get(OAUTH_FIELD_TOKEN_TYPE, "Bearer")),
expires_at=expires_at,
scope=str(data.get(OAUTH_FIELD_SCOPE, "")),
)
def validate_oauth_state(
oauth_state: OAuthState,
*,
provider: OAuthProvider,
) -> None:
"""Validate OAuth state values, raising ValueError on failures."""
if oauth_state.is_state_expired():
raise ValueError("State token has expired")
if oauth_state.provider != provider:
raise ValueError(
f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
)

View File

@@ -6,27 +6,28 @@ Uses PKCE (Proof Key for Code Exchange) for secure OAuth 2.0 flow.
from __future__ import annotations
import base64
import hashlib
import secrets
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, ClassVar
from urllib.parse import urlencode
import httpx
from noteflow.config.constants import (
DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS,
ERR_TOKEN_REFRESH_PREFIX,
HTTP_STATUS_NO_CONTENT,
HTTP_STATUS_OK,
OAUTH_FIELD_ACCESS_TOKEN,
OAUTH_FIELD_REFRESH_TOKEN,
OAUTH_FIELD_SCOPE,
OAUTH_FIELD_TOKEN_TYPE,
)
from noteflow.domain.ports.calendar import OAuthPort
from noteflow.domain.value_objects import OAuthProvider, OAuthState, OAuthTokens
from noteflow.infrastructure.calendar.oauth_helpers import (
OAuthFlowConfig,
get_revoke_url,
get_scopes,
get_token_url,
parse_token_response,
prepare_oauth_flow,
validate_oauth_state,
)
from noteflow.infrastructure.logging import get_logger, log_timing
if TYPE_CHECKING:
@@ -157,33 +158,25 @@ class OAuthManager(OAuthPort):
)
raise OAuthError("Too many pending OAuth flows. Please try again later.")
# Generate PKCE code verifier and challenge
code_verifier = self._generate_code_verifier()
code_challenge = self._generate_code_challenge(code_verifier)
# Generate state token for CSRF protection
state_token = secrets.token_urlsafe(32)
# Store state for validation during callback
now = datetime.now(UTC)
oauth_state = OAuthState(
state=state_token,
provider=provider,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
created_at=now,
expires_at=now + timedelta(seconds=self.STATE_TTL_SECONDS),
client_id, _ = self._get_credentials(provider)
scopes = get_scopes(
provider,
google_scopes=self.GOOGLE_SCOPES,
outlook_scopes=self.OUTLOOK_SCOPES,
)
state_token, oauth_state, auth_url = prepare_oauth_flow(
OAuthFlowConfig(
provider=provider,
redirect_uri=redirect_uri,
client_id=client_id,
scopes=scopes,
google_auth_url=self.GOOGLE_AUTH_URL,
outlook_auth_url=self.OUTLOOK_AUTH_URL,
state_ttl_seconds=self.STATE_TTL_SECONDS,
)
)
self._pending_states[state_token] = oauth_state
# Build authorization URL
auth_url = self._build_auth_url(
provider=provider,
redirect_uri=redirect_uri,
state=state_token,
code_challenge=code_challenge,
)
logger.info(
"oauth_initiated",
provider=provider.value,
@@ -222,26 +215,24 @@ class OAuthManager(OAuthPort):
)
raise OAuthError("Invalid or expired state token")
if oauth_state.is_state_expired():
try:
validate_oauth_state(oauth_state, provider=provider)
except ValueError as exc:
event = (
"oauth_state_expired"
if "expired" in str(exc).lower()
else "oauth_provider_mismatch"
)
logger.warning(
"oauth_state_expired",
event,
event_type="security",
provider=provider.value,
created_at=oauth_state.created_at.isoformat(),
expires_at=oauth_state.expires_at.isoformat(),
)
raise OAuthError("State token has expired")
if oauth_state.provider != provider:
logger.warning(
"oauth_provider_mismatch",
event_type="security",
expected_provider=oauth_state.provider.value,
received_provider=provider.value,
)
raise OAuthError(
f"Provider mismatch: expected {oauth_state.provider}, got {provider}"
)
raise OAuthError(str(exc)) from exc
# Exchange code for tokens
tokens = await self._exchange_code(
@@ -271,7 +262,11 @@ class OAuthManager(OAuthPort):
Raises:
OAuthError: If refresh fails.
"""
token_url = self._get_token_url(provider)
token_url = get_token_url(
provider,
google_url=self.GOOGLE_TOKEN_URL,
outlook_url=self.OUTLOOK_TOKEN_URL,
)
client_id, client_secret = self._get_credentials(provider)
data = {
@@ -298,7 +293,13 @@ class OAuthManager(OAuthPort):
raise OAuthError(f"{ERR_TOKEN_REFRESH_PREFIX}{error_detail}")
token_data = response.json()
tokens = self._parse_token_response(token_data, refresh_token)
try:
tokens = parse_token_response(
token_data,
existing_refresh_token=refresh_token,
)
except ValueError as exc:
raise OAuthError(str(exc)) from exc
logger.info("oauth_tokens_refreshed", provider=provider.value)
return tokens
@@ -317,7 +318,11 @@ class OAuthManager(OAuthPort):
Returns:
True if revoked successfully.
"""
revoke_url = self._get_revoke_url(provider)
revoke_url = get_revoke_url(
provider,
google_url=self.GOOGLE_REVOKE_URL,
outlook_url=self.OUTLOOK_REVOKE_URL,
)
async with httpx.AsyncClient() as client:
if provider == OAuthProvider.GOOGLE:
@@ -363,61 +368,6 @@ class OAuthManager(OAuthPort):
self._settings.outlook_client_secret,
)
def _get_auth_url(self, provider: OAuthProvider) -> str:
"""Get authorization URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_AUTH_URL
return self.OUTLOOK_AUTH_URL
def _get_token_url(self, provider: OAuthProvider) -> str:
"""Get token URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_TOKEN_URL
return self.OUTLOOK_TOKEN_URL
def _get_revoke_url(self, provider: OAuthProvider) -> str:
"""Get revoke URL for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_REVOKE_URL
return self.OUTLOOK_REVOKE_URL
def _get_scopes(self, provider: OAuthProvider) -> list[str]:
"""Get OAuth scopes for provider."""
if provider == OAuthProvider.GOOGLE:
return self.GOOGLE_SCOPES
return self.OUTLOOK_SCOPES
def _build_auth_url(
self,
provider: OAuthProvider,
redirect_uri: str,
state: str,
code_challenge: str,
) -> str:
"""Build OAuth authorization URL with PKCE parameters."""
client_id, _ = self._get_credentials(provider)
scopes = self._get_scopes(provider)
base_url = self._get_auth_url(provider)
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
OAUTH_FIELD_SCOPE: " ".join(scopes),
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
# Provider-specific parameters
if provider == OAuthProvider.GOOGLE:
params["access_type"] = "offline"
params["prompt"] = "consent"
elif provider == OAuthProvider.OUTLOOK:
params["response_mode"] = "query"
return f"{base_url}?{urlencode(params)}"
async def _exchange_code(
self,
provider: OAuthProvider,
@@ -426,7 +376,11 @@ class OAuthManager(OAuthPort):
code_verifier: str,
) -> OAuthTokens:
"""Exchange authorization code for tokens."""
token_url = self._get_token_url(provider)
token_url = get_token_url(
provider,
google_url=self.GOOGLE_TOKEN_URL,
outlook_url=self.OUTLOOK_TOKEN_URL,
)
client_id, client_secret = self._get_credentials(provider)
data = {
@@ -454,52 +408,10 @@ class OAuthManager(OAuthPort):
raise OAuthError(f"Token exchange failed: {error_detail}")
token_data = response.json()
return self._parse_token_response(token_data)
def _parse_token_response(
self,
data: dict[str, object],
existing_refresh_token: str | None = None,
) -> OAuthTokens:
"""Parse token response into OAuthTokens."""
access_token = str(data.get(OAUTH_FIELD_ACCESS_TOKEN, ""))
if not access_token:
raise OAuthError("No access_token in response")
# Calculate expiry time
expires_in_raw = data.get("expires_in", DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS)
expires_in = (
int(expires_in_raw)
if isinstance(expires_in_raw, (int, float, str))
else DEFAULT_OAUTH_TOKEN_EXPIRY_SECONDS
)
expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
# Refresh token may not be returned on refresh
refresh_token = data.get(OAUTH_FIELD_REFRESH_TOKEN)
if isinstance(refresh_token, str):
final_refresh_token: str | None = refresh_token
else:
final_refresh_token = existing_refresh_token
return OAuthTokens(
access_token=access_token,
refresh_token=final_refresh_token,
token_type=str(data.get(OAUTH_FIELD_TOKEN_TYPE, "Bearer")),
expires_at=expires_at,
scope=str(data.get(OAUTH_FIELD_SCOPE, "")),
)
@staticmethod
def _generate_code_verifier() -> str:
"""Generate a cryptographically random code verifier for PKCE."""
return secrets.token_urlsafe(64)
@staticmethod
def _generate_code_challenge(verifier: str) -> str:
"""Generate code challenge from verifier using S256 method."""
digest = hashlib.sha256(verifier.encode("ascii")).digest()
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
try:
return parse_token_response(token_data)
except ValueError as exc:
raise OAuthError(str(exc)) from exc
def _cleanup_expired_states(self) -> None:
"""Remove expired state tokens."""

View File

@@ -64,19 +64,21 @@ def _patch_torch_load() -> None:
try:
import torch
from packaging.version import Version
if Version(torch.__version__) >= Version("2.6.0"):
original_load = cast(Callable[..., object], torch.load)
def _patched_load(*args: object, **kwargs: object) -> object:
if "weights_only" not in kwargs:
kwargs["weights_only"] = False
return original_load(*args, **kwargs)
setattr(torch, _ATTR_LOAD, _patched_load)
logger.debug("Patched torch.load for weights_only=False default")
except ImportError:
pass
return
if Version(torch.__version__) < Version("2.6.0"):
return
original_load = cast(Callable[..., object], torch.load)
def _patched_load(*args: object, **kwargs: object) -> object:
if "weights_only" not in kwargs:
kwargs["weights_only"] = False
return original_load(*args, **kwargs)
setattr(torch, _ATTR_LOAD, _patched_load)
logger.debug("Patched torch.load for weights_only=False default")
def _patch_huggingface_auth() -> None:

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol
import numpy as np
@@ -19,6 +19,19 @@ from noteflow.infrastructure.logging import get_logger
if TYPE_CHECKING:
from diart import SpeakerDiarization
class _TrackSegment(Protocol):
start: float
end: float
class _Annotation(Protocol):
def itertracks(
self,
*,
yield_label: bool,
) -> Sequence[tuple[_TrackSegment, object, object]]: ...
from numpy.typing import NDArray
logger = get_logger(__name__)
@@ -27,6 +40,27 @@ logger = get_logger(__name__)
DEFAULT_CHUNK_DURATION: float = 5.0
def _collect_turns(
results: Sequence[tuple[_Annotation, object]],
stream_time: float,
) -> list[SpeakerTurn]:
"""Convert pipeline results to speaker turns with absolute time offsets."""
turns: list[SpeakerTurn] = []
for annotation, _ in results:
for track in annotation.itertracks(yield_label=True):
if len(track) != 3:
continue
segment, _, speaker = track
turns.append(
SpeakerTurn(
speaker=str(speaker),
start=segment.start + stream_time,
end=segment.end + stream_time,
)
)
return turns
@dataclass
class DiarizationSession:
"""Per-meeting streaming diarization session.
@@ -149,18 +183,8 @@ class DiarizationSession:
results = self._pipeline([waveform])
# Convert results to turns with absolute time offsets
new_turns: list[SpeakerTurn] = []
for annotation, _ in results:
for track in annotation.itertracks(yield_label=True):
if len(track) == 3:
segment, _, speaker = track
turn = SpeakerTurn(
speaker=str(speaker),
start=segment.start + self._stream_time,
end=segment.end + self._stream_time,
)
new_turns.append(turn)
self._turns.append(turn)
new_turns = _collect_turns(results, self._stream_time)
self._turns.extend(new_turns)
except (RuntimeError, ZeroDivisionError, ValueError) as e:
# Handle frame/weights mismatch and related errors gracefully

View File

@@ -145,7 +145,9 @@ class MemorySummaryRepository:
async def get_by_meeting(self, meeting_id: MeetingId) -> Summary | None:
"""Get summary for a meeting."""
return self._store.get_meeting_summary(str(meeting_id))
meeting_key = str(meeting_id)
summary = self._store.get_meeting_summary(meeting_key)
return summary
async def delete_by_meeting(self, meeting_id: MeetingId) -> bool:
"""Delete summary for a meeting."""

View File

@@ -80,7 +80,7 @@ def collect_assertion_roulette() -> list[Violation]:
if node.msg is None:
assertions_without_msg += 1
if assertions_without_msg > 3:
if assertions_without_msg > 1:
violations.append(
Violation(
rule="assertion_roulette",
@@ -138,9 +138,6 @@ def collect_sleepy_tests() -> list[Violation]:
"""Collect sleepy test violations."""
allowed_sleepy_paths = {
"tests/stress/",
"tests/integration/test_signal_handling.py",
"tests/integration/test_database_resilience.py",
"tests/grpc/test_stream_lifecycle.py",
}
violations: list[Violation] = []
@@ -312,8 +309,8 @@ def collect_magic_number_tests() -> list[Violation]:
def collect_sensitive_equality() -> list[Violation]:
"""Collect sensitive equality (str/repr comparison) violations."""
excluded_test_patterns = {"string", "proto", "conversion", "serializ", "preserves_message"}
excluded_file_patterns = {"_mixin"}
excluded_test_patterns = {"string", "proto"}
excluded_file_patterns: set[str] = set()
violations: list[Violation] = []
for py_file in find_test_files():
@@ -351,7 +348,7 @@ def collect_sensitive_equality() -> list[Violation]:
def collect_eager_tests() -> list[Violation]:
"""Collect eager test (too many method calls) violations."""
max_method_calls = 10
max_method_calls = 7
violations: list[Violation] = []
for py_file in find_test_files():
@@ -414,7 +411,7 @@ def collect_duplicate_test_names() -> list[Violation]:
def collect_long_tests() -> list[Violation]:
"""Collect long test method violations."""
max_lines = 46
max_lines = 35
violations: list[Violation] = []
for py_file in find_test_files():

View File

@@ -1,38 +1,5 @@
{
"generated_at": "2026-01-05T03:06:34.258338+00:00",
"rules": {
"deep_nesting": [
"deep_nesting|src/noteflow/application/services/auth_service.py|get_current_user|depth=5",
"deep_nesting|src/noteflow/grpc/_mixins/identity.py|GetCurrentUser|depth=4",
"deep_nesting|src/noteflow/infrastructure/diarization/_compat.py|_patch_torch_load|depth=4",
"deep_nesting|src/noteflow/infrastructure/diarization/session.py|_process_full_chunk|depth=4"
],
"feature_envy": [
"feature_envy|src/noteflow/application/services/auth_service.py|AuthService.refresh_auth_tokens|integration=10_vs_self=3",
"feature_envy|src/noteflow/grpc/meeting_store.py|MeetingStore.list_all|kwargs=6_vs_self=2",
"feature_envy|src/noteflow/grpc/meeting_store.py|MeetingStore.list_all|m=6_vs_self=2",
"feature_envy|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager.complete_auth|oauth_state=8_vs_self=2"
],
"god_class": [
"god_class|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager|lines=513",
"god_class|src/noteflow/infrastructure/calendar/oauth_manager.py|OAuthManager|methods=21"
],
"long_method": [
"long_method|src/noteflow/application/services/auth_service.py|_store_auth_user|lines=85",
"long_method|src/noteflow/application/services/auth_service.py|refresh_auth_tokens|lines=76",
"long_method|src/noteflow/grpc/_mixins/meeting.py|ListMeetings|lines=72"
],
"module_size_soft": [
"module_size_soft|src/noteflow/application/services/auth_service.py|module|lines=571",
"module_size_soft|src/noteflow/grpc/service.py|module|lines=522",
"module_size_soft|src/noteflow/infrastructure/calendar/oauth_manager.py|module|lines=554"
],
"passthrough_class": [
"passthrough_class|src/noteflow/domain/entities/meeting.py|ProcessingStepState|5_methods"
],
"thin_wrapper": [
"thin_wrapper|src/noteflow/infrastructure/persistence/memory/repositories/core.py|get_by_meeting|get_meeting_summary"
]
},
"generated_at": "2026-01-05T15:51:45.809039+00:00",
"rules": {},
"schema_version": 1
}

View File

@@ -190,7 +190,7 @@ def collect_deprecated_patterns() -> list[Violation]:
def collect_high_complexity() -> list[Violation]:
"""Collect high complexity violations."""
max_complexity = 15
max_complexity = 12
violations: list[Violation] = []
def count_branches(node: ast.AST) -> int:
@@ -232,7 +232,7 @@ def collect_high_complexity() -> list[Violation]:
def collect_long_parameter_lists() -> list[Violation]:
"""Collect long parameter list violations."""
max_params = 5
max_params = 4
violations: list[Violation] = []
for py_file in find_source_files(include_migrations=False):
@@ -282,7 +282,6 @@ def collect_thin_wrappers() -> list[Violation]:
("full_transcript", "join"),
("duration", "sub"),
("is_active", "property"),
("is_admin", "can_admin"),
# Domain method accessors (type-safe dict access)
("get_metadata", "get"),
# Strategy pattern implementations (RuleType.evaluate for simple mode)
@@ -292,36 +291,21 @@ def collect_thin_wrappers() -> list[Violation]:
("generate_request_id", "str"),
# Context variable accessors (public API over internal contextvars)
("get_request_id", "get"),
("get_user_id", "get"),
("get_workspace_id", "get"),
# Time conversion utilities (semantic naming for datetime operations)
("datetime_to_epoch_seconds", "timestamp"),
("datetime_to_iso_string", "isoformat"),
("epoch_seconds_to_datetime", "fromtimestamp"),
("proto_timestamp_to_datetime", "replace"),
# Accessor-style wrappers with semantic names
("from_metrics", "cls"),
("from_dict", "cls"),
("empty", "cls"),
("get_log_level", "get"),
("get_preset_config", "get"),
("get_provider", "get"),
("get_pending_state", "get"),
("get_stream_state", "get"),
("get_async_session_factory", "async_sessionmaker"),
("process_chunk", "process"),
("get_openai_client", "_get_openai_client"),
("meeting_apps", "frozenset"),
("suppressed_apps", "frozenset"),
("get_sync_run", "get"),
("list_all", "list"),
("get_by_id", "get"),
("create", "insert"),
("delete_by_meeting", "clear_summary"),
("get_by_meeting", "fetch_segments"),
("get_by_meeting", "get_summary"),
("check_otel_available", "_check_otel_available"),
("start_as_current_span", "_NoOpSpanContext"),
("start_span", "_NoOpSpan"),
("detected_app", "next"),
}
@@ -376,7 +360,7 @@ def collect_thin_wrappers() -> list[Violation]:
def collect_long_methods() -> list[Violation]:
"""Collect long method violations."""
max_lines = 68
max_lines = 50
violations: list[Violation] = []
def count_function_lines(node: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
@@ -407,7 +391,7 @@ def collect_long_methods() -> list[Violation]:
def collect_module_size_soft() -> list[Violation]:
"""Collect module size soft limit violations."""
soft_limit = 500
soft_limit = 350
violations: list[Violation] = []
for py_file in find_source_files(include_migrations=False):
@@ -466,8 +450,8 @@ def collect_alias_imports() -> list[Violation]:
def collect_god_classes() -> list[Violation]:
"""Collect god class violations."""
max_methods = 20
max_lines = 500
max_methods = 15
max_lines = 400
violations: list[Violation] = []
for py_file in find_source_files(include_migrations=False):
@@ -512,7 +496,7 @@ def collect_god_classes() -> list[Violation]:
def collect_deep_nesting() -> list[Violation]:
"""Collect deep nesting violations."""
max_nesting = 3
max_nesting = 2
violations: list[Violation] = []
def count_nesting_depth(node: ast.AST, current_depth: int = 0) -> int:
@@ -559,7 +543,6 @@ def collect_feature_envy() -> list[Violation]:
"converter",
"exporter",
"repository",
"repo",
}
excluded_method_patterns = {
"_to_domain",
@@ -567,8 +550,6 @@ def collect_feature_envy() -> list[Violation]:
"_proto_to_",
"_to_orm",
"_from_orm",
"export",
"_parse",
}
excluded_object_names = {
"model",
@@ -580,7 +561,6 @@ def collect_feature_envy() -> list[Violation]:
"noteflow_pb2",
"seg",
"job",
"repo",
"ai",
"summary",
"MeetingState",
@@ -590,12 +570,6 @@ def collect_feature_envy() -> list[Violation]:
"uow",
"span",
"host",
"logger",
"data",
"config",
"p",
"params",
"args",
}
violations: list[Violation] = []
@@ -634,7 +608,7 @@ def collect_feature_envy() -> list[Violation]:
for other_obj, count in other_accesses.items():
if other_obj in excluded_object_names:
continue
if count > self_accesses + 3 and count > 5:
if count > self_accesses + 2 and count > 4:
violations.append(
Violation(
rule="feature_envy",

View File

@@ -75,7 +75,7 @@ def count_function_lines(node: ast.FunctionDef | ast.AsyncFunctionDef) -> int:
def test_no_high_complexity_functions() -> None:
"""Detect functions with high cyclomatic complexity."""
max_complexity = 15
max_complexity = 12
violations: list[Violation] = []
parse_errors: list[str] = []
@@ -109,7 +109,7 @@ def test_no_high_complexity_functions() -> None:
def test_no_long_parameter_lists() -> None:
"""Detect functions with too many parameters."""
max_params = 5
max_params = 4
violations: list[Violation] = []
parse_errors: list[str] = []
@@ -151,8 +151,8 @@ def test_no_long_parameter_lists() -> None:
def test_no_god_classes() -> None:
"""Detect classes with too many methods or too much responsibility."""
max_methods = 20
max_lines = 500
max_methods = 15
max_lines = 400
violations: list[Violation] = []
parse_errors: list[str] = []
@@ -203,7 +203,7 @@ def test_no_god_classes() -> None:
def test_no_deep_nesting() -> None:
"""Detect functions with excessive nesting depth."""
max_nesting = 3
max_nesting = 2
violations: list[Violation] = []
parse_errors: list[str] = []
@@ -237,7 +237,7 @@ def test_no_deep_nesting() -> None:
def test_no_long_methods() -> None:
"""Detect methods that are too long."""
max_lines = 68
max_lines = 50
violations: list[Violation] = []
parse_errors: list[str] = []
@@ -285,7 +285,6 @@ def test_no_feature_envy() -> None:
"converter",
"exporter",
"repository",
"repo",
}
excluded_method_patterns = {
"_to_domain",
@@ -293,8 +292,6 @@ def test_no_feature_envy() -> None:
"_proto_to_",
"_to_orm",
"_from_orm",
"export",
"_parse", # Parsing external data (API responses)
}
# Objects that are commonly used more than self but aren't feature envy
excluded_object_names = {
@@ -317,12 +314,6 @@ def test_no_feature_envy() -> None:
"uow", # Unit of work in service methods
"span", # OpenTelemetry span in observability
"host", # Servicer host in mixin methods
"logger", # Logging is cross-cutting, not feature envy
"data", # Dict parsing in from_dict factory methods
"config", # Configuration object access
"p", # Short alias for params in factory methods
"params", # Parameters object in factory methods
"args", # CLI args parsing in factory methods
}
def _is_excluded_class(class_name: str) -> bool:
@@ -385,7 +376,7 @@ def test_no_feature_envy() -> None:
for other_obj, count in other_accesses.items():
if other_obj in excluded_object_names:
continue
if count > self_accesses + 3 and count > 5:
if count > self_accesses + 2 and count > 4:
violations.append(
Violation(
rule="feature_envy",
@@ -401,8 +392,8 @@ def test_no_feature_envy() -> None:
def test_module_size_limits() -> None:
"""Check that modules don't exceed size limits."""
soft_limit = 500
hard_limit = 750
soft_limit = 350
hard_limit = 600
soft_violations: list[Violation] = []
hard_violations: list[Violation] = []

View File

@@ -124,11 +124,11 @@ def test_helpers_not_scattered() -> None:
f" {', '.join(locations)}"
)
# Target: 15 scattered helpers max - some duplication is expected for:
# Target: 5 scattered helpers max - some duplication is expected for:
# - Client/server pairs with same method names
# - Mixin protocols + implementations
assert len(scattered) <= 15, (
f"Found {len(scattered)} scattered helper functions (max 15 allowed). "
assert len(scattered) <= 5, (
f"Found {len(scattered)} scattered helper functions (max 5 allowed). "
"Consider consolidating:\n\n" + "\n\n".join(scattered[:5])
)
@@ -192,10 +192,10 @@ def test_no_duplicate_helper_implementations() -> None:
loc_strs = [f"{f}:{line}" for f, line, _ in locations]
duplicates.append(f"'{signature}' defined at: {', '.join(loc_strs)}")
# Target: 25 duplicate helper signatures - some duplication expected for:
# Target: 10 duplicate helper signatures - some duplication expected for:
# - Mixin composition (protocol + implementation)
# - Client/server pairs
assert len(duplicates) <= 25, (
f"Found {len(duplicates)} duplicate helper signatures (max 25 allowed):\n"
assert len(duplicates) <= 10, (
f"Found {len(duplicates)} duplicate helper signatures (max 10 allowed):\n"
+ "\n".join(duplicates[:5])
)

View File

@@ -146,10 +146,9 @@ def test_no_duplicate_function_bodies() -> None:
f" Preview: {preview}..."
)
# Allow baseline - some duplication exists between client.py and streaming_session.py
# for callback notification methods which will be consolidated during client refactoring
assert len(violations) <= 1, (
f"Found {len(violations)} duplicate function groups (max 1 allowed):\n\n"
# Tighten: no duplicate function bodies allowed.
assert len(violations) <= 0, (
f"Found {len(violations)} duplicate function groups (max 0 allowed):\n\n"
+ "\n\n".join(violations)
)
@@ -186,7 +185,7 @@ def test_no_repeated_code_patterns() -> None:
f" Sample locations: {', '.join(locations)}"
)
# Target: 182 repeated patterns max - remaining are architectural:
# Target: 120 repeated patterns max - remaining are architectural:
# Hexagonal architecture requires Protocol interfaces to match implementations:
# - Repository method signatures (~60): Service → Protocol → SQLAlchemy → Memory
# Each method signature creates multiple overlapping 4-line windows
@@ -200,8 +199,8 @@ def test_no_repeated_code_patterns() -> None:
# - Import patterns (~10): webhook imports, RULE_FIELD imports, service TYPE_CHECKING
# imports in _config.py/server.py/service.py for ServicesConfig pattern
# Note: Alembic migrations are excluded from this check (immutable historical records)
# Updated: 189 patterns after observability usage tracking additions
assert len(repeated_patterns) <= 189, (
f"Found {len(repeated_patterns)} significantly repeated patterns (max 189 allowed). "
# Updated: 120 patterns after tightening expectations
assert len(repeated_patterns) <= 120, (
f"Found {len(repeated_patterns)} significantly repeated patterns (max 120 allowed). "
f"Consider abstracting:\n\n" + "\n\n".join(repeated_patterns[:5])
)

View File

@@ -29,7 +29,7 @@ ALLOWED_NUMBERS = {
0, 1, 2, 3, 4, 5, -1, # Small integers
10, 20, 30, 50, # Common timeout/limit values
60, 100, 200, 255, 365, 1000, 1024, # Common constants
0.1, 0.3, 0.5, # Common float values
0.1, 0.5, # Common float values
16000, 50051, # Sample rate and gRPC port
}
ALLOWED_STRINGS = {
@@ -39,11 +39,10 @@ ALLOWED_STRINGS = {
"\t",
"utf-8",
"utf8",
"w",
"r",
"w",
"rb",
"wb",
"a",
"GET",
"POST",
"PUT",
@@ -58,205 +57,71 @@ ALLOWED_STRINGS = {
"name",
"type",
"value",
# Common domain/infrastructure terms
"__main__",
"noteflow",
"meeting",
"segment",
"summary",
"annotation",
"CASCADE",
"selectin",
"schema",
"role",
"user",
"text",
"title",
"status",
"content",
"created_at",
"updated_at",
"start_time",
"end_time",
"meeting_id",
"user_id",
"request_id",
# Domain enums
"action_item",
"decision",
"note",
"risk",
"unknown",
"completed",
"failed",
"pending",
"running",
"markdown",
"html",
# Common patterns
"base",
"auto",
"cuda",
"int8",
"float32",
# argparse actions
"store_true",
"store_false",
# ORM table/column names (intentionally repeated across models/repos)
"meetings",
"segments",
"summaries",
"annotations",
"key_points",
"action_items",
"word_timings",
"sample_rate",
"segment_ids",
"summary_id",
"wrapped_dek",
"diarization_jobs",
"user_preferences",
"streamingdiarization_turns",
# ORM cascade settings
"all, delete-orphan",
# Foreign key references
"noteflow.meetings.id",
"noteflow.summaries.id",
# Error message patterns (intentional consistency)
"UnitOfWork not in context",
"Invalid meeting_id",
"Invalid annotation_id",
# File names (infrastructure constants)
"manifest.json",
"audio.enc",
# HTML tags
"</div>",
"</dd>",
# Model class names (ORM back_populates/relationships - required by SQLAlchemy)
"ActionItemModel",
"AnnotationModel",
"CalendarEventModel",
"DiarizationJobModel",
"ExternalRefModel",
"IntegrationModel",
"IntegrationSecretModel",
"IntegrationSyncRunModel",
"KeyPointModel",
"MeetingCalendarLinkModel",
"MeetingModel",
"MeetingSpeakerModel",
"MeetingTagModel",
"NamedEntityModel",
"PersonModel",
"SegmentModel",
"SettingsModel",
"StreamingDiarizationTurnModel",
"SummaryModel",
"TagModel",
"TaskModel",
"UserModel",
"UserPreferencesModel",
"WebhookConfigModel",
"WebhookDeliveryModel",
"WordTimingModel",
"WorkspaceMembershipModel",
"WorkspaceModel",
# ORM relationship back_populates names
"workspace",
"memberships",
"integration",
"meeting_tags",
"tasks",
# Foreign key references
"noteflow.workspaces.id",
"noteflow.users.id",
"noteflow.integrations.id",
# Database ondelete actions
"SET NULL",
"RESTRICT",
# Column names used in mappings
"metadata",
"workspace_id",
# Database URL prefixes
"postgres://",
"postgresql://",
"postgresql+asyncpg://",
# OIDC standard claim names (RFC 7519 / OpenID Connect Core spec)
"sub",
"email",
"email_verified",
"preferred_username",
"groups",
"picture",
"given_name",
"family_name",
"openid",
"profile",
"offline_access",
# OIDC discovery document fields (OpenID Connect Discovery spec)
"issuer",
"authorization_endpoint",
"token_endpoint",
"userinfo_endpoint",
"jwks_uri",
"end_session_endpoint",
"revocation_endpoint",
"introspection_endpoint",
"scopes_supported",
"code_challenge_methods_supported",
"claims_supported",
"response_types_supported",
# OIDC provider config fields
"discovery",
"discovery_refreshed_at",
"issuer_url",
"client_id",
"client_secret",
"claim_mapping",
"scopes",
"preset",
"require_email_verified",
"allowed_groups",
"enabled",
# Integration status values
"success",
"error",
"calendar",
"provider",
# Common error message fragments
" not found",
# HTML markup
"<li>",
"</li>",
# Logging levels
"INFO",
"DEBUG",
"WARNING",
"ERROR",
# Domain terms
"project",
# Internal attribute names (used in multiple gRPC handlers)
"_pending_chunks",
# Sentinel UUIDs
"00000000-0000-0000-0000-000000000001",
# Repository type names (used in TYPE_CHECKING imports and annotations)
"MeetingRepository",
"SegmentRepository",
"SummaryRepository",
"AnnotationRepository",
"UserRepository",
"WorkspaceRepository",
# Pagination and filter parameter names (used in repositories and services)
"states",
"limit",
"offset",
"sort_desc",
"project_id",
"project_ids",
# Domain attribute names (used across entities, converters, services)
"provider_name",
"model_name",
"annotation_type",
"error_message",
"integration_id",
"started_at",
@@ -264,53 +129,22 @@ ALLOWED_STRINGS = {
"slug",
"description",
"settings",
"location",
"date",
"start",
"attendees",
"secret",
"timeout_ms",
"max_retries",
"ascii",
"code",
# Security and logging categories
"security",
# Identity and role names
"viewer",
"User",
"Webhook",
"Workspaces",
# Settings field names
"rag_enabled",
"default_summarization_template",
# Cache keys
"sync_run_cache_times",
# Log context names
"diarization_job",
"segmenter_state_transition",
# Default model names
"gpt-4o-mini",
"claude-3-haiku-20240307",
# ORM model class names
"ProjectMembershipModel",
"ProjectModel",
# Error code constants
"service_not_enabled",
# Protocol prefixes
"http://",
# Timezone format
"+00:00",
# gRPC status codes
"INTERNAL",
"UNKNOWN",
# Proto type names (used in TYPE_CHECKING)
"ProtoAnnotation",
"ProtoMeeting",
# Error message fragments
", got ",
# Error message templates (shared across diarization status handlers)
"Cannot update job %s: database required",
# Ruff ignore directive
"ignore",
}
@@ -392,11 +226,9 @@ def test_no_magic_numbers() -> None:
for value, mvs in repeated
]
# Target: 11 repeated magic numbers max - common values need named constants:
# - 10 (20x), 1024 (14x), 5 (13x), 50 (12x), 40, 24, 300, 10000, 500 should be constants
# Note: 40 (model display width), 24 (hours), 300 (timeouts), 10000/500 (http codes) are repeated
assert len(violations) <= 11, (
f"Found {len(violations)} repeated magic numbers (max 11 allowed). "
# Target: 5 repeated magic numbers max - common values should be constants.
assert len(violations) <= 5, (
f"Found {len(violations)} repeated magic numbers (max 5 allowed). "
"Consider extracting to constants:\n\n" + "\n\n".join(violations[:5])
)
@@ -440,10 +272,10 @@ def test_no_repeated_string_literals() -> None:
for value, locs in repeated
]
# Target: 31 repeated strings max - many can be extracted to constants
# Target: 10 repeated strings max - many can be extracted to constants
# - Error messages, schema names, log formats should be centralized
assert len(violations) <= 31, (
f"Found {len(violations)} repeated string literals (max 31 allowed). "
assert len(violations) <= 10, (
f"Found {len(violations)} repeated string literals (max 10 allowed). "
"Consider using constants or enums:\n\n" + "\n\n".join(violations[:5])
)

View File

@@ -106,8 +106,8 @@ def test_no_assertion_roulette() -> None:
if not has_assertion_message(node):
assertions_without_msg += 1
# Flag if >3 assertions without messages
if assertions_without_msg > 3:
# Flag if >1 assertions without messages
if assertions_without_msg > 1:
violations.append(
Violation(
rule="assertion_roulette",
@@ -237,9 +237,6 @@ def test_no_sleepy_tests() -> None:
# Paths where sleep is legitimately needed for stress/resilience testing
allowed_sleepy_paths = {
"tests/stress/",
"tests/integration/test_signal_handling.py",
"tests/integration/test_database_resilience.py",
"tests/grpc/test_stream_lifecycle.py",
}
violations: list[Violation] = []
@@ -589,15 +586,10 @@ def test_no_sensitive_equality() -> None:
excluded_test_patterns = {
"string", # Testing string conversion behavior
"proto", # Testing protobuf field serialization
"conversion", # Testing type conversion
"serializ", # Testing serialization
"preserves_message", # Testing error message preservation
}
# File patterns where str() comparison is expected (gRPC field serialization)
excluded_file_patterns = {
"_mixin", # gRPC mixin tests compare ID fields
}
excluded_file_patterns: set[str] = set()
def _is_excluded_test(test_name: str) -> bool:
"""Check if test legitimately uses str() comparison."""
@@ -662,7 +654,7 @@ def test_no_eager_tests() -> None:
"""
violations: list[Violation] = []
parse_errors: list[str] = []
max_method_calls = 10 # Threshold for "too many" method calls
max_method_calls = 7 # Threshold for "too many" method calls
for py_file in find_test_files():
tree, error = parse_file_safe(py_file)
@@ -795,7 +787,7 @@ def test_no_long_test_methods() -> None:
Long tests are hard to understand and maintain. Break them into
smaller, focused tests or extract helper functions.
"""
max_lines = 46
max_lines = 35
violations: list[Violation] = []
parse_errors: list[str] = []

View File

@@ -69,7 +69,6 @@ def test_no_trivial_wrapper_functions() -> None:
("full_transcript", "join"),
("duration", "sub"),
("is_active", "property"),
("is_admin", "can_admin"), # semantic alias for operation context
# Domain method accessors (type-safe dict access)
("get_metadata", "get"),
# Strategy pattern implementations (RuleType.evaluate for simple mode)
@@ -79,43 +78,25 @@ def test_no_trivial_wrapper_functions() -> None:
("generate_request_id", "str"), # UUID to string conversion
# Context variable accessors (public API over internal contextvars)
("get_request_id", "get"),
("get_user_id", "get"),
("get_workspace_id", "get"),
# Time conversion utilities (semantic naming for datetime operations)
("datetime_to_epoch_seconds", "timestamp"),
("datetime_to_iso_string", "isoformat"),
("epoch_seconds_to_datetime", "fromtimestamp"),
("proto_timestamp_to_datetime", "replace"),
# Accessor-style wrappers with semantic names
("from_metrics", "cls"),
("from_dict", "cls"),
("empty", "cls"),
# ProcessingStepState factory methods (GAP-W05)
("pending", "cls"),
("running", "cls"),
("completed", "cls"),
("failed", "cls"),
("skipped", "cls"),
("create_pending", "cls"),
("get_log_level", "get"),
("get_preset_config", "get"),
("get_provider", "get"),
("get_pending_state", "get"),
("get_stream_state", "get"),
("get_async_session_factory", "async_sessionmaker"),
("process_chunk", "process"),
("get_openai_client", "_get_openai_client"),
("meeting_apps", "frozenset"),
("suppressed_apps", "frozenset"),
("get_sync_run", "get"),
("list_all", "list"),
("get_by_id", "get"),
("create", "insert"),
("delete_by_meeting", "clear_summary"),
("get_by_meeting", "fetch_segments"),
("get_by_meeting", "get_summary"),
("check_otel_available", "_check_otel_available"),
("start_as_current_span", "_NoOpSpanContext"),
("start_span", "_NoOpSpan"),
("detected_app", "next"),
}
@@ -231,10 +212,7 @@ def test_no_redundant_type_aliases() -> None:
def test_no_passthrough_classes() -> None:
"""Detect classes that only delegate to another object."""
# Classes that are intentionally factory-pattern based (all methods return cls())
allowed_factory_classes = {
# Domain entity with factory methods for creating different states (GAP-W05)
"ProcessingStepState",
}
allowed_factory_classes: set[str] = set()
violations: list[Violation] = []
parse_errors: list[str] = []