This commit is contained in:
2026-01-21 00:49:40 +00:00
parent f70b35c39f
commit fc7bbd0ea2
36 changed files with 487 additions and 349 deletions

View File

@@ -114,7 +114,7 @@ file_path_pattern := `\.py$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -132,7 +132,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -151,7 +151,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -126,7 +126,7 @@ assertion_pattern := `^\s*assert\s+[^,\n]+\n\s*assert\s+[^,\n]+$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -126,7 +126,7 @@ ignore_pattern := `//\s*biome-ignore|//\s*@ts-ignore|//\s*@ts-expect-error|//\s*
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -104,7 +104,7 @@ handler_pattern := `except\s+Exception\s*(?:as\s+\w+)?:\s*\n\s+(?:logger\.|loggi
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
content := new_content
content != null
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `src/test/code-quality\.test\.ts$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ pattern := `return\s+datetime\.now\s*\(`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
content := new_content
content != null
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -104,7 +104,7 @@ pattern := `except\s+\w*(?:Error|Exception).*?:\s*\n\s+.*?(?:logger\.|logging\.)
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
content := new_content
content != null
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -127,7 +127,7 @@ fixture_pattern := `@pytest\.fixture[^@]*\ndef\s+(mock_uow|crypto|meetings_dir|w
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -146,7 +146,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -166,7 +166,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(^|/)client/.*(?:\.?eslint(?:rc|\.config).*|\.?prettier(?:
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(?:pyproject\.toml|\.?ruff\.toml|\.?pyrightconfig\.json|\.
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -126,7 +126,7 @@ number_pattern := `(?:timeout|delay|interval|duration|limit|max|min|size|count|t
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -145,7 +145,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -165,7 +165,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(?:^|/)Makefile$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ pattern := `except\s+\w*Error.*?:\s*\n\s+.*?(?:logger\.|logging\.).*?\n\s+return
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
content := new_content
content != null
@@ -119,7 +119,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -126,7 +126,7 @@ pattern := `def test_[^(]+\([^)]*\)[^:]*:[\s\S]*?\b(for|while|if)\s+[^:]+:[\s\S]
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -105,7 +105,7 @@ exclude_pattern := `baselines\.json$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -120,7 +120,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -128,7 +128,7 @@ any_type_patterns := [
# Block Write/Edit operations that introduce Any in Python files
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
# Only enforce for Python files
file_path := lower(resolved_file_path)
@@ -149,7 +149,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
content := patch_content
content != null
@@ -166,7 +166,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := lower(edit_path(edit))

View File

@@ -123,7 +123,7 @@ type_suppression_patterns := [
# Block Write/Edit operations that introduce type suppression in Python files
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
# Only enforce for Python files
file_path := lower(resolved_file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
content := patch_content
content != null
@@ -161,7 +161,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := lower(edit_path(edit))

View File

@@ -104,7 +104,7 @@ file_path_pattern := `tests/quality/baselines\.json$`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)

View File

@@ -126,7 +126,7 @@ pattern := `(?:.*\n){500,}`
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Write", "Edit", "NotebookEdit"}
tool_name in {"write", "edit", "notebookedit"}
file_path := resolved_file_path
regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name == "MultiEdit"
tool_name == "multiedit"
some edit in tool_input.edits
file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
deny contains decision if {
input.hook_event_name == "PreToolUse"
tool_name in {"Patch", "ApplyPatch"}
tool_name in {"patch", "apply_patch"}
patch := patch_content
patch != null

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from enum import StrEnum
from uuid import UUID, uuid4
from noteflow.domain.constants.fields import CALENDAR, EMAIL
from noteflow.domain.constants.fields import CALENDAR, EMAIL, UNKNOWN as UNKNOWN_VALUE
from noteflow.domain.utils.time import utc_now
from noteflow.infrastructure.logging import log_state_transition
@@ -150,7 +150,7 @@ class SyncErrorCode(StrEnum):
AUTH_REQUIRED = "auth_required"
PROVIDER_ERROR = "provider_error"
INTERNAL_ERROR = "internal_error"
UNKNOWN = "unknown"
UNKNOWN = UNKNOWN_VALUE
@dataclass

View File

@@ -16,9 +16,14 @@ from typing import NoReturn, Protocol, TypeVar, cast
import grpc
TRequest = TypeVar("TRequest")
TResponse = TypeVar("TResponse")
# Type variables for RpcMethodHandlerProtocol variance
TRequest_co = TypeVar("TRequest_co", covariant=True)
TResponse_contra = TypeVar("TResponse_contra", contravariant=True)
# Metadata type alias
MetadataLike = Sequence[tuple[str, str | bytes]]
@@ -55,42 +60,74 @@ class ServicerContextProtocol(Protocol):
...
class RpcMethodHandlerProtocol(Protocol):
class HandlerCallDetailsProtocol(Protocol):
"""Protocol for gRPC HandlerCallDetails.
This matches the grpc.HandlerCallDetails interface at runtime.
"""
method: str | None
invocation_metadata: MetadataLike | None
class ServerInterceptorProtocol(Protocol):
"""Protocol for grpc.aio.ServerInterceptor."""
async def intercept_service(
self,
continuation: Callable[
[HandlerCallDetailsProtocol],
Awaitable[RpcMethodHandlerProtocol[TRequest, TResponse]],
],
handler_call_details: HandlerCallDetailsProtocol,
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
...
class RpcMethodHandlerProtocol[TRequest_co, TResponse_contra](Protocol):
"""Protocol for RPC method handlers.
This matches the grpc.RpcMethodHandler interface at runtime.
Note: The handler properties return objects that are callable with
specific request/response types, but we use object here for Protocol
compatibility. Use casts at call sites for proper typing.
"""
@property
def request_deserializer(self) -> Callable[[bytes], object] | None:
def request_deserializer(self) -> Callable[[bytes], TRequest_co] | None:
"""Request deserializer function."""
...
@property
def response_serializer(self) -> Callable[[object], bytes] | None:
def response_serializer(self) -> Callable[[TResponse_contra], bytes] | None:
"""Response serializer function."""
...
@property
def unary_unary(self) -> object | None:
def unary_unary(
self,
) -> Callable[[TRequest_co, ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
"""Unary-unary handler behavior."""
...
@property
def unary_stream(self) -> object | None:
def unary_stream(
self,
) -> Callable[[TRequest_co, ServicerContextProtocol], AsyncIterator[TResponse_contra]] | None:
"""Unary-stream handler behavior."""
...
@property
def stream_unary(self) -> object | None:
def stream_unary(
self,
) -> Callable[[AsyncIterator[TRequest_co], ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
"""Stream-unary handler behavior."""
...
@property
def stream_stream(self) -> object | None:
def stream_stream(
self,
) -> Callable[
[AsyncIterator[TRequest_co], ServicerContextProtocol],
AsyncIterator[TResponse_contra],
] | None:
"""Stream-stream handler behavior."""
...
@@ -112,7 +149,7 @@ class GrpcFactoriesProtocol(Protocol):
*,
request_deserializer: Callable[[bytes], TRequest] | None = None,
response_serializer: Callable[[TResponse], bytes] | None = None,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
"""Create a unary-unary RPC method handler."""
...
@@ -125,7 +162,7 @@ class GrpcFactoriesProtocol(Protocol):
*,
request_deserializer: Callable[[bytes], TRequest] | None = None,
response_serializer: Callable[[TResponse], bytes] | None = None,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
"""Create a unary-stream RPC method handler."""
...
@@ -138,7 +175,7 @@ class GrpcFactoriesProtocol(Protocol):
*,
request_deserializer: Callable[[bytes], TRequest] | None = None,
response_serializer: Callable[[TResponse], bytes] | None = None,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
"""Create a stream-unary RPC method handler."""
...
@@ -151,13 +188,11 @@ class GrpcFactoriesProtocol(Protocol):
*,
request_deserializer: Callable[[bytes], TRequest] | None = None,
response_serializer: Callable[[TResponse], bytes] | None = None,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
"""Create a stream-stream RPC method handler."""
...
# Typed grpc module with proper factory function signatures
# Use this instead of importing grpc directly when you need typed factories
typed_grpc: GrpcFactoriesProtocol = cast(
GrpcFactoriesProtocol, importlib.import_module("grpc")
)
typed_grpc: GrpcFactoriesProtocol = cast(GrpcFactoriesProtocol, importlib.import_module("grpc"))

View File

@@ -11,11 +11,9 @@ from __future__ import annotations
from collections.abc import AsyncIterator, Awaitable, Callable
from functools import partial
from typing import NoReturn, cast
from typing import TYPE_CHECKING, NoReturn, TypeVar, cast
import grpc
from grpc import HandlerCallDetails, RpcMethodHandler
from grpc.aio import ServerInterceptor
from noteflow.infrastructure.logging import (
get_logger,
@@ -25,13 +23,22 @@ from noteflow.infrastructure.logging import (
)
from ._types import (
HandlerCallDetailsProtocol,
RpcMethodHandlerProtocol,
ServerInterceptorProtocol,
ServicerContextProtocol,
TRequest,
TResponse,
typed_grpc,
)
_TRequest = TypeVar("_TRequest")
_TResponse = TypeVar("_TResponse")
if TYPE_CHECKING:
ServerInterceptor = ServerInterceptorProtocol
else:
from grpc.aio import ServerInterceptor
logger = get_logger(__name__)
# Metadata keys for identity context
@@ -126,21 +133,18 @@ class IdentityInterceptor(ServerInterceptor):
async def intercept_service(
self,
continuation: Callable[
[HandlerCallDetails],
Awaitable[RpcMethodHandler[TRequest, TResponse]],
[HandlerCallDetailsProtocol],
Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
],
handler_call_details: HandlerCallDetails,
) -> RpcMethodHandler[TRequest, TResponse]:
handler_call_details: HandlerCallDetailsProtocol,
) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
"""Intercept incoming RPC calls to validate and set identity context."""
metadata = dict(handler_call_details.invocation_metadata or [])
request_id = _get_request_id(metadata, handler_call_details.method)
if request_id is None:
handler = await continuation(handler_call_details)
return cast(
RpcMethodHandler[TRequest, TResponse],
_create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID),
)
return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID)
_apply_identity_context(metadata, request_id)
@@ -156,35 +160,50 @@ class IdentityInterceptor(ServerInterceptor):
def _create_unauthenticated_handler(
handler: RpcMethodHandlerProtocol,
handler: RpcMethodHandlerProtocol[_TRequest, _TResponse],
message: str,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
"""Create a handler that rejects with UNAUTHENTICATED status."""
request_deserializer = handler.request_deserializer
response_serializer = handler.response_serializer
if response_serializer is not None:
assert callable(response_serializer)
response_serializer_any = cast(Callable[[object], bytes] | None, response_serializer)
if handler.unary_unary is not None:
return typed_grpc.unary_unary_rpc_method_handler(
partial(_reject_unary_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
return cast(
RpcMethodHandlerProtocol[_TRequest, _TResponse],
typed_grpc.unary_unary_rpc_method_handler(
partial(_reject_unary_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer_any,
),
)
if handler.unary_stream is not None:
return typed_grpc.unary_stream_rpc_method_handler(
partial(_reject_unary_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
return cast(
RpcMethodHandlerProtocol[_TRequest, _TResponse],
typed_grpc.unary_stream_rpc_method_handler(
partial(_reject_unary_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer_any,
),
)
if handler.stream_unary is not None:
return typed_grpc.stream_unary_rpc_method_handler(
partial(_reject_stream_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
return cast(
RpcMethodHandlerProtocol[_TRequest, _TResponse],
typed_grpc.stream_unary_rpc_method_handler(
partial(_reject_stream_unary, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer_any,
),
)
if handler.stream_stream is not None:
return typed_grpc.stream_stream_rpc_method_handler(
partial(_reject_stream_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer,
return cast(
RpcMethodHandlerProtocol[_TRequest, _TResponse],
typed_grpc.stream_stream_rpc_method_handler(
partial(_reject_stream_stream, message),
request_deserializer=request_deserializer,
response_serializer=response_serializer_any,
),
)
return handler

View File

@@ -3,10 +3,13 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import cast
from .._types import (
RpcMethodHandlerProtocol,
ServicerContextProtocol,
TRequest,
TResponse,
typed_grpc,
)
from ._wrappers import (
@@ -17,90 +20,118 @@ from ._wrappers import (
)
def wrap_unary_unary_handler(
handler: RpcMethodHandlerProtocol,
def _wrap_unary_unary_handler(
handler: RpcMethodHandlerProtocol[TRequest, TResponse],
method: str,
) -> RpcMethodHandlerProtocol:
"""Create wrapped unary-unary handler with logging."""
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
if handler.unary_unary is None:
raise TypeError("Unary-unary handler is missing")
wrapped: Callable[
[object, ServicerContextProtocol],
Awaitable[object],
] = wrap_unary_unary(handler.unary_unary, method)
request_deserializer = handler.request_deserializer
response_serializer = handler.response_serializer
return typed_grpc.unary_unary_rpc_method_handler(
wrapped,
request_deserializer=request_deserializer,
response_serializer=response_serializer,
] = wrap_unary_unary(
cast(Callable[[object, ServicerContextProtocol], Awaitable[object]], handler.unary_unary),
method,
)
return cast(
RpcMethodHandlerProtocol[TRequest, TResponse],
typed_grpc.unary_unary_rpc_method_handler(
wrapped,
request_deserializer=cast(
Callable[[bytes], object] | None, handler.request_deserializer
),
response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
),
)
def wrap_unary_stream_handler(
handler: RpcMethodHandlerProtocol,
def _wrap_unary_stream_handler(
handler: RpcMethodHandlerProtocol[TRequest, TResponse],
method: str,
) -> RpcMethodHandlerProtocol:
"""Create wrapped unary-stream handler with logging."""
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
if handler.unary_stream is None:
raise TypeError("Unary-stream handler is missing")
wrapped: Callable[
[object, ServicerContextProtocol],
AsyncIterator[object],
] = wrap_unary_stream(handler.unary_stream, method)
request_deserializer = handler.request_deserializer
response_serializer = handler.response_serializer
return typed_grpc.unary_stream_rpc_method_handler(
wrapped,
request_deserializer=request_deserializer,
response_serializer=response_serializer,
] = wrap_unary_stream(
cast(
Callable[[object, ServicerContextProtocol], AsyncIterator[object]], handler.unary_stream
),
method,
)
return cast(
RpcMethodHandlerProtocol[TRequest, TResponse],
typed_grpc.unary_stream_rpc_method_handler(
wrapped,
request_deserializer=cast(
Callable[[bytes], object] | None, handler.request_deserializer
),
response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
),
)
def wrap_stream_unary_handler(
handler: RpcMethodHandlerProtocol,
def _wrap_stream_unary_handler(
handler: RpcMethodHandlerProtocol[TRequest, TResponse],
method: str,
) -> RpcMethodHandlerProtocol:
"""Create wrapped stream-unary handler with logging."""
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
if handler.stream_unary is None:
raise TypeError("Stream-unary handler is missing")
wrapped: Callable[
[AsyncIterator[object], ServicerContextProtocol],
Awaitable[object],
] = wrap_stream_unary(handler.stream_unary, method)
request_deserializer = handler.request_deserializer
response_serializer = handler.response_serializer
return typed_grpc.stream_unary_rpc_method_handler(
wrapped,
request_deserializer=request_deserializer,
response_serializer=response_serializer,
] = wrap_stream_unary(
cast(
Callable[[AsyncIterator[object], ServicerContextProtocol], Awaitable[object]],
handler.stream_unary,
),
method,
)
return cast(
RpcMethodHandlerProtocol[TRequest, TResponse],
typed_grpc.stream_unary_rpc_method_handler(
wrapped,
request_deserializer=cast(
Callable[[bytes], object] | None, handler.request_deserializer
),
response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
),
)
def wrap_stream_stream_handler(
handler: RpcMethodHandlerProtocol,
def _wrap_stream_stream_handler(
handler: RpcMethodHandlerProtocol[TRequest, TResponse],
method: str,
) -> RpcMethodHandlerProtocol:
"""Create wrapped stream-stream handler with logging."""
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
if handler.stream_stream is None:
raise TypeError("Stream-stream handler is missing")
wrapped: Callable[
[AsyncIterator[object], ServicerContextProtocol],
AsyncIterator[object],
] = wrap_stream_stream(handler.stream_stream, method)
request_deserializer = handler.request_deserializer
response_serializer = handler.response_serializer
return typed_grpc.stream_stream_rpc_method_handler(
wrapped,
request_deserializer=request_deserializer,
response_serializer=response_serializer,
] = wrap_stream_stream(
cast(
Callable[[AsyncIterator[object], ServicerContextProtocol], AsyncIterator[object]],
handler.stream_stream,
),
method,
)
return cast(
RpcMethodHandlerProtocol[TRequest, TResponse],
typed_grpc.stream_stream_rpc_method_handler(
wrapped,
request_deserializer=cast(
Callable[[bytes], object] | None, handler.request_deserializer
),
response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
),
)
def create_logging_handler(
handler: RpcMethodHandlerProtocol,
handler: RpcMethodHandlerProtocol[TRequest, TResponse],
method: str,
) -> RpcMethodHandlerProtocol:
) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
"""Wrap an RPC handler to add request logging.
Args:
@@ -111,12 +142,11 @@ def create_logging_handler(
Wrapped handler with logging.
"""
if handler.unary_unary is not None:
return wrap_unary_unary_handler(handler, method)
return _wrap_unary_unary_handler(handler, method)
if handler.unary_stream is not None:
return wrap_unary_stream_handler(handler, method)
return _wrap_unary_stream_handler(handler, method)
if handler.stream_unary is not None:
return wrap_stream_unary_handler(handler, method)
return _wrap_stream_unary_handler(handler, method)
if handler.stream_stream is not None:
return wrap_stream_stream_handler(handler, method)
# Fallback: return original handler if type unknown
return _wrap_stream_stream_handler(handler, method)
return handler

View File

@@ -1,19 +1,29 @@
"""Request logging interceptor for gRPC calls.
Log every RPC call with method, status, duration, peer, and request context
at INFO level for production observability and traceability.
Log every RPC call with method, status, duration, peer, and request context at
INFO level for production observability and traceability.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import cast
from grpc import HandlerCallDetails, RpcMethodHandler
from grpc.aio import ServerInterceptor
from typing import TYPE_CHECKING, TypeVar
from .._types import (
HandlerCallDetailsProtocol,
RpcMethodHandlerProtocol,
ServerInterceptorProtocol,
)
from ._handler_factory import create_logging_handler
from .._types import TRequest, TResponse
_TRequest = TypeVar("_TRequest")
_TResponse = TypeVar("_TResponse")
if TYPE_CHECKING:
ServerInterceptor = ServerInterceptorProtocol
else:
from grpc.aio import ServerInterceptor
class RequestLoggingInterceptor(ServerInterceptor):
@@ -30,25 +40,12 @@ class RequestLoggingInterceptor(ServerInterceptor):
async def intercept_service(
self,
continuation: Callable[
[HandlerCallDetails],
Awaitable[RpcMethodHandler[TRequest, TResponse]],
[HandlerCallDetailsProtocol],
Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
],
handler_call_details: HandlerCallDetails,
) -> RpcMethodHandler[TRequest, TResponse]:
"""Intercept incoming RPC calls to log request timing and status.
Args:
continuation: The next interceptor or handler.
handler_call_details: Details about the RPC call.
Returns:
Wrapped RPC handler that logs on completion.
"""
handler_call_details: HandlerCallDetailsProtocol,
) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
"""Intercept incoming RPC calls to log request timing and status."""
handler = await continuation(handler_call_details)
method = handler_call_details.method
# Return wrapped handler that logs on completion
return cast(
RpcMethodHandler[TRequest, TResponse],
create_logging_handler(handler, method),
)
method = handler_call_details.method or ""
return create_logging_handler(handler, method)

View File

@@ -77,7 +77,6 @@ async def require_calendar_service(
return host.calendar_service
logger.warning(f"{operation}_unavailable", reason="service_not_enabled")
await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED)
raise # Unreachable but helps type checker
class CalendarMixin:
@@ -117,7 +116,6 @@ class CalendarMixin:
except CalendarServiceError as e:
logger.error("calendar_list_events_failed", error=str(e), provider=provider)
await abort_internal(context, str(e))
raise # Unreachable but helps type checker
proto_events = [_calendar_event_to_proto(event) for event in events]
@@ -166,9 +164,7 @@ class CalendarMixin:
status=status.status,
)
authenticated_count = sum(
bool(provider.is_authenticated) for provider in providers
)
authenticated_count = sum(bool(provider.is_authenticated) for provider in providers)
logger.info(
"calendar_get_providers_success",
total_providers=len(providers),
@@ -205,7 +201,6 @@ class CalendarMixin:
error=str(e),
)
await abort_invalid_argument(context, str(e))
raise # Unreachable but helps type checker
logger.info(
"oauth_initiate_success",

View File

@@ -92,20 +92,18 @@ class DiarizationJobMixin:
Prunes both in-memory task references and database records.
"""
# Clean up in-memory task references for completed tasks
completed_tasks = [
job_id for job_id, task in self.diarization_tasks.items() if task.done()
]
completed_tasks = [job_id for job_id, task in self.diarization_tasks.items() if task.done()]
for job_id in completed_tasks:
self.diarization_tasks.pop(job_id, None)
# Prune old completed jobs from database
async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
async with cast(
DiarizationJobRepositoryProvider, self.create_repository_provider()
) as repo:
if not repo.supports_diarization_jobs:
logger.debug("Job pruning skipped: database required")
return
pruned = await repo.diarization_jobs.prune_completed(
self.diarization_job_ttl_seconds
)
pruned = await repo.diarization_jobs.prune_completed(self.diarization_job_ttl_seconds)
await repo.commit()
if pruned > 0:
logger.debug("Pruned %d completed diarization jobs", pruned)
@@ -121,15 +119,15 @@ class DiarizationJobMixin:
"""
await self.prune_diarization_jobs()
async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
async with cast(
DiarizationJobRepositoryProvider, self.create_repository_provider()
) as repo:
if not repo.supports_diarization_jobs:
await abort_not_found(context, "Diarization jobs (database required)", "")
raise # Unreachable but helps type checker
job = await repo.diarization_jobs.get(request.job_id)
if job is None:
await abort_not_found(context, "Diarization job", request.job_id)
raise # Unreachable but helps type checker
return _build_job_status(job)
@@ -145,7 +143,9 @@ class DiarizationJobMixin:
job_id = request.job_id
await _cancel_running_task(self.diarization_tasks, job_id)
async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
async with cast(
DiarizationJobRepositoryProvider, self.create_repository_provider()
) as repo:
if not repo.supports_diarization_jobs:
await abort_database_required(context, "Diarization job cancellation")
raise AssertionError(UNREACHABLE_ERROR) # abort is NoReturn
@@ -185,7 +185,9 @@ class DiarizationJobMixin:
"""
response = noteflow_pb2.GetActiveDiarizationJobsResponse()
async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
async with cast(
DiarizationJobRepositoryProvider, self.create_repository_provider()
) as repo:
if not repo.supports_diarization_jobs:
# Return empty list if DB not available
return response

View File

@@ -175,7 +175,6 @@ class EntitiesMixin:
if updated is None:
await abort_not_found(context, ENTITY_ENTITY, request.entity_id)
raise # Unreachable but helps type checker
await uow.commit()

View File

@@ -75,9 +75,7 @@ def _merge_workspace_settings(
trigger_rules=updates.trigger_rules
if updates.trigger_rules is not None
else current.trigger_rules,
rag_enabled=updates.rag_enabled
if updates.rag_enabled is not None
else current.rag_enabled,
rag_enabled=updates.rag_enabled if updates.rag_enabled is not None else current.rag_enabled,
default_summarization_template=updates.default_summarization_template
if updates.default_summarization_template is not None
else current.default_summarization_template,
@@ -192,14 +190,12 @@ class IdentityMixin:
if not request.workspace_id:
await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
raise # Unreachable but helps type checker
workspace_id = await parse_workspace_id(request.workspace_id, context)
async with cast(UnitOfWork, self.create_repository_provider()) as uow:
if not uow.supports_workspaces:
await abort_database_required(context, WORKSPACES_LABEL)
raise # Unreachable but helps type checker
user_ctx = await self.identity_service.get_or_create_default_user(uow)
workspace, membership = await self._verify_workspace_access(
@@ -231,14 +227,12 @@ class IdentityMixin:
"""Get workspace settings."""
if not request.workspace_id:
await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
raise # Unreachable but helps type checker
workspace_id = await parse_workspace_id(request.workspace_id, context)
async with cast(UnitOfWork, self.create_repository_provider()) as uow:
if not uow.supports_workspaces:
await abort_database_required(context, WORKSPACES_LABEL)
raise # Unreachable but helps type checker
user_ctx = await self.identity_service.get_or_create_default_user(uow)
workspace, _ = await self._verify_workspace_access(
@@ -259,14 +253,12 @@ class IdentityMixin:
"""Update workspace settings."""
if not request.workspace_id:
await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
raise # Unreachable but helps type checker
workspace_id = await parse_workspace_id(request.workspace_id, context)
async with cast(UnitOfWork, self.create_repository_provider()) as uow:
if not uow.supports_workspaces:
await abort_database_required(context, WORKSPACES_LABEL)
raise # Unreachable but helps type checker
user_ctx = await self.identity_service.get_or_create_default_user(uow)
workspace, membership = await self._verify_workspace_access(
@@ -278,7 +270,6 @@ class IdentityMixin:
if not membership.role.can_admin():
await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED)
raise # Unreachable but helps type checker
updates = proto_to_workspace_settings(request.settings)
if updates is None:
@@ -302,11 +293,9 @@ class IdentityMixin:
workspace = await uow.workspaces.get(workspace_id)
if not workspace:
await abort_not_found(context, ENTITY_WORKSPACE, str(workspace_id))
raise # Unreachable but helps type checker
membership = await uow.workspaces.get_membership(workspace_id, user_id)
if not membership:
await abort_not_found(context, "Workspace membership", str(workspace_id))
raise # Unreachable but helps type checker
return workspace, membership

View File

@@ -42,7 +42,6 @@ if TYPE_CHECKING:
from .._types import GrpcContext
from ..protocols import ServicerHost
from ..errors._constants import UNREACHABLE_ERROR
logger = get_logger(__name__)
@@ -66,7 +65,6 @@ async def _load_meeting_for_stop(
if meeting is None:
logger.warning("StopMeeting: meeting not found", meeting_id=meeting_id_str)
await abort_not_found(context, ENTITY_MEETING, meeting_id_str)
raise AssertionError(UNREACHABLE_ERROR)
return meeting
@@ -85,7 +83,7 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
)
return context.meeting
previous_state = context.meeting.state.value
previous_state = context.meeting.state.name
await transition_to_stopped(
context.meeting,
context.meeting_id,
@@ -143,7 +141,9 @@ class MeetingMixin:
op_context = self.get_operation_context(context)
async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
project_id = await resolve_create_project_id(self, repo, op_context, project_id)
project_id = await resolve_create_project_id(
cast("ServicerHost", self), repo, op_context, project_id
)
meeting = Meeting.create(
title=request.title,
@@ -212,7 +212,7 @@ class MeetingMixin:
async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
if project_id is None and not project_ids:
project_id = await resolve_active_project_id(self, repo)
project_id = await resolve_active_project_id(cast("ServicerHost", self), repo)
meetings, total = await repo.meetings.list_all(
states=states,
@@ -254,7 +254,6 @@ class MeetingMixin:
if meeting is None:
logger.warning("GetMeeting: meeting not found", meeting_id=request.meeting_id)
await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
raise # Unreachable but helps type checker
# Load segments if requested
if request.include_segments:
segments = await repo.segments.get_by_meeting(meeting.id)
@@ -283,7 +282,6 @@ class MeetingMixin:
if not success:
logger.warning("DeleteMeeting: meeting not found", meeting_id=request.meeting_id)
await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
raise # Unreachable but helps type checker
await repo.commit()
logger.info("Meeting deleted", meeting_id=request.meeting_id)

View File

@@ -35,8 +35,6 @@ if TYPE_CHECKING:
from ..protocols import ProjectRepositoryProvider
async def _parse_project_and_user_ids(
request_project_id: str,
request_user_id: str,
@@ -45,21 +43,19 @@ async def _parse_project_and_user_ids(
"""Parse and validate project and user IDs from request."""
if not request_project_id:
await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
raise # Unreachable but helps type checker
if not request_user_id:
await abort_invalid_argument(context, ERROR_USER_ID_REQUIRED)
raise # Unreachable but helps type checker
try:
project_id = UUID(request_project_id)
user_id = UUID(request_user_id)
except ValueError as e:
await abort_invalid_argument(context, f"{ERROR_INVALID_UUID_PREFIX}{e}")
raise # Unreachable but helps type checker
return project_id, user_id
class ProjectMembershipMixin:
"""Mixin providing project membership functionality.
@@ -93,7 +89,6 @@ class ProjectMembershipMixin:
)
if membership is None:
await abort_not_found(context, ENTITY_PROJECT, request.project_id)
raise # Unreachable but helps type checker
return membership_to_proto(membership)
@@ -119,8 +114,9 @@ class ProjectMembershipMixin:
role=role,
)
if membership is None:
await abort_not_found(context, "Membership", f"{request.project_id}/{request.user_id}")
raise # Unreachable but helps type checker
await abort_not_found(
context, "Membership", f"{request.project_id}/{request.user_id}"
)
return membership_to_proto(membership)
@@ -159,8 +155,9 @@ class ProjectMembershipMixin:
try:
project_id = UUID(request.project_id)
except ValueError:
await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
raise # Unreachable but helps type checker
await abort_invalid_argument(
context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
)
limit = request.limit if request.limit > 0 else 100
offset = max(request.offset, 0)

View File

@@ -41,8 +41,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
async def _require_and_parse_project_id(
request_project_id: str,
context: GrpcContext,
@@ -50,7 +48,6 @@ async def _require_and_parse_project_id(
"""Require and parse a project_id from request."""
if not request_project_id:
await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
raise # Unreachable but helps type checker
return await parse_project_id(request_project_id, context)
@@ -61,9 +58,9 @@ async def _require_and_parse_workspace_id(
"""Require and parse a workspace_id from request."""
if not request_workspace_id:
await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
raise # Unreachable but helps type checker
return await parse_workspace_id(request_workspace_id, context)
class ProjectMixin:
"""Mixin providing project management functionality.
@@ -89,11 +86,12 @@ class ProjectMixin:
if not request.name:
await abort_invalid_argument(context, "name is required")
raise # Unreachable but helps type checker
slug = request.slug if request.HasField("slug") else None
description = request.description if request.HasField("description") else None
settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
settings = (
proto_to_project_settings(request.settings) if request.HasField("settings") else None
)
async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
await require_feature_projects(uow, context)
@@ -123,7 +121,6 @@ class ProjectMixin:
project = await project_service.get_project(uow, project_id)
if project is None:
await abort_not_found(context, ENTITY_PROJECT, request.project_id)
raise # Unreachable but helps type checker
return project_to_proto(project)
@@ -138,7 +135,6 @@ class ProjectMixin:
if not request.slug:
await abort_invalid_argument(context, "slug is required")
raise # Unreachable but helps type checker
async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
await require_feature_projects(uow, context)
@@ -146,7 +142,6 @@ class ProjectMixin:
project = await project_service.get_project_by_slug(uow, workspace_id, request.slug)
if project is None:
await abort_not_found(context, ENTITY_PROJECT, request.slug)
raise # Unreachable but helps type checker
return project_to_proto(project)
@@ -159,6 +154,7 @@ class ProjectMixin:
project_service = await require_project_service(self.project_service, context)
workspace_id = await _require_and_parse_workspace_id(request.workspace_id, context)
include_archived = request.include_archived
limit = request.limit if request.limit > 0 else 50
offset = max(request.offset, 0)
@@ -166,21 +162,20 @@ class ProjectMixin:
await require_feature_projects(uow, context)
projects = await project_service.list_projects(
uow=uow,
workspace_id=workspace_id,
include_archived=request.include_archived,
uow,
workspace_id,
include_archived=include_archived,
limit=limit,
offset=offset,
)
total_count = await project_service.count_projects(
uow=uow,
workspace_id=workspace_id,
include_archived=request.include_archived,
uow,
workspace_id,
include_archived=include_archived,
)
return noteflow_pb2.ListProjectsResponse(
projects=[project_to_proto(p) for p in projects],
projects=[project_to_proto(project) for project in projects],
total_count=total_count,
)
@@ -196,14 +191,16 @@ class ProjectMixin:
name = request.name if request.HasField("name") else None
slug = request.slug if request.HasField("slug") else None
description = request.description if request.HasField("description") else None
settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
settings = (
proto_to_project_settings(request.settings) if request.HasField("settings") else None
)
async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
await require_feature_projects(uow, context)
project = await project_service.update_project(
uow=uow,
project_id=project_id,
uow,
project_id,
name=name,
slug=slug,
description=description,
@@ -211,7 +208,6 @@ class ProjectMixin:
)
if project is None:
await abort_not_found(context, ENTITY_PROJECT, request.project_id)
raise # Unreachable but helps type checker
return project_to_proto(project)
@@ -231,11 +227,9 @@ class ProjectMixin:
project = await project_service.archive_project(uow, project_id)
except CannotArchiveDefaultProjectError:
await abort_failed_precondition(context, "Cannot archive the default project")
raise # Unreachable but helps type checker
if project is None:
await abort_not_found(context, ENTITY_PROJECT, request.project_id)
raise # Unreachable but helps type checker
return project_to_proto(project)
@@ -254,7 +248,6 @@ class ProjectMixin:
project = await project_service.restore_project(uow, project_id)
if project is None:
await abort_not_found(context, ENTITY_PROJECT, request.project_id)
raise # Unreachable but helps type checker
return project_to_proto(project)
@@ -302,7 +295,6 @@ class ProjectMixin:
)
except ValueError as exc:
await abort_invalid_argument(context, str(exc))
raise # Unreachable but helps type checker
await uow.commit()
return noteflow_pb2.SetActiveProjectResponse()
@@ -327,11 +319,9 @@ class ProjectMixin:
)
except ValueError as exc:
await abort_invalid_argument(context, str(exc))
raise # Unreachable but helps type checker
if project is None:
await abort_not_found(context, ENTITY_PROJECT, "default")
raise # Unreachable but helps type checker
response = noteflow_pb2.GetActiveProjectResponse(
project=project_to_proto(project),

View File

@@ -59,7 +59,6 @@ async def _decode_chunk_audio(
)
except ValueError as e:
await abort_invalid_argument(context, str(e))
raise # Unreachable but helps type checker
conversion = AudioConversionContext(
source_sample_rate=sample_rate,

View File

@@ -30,10 +30,13 @@ from ._template_resolution import (
if TYPE_CHECKING:
from collections.abc import Callable
from noteflow.application.services.ner import NerService
from noteflow.application.services.summarization import SummarizationService
from noteflow.application.services.webhooks import WebhookService
from noteflow.domain.entities import Meeting
from noteflow.domain.identity import OperationContext
from noteflow.infrastructure.asr import FasterWhisperEngine
from noteflow.infrastructure.diarization.engine import DiarizationEngine
logger = get_logger(__name__)
@@ -48,9 +51,41 @@ class _SummaryGenerationContext:
force_regenerate: bool
@dataclass(frozen=True, slots=True)
class _SummaryRequestContext:
request: noteflow_pb2.GenerateSummaryRequest
meeting: Meeting
segments: list[Segment]
style_instructions: str | None
context: GrpcContext
op_context: OperationContext
async def resolve_style_prompt(
self,
repo: UnitOfWork,
summarization_service: SummarizationService | None,
) -> str | None:
"""Resolve style prompt for this request context."""
return await resolve_template_prompt(
TemplateResolutionInputs(
request=self.request,
meeting=self.meeting,
segments=self.segments,
style_instructions=self.style_instructions,
context=self.context,
op_context=self.op_context,
repo=repo,
summarization_service=summarization_service,
)
)
class SummarizationGenerationMixin:
"""Generate summaries and handle summary webhooks."""
asr_engine: FasterWhisperEngine | None
diarization_engine: DiarizationEngine | None
ner_service: NerService | None
summarization_service: SummarizationService | None
webhook_service: WebhookService | None
create_repository_provider: Callable[..., object]
@@ -68,28 +103,28 @@ class SummarizationGenerationMixin:
meeting_id=request.meeting_id,
include_provider_details=True,
)
meeting_id, op_context, style_instructions, meeting, existing, segments = (
await self._prepare_summary_request(request, context)
)
(
meeting_id,
op_context,
style_instructions,
meeting,
existing,
segments,
) = await self._prepare_summary_request(request, context)
if existing and not request.force_regenerate:
await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
return summary_to_proto(existing)
await self._ensure_cloud_provider()
async with cast(UnitOfWork, self.create_repository_provider()) as repo:
style_prompt = await self._resolve_style_prompt(
TemplateResolutionInputs(
request=request,
meeting=meeting,
segments=segments,
style_instructions=style_instructions,
context=context,
op_context=op_context,
repo=repo,
summarization_service=self.summarization_service,
)
)
request_context = _SummaryRequestContext(
request=request,
meeting=meeting,
segments=segments,
style_instructions=style_instructions,
context=context,
op_context=op_context,
)
style_prompt = await self._resolve_style_prompt_for_request(request_context)
saved, trigger_webhook = await self._generate_summary_with_status(
_SummaryGenerationContext(
meeting_id=meeting_id,
@@ -121,6 +156,13 @@ class SummarizationGenerationMixin:
),
)
async def _resolve_style_prompt_for_request(
self,
request_context: _SummaryRequestContext,
) -> str | None:
async with cast(UnitOfWork, self.create_repository_provider()) as repo:
return await request_context.resolve_style_prompt(repo, self.summarization_service)
async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
return await resolve_template_prompt(inputs)
@@ -244,7 +286,6 @@ class SummarizationGenerationMixin:
meeting = await repo.meetings.get(meeting_id)
if meeting is None:
await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
raise # Unreachable but helps type checker
existing = await repo.summaries.get_by_meeting(meeting.id)
segments = list(await repo.segments.get_by_meeting(meeting.id))

View File

@@ -9,6 +9,8 @@ from typing import Generic, TypeVar, Never
import grpc
from noteflow.domain.constants.fields import UNKNOWN as UNKNOWN_VALUE
T = TypeVar("T", covariant=True)
U = TypeVar("U")
E = TypeVar("E", bound=BaseException)
@@ -17,7 +19,7 @@ E = TypeVar("E", bound=BaseException)
class ClientErrorCode(StrEnum):
"""Client-facing error codes for gRPC operations."""
UNKNOWN = "unknown"
UNKNOWN = UNKNOWN_VALUE
NOT_CONNECTED = "not_connected"
NOT_FOUND = "not_found"
INVALID_ARGUMENT = "invalid_argument"

View File

@@ -5,8 +5,8 @@ Tests identity context validation and per-RPC request logging.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Protocol, cast
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import grpc
@@ -19,7 +19,11 @@ from noteflow.grpc.interceptors import (
IdentityInterceptor,
RequestLoggingInterceptor,
)
from noteflow.grpc.interceptors._types import ServicerContextProtocol
from noteflow.grpc.interceptors._types import (
HandlerCallDetailsProtocol,
RpcMethodHandlerProtocol,
ServicerContextProtocol,
)
from noteflow.infrastructure.logging import (
get_request_id,
get_user_id,
@@ -27,14 +31,13 @@ from noteflow.infrastructure.logging import (
request_id_var,
)
# Type alias for callable unary-unary handlers
UnaryUnaryCallable = Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]]
class _DummyRequest:
"""Placeholder request type for handler casts."""
# Type alias for test handler protocol (using bytes for wire-level compatibility)
TestHandlerProtocol = RpcMethodHandlerProtocol[bytes, bytes]
class _DummyResponse:
"""Placeholder response type for handler casts."""
# Test data
TEST_REQUEST_ID = "test-request-123"
TEST_USER_ID = "user-456"
@@ -45,25 +48,28 @@ TEST_METHOD = "/noteflow.NoteFlowService/GetMeeting"
pytestmark = pytest.mark.usefixtures("reset_context_vars")
# Type alias for handler call details matching grpc module
HandlerCallDetails = HandlerCallDetailsProtocol
def create_handler_call_details(
method: str = TEST_METHOD,
metadata: list[tuple[str, str | bytes]] | None = None,
) -> grpc.HandlerCallDetails:
) -> HandlerCallDetails:
"""Create mock HandlerCallDetails with metadata."""
details = MagicMock(spec=grpc.HandlerCallDetails)
details = MagicMock(spec=HandlerCallDetails)
details.method = method
details.invocation_metadata = metadata or []
return details
def create_mock_handler() -> _UnaryUnaryHandler:
def create_mock_handler() -> TestHandlerProtocol:
"""Create a mock RPC method handler."""
return _MockHandler()
def create_mock_continuation(
handler: _UnaryUnaryHandler | None = None,
handler: TestHandlerProtocol | None = None,
) -> AsyncMock:
"""Create a mock continuation function."""
if handler is None:
@@ -71,46 +77,75 @@ def create_mock_continuation(
return AsyncMock(return_value=handler)
class _UnaryUnaryHandler(Protocol):
"""Protocol for unary-unary RPC method handlers."""
unary_unary: Callable[
[_DummyRequest, ServicerContextProtocol],
Awaitable[_DummyResponse],
] | None
unary_stream: object | None
stream_unary: object | None
stream_stream: object | None
request_deserializer: object | None
response_serializer: object | None
class _MockHandler:
"""Concrete handler for tests with typed unary_unary."""
"""Concrete handler for tests with typed unary_unary.
unary_unary: Callable[
[_DummyRequest, ServicerContextProtocol],
Awaitable[_DummyResponse],
] | None
unary_stream: object | None
stream_unary: object | None
stream_stream: object | None
request_deserializer: object | None
response_serializer: object | None
Implements RpcMethodHandlerProtocol for interceptor testing.
"""
def __init__(self) -> None:
self.unary_unary = cast(
Callable[
[_DummyRequest, ServicerContextProtocol],
Awaitable[_DummyResponse],
],
AsyncMock(return_value="response"),
self._unary_unary: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None = (
cast(
Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]],
AsyncMock(return_value=b"response"),
)
)
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
self.request_deserializer = None
self.response_serializer = None
self._unary_stream: (
Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None
) = None
self._stream_unary: (
Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None
) = None
self._stream_stream: (
Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None
) = None
self._request_deserializer: Callable[[bytes], bytes] | None = None
self._response_serializer: Callable[[bytes], bytes] | None = None
@property
def request_deserializer(self) -> Callable[[bytes], bytes] | None:
"""Request deserializer function."""
return self._request_deserializer
@property
def response_serializer(self) -> Callable[[bytes], bytes] | None:
"""Response serializer function."""
return self._response_serializer
@property
def unary_unary(
self,
) -> Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None:
"""Unary-unary handler behavior."""
return self._unary_unary
@unary_unary.setter
def unary_unary(
self, value: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None
) -> None:
"""Set unary-unary handler behavior."""
self._unary_unary = value
@property
def unary_stream(
self,
) -> Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None:
"""Unary-stream handler behavior."""
return self._unary_stream
@property
def stream_unary(
self,
) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None:
"""Stream-unary handler behavior."""
return self._stream_unary
@property
def stream_stream(
self,
) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None:
"""Stream-stream handler behavior."""
return self._stream_stream
class TestIdentityInterceptor:
@@ -157,8 +192,9 @@ class TestIdentityInterceptor:
original_handler = create_mock_handler()
continuation = create_mock_continuation(original_handler)
handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, handler)
typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
continuation, details
)
# Handler should be a rejection handler wrapping the original
assert typed_handler.unary_unary is not None, "handler should have unary_unary"
@@ -166,7 +202,9 @@ class TestIdentityInterceptor:
# is a rejection wrapper that will abort with UNAUTHENTICATED
continuation.assert_called_once()
# The returned handler should NOT be the original handler
assert typed_handler.unary_unary is not original_handler.unary_unary, "should return rejection handler, not original"
assert typed_handler.unary_unary is not original_handler.unary_unary, (
"should return rejection handler, not original"
)
@pytest.mark.asyncio
async def test_reject_handler_aborts_with_unauthenticated(self) -> None:
@@ -175,8 +213,9 @@ class TestIdentityInterceptor:
details = create_handler_call_details(metadata=[])
continuation = create_mock_continuation()
handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, handler)
typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
continuation, details
)
# Create mock context to verify abort behavior
context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
@@ -184,11 +223,14 @@ class TestIdentityInterceptor:
with pytest.raises(grpc.RpcError, match="x-request-id"):
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
unary_fn = typed_handler.unary_unary
await unary_fn(MagicMock(), context)
context.abort.assert_called_once()
call_args = context.abort.call_args
assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, "should abort with UNAUTHENTICATED"
assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, (
"should abort with UNAUTHENTICATED"
)
assert "x-request-id" in call_args[0][1], "error message should mention x-request-id"
@pytest.mark.asyncio
@@ -224,16 +266,16 @@ class TestRequestLoggingInterceptor:
request_id_var.set(TEST_REQUEST_ID)
with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
continuation, details
)
# Execute the wrapped handler
context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
assert wrapped_handler.unary_unary is not None
unary_fn = wrapped_handler.unary_unary
await unary_fn(MagicMock(), context)
# Verify logging
mock_logger.info.assert_called_once()
call_kwargs = mock_logger.info.call_args[1]
assert call_kwargs["method"] == TEST_METHOD, "should log method"
@@ -246,28 +288,31 @@ class TestRequestLoggingInterceptor:
"""Interceptor logs error status when handler raises."""
interceptor = RequestLoggingInterceptor()
# Create handler that raises
handler = create_mock_handler()
handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))
mock_handler = cast(_MockHandler, handler)
mock_handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))
details = create_handler_call_details(method=TEST_METHOD)
continuation = create_mock_continuation(handler)
with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
continuation, details
)
context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
with pytest.raises(Exception, match="Test error"):
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
assert wrapped_handler.unary_unary is not None
unary_fn = wrapped_handler.unary_unary
await unary_fn(MagicMock(), context)
# Should still log with INTERNAL status
mock_logger.info.assert_called_once()
call_kwargs = mock_logger.info.call_args[1]
assert call_kwargs["status"] == "INTERNAL", "should log INTERNAL for unhandled exception"
assert call_kwargs["status"] == "INTERNAL", (
"should log INTERNAL for unhandled exception"
)
@pytest.mark.asyncio
async def test_passes_through_to_continuation(self) -> None:
@@ -277,7 +322,7 @@ class TestRequestLoggingInterceptor:
details = create_handler_call_details()
continuation = create_mock_continuation(handler)
await interceptor.intercept_service(continuation, details)
_: TestHandlerProtocol = await interceptor.intercept_service(continuation, details)
continuation.assert_called_once_with(details)
@@ -290,17 +335,17 @@ class TestRequestLoggingInterceptor:
continuation = create_mock_continuation(handler)
with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
wrapped_handler = await interceptor.intercept_service(continuation, details)
typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
continuation, details
)
# Context without peer method
context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
context.peer = MagicMock(side_effect=RuntimeError("No peer"))
assert typed_handler.unary_unary is not None
await typed_handler.unary_unary(MagicMock(), context)
assert wrapped_handler.unary_unary is not None
unary_fn = wrapped_handler.unary_unary
await unary_fn(MagicMock(), context)
# Should still log with None peer
mock_logger.info.assert_called_once()
call_kwargs = mock_logger.info.call_args[1]
assert call_kwargs["peer"] is None, "should handle missing peer"