2026-01-21 00:49:40 +00:00
parent f70b35c39f
commit fc7bbd0ea2
36 changed files with 487 additions and 349 deletions

View File

@@ -114,7 +114,7 @@ file_path_pattern := `\.py$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -132,7 +132,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -151,7 +151,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null
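Every policy hunk in this commit makes the same one-line change: the tool-name checks move from mixed-case names to lowercase ones. Note that the mapping is not uniform snake-casing: MultiEdit collapses to multiedit while ApplyPatch gains an underscore. A minimal Python sketch of the correspondence, assuming a hook runner that canonicalizes names before policy evaluation (normalize_tool_name is a hypothetical helper, not part of this commit):

# Hypothetical lookup inferred from the policy hunks; a table is safer than
# algorithmic snake-casing because "MultiEdit" -> "multiedit" while
# "ApplyPatch" -> "apply_patch".
CANONICAL_TOOL_NAMES: dict[str, str] = {
    "Write": "write",
    "Edit": "edit",
    "NotebookEdit": "notebookedit",
    "MultiEdit": "multiedit",
    "Patch": "patch",
    "ApplyPatch": "apply_patch",
}

def normalize_tool_name(name: str) -> str:
    """Map a mixed-case tool name to the form the policies now match."""
    return CANONICAL_TOOL_NAMES.get(name, name.lower())

The policy files that follow receive the identical substitution; only the surrounding patterns differ.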

View File

@@ -126,7 +126,7 @@ assertion_pattern := `^\s*assert\s+[^,\n]+\n\s*assert\s+[^,\n]+$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -126,7 +126,7 @@ ignore_pattern := `//\s*biome-ignore|//\s*@ts-ignore|//\s*@ts-expect-error|//\s*
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -104,7 +104,7 @@ handler_pattern := `except\s+Exception\s*(?:as\s+\w+)?:\s*\n\s+(?:logger\.|loggi
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     content := new_content
     content != null
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `src/test/code-quality\.test\.ts$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ pattern := `return\s+datetime\.now\s*\(`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     content := new_content
     content != null
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -104,7 +104,7 @@ pattern := `except\s+\w*(?:Error|Exception).*?:\s*\n\s+.*?(?:logger\.|logging\.)
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     content := new_content
     content != null
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -127,7 +127,7 @@ fixture_pattern := `@pytest\.fixture[^@]*\ndef\s+(mock_uow|crypto|meetings_dir|w
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -146,7 +146,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -166,7 +166,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(^|/)client/.*(?:\.?eslint(?:rc|\.config).*|\.?prettier(?:
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(?:pyproject\.toml|\.?ruff\.toml|\.?pyrightconfig\.json|\.
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -126,7 +126,7 @@ number_pattern := `(?:timeout|delay|interval|duration|limit|max|min|size|count|t
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -145,7 +145,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -165,7 +165,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -104,7 +104,7 @@ file_path_pattern := `(?:^|/)Makefile$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -104,7 +104,7 @@ pattern := `except\s+\w*Error.*?:\s*\n\s+.*?(?:logger\.|logging\.).*?\n\s+return
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     content := new_content
     content != null
@@ -119,7 +119,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     content := edit_new_content(edit)
@@ -135,7 +135,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -126,7 +126,7 @@ pattern := `def test_[^(]+\([^)]*\)[^:]*:[\s\S]*?\b(for|while|if)\s+[^:]+:[\s\S]
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -105,7 +105,7 @@ exclude_pattern := `baselines\.json$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -120,7 +120,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -128,7 +128,7 @@ any_type_patterns := [
 # Block Write/Edit operations that introduce Any in Python files
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     # Only enforce for Python files
     file_path := lower(resolved_file_path)
@@ -149,7 +149,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     content := patch_content
     content != null
@@ -166,7 +166,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := lower(edit_path(edit))

View File

@@ -123,7 +123,7 @@ type_suppression_patterns := [
 # Block Write/Edit operations that introduce type suppression in Python files
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     # Only enforce for Python files
     file_path := lower(resolved_file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     content := patch_content
     content != null
@@ -161,7 +161,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := lower(edit_path(edit))

View File

@@ -104,7 +104,7 @@ file_path_pattern := `tests/quality/baselines\.json$`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -118,7 +118,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)

View File

@@ -126,7 +126,7 @@ pattern := `(?:.*\n){500,}`
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Write", "Edit", "NotebookEdit"}
+    tool_name in {"write", "edit", "notebookedit"}
     file_path := resolved_file_path
     regex.match(file_path_pattern, file_path)
@@ -144,7 +144,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name == "MultiEdit"
+    tool_name == "multiedit"
     some edit in tool_input.edits
     file_path := edit_path(edit)
@@ -163,7 +163,7 @@ deny contains decision if {
 deny contains decision if {
     input.hook_event_name == "PreToolUse"
-    tool_name in {"Patch", "ApplyPatch"}
+    tool_name in {"patch", "apply_patch"}
     patch := patch_content
     patch != null

View File

@@ -7,7 +7,7 @@ from datetime import datetime
 from enum import StrEnum
 from uuid import UUID, uuid4

-from noteflow.domain.constants.fields import CALENDAR, EMAIL
+from noteflow.domain.constants.fields import CALENDAR, EMAIL, UNKNOWN as UNKNOWN_VALUE
 from noteflow.domain.utils.time import utc_now
 from noteflow.infrastructure.logging import log_state_transition
@@ -150,7 +150,7 @@ class SyncErrorCode(StrEnum):
     AUTH_REQUIRED = "auth_required"
     PROVIDER_ERROR = "provider_error"
     INTERNAL_ERROR = "internal_error"
-    UNKNOWN = "unknown"
+    UNKNOWN = UNKNOWN_VALUE

 @dataclass
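The import change above threads a shared UNKNOWN constant into the enum instead of repeating the string literal. A self-contained sketch of the pattern (the constant is inlined here; in the commit it lives in noteflow.domain.constants.fields and is imported under the alias UNKNOWN_VALUE because the enum member is also named UNKNOWN):

from enum import StrEnum

UNKNOWN_VALUE = "unknown"  # stand-in for noteflow.domain.constants.fields.UNKNOWN

class SyncErrorCode(StrEnum):
    INTERNAL_ERROR = "internal_error"
    UNKNOWN = UNKNOWN_VALUE  # single source of truth for the "unknown" literal

assert SyncErrorCode.UNKNOWN == "unknown"  # StrEnum members compare equal to str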

View File

@@ -16,9 +16,14 @@ from typing import NoReturn, Protocol, TypeVar, cast
 import grpc

 TRequest = TypeVar("TRequest")
 TResponse = TypeVar("TResponse")
+
+# Type variables for RpcMethodHandlerProtocol variance
+TRequest_co = TypeVar("TRequest_co", covariant=True)
+TResponse_contra = TypeVar("TResponse_contra", contravariant=True)

 # Metadata type alias
 MetadataLike = Sequence[tuple[str, str | bytes]]
@@ -55,42 +60,74 @@ class ServicerContextProtocol(Protocol):
         ...

-class RpcMethodHandlerProtocol(Protocol):
+class HandlerCallDetailsProtocol(Protocol):
+    """Protocol for gRPC HandlerCallDetails.
+
+    This matches the grpc.HandlerCallDetails interface at runtime.
+    """
+
+    method: str | None
+    invocation_metadata: MetadataLike | None
+
+
+class ServerInterceptorProtocol(Protocol):
+    """Protocol for grpc.aio.ServerInterceptor."""
+
+    async def intercept_service(
+        self,
+        continuation: Callable[
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[TRequest, TResponse]],
+        ],
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
+        ...
+
+
+class RpcMethodHandlerProtocol[TRequest_co, TResponse_contra](Protocol):
     """Protocol for RPC method handlers.

     This matches the grpc.RpcMethodHandler interface at runtime.
-
-    Note: The handler properties return objects that are callable with
-    specific request/response types, but we use object here for Protocol
-    compatibility. Use casts at call sites for proper typing.
     """

     @property
-    def request_deserializer(self) -> Callable[[bytes], object] | None:
+    def request_deserializer(self) -> Callable[[bytes], TRequest_co] | None:
         """Request deserializer function."""
         ...

     @property
-    def response_serializer(self) -> Callable[[object], bytes] | None:
+    def response_serializer(self) -> Callable[[TResponse_contra], bytes] | None:
         """Response serializer function."""
         ...

     @property
-    def unary_unary(self) -> object | None:
+    def unary_unary(
+        self,
+    ) -> Callable[[TRequest_co, ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
         """Unary-unary handler behavior."""
         ...

     @property
-    def unary_stream(self) -> object | None:
+    def unary_stream(
+        self,
+    ) -> Callable[[TRequest_co, ServicerContextProtocol], AsyncIterator[TResponse_contra]] | None:
         """Unary-stream handler behavior."""
         ...

     @property
-    def stream_unary(self) -> object | None:
+    def stream_unary(
+        self,
+    ) -> Callable[[AsyncIterator[TRequest_co], ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
         """Stream-unary handler behavior."""
         ...

     @property
-    def stream_stream(self) -> object | None:
+    def stream_stream(
+        self,
+    ) -> Callable[
+        [AsyncIterator[TRequest_co], ServicerContextProtocol],
+        AsyncIterator[TResponse_contra],
+    ] | None:
         """Stream-stream handler behavior."""
         ...
@@ -112,7 +149,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a unary-unary RPC method handler."""
         ...
@@ -125,7 +162,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a unary-stream RPC method handler."""
         ...
@@ -138,7 +175,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a stream-unary RPC method handler."""
         ...
@@ -151,13 +188,11 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a stream-stream RPC method handler."""
         ...

 # Typed grpc module with proper factory function signatures
 # Use this instead of importing grpc directly when you need typed factories
-typed_grpc: GrpcFactoriesProtocol = cast(
-    GrpcFactoriesProtocol, importlib.import_module("grpc")
-)
+typed_grpc: GrpcFactoriesProtocol = cast(GrpcFactoriesProtocol, importlib.import_module("grpc"))
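Making RpcMethodHandlerProtocol generic lets the deserializer and serializer properties carry real types instead of object, which is what allowed the casts at the old call sites to be narrowed. A minimal sketch of the variance choice, with illustrative names (not the module's own):

from collections.abc import Callable
from typing import Protocol, TypeVar

TReq_co = TypeVar("TReq_co", covariant=True)                # produced by the deserializer
TResp_contra = TypeVar("TResp_contra", contravariant=True)  # consumed by the serializer

class HandlerLike(Protocol[TReq_co, TResp_contra]):
    @property
    def request_deserializer(self) -> Callable[[bytes], TReq_co] | None: ...
    @property
    def response_serializer(self) -> Callable[[TResp_contra], bytes] | None: ...

With the PEP 695 class syntax used above (class RpcMethodHandlerProtocol[TRequest_co, TResponse_contra](Protocol):), variance is inferred from usage rather than declared on the TypeVar.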

View File

@@ -11,11 +11,9 @@ from __future__ import annotations
 from collections.abc import AsyncIterator, Awaitable, Callable
 from functools import partial
-from typing import NoReturn, cast
+from typing import TYPE_CHECKING, NoReturn, TypeVar, cast

 import grpc
-from grpc import HandlerCallDetails, RpcMethodHandler
-from grpc.aio import ServerInterceptor

 from noteflow.infrastructure.logging import (
     get_logger,
@@ -25,13 +23,22 @@ from noteflow.infrastructure.logging import (
 )
 from ._types import (
+    HandlerCallDetailsProtocol,
     RpcMethodHandlerProtocol,
+    ServerInterceptorProtocol,
     ServicerContextProtocol,
-    TRequest,
-    TResponse,
     typed_grpc,
 )
+
+_TRequest = TypeVar("_TRequest")
+_TResponse = TypeVar("_TResponse")
+
+if TYPE_CHECKING:
+    ServerInterceptor = ServerInterceptorProtocol
+else:
+    from grpc.aio import ServerInterceptor

 logger = get_logger(__name__)

 # Metadata keys for identity context
@@ -126,21 +133,18 @@ class IdentityInterceptor(ServerInterceptor):
     async def intercept_service(
         self,
         continuation: Callable[
-            [HandlerCallDetails],
-            Awaitable[RpcMethodHandler[TRequest, TResponse]],
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
         ],
-        handler_call_details: HandlerCallDetails,
-    ) -> RpcMethodHandler[TRequest, TResponse]:
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
         """Intercept incoming RPC calls to validate and set identity context."""
         metadata = dict(handler_call_details.invocation_metadata or [])
         request_id = _get_request_id(metadata, handler_call_details.method)
         if request_id is None:
             handler = await continuation(handler_call_details)
-            return cast(
-                RpcMethodHandler[TRequest, TResponse],
-                _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID),
-            )
+            return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID)

         _apply_identity_context(metadata, request_id)
@@ -156,35 +160,50 @@ class IdentityInterceptor(ServerInterceptor):
 def _create_unauthenticated_handler(
-    handler: RpcMethodHandlerProtocol,
+    handler: RpcMethodHandlerProtocol[_TRequest, _TResponse],
     message: str,
-) -> RpcMethodHandlerProtocol:
+) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
     """Create a handler that rejects with UNAUTHENTICATED status."""
     request_deserializer = handler.request_deserializer
     response_serializer = handler.response_serializer
+    if response_serializer is not None:
+        assert callable(response_serializer)
+    response_serializer_any = cast(Callable[[object], bytes] | None, response_serializer)
+
     if handler.unary_unary is not None:
-        return typed_grpc.unary_unary_rpc_method_handler(
-            partial(_reject_unary_unary, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.unary_unary_rpc_method_handler(
+                partial(_reject_unary_unary, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     if handler.unary_stream is not None:
-        return typed_grpc.unary_stream_rpc_method_handler(
-            partial(_reject_unary_stream, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.unary_stream_rpc_method_handler(
+                partial(_reject_unary_stream, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     if handler.stream_unary is not None:
-        return typed_grpc.stream_unary_rpc_method_handler(
-            partial(_reject_stream_unary, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.stream_unary_rpc_method_handler(
+                partial(_reject_stream_unary, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
        )
     if handler.stream_stream is not None:
-        return typed_grpc.stream_stream_rpc_method_handler(
-            partial(_reject_stream_stream, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.stream_stream_rpc_method_handler(
+                partial(_reject_stream_stream, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
        )
     return handler
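The TYPE_CHECKING block above swaps the base class: the type checker sees the structural ServerInterceptorProtocol, while at runtime the class actually derives from grpc.aio.ServerInterceptor so the server recognizes it. A condensed, self-contained sketch of the pattern (requires grpcio at runtime; Base and MyInterceptor are illustrative names):

from typing import TYPE_CHECKING, Protocol

class InterceptorProtocol(Protocol):
    async def intercept_service(self, continuation, handler_call_details): ...

if TYPE_CHECKING:
    # The checker types subclasses structurally, without depending on
    # grpc's partially untyped class hierarchy.
    Base = InterceptorProtocol
else:
    # At runtime the genuine grpc base class is used.
    from grpc.aio import ServerInterceptor as Base

class MyInterceptor(Base):
    async def intercept_service(self, continuation, handler_call_details):
        # Pass-through: delegate to the next interceptor or the handler.
        return await continuation(handler_call_details)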

View File

@@ -3,10 +3,13 @@
 from __future__ import annotations

 from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import cast

 from .._types import (
     RpcMethodHandlerProtocol,
     ServicerContextProtocol,
+    TRequest,
+    TResponse,
     typed_grpc,
 )
 from ._wrappers import (
@@ -17,90 +20,118 @@ from ._wrappers import (
 )

-def wrap_unary_unary_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_unary_unary_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped unary-unary handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.unary_unary is None:
         raise TypeError("Unary-unary handler is missing")
     wrapped: Callable[
         [object, ServicerContextProtocol],
         Awaitable[object],
-    ] = wrap_unary_unary(handler.unary_unary, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.unary_unary_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_unary_unary(
+        cast(Callable[[object, ServicerContextProtocol], Awaitable[object]], handler.unary_unary),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.unary_unary_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
     )

-def wrap_unary_stream_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_unary_stream_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped unary-stream handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.unary_stream is None:
         raise TypeError("Unary-stream handler is missing")
     wrapped: Callable[
         [object, ServicerContextProtocol],
         AsyncIterator[object],
-    ] = wrap_unary_stream(handler.unary_stream, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.unary_stream_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_unary_stream(
+        cast(
+            Callable[[object, ServicerContextProtocol], AsyncIterator[object]], handler.unary_stream
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.unary_stream_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
     )

-def wrap_stream_unary_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_stream_unary_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped stream-unary handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
    if handler.stream_unary is None:
         raise TypeError("Stream-unary handler is missing")
     wrapped: Callable[
         [AsyncIterator[object], ServicerContextProtocol],
         Awaitable[object],
-    ] = wrap_stream_unary(handler.stream_unary, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.stream_unary_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_stream_unary(
+        cast(
+            Callable[[AsyncIterator[object], ServicerContextProtocol], Awaitable[object]],
+            handler.stream_unary,
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.stream_unary_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
    )

-def wrap_stream_stream_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_stream_stream_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped stream-stream handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.stream_stream is None:
         raise TypeError("Stream-stream handler is missing")
     wrapped: Callable[
         [AsyncIterator[object], ServicerContextProtocol],
         AsyncIterator[object],
-    ] = wrap_stream_stream(handler.stream_stream, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.stream_stream_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_stream_stream(
+        cast(
+            Callable[[AsyncIterator[object], ServicerContextProtocol], AsyncIterator[object]],
+            handler.stream_stream,
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.stream_stream_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
    )

 def create_logging_handler(
-    handler: RpcMethodHandlerProtocol,
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     """Wrap an RPC handler to add request logging.

     Args:
@@ -111,12 +142,11 @@ def create_logging_handler(
         Wrapped handler with logging.
     """
     if handler.unary_unary is not None:
-        return wrap_unary_unary_handler(handler, method)
+        return _wrap_unary_unary_handler(handler, method)
     if handler.unary_stream is not None:
-        return wrap_unary_stream_handler(handler, method)
+        return _wrap_unary_stream_handler(handler, method)
     if handler.stream_unary is not None:
-        return wrap_stream_unary_handler(handler, method)
+        return _wrap_stream_unary_handler(handler, method)
     if handler.stream_stream is not None:
-        return wrap_stream_stream_handler(handler, method)
+        return _wrap_stream_stream_handler(handler, method)
-    # Fallback: return original handler if type unknown
     return handler
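create_logging_handler leans on the gRPC convention that exactly one of the four behavior attributes is non-None on any method handler, so the first matching branch wins. A compressed sketch of that dispatch (handler_kind is a hypothetical helper, not part of the commit):

def handler_kind(handler: object) -> str:
    """Return which of the four RPC shapes a grpc method handler implements."""
    for name in ("unary_unary", "unary_stream", "stream_unary", "stream_stream"):
        if getattr(handler, name, None) is not None:
            return name
    return "unknown"  # mirrors the final fallback: pass the handler through unwrapped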

View File

@@ -1,19 +1,29 @@
 """Request logging interceptor for gRPC calls.

-Log every RPC call with method, status, duration, peer, and request context
-at INFO level for production observability and traceability.
+Log every RPC call with method, status, duration, peer, and request context at
+INFO level for production observability and traceability.
 """

 from __future__ import annotations

 from collections.abc import Awaitable, Callable
-from typing import cast
+from typing import TYPE_CHECKING, TypeVar

-from grpc import HandlerCallDetails, RpcMethodHandler
-from grpc.aio import ServerInterceptor
-
+from .._types import (
+    HandlerCallDetailsProtocol,
+    RpcMethodHandlerProtocol,
+    ServerInterceptorProtocol,
+)
 from ._handler_factory import create_logging_handler
-from .._types import TRequest, TResponse
+
+_TRequest = TypeVar("_TRequest")
+_TResponse = TypeVar("_TResponse")
+
+if TYPE_CHECKING:
+    ServerInterceptor = ServerInterceptorProtocol
+else:
+    from grpc.aio import ServerInterceptor

 class RequestLoggingInterceptor(ServerInterceptor):
@@ -30,25 +40,12 @@ class RequestLoggingInterceptor(ServerInterceptor):
     async def intercept_service(
         self,
         continuation: Callable[
-            [HandlerCallDetails],
-            Awaitable[RpcMethodHandler[TRequest, TResponse]],
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
         ],
-        handler_call_details: HandlerCallDetails,
-    ) -> RpcMethodHandler[TRequest, TResponse]:
-        """Intercept incoming RPC calls to log request timing and status.
-
-        Args:
-            continuation: The next interceptor or handler.
-            handler_call_details: Details about the RPC call.
-
-        Returns:
-            Wrapped RPC handler that logs on completion.
-        """
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
+        """Intercept incoming RPC calls to log request timing and status."""
         handler = await continuation(handler_call_details)
-        method = handler_call_details.method
-
-        # Return wrapped handler that logs on completion
-        return cast(
-            RpcMethodHandler[TRequest, TResponse],
-            create_logging_handler(handler, method),
-        )
+        method = handler_call_details.method or ""
+        return create_logging_handler(handler, method)
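For context, a hedged sketch of how interceptors like these are registered on an asyncio gRPC server (grpc.aio.server and its interceptors keyword are the standard API; the zero-argument constructors and the import paths of the two classes are assumptions):

import asyncio

import grpc

async def serve() -> None:
    server = grpc.aio.server(
        interceptors=[
            IdentityInterceptor(),        # from the identity interceptor diff above (import omitted)
            RequestLoggingInterceptor(),  # from this file's diff (import omitted)
        ],
    )
    server.add_insecure_port("127.0.0.1:50051")
    await server.start()
    await server.wait_for_termination()

asyncio.run(serve())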

View File

@@ -77,7 +77,6 @@ async def require_calendar_service(
         return host.calendar_service
     logger.warning(f"{operation}_unavailable", reason="service_not_enabled")
     await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED)
-    raise  # Unreachable but helps type checker

 class CalendarMixin:
@@ -117,7 +116,6 @@ class CalendarMixin:
         except CalendarServiceError as e:
             logger.error("calendar_list_events_failed", error=str(e), provider=provider)
             await abort_internal(context, str(e))
-            raise  # Unreachable but helps type checker

         proto_events = [_calendar_event_to_proto(event) for event in events]
@@ -166,9 +164,7 @@ class CalendarMixin:
                 status=status.status,
             )
-        authenticated_count = sum(
-            bool(provider.is_authenticated) for provider in providers
-        )
+        authenticated_count = sum(bool(provider.is_authenticated) for provider in providers)
         logger.info(
             "calendar_get_providers_success",
             total_providers=len(providers),
@@ -205,7 +201,6 @@ class CalendarMixin:
                 error=str(e),
             )
             await abort_invalid_argument(context, str(e))
-            raise  # Unreachable but helps type checker

         logger.info(
             "oauth_initiate_success",

View File

@@ -92,20 +92,18 @@ class DiarizationJobMixin:
         Prunes both in-memory task references and database records.
         """
         # Clean up in-memory task references for completed tasks
-        completed_tasks = [
-            job_id for job_id, task in self.diarization_tasks.items() if task.done()
-        ]
+        completed_tasks = [job_id for job_id, task in self.diarization_tasks.items() if task.done()]
         for job_id in completed_tasks:
             self.diarization_tasks.pop(job_id, None)

         # Prune old completed jobs from database
-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
             if not repo.supports_diarization_jobs:
                 logger.debug("Job pruning skipped: database required")
                 return
-            pruned = await repo.diarization_jobs.prune_completed(
-                self.diarization_job_ttl_seconds
-            )
+            pruned = await repo.diarization_jobs.prune_completed(self.diarization_job_ttl_seconds)
             await repo.commit()
             if pruned > 0:
                 logger.debug("Pruned %d completed diarization jobs", pruned)
@@ -121,15 +119,15 @@ class DiarizationJobMixin:
         """
         await self.prune_diarization_jobs()

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
             if not repo.supports_diarization_jobs:
                 await abort_not_found(context, "Diarization jobs (database required)", "")
-                raise  # Unreachable but helps type checker

             job = await repo.diarization_jobs.get(request.job_id)
             if job is None:
                 await abort_not_found(context, "Diarization job", request.job_id)
-                raise  # Unreachable but helps type checker

             return _build_job_status(job)
@@ -145,7 +143,9 @@ class DiarizationJobMixin:
         job_id = request.job_id
         await _cancel_running_task(self.diarization_tasks, job_id)

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
             if not repo.supports_diarization_jobs:
                 await abort_database_required(context, "Diarization job cancellation")
                 raise AssertionError(UNREACHABLE_ERROR)  # abort is NoReturn
@@ -185,7 +185,9 @@ class DiarizationJobMixin:
         """
         response = noteflow_pb2.GetActiveDiarizationJobsResponse()

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
             if not repo.supports_diarization_jobs:
                 # Return empty list if DB not available
                 return response

View File

@@ -175,7 +175,6 @@ class EntitiesMixin:
             if updated is None:
                 await abort_not_found(context, ENTITY_ENTITY, request.entity_id)
-                raise  # Unreachable but helps type checker

             await uow.commit()

View File

@@ -75,9 +75,7 @@ def _merge_workspace_settings(
         trigger_rules=updates.trigger_rules
         if updates.trigger_rules is not None
         else current.trigger_rules,
-        rag_enabled=updates.rag_enabled
-        if updates.rag_enabled is not None
-        else current.rag_enabled,
+        rag_enabled=updates.rag_enabled if updates.rag_enabled is not None else current.rag_enabled,
         default_summarization_template=updates.default_summarization_template
         if updates.default_summarization_template is not None
         else current.default_summarization_template,
@@ -192,14 +190,12 @@ class IdentityMixin:
         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)
         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, membership = await self._verify_workspace_access(
@@ -231,14 +227,12 @@ class IdentityMixin:
         """Get workspace settings."""
         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)
         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, _ = await self._verify_workspace_access(
@@ -259,14 +253,12 @@ class IdentityMixin:
         """Update workspace settings."""
         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)
         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, membership = await self._verify_workspace_access(
@@ -278,7 +270,6 @@ class IdentityMixin:
         if not membership.role.can_admin():
             await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED)
-            raise  # Unreachable but helps type checker

         updates = proto_to_workspace_settings(request.settings)
         if updates is None:
@@ -302,11 +293,9 @@ class IdentityMixin:
             workspace = await uow.workspaces.get(workspace_id)
             if not workspace:
                 await abort_not_found(context, ENTITY_WORKSPACE, str(workspace_id))
-                raise  # Unreachable but helps type checker
             membership = await uow.workspaces.get_membership(workspace_id, user_id)
             if not membership:
                 await abort_not_found(context, "Workspace membership", str(workspace_id))
-                raise  # Unreachable but helps type checker
             return workspace, membership

View File

@@ -42,7 +42,6 @@ if TYPE_CHECKING:
     from .._types import GrpcContext
     from ..protocols import ServicerHost

-from ..errors._constants import UNREACHABLE_ERROR

 logger = get_logger(__name__)
@@ -66,7 +65,6 @@ async def _load_meeting_for_stop(
     if meeting is None:
         logger.warning("StopMeeting: meeting not found", meeting_id=meeting_id_str)
         await abort_not_found(context, ENTITY_MEETING, meeting_id_str)
-        raise AssertionError(UNREACHABLE_ERROR)
     return meeting
@@ -85,7 +83,7 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
         )
         return context.meeting

-    previous_state = context.meeting.state.value
+    previous_state = context.meeting.state.name
     await transition_to_stopped(
         context.meeting,
         context.meeting_id,
@@ -143,7 +141,9 @@ class MeetingMixin:
         op_context = self.get_operation_context(context)

         async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
-            project_id = await resolve_create_project_id(self, repo, op_context, project_id)
+            project_id = await resolve_create_project_id(
+                cast("ServicerHost", self), repo, op_context, project_id
+            )

             meeting = Meeting.create(
                 title=request.title,
@@ -212,7 +212,7 @@ class MeetingMixin:
         async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
             if project_id is None and not project_ids:
-                project_id = await resolve_active_project_id(self, repo)
+                project_id = await resolve_active_project_id(cast("ServicerHost", self), repo)

             meetings, total = await repo.meetings.list_all(
                 states=states,
@@ -254,7 +254,6 @@ class MeetingMixin:
             if meeting is None:
                 logger.warning("GetMeeting: meeting not found", meeting_id=request.meeting_id)
                 await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
-                raise  # Unreachable but helps type checker

             # Load segments if requested
             if request.include_segments:
                 segments = await repo.segments.get_by_meeting(meeting.id)
@@ -283,7 +282,6 @@ class MeetingMixin:
             if not success:
                 logger.warning("DeleteMeeting: meeting not found", meeting_id=request.meeting_id)
                 await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
-                raise  # Unreachable but helps type checker

             await repo.commit()
             logger.info("Meeting deleted", meeting_id=request.meeting_id)

View File

@@ -35,8 +35,6 @@ if TYPE_CHECKING:
     from ..protocols import ProjectRepositoryProvider
 async def _parse_project_and_user_ids(
     request_project_id: str,
     request_user_id: str,
@@ -45,21 +43,19 @@ async def _parse_project_and_user_ids(
     """Parse and validate project and user IDs from request."""
     if not request_project_id:
         await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     if not request_user_id:
         await abort_invalid_argument(context, ERROR_USER_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     try:
         project_id = UUID(request_project_id)
         user_id = UUID(request_user_id)
     except ValueError as e:
         await abort_invalid_argument(context, f"{ERROR_INVALID_UUID_PREFIX}{e}")
-        raise  # Unreachable but helps type checker
     return project_id, user_id
 class ProjectMembershipMixin:
     """Mixin providing project membership functionality.
@@ -93,7 +89,6 @@ class ProjectMembershipMixin:
         )
         if membership is None:
             await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-            raise  # Unreachable but helps type checker
         return membership_to_proto(membership)
@@ -119,8 +114,9 @@ class ProjectMembershipMixin:
             role=role,
         )
         if membership is None:
-            await abort_not_found(context, "Membership", f"{request.project_id}/{request.user_id}")
-            raise  # Unreachable but helps type checker
+            await abort_not_found(
+                context, "Membership", f"{request.project_id}/{request.user_id}"
+            )
         return membership_to_proto(membership)
@@ -159,8 +155,9 @@ class ProjectMembershipMixin:
         try:
             project_id = UUID(request.project_id)
         except ValueError:
-            await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
-            raise  # Unreachable but helps type checker
+            await abort_invalid_argument(
                context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
+            )
         limit = request.limit if request.limit > 0 else 100
         offset = max(request.offset, 0)
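
The same Never-based narrowing covers the except branches above: a function can still be typed as returning tuple[UUID, UUID] when its error path ends in an abort call. A runnable sketch, with a stand-in abort helper in place of the project's real one:

    from typing import Never
    from uuid import UUID


    class _Abort(Exception):
        """Stand-in for the gRPC abort exception."""


    async def abort_invalid_argument(context: object, message: str) -> Never:
        # The real helper calls context.abort(); raising here keeps the sketch local.
        raise _Abort(message)


    async def parse_ids(project_raw: str, user_raw: str, context: object) -> tuple[UUID, UUID]:
        try:
            return UUID(project_raw), UUID(user_raw)
        except ValueError as e:
            # The except branch ends in a Never call, so every reachable
            # path still returns tuple[UUID, UUID].
            await abort_invalid_argument(context, f"invalid UUID: {e}")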

View File

@@ -41,8 +41,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 async def _require_and_parse_project_id(
     request_project_id: str,
     context: GrpcContext,
@@ -50,7 +48,6 @@ async def _require_and_parse_project_id(
     """Require and parse a project_id from request."""
     if not request_project_id:
         await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     return await parse_project_id(request_project_id, context)
@@ -61,9 +58,9 @@ async def _require_and_parse_workspace_id(
     """Require and parse a workspace_id from request."""
     if not request_workspace_id:
         await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     return await parse_workspace_id(request_workspace_id, context)
 class ProjectMixin:
     """Mixin providing project management functionality.
@@ -89,11 +86,12 @@ class ProjectMixin:
         if not request.name:
             await abort_invalid_argument(context, "name is required")
-            raise  # Unreachable but helps type checker
         slug = request.slug if request.HasField("slug") else None
         description = request.description if request.HasField("description") else None
-        settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        settings = (
+            proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        )
         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)
@@ -123,7 +121,6 @@ class ProjectMixin:
             project = await project_service.get_project(uow, project_id)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker
             return project_to_proto(project)
@@ -138,7 +135,6 @@ class ProjectMixin:
         if not request.slug:
             await abort_invalid_argument(context, "slug is required")
-            raise  # Unreachable but helps type checker
         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)
@@ -146,7 +142,6 @@ class ProjectMixin:
             project = await project_service.get_project_by_slug(uow, workspace_id, request.slug)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.slug)
-                raise  # Unreachable but helps type checker
             return project_to_proto(project)
@@ -159,6 +154,7 @@ class ProjectMixin:
         project_service = await require_project_service(self.project_service, context)
         workspace_id = await _require_and_parse_workspace_id(request.workspace_id, context)
+        include_archived = request.include_archived
         limit = request.limit if request.limit > 0 else 50
         offset = max(request.offset, 0)
@@ -166,21 +162,20 @@ class ProjectMixin:
             await require_feature_projects(uow, context)
             projects = await project_service.list_projects(
-                uow=uow,
-                workspace_id=workspace_id,
-                include_archived=request.include_archived,
+                uow,
+                workspace_id,
+                include_archived=include_archived,
                 limit=limit,
                 offset=offset,
             )
             total_count = await project_service.count_projects(
-                uow=uow,
-                workspace_id=workspace_id,
-                include_archived=request.include_archived,
+                uow,
+                workspace_id,
+                include_archived=include_archived,
             )
             return noteflow_pb2.ListProjectsResponse(
-                projects=[project_to_proto(p) for p in projects],
+                projects=[project_to_proto(project) for project in projects],
                 total_count=total_count,
             )
@@ -196,14 +191,16 @@ class ProjectMixin:
         name = request.name if request.HasField("name") else None
         slug = request.slug if request.HasField("slug") else None
         description = request.description if request.HasField("description") else None
-        settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        settings = (
+            proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        )
         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)
             project = await project_service.update_project(
-                uow=uow,
-                project_id=project_id,
+                uow,
+                project_id,
                 name=name,
                 slug=slug,
                 description=description,
@@ -211,7 +208,6 @@ class ProjectMixin:
             )
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker
             return project_to_proto(project)
@@ -231,11 +227,9 @@ class ProjectMixin:
                 project = await project_service.archive_project(uow, project_id)
             except CannotArchiveDefaultProjectError:
                 await abort_failed_precondition(context, "Cannot archive the default project")
-                raise  # Unreachable but helps type checker
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker
             return project_to_proto(project)
@@ -254,7 +248,6 @@ class ProjectMixin:
             project = await project_service.restore_project(uow, project_id)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker
             return project_to_proto(project)
@@ -302,7 +295,6 @@ class ProjectMixin:
                 )
             except ValueError as exc:
                 await abort_invalid_argument(context, str(exc))
-                raise  # Unreachable but helps type checker
             await uow.commit()
             return noteflow_pb2.SetActiveProjectResponse()
@@ -327,11 +319,9 @@ class ProjectMixin:
                 )
             except ValueError as exc:
                 await abort_invalid_argument(context, str(exc))
-                raise  # Unreachable but helps type checker
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, "default")
-                raise  # Unreachable but helps type checker
             response = noteflow_pb2.GetActiveProjectResponse(
                 project=project_to_proto(project),
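
The HasField checks threaded through these handlers are the standard proto3 presence test: for optional fields it distinguishes "never set" from an explicitly sent default such as the empty string. A small sketch of the mapping, assuming a generated request message with optional fields:

    def read_optional_fields(request) -> tuple[str | None, str | None]:
        # HasField() reports presence; falling back to None keeps "unset"
        # distinct from an explicitly-sent empty string.
        slug = request.slug if request.HasField("slug") else None
        description = request.description if request.HasField("description") else None
        return slug, description

The pagination lines follow the same defensive shape: non-positive limits fall back to a server default, and negative offsets clamp to zero via max().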

View File

@@ -59,7 +59,6 @@ async def _decode_chunk_audio(
         )
     except ValueError as e:
         await abort_invalid_argument(context, str(e))
-        raise  # Unreachable but helps type checker
     conversion = AudioConversionContext(
         source_sample_rate=sample_rate,

View File

@@ -30,10 +30,13 @@ from ._template_resolution import (
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from noteflow.application.services.ner import NerService
     from noteflow.application.services.summarization import SummarizationService
     from noteflow.application.services.webhooks import WebhookService
     from noteflow.domain.entities import Meeting
     from noteflow.domain.identity import OperationContext
+    from noteflow.infrastructure.asr import FasterWhisperEngine
+    from noteflow.infrastructure.diarization.engine import DiarizationEngine
 logger = get_logger(__name__)
@@ -48,9 +51,41 @@ class _SummaryGenerationContext:
     force_regenerate: bool
+@dataclass(frozen=True, slots=True)
+class _SummaryRequestContext:
+    request: noteflow_pb2.GenerateSummaryRequest
+    meeting: Meeting
+    segments: list[Segment]
+    style_instructions: str | None
+    context: GrpcContext
+    op_context: OperationContext
+
+    async def resolve_style_prompt(
+        self,
+        repo: UnitOfWork,
+        summarization_service: SummarizationService | None,
+    ) -> str | None:
+        """Resolve style prompt for this request context."""
+        return await resolve_template_prompt(
+            TemplateResolutionInputs(
+                request=self.request,
+                meeting=self.meeting,
+                segments=self.segments,
+                style_instructions=self.style_instructions,
+                context=self.context,
+                op_context=self.op_context,
+                repo=repo,
+                summarization_service=summarization_service,
+            )
+        )
 class SummarizationGenerationMixin:
     """Generate summaries and handle summary webhooks."""
+    asr_engine: FasterWhisperEngine | None
+    diarization_engine: DiarizationEngine | None
+    ner_service: NerService | None
     summarization_service: SummarizationService | None
     webhook_service: WebhookService | None
     create_repository_provider: Callable[..., object]
@@ -68,28 +103,28 @@ class SummarizationGenerationMixin:
             meeting_id=request.meeting_id,
             include_provider_details=True,
         )
-        meeting_id, op_context, style_instructions, meeting, existing, segments = (
-            await self._prepare_summary_request(request, context)
-        )
+        (
+            meeting_id,
+            op_context,
+            style_instructions,
+            meeting,
+            existing,
+            segments,
+        ) = await self._prepare_summary_request(request, context)
         if existing and not request.force_regenerate:
             await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
             return summary_to_proto(existing)
         await self._ensure_cloud_provider()
-        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
-            style_prompt = await self._resolve_style_prompt(
-                TemplateResolutionInputs(
-                    request=request,
-                    meeting=meeting,
-                    segments=segments,
-                    style_instructions=style_instructions,
-                    context=context,
-                    op_context=op_context,
-                    repo=repo,
-                    summarization_service=self.summarization_service,
-                )
-            )
+        request_context = _SummaryRequestContext(
+            request=request,
+            meeting=meeting,
+            segments=segments,
+            style_instructions=style_instructions,
+            context=context,
+            op_context=op_context,
+        )
+        style_prompt = await self._resolve_style_prompt_for_request(request_context)
         saved, trigger_webhook = await self._generate_summary_with_status(
             _SummaryGenerationContext(
                 meeting_id=meeting_id,
@@ -121,6 +156,13 @@ class SummarizationGenerationMixin:
             ),
         )
+    async def _resolve_style_prompt_for_request(
+        self,
+        request_context: _SummaryRequestContext,
+    ) -> str | None:
+        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
+            return await request_context.resolve_style_prompt(repo, self.summarization_service)
     async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
         return await resolve_template_prompt(inputs)
@@ -244,7 +286,6 @@
             meeting = await repo.meetings.get(meeting_id)
             if meeting is None:
                 await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
-                raise  # Unreachable but helps type checker
             existing = await repo.summaries.get_by_meeting(meeting.id)
             segments = list(await repo.segments.get_by_meeting(meeting.id))
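
_SummaryRequestContext is a parameter-object refactor: a frozen, slotted dataclass carries the per-request state, and the unit of work is opened only around the single call that needs it. A generic sketch of the shape, with illustrative names:

    from dataclasses import dataclass


    @dataclass(frozen=True, slots=True)  # immutable, no per-instance __dict__
    class RequestContext:
        meeting_id: str
        style_instructions: str | None

        async def resolve(self, repo: object) -> str | None:
            # Behavior travels with the bundled state; short-lived
            # dependencies (the repo) are passed in at call time.
            return self.style_instructions


    async def resolve_for_request(ctx: RequestContext, repo_factory) -> str | None:
        # repo_factory is any async-context-manager factory, e.g. a UoW provider.
        async with repo_factory() as repo:  # unit of work scoped to one call
            return await ctx.resolve(repo)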

View File

@@ -9,6 +9,8 @@ from typing import Generic, TypeVar, Never
 import grpc
+from noteflow.domain.constants.fields import UNKNOWN as UNKNOWN_VALUE
 T = TypeVar("T", covariant=True)
 U = TypeVar("U")
 E = TypeVar("E", bound=BaseException)
@@ -17,7 +19,7 @@ E = TypeVar("E", bound=BaseException)
 class ClientErrorCode(StrEnum):
     """Client-facing error codes for gRPC operations."""
-    UNKNOWN = "unknown"
+    UNKNOWN = UNKNOWN_VALUE
     NOT_CONNECTED = "not_connected"
     NOT_FOUND = "not_found"
     INVALID_ARGUMENT = "invalid_argument"
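
Pointing UNKNOWN at a shared constant works because StrEnum members are real strings: the member value can be any str expression, and comparisons against the raw string still hold. A quick check:

    from enum import StrEnum

    UNKNOWN_VALUE = "unknown"  # shared constant, mirroring the import above


    class ClientErrorCode(StrEnum):
        UNKNOWN = UNKNOWN_VALUE
        NOT_FOUND = "not_found"


    assert ClientErrorCode.UNKNOWN == "unknown"  # StrEnum members compare as str
    assert f"{ClientErrorCode.UNKNOWN}" == "unknown"  # and format as their value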

View File

@@ -5,8 +5,8 @@ Tests identity context validation and per-RPC request logging.
 from __future__ import annotations
-from collections.abc import Awaitable, Callable
-from typing import Protocol, cast
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import cast
 from unittest.mock import AsyncMock, MagicMock, patch
 import grpc
@@ -19,7 +19,11 @@ from noteflow.grpc.interceptors import (
     IdentityInterceptor,
     RequestLoggingInterceptor,
 )
-from noteflow.grpc.interceptors._types import ServicerContextProtocol
+from noteflow.grpc.interceptors._types import (
+    HandlerCallDetailsProtocol,
+    RpcMethodHandlerProtocol,
+    ServicerContextProtocol,
+)
 from noteflow.infrastructure.logging import (
     get_request_id,
     get_user_id,
@@ -27,14 +31,13 @@ from noteflow.infrastructure.logging import (
     request_id_var,
 )
-# Type alias for callable unary-unary handlers
-UnaryUnaryCallable = Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]]
-class _DummyRequest:
-    """Placeholder request type for handler casts."""
-class _DummyResponse:
-    """Placeholder response type for handler casts."""
+# Type alias for test handler protocol (using bytes for wire-level compatibility)
+TestHandlerProtocol = RpcMethodHandlerProtocol[bytes, bytes]
 # Test data
 TEST_REQUEST_ID = "test-request-123"
 TEST_USER_ID = "user-456"
@@ -45,25 +48,28 @@ TEST_METHOD = "/noteflow.NoteFlowService/GetMeeting"
 pytestmark = pytest.mark.usefixtures("reset_context_vars")
+# Type alias for handler call details matching grpc module
+HandlerCallDetails = HandlerCallDetailsProtocol
 def create_handler_call_details(
     method: str = TEST_METHOD,
     metadata: list[tuple[str, str | bytes]] | None = None,
-) -> grpc.HandlerCallDetails:
+) -> HandlerCallDetails:
     """Create mock HandlerCallDetails with metadata."""
-    details = MagicMock(spec=grpc.HandlerCallDetails)
+    details = MagicMock(spec=HandlerCallDetails)
     details.method = method
     details.invocation_metadata = metadata or []
     return details
-def create_mock_handler() -> _UnaryUnaryHandler:
+def create_mock_handler() -> TestHandlerProtocol:
     """Create a mock RPC method handler."""
     return _MockHandler()
 def create_mock_continuation(
-    handler: _UnaryUnaryHandler | None = None,
+    handler: TestHandlerProtocol | None = None,
 ) -> AsyncMock:
     """Create a mock continuation function."""
     if handler is None:
@@ -71,46 +77,75 @@ def create_mock_continuation(
     return AsyncMock(return_value=handler)
-class _UnaryUnaryHandler(Protocol):
-    """Protocol for unary-unary RPC method handlers."""
-    unary_unary: Callable[
-        [_DummyRequest, ServicerContextProtocol],
-        Awaitable[_DummyResponse],
-    ] | None
-    unary_stream: object | None
-    stream_unary: object | None
-    stream_stream: object | None
-    request_deserializer: object | None
-    response_serializer: object | None
 class _MockHandler:
-    """Concrete handler for tests with typed unary_unary."""
-    unary_unary: Callable[
-        [_DummyRequest, ServicerContextProtocol],
-        Awaitable[_DummyResponse],
-    ] | None
-    unary_stream: object | None
-    stream_unary: object | None
-    stream_stream: object | None
-    request_deserializer: object | None
-    response_serializer: object | None
+    """Concrete handler for tests with typed unary_unary.
+
+    Implements RpcMethodHandlerProtocol for interceptor testing.
+    """
     def __init__(self) -> None:
-        self.unary_unary = cast(
-            Callable[
-                [_DummyRequest, ServicerContextProtocol],
-                Awaitable[_DummyResponse],
-            ],
-            AsyncMock(return_value="response"),
-        )
-        self.unary_stream = None
-        self.stream_unary = None
-        self.stream_stream = None
-        self.request_deserializer = None
-        self.response_serializer = None
+        self._unary_unary: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None = (
+            cast(
+                Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]],
+                AsyncMock(return_value=b"response"),
+            )
+        )
+        self._unary_stream: (
+            Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None
+        ) = None
+        self._stream_unary: (
+            Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None
+        ) = None
+        self._stream_stream: (
+            Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None
+        ) = None
+        self._request_deserializer: Callable[[bytes], bytes] | None = None
+        self._response_serializer: Callable[[bytes], bytes] | None = None
+
+    @property
+    def request_deserializer(self) -> Callable[[bytes], bytes] | None:
+        """Request deserializer function."""
+        return self._request_deserializer
+
+    @property
+    def response_serializer(self) -> Callable[[bytes], bytes] | None:
+        """Response serializer function."""
+        return self._response_serializer
+
+    @property
+    def unary_unary(
+        self,
+    ) -> Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None:
+        """Unary-unary handler behavior."""
+        return self._unary_unary
+
+    @unary_unary.setter
+    def unary_unary(
+        self, value: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None
+    ) -> None:
+        """Set unary-unary handler behavior."""
+        self._unary_unary = value
+
+    @property
+    def unary_stream(
+        self,
+    ) -> Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None:
+        """Unary-stream handler behavior."""
+        return self._unary_stream
+
+    @property
+    def stream_unary(
+        self,
+    ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None:
+        """Stream-unary handler behavior."""
+        return self._stream_unary
+
+    @property
+    def stream_stream(
+        self,
+    ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None:
+        """Stream-stream handler behavior."""
+        return self._stream_stream
 class TestIdentityInterceptor:
@@ -157,8 +192,9 @@ class TestIdentityInterceptor:
         original_handler = create_mock_handler()
         continuation = create_mock_continuation(original_handler)
-        handler = await interceptor.intercept_service(continuation, details)
-        typed_handler = cast(_UnaryUnaryHandler, handler)
+        typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
+            continuation, details
+        )
         # Handler should be a rejection handler wrapping the original
         assert typed_handler.unary_unary is not None, "handler should have unary_unary"
@@ -166,7 +202,9 @@
         # is a rejection wrapper that will abort with UNAUTHENTICATED
         continuation.assert_called_once()
         # The returned handler should NOT be the original handler
-        assert typed_handler.unary_unary is not original_handler.unary_unary, "should return rejection handler, not original"
+        assert typed_handler.unary_unary is not original_handler.unary_unary, (
+            "should return rejection handler, not original"
+        )
     @pytest.mark.asyncio
     async def test_reject_handler_aborts_with_unauthenticated(self) -> None:
@@ -175,8 +213,9 @@
         details = create_handler_call_details(metadata=[])
         continuation = create_mock_continuation()
-        handler = await interceptor.intercept_service(continuation, details)
-        typed_handler = cast(_UnaryUnaryHandler, handler)
+        typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
+            continuation, details
+        )
         # Create mock context to verify abort behavior
         context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
@@ -184,11 +223,14 @@
         with pytest.raises(grpc.RpcError, match="x-request-id"):
             assert typed_handler.unary_unary is not None
-            await typed_handler.unary_unary(MagicMock(), context)
+            unary_fn = typed_handler.unary_unary
+            await unary_fn(MagicMock(), context)
         context.abort.assert_called_once()
         call_args = context.abort.call_args
-        assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, "should abort with UNAUTHENTICATED"
+        assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, (
+            "should abort with UNAUTHENTICATED"
+        )
         assert "x-request-id" in call_args[0][1], "error message should mention x-request-id"
     @pytest.mark.asyncio
@@ -224,16 +266,16 @@ class TestRequestLoggingInterceptor:
         request_id_var.set(TEST_REQUEST_ID)
         with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
-            wrapped_handler = await interceptor.intercept_service(continuation, details)
-            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
+            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
+                continuation, details
+            )
-            # Execute the wrapped handler
             context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
             context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
-            assert typed_handler.unary_unary is not None
-            await typed_handler.unary_unary(MagicMock(), context)
+            assert wrapped_handler.unary_unary is not None
+            unary_fn = wrapped_handler.unary_unary
+            await unary_fn(MagicMock(), context)
-            # Verify logging
             mock_logger.info.assert_called_once()
             call_kwargs = mock_logger.info.call_args[1]
             assert call_kwargs["method"] == TEST_METHOD, "should log method"
@@ -246,28 +288,31 @@
         """Interceptor logs error status when handler raises."""
         interceptor = RequestLoggingInterceptor()
-        # Create handler that raises
         handler = create_mock_handler()
-        handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))
+        mock_handler = cast(_MockHandler, handler)
+        mock_handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))
         details = create_handler_call_details(method=TEST_METHOD)
         continuation = create_mock_continuation(handler)
         with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
-            wrapped_handler = await interceptor.intercept_service(continuation, details)
-            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
+            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
+                continuation, details
+            )
             context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
            context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
             with pytest.raises(Exception, match="Test error"):
-                assert typed_handler.unary_unary is not None
-                await typed_handler.unary_unary(MagicMock(), context)
+                assert wrapped_handler.unary_unary is not None
+                unary_fn = wrapped_handler.unary_unary
+                await unary_fn(MagicMock(), context)
-            # Should still log with INTERNAL status
             mock_logger.info.assert_called_once()
             call_kwargs = mock_logger.info.call_args[1]
-            assert call_kwargs["status"] == "INTERNAL", "should log INTERNAL for unhandled exception"
+            assert call_kwargs["status"] == "INTERNAL", (
+                "should log INTERNAL for unhandled exception"
+            )
     @pytest.mark.asyncio
     async def test_passes_through_to_continuation(self) -> None:
@@ -277,7 +322,7 @@ class TestRequestLoggingInterceptor:
         details = create_handler_call_details()
         continuation = create_mock_continuation(handler)
-        await interceptor.intercept_service(continuation, details)
+        _: TestHandlerProtocol = await interceptor.intercept_service(continuation, details)
         continuation.assert_called_once_with(details)
@@ -290,17 +335,17 @@
         continuation = create_mock_continuation(handler)
         with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
-            wrapped_handler = await interceptor.intercept_service(continuation, details)
-            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
+            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
+                continuation, details
+            )
-            # Context without peer method
             context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
             context.peer = MagicMock(side_effect=RuntimeError("No peer"))
-            assert typed_handler.unary_unary is not None
-            await typed_handler.unary_unary(MagicMock(), context)
+            assert wrapped_handler.unary_unary is not None
+            unary_fn = wrapped_handler.unary_unary
+            await unary_fn(MagicMock(), context)
-            # Should still log with None peer
             mock_logger.info.assert_called_once()
             call_kwargs = mock_logger.info.call_args[1]
             assert call_kwargs["peer"] is None, "should handle missing peer"