@@ -114,7 +114,7 @@ file_path_pattern := `\.py$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -132,7 +132,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -151,7 +151,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
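Every policy file in this change makes the same edit: the tool-name literals in the deny rules move to lowercase ({"write", "edit", "notebookedit"}, "multiedit", {"patch", "apply_patch"}), so the rules only fire if tool_name reaches the policy already normalized. A minimal sketch of that normalization on the hook side, in Python — the stdin/stdout protocol and field handling here are assumptions for illustration, not part of this commit:

    import json
    import sys


    def normalize_hook_event(raw: str) -> dict:
        """Lowercase tool_name so policy literals like {"write", "edit"} can match."""
        event = json.loads(raw)
        tool_name = event.get("tool_name")
        if isinstance(tool_name, str):
            event["tool_name"] = tool_name.lower()
        return event


    if __name__ == "__main__":
        print(json.dumps(normalize_hook_event(sys.stdin.read())))

The alternative would be normalizing inside each rule (e.g. lower(input.tool_name)); lowercasing the literals instead implies the input is expected to arrive already canonical.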
@@ -126,7 +126,7 @@ assertion_pattern := `^\s*assert\s+[^,\n]+\n\s*assert\s+[^,\n]+$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -144,7 +144,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -163,7 +163,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -126,7 +126,7 @@ ignore_pattern := `//\s*biome-ignore|//\s*@ts-ignore|//\s*@ts-expect-error|//\s*

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -144,7 +144,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -163,7 +163,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -104,7 +104,7 @@ handler_pattern := `except\s+Exception\s*(?:as\s+\w+)?:\s*\n\s+(?:logger\.|loggi

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	content := new_content
 	content != null

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	content := edit_new_content(edit)

@@ -135,7 +135,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -104,7 +104,7 @@ file_path_pattern := `src/test/code-quality\.test\.ts$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -118,7 +118,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -104,7 +104,7 @@ pattern := `return\s+datetime\.now\s*\(`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	content := new_content
 	content != null

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	content := edit_new_content(edit)

@@ -135,7 +135,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -104,7 +104,7 @@ pattern := `except\s+\w*(?:Error|Exception).*?:\s*\n\s+.*?(?:logger\.|logging\.)

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	content := new_content
 	content != null

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	content := edit_new_content(edit)

@@ -135,7 +135,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -127,7 +127,7 @@ fixture_pattern := `@pytest\.fixture[^@]*\ndef\s+(mock_uow|crypto|meetings_dir|w

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -146,7 +146,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -166,7 +166,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -104,7 +104,7 @@ file_path_pattern := `(^|/)client/.*(?:\.?eslint(?:rc|\.config).*|\.?prettier(?:

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -104,7 +104,7 @@ file_path_pattern := `(?:pyproject\.toml|\.?ruff\.toml|\.?pyrightconfig\.json|\.

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -126,7 +126,7 @@ number_pattern := `(?:timeout|delay|interval|duration|limit|max|min|size|count|t

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -145,7 +145,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -165,7 +165,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -104,7 +104,7 @@ file_path_pattern := `(?:^|/)Makefile$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -118,7 +118,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -104,7 +104,7 @@ pattern := `except\s+\w*Error.*?:\s*\n\s+.*?(?:logger\.|logging\.).*?\n\s+return

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	content := new_content
 	content != null

@@ -119,7 +119,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	content := edit_new_content(edit)

@@ -135,7 +135,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -126,7 +126,7 @@ pattern := `def test_[^(]+\([^)]*\)[^:]*:[\s\S]*?\b(for|while|if)\s+[^:]+:[\s\S]

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -144,7 +144,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -163,7 +163,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -105,7 +105,7 @@ exclude_pattern := `baselines\.json$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -120,7 +120,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -128,7 +128,7 @@ any_type_patterns := [
 # Block Write/Edit operations that introduce Any in Python files
 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	# Only enforce for Python files
 	file_path := lower(resolved_file_path)

@@ -149,7 +149,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	content := patch_content
 	content != null

@@ -166,7 +166,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := lower(edit_path(edit))
@@ -123,7 +123,7 @@ type_suppression_patterns := [
 # Block Write/Edit operations that introduce type suppression in Python files
 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	# Only enforce for Python files
 	file_path := lower(resolved_file_path)

@@ -144,7 +144,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	content := patch_content
 	content != null

@@ -161,7 +161,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := lower(edit_path(edit))
@@ -104,7 +104,7 @@ file_path_pattern := `tests/quality/baselines\.json$`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -118,7 +118,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)
@@ -126,7 +126,7 @@ pattern := `(?:.*\n){500,}`

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Write", "Edit", "NotebookEdit"}
+	tool_name in {"write", "edit", "notebookedit"}

 	file_path := resolved_file_path
 	regex.match(file_path_pattern, file_path)

@@ -144,7 +144,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name == "MultiEdit"
+	tool_name == "multiedit"

 	some edit in tool_input.edits
 	file_path := edit_path(edit)

@@ -163,7 +163,7 @@ deny contains decision if {

 deny contains decision if {
 	input.hook_event_name == "PreToolUse"
-	tool_name in {"Patch", "ApplyPatch"}
+	tool_name in {"patch", "apply_patch"}

 	patch := patch_content
 	patch != null
@@ -7,7 +7,7 @@ from datetime import datetime
 from enum import StrEnum
 from uuid import UUID, uuid4

-from noteflow.domain.constants.fields import CALENDAR, EMAIL
+from noteflow.domain.constants.fields import CALENDAR, EMAIL, UNKNOWN as UNKNOWN_VALUE
 from noteflow.domain.utils.time import utc_now
 from noteflow.infrastructure.logging import log_state_transition
@@ -150,7 +150,7 @@ class SyncErrorCode(StrEnum):
     AUTH_REQUIRED = "auth_required"
     PROVIDER_ERROR = "provider_error"
     INTERNAL_ERROR = "internal_error"
-    UNKNOWN = "unknown"
+    UNKNOWN = UNKNOWN_VALUE


 @dataclass
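The UNKNOWN member now takes its value from the shared UNKNOWN_VALUE constant in noteflow.domain.constants.fields, so the wire string is defined in one place. The pattern in isolation — the constant's value is assumed to be "unknown" for this sketch:

    from enum import StrEnum

    UNKNOWN_VALUE = "unknown"  # stand-in for the shared constant in the fields module


    class SyncErrorCode(StrEnum):
        PROVIDER_ERROR = "provider_error"
        UNKNOWN = UNKNOWN_VALUE  # member value sourced from the shared constant


    # StrEnum members compare equal to their string value.
    assert SyncErrorCode.UNKNOWN == "unknown"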
@@ -16,9 +16,14 @@ from typing import NoReturn, Protocol, TypeVar, cast

 import grpc

 TRequest = TypeVar("TRequest")
 TResponse = TypeVar("TResponse")

+# Type variables for RpcMethodHandlerProtocol variance
+TRequest_co = TypeVar("TRequest_co", covariant=True)
+TResponse_contra = TypeVar("TResponse_contra", contravariant=True)
+
+# Metadata type alias
+MetadataLike = Sequence[tuple[str, str | bytes]]
@@ -55,42 +60,74 @@ class ServicerContextProtocol(Protocol):
         ...


-class RpcMethodHandlerProtocol(Protocol):
+class HandlerCallDetailsProtocol(Protocol):
+    """Protocol for gRPC HandlerCallDetails.
+
+    This matches the grpc.HandlerCallDetails interface at runtime.
+    """
+
+    method: str | None
+    invocation_metadata: MetadataLike | None
+
+
+class ServerInterceptorProtocol(Protocol):
+    """Protocol for grpc.aio.ServerInterceptor."""
+
+    async def intercept_service(
+        self,
+        continuation: Callable[
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[TRequest, TResponse]],
+        ],
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
+        ...
+
+
+class RpcMethodHandlerProtocol[TRequest_co, TResponse_contra](Protocol):
     """Protocol for RPC method handlers.

     This matches the grpc.RpcMethodHandler interface at runtime.
     Note: The handler properties return objects that are callable with
     specific request/response types, but we use object here for Protocol
     compatibility. Use casts at call sites for proper typing.
     """

     @property
-    def request_deserializer(self) -> Callable[[bytes], object] | None:
+    def request_deserializer(self) -> Callable[[bytes], TRequest_co] | None:
         """Request deserializer function."""
         ...

     @property
-    def response_serializer(self) -> Callable[[object], bytes] | None:
+    def response_serializer(self) -> Callable[[TResponse_contra], bytes] | None:
         """Response serializer function."""
         ...

     @property
-    def unary_unary(self) -> object | None:
+    def unary_unary(
+        self,
+    ) -> Callable[[TRequest_co, ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
         """Unary-unary handler behavior."""
         ...

     @property
-    def unary_stream(self) -> object | None:
+    def unary_stream(
+        self,
+    ) -> Callable[[TRequest_co, ServicerContextProtocol], AsyncIterator[TResponse_contra]] | None:
         """Unary-stream handler behavior."""
         ...

     @property
-    def stream_unary(self) -> object | None:
+    def stream_unary(
+        self,
+    ) -> Callable[[AsyncIterator[TRequest_co], ServicerContextProtocol], Awaitable[TResponse_contra]] | None:
         """Stream-unary handler behavior."""
         ...

     @property
-    def stream_stream(self) -> object | None:
+    def stream_stream(
+        self,
+    ) -> Callable[
+        [AsyncIterator[TRequest_co], ServicerContextProtocol],
+        AsyncIterator[TResponse_contra],
+    ] | None:
         """Stream-stream handler behavior."""
         ...
@@ -112,7 +149,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a unary-unary RPC method handler."""
         ...

@@ -125,7 +162,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a unary-stream RPC method handler."""
         ...

@@ -138,7 +175,7 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a stream-unary RPC method handler."""
         ...

@@ -151,13 +188,11 @@ class GrpcFactoriesProtocol(Protocol):
         *,
         request_deserializer: Callable[[bytes], TRequest] | None = None,
         response_serializer: Callable[[TResponse], bytes] | None = None,
-    ) -> RpcMethodHandlerProtocol:
+    ) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
         """Create a stream-stream RPC method handler."""
         ...


 # Typed grpc module with proper factory function signatures
 # Use this instead of importing grpc directly when you need typed factories
-typed_grpc: GrpcFactoriesProtocol = cast(
-    GrpcFactoriesProtocol, importlib.import_module("grpc")
-)
+typed_grpc: GrpcFactoriesProtocol = cast(GrpcFactoriesProtocol, importlib.import_module("grpc"))
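Making RpcMethodHandlerProtocol generic lets the factory protocol tie a handler's deserializer and serializer to the same pair of type parameters instead of object. A self-contained sketch of the idea with simplified names (this is not the module's actual definition):

    import json
    from collections.abc import Callable
    from typing import Protocol


    class Handler[TReq, TResp](Protocol):
        """Cut-down stand-in for RpcMethodHandlerProtocol[TRequest, TResponse]."""

        request_deserializer: Callable[[bytes], TReq] | None
        response_serializer: Callable[[TResp], bytes] | None


    class JsonHandler:
        """Concrete handler whose request and response payloads are dicts."""

        def __init__(self) -> None:
            self.request_deserializer: Callable[[bytes], dict] | None = json.loads
            self.response_serializer: Callable[[dict], bytes] | None = None


    handler: Handler[dict, dict] = JsonHandler()  # structural match binds both parameters to dict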
@@ -11,11 +11,9 @@ from __future__ import annotations

 from collections.abc import AsyncIterator, Awaitable, Callable
 from functools import partial
-from typing import NoReturn, cast
+from typing import TYPE_CHECKING, NoReturn, TypeVar, cast

 import grpc
-from grpc import HandlerCallDetails, RpcMethodHandler
-from grpc.aio import ServerInterceptor

 from noteflow.infrastructure.logging import (
     get_logger,
@@ -25,13 +23,22 @@ from noteflow.infrastructure.logging import (
 )

 from ._types import (
+    HandlerCallDetailsProtocol,
     RpcMethodHandlerProtocol,
+    ServerInterceptorProtocol,
     ServicerContextProtocol,
-    TRequest,
-    TResponse,
     typed_grpc,
 )

+_TRequest = TypeVar("_TRequest")
+_TResponse = TypeVar("_TResponse")
+
+
+if TYPE_CHECKING:
+    ServerInterceptor = ServerInterceptorProtocol
+else:
+    from grpc.aio import ServerInterceptor
+
 logger = get_logger(__name__)

 # Metadata keys for identity context
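The if TYPE_CHECKING split gives the checker the protocol (so intercept_service is verified against the protocol signature) while the runtime class still inherits from the real grpc.aio.ServerInterceptor. The trick in its general form, with a stand-in protocol and base:

    from typing import TYPE_CHECKING, Protocol


    class RunnerProtocol(Protocol):
        def run(self) -> None: ...


    if TYPE_CHECKING:
        Base = RunnerProtocol  # the checker verifies subclasses against the protocol
    else:
        Base = object  # at runtime the real base class goes here


    class Worker(Base):
        def run(self) -> None:
            print("running")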
@@ -126,21 +133,18 @@ class IdentityInterceptor(ServerInterceptor):
     async def intercept_service(
         self,
         continuation: Callable[
-            [HandlerCallDetails],
-            Awaitable[RpcMethodHandler[TRequest, TResponse]],
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
         ],
-        handler_call_details: HandlerCallDetails,
-    ) -> RpcMethodHandler[TRequest, TResponse]:
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
         """Intercept incoming RPC calls to validate and set identity context."""
         metadata = dict(handler_call_details.invocation_metadata or [])

         request_id = _get_request_id(metadata, handler_call_details.method)
         if request_id is None:
             handler = await continuation(handler_call_details)
-            return cast(
-                RpcMethodHandler[TRequest, TResponse],
-                _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID),
-            )
+            return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID)

         _apply_identity_context(metadata, request_id)
@@ -156,35 +160,50 @@ class IdentityInterceptor(ServerInterceptor):


 def _create_unauthenticated_handler(
-    handler: RpcMethodHandlerProtocol,
+    handler: RpcMethodHandlerProtocol[_TRequest, _TResponse],
     message: str,
-) -> RpcMethodHandlerProtocol:
+) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
     """Create a handler that rejects with UNAUTHENTICATED status."""
     request_deserializer = handler.request_deserializer
     response_serializer = handler.response_serializer
-    if response_serializer is not None:
-        assert callable(response_serializer)
+    response_serializer_any = cast(Callable[[object], bytes] | None, response_serializer)

     if handler.unary_unary is not None:
-        return typed_grpc.unary_unary_rpc_method_handler(
-            partial(_reject_unary_unary, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.unary_unary_rpc_method_handler(
+                partial(_reject_unary_unary, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     if handler.unary_stream is not None:
-        return typed_grpc.unary_stream_rpc_method_handler(
-            partial(_reject_unary_stream, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.unary_stream_rpc_method_handler(
+                partial(_reject_unary_stream, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     if handler.stream_unary is not None:
-        return typed_grpc.stream_unary_rpc_method_handler(
-            partial(_reject_stream_unary, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.stream_unary_rpc_method_handler(
+                partial(_reject_stream_unary, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     if handler.stream_stream is not None:
-        return typed_grpc.stream_stream_rpc_method_handler(
-            partial(_reject_stream_stream, message),
-            request_deserializer=request_deserializer,
-            response_serializer=response_serializer,
+        return cast(
+            RpcMethodHandlerProtocol[_TRequest, _TResponse],
+            typed_grpc.stream_stream_rpc_method_handler(
+                partial(_reject_stream_stream, message),
+                request_deserializer=request_deserializer,
+                response_serializer=response_serializer_any,
+            ),
         )
     return handler
@@ -3,10 +3,13 @@
 from __future__ import annotations

 from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import cast

 from .._types import (
     RpcMethodHandlerProtocol,
     ServicerContextProtocol,
+    TRequest,
+    TResponse,
     typed_grpc,
 )
 from ._wrappers import (
@@ -17,90 +20,118 @@ from ._wrappers import (
 )


-def wrap_unary_unary_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_unary_unary_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped unary-unary handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.unary_unary is None:
         raise TypeError("Unary-unary handler is missing")
     wrapped: Callable[
         [object, ServicerContextProtocol],
         Awaitable[object],
-    ] = wrap_unary_unary(handler.unary_unary, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.unary_unary_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_unary_unary(
+        cast(Callable[[object, ServicerContextProtocol], Awaitable[object]], handler.unary_unary),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.unary_unary_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
     )


-def wrap_unary_stream_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_unary_stream_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped unary-stream handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.unary_stream is None:
         raise TypeError("Unary-stream handler is missing")
     wrapped: Callable[
         [object, ServicerContextProtocol],
         AsyncIterator[object],
-    ] = wrap_unary_stream(handler.unary_stream, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.unary_stream_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_unary_stream(
+        cast(
+            Callable[[object, ServicerContextProtocol], AsyncIterator[object]], handler.unary_stream
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.unary_stream_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
    )


-def wrap_stream_unary_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_stream_unary_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped stream-unary handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.stream_unary is None:
         raise TypeError("Stream-unary handler is missing")
     wrapped: Callable[
         [AsyncIterator[object], ServicerContextProtocol],
         Awaitable[object],
-    ] = wrap_stream_unary(handler.stream_unary, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.stream_unary_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_stream_unary(
+        cast(
+            Callable[[AsyncIterator[object], ServicerContextProtocol], Awaitable[object]],
+            handler.stream_unary,
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.stream_unary_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
    )


-def wrap_stream_stream_handler(
-    handler: RpcMethodHandlerProtocol,
+def _wrap_stream_stream_handler(
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
-    """Create wrapped stream-stream handler with logging."""
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     if handler.stream_stream is None:
         raise TypeError("Stream-stream handler is missing")
     wrapped: Callable[
         [AsyncIterator[object], ServicerContextProtocol],
         AsyncIterator[object],
-    ] = wrap_stream_stream(handler.stream_stream, method)
-    request_deserializer = handler.request_deserializer
-    response_serializer = handler.response_serializer
-    return typed_grpc.stream_stream_rpc_method_handler(
-        wrapped,
-        request_deserializer=request_deserializer,
-        response_serializer=response_serializer,
+    ] = wrap_stream_stream(
+        cast(
+            Callable[[AsyncIterator[object], ServicerContextProtocol], AsyncIterator[object]],
+            handler.stream_stream,
+        ),
+        method,
+    )
+    return cast(
+        RpcMethodHandlerProtocol[TRequest, TResponse],
+        typed_grpc.stream_stream_rpc_method_handler(
+            wrapped,
+            request_deserializer=cast(
+                Callable[[bytes], object] | None, handler.request_deserializer
+            ),
+            response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer),
+        ),
    )


 def create_logging_handler(
-    handler: RpcMethodHandlerProtocol,
+    handler: RpcMethodHandlerProtocol[TRequest, TResponse],
     method: str,
-) -> RpcMethodHandlerProtocol:
+) -> RpcMethodHandlerProtocol[TRequest, TResponse]:
     """Wrap an RPC handler to add request logging.

     Args:
@@ -111,12 +142,11 @@ def create_logging_handler(
         Wrapped handler with logging.
     """
     if handler.unary_unary is not None:
-        return wrap_unary_unary_handler(handler, method)
+        return _wrap_unary_unary_handler(handler, method)
     if handler.unary_stream is not None:
-        return wrap_unary_stream_handler(handler, method)
+        return _wrap_unary_stream_handler(handler, method)
     if handler.stream_unary is not None:
-        return wrap_stream_unary_handler(handler, method)
+        return _wrap_stream_unary_handler(handler, method)
     if handler.stream_stream is not None:
-        return wrap_stream_stream_handler(handler, method)
-    # Fallback: return original handler if type unknown
+        return _wrap_stream_stream_handler(handler, method)
     return handler
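create_logging_handler re-registers each behavior through the matching wrap_* helper from ._wrappers, which this diff does not show. A minimal sketch of what such a wrapper could look like for the unary-unary case, with print standing in for the real structured logging:

    import time
    from collections.abc import Awaitable, Callable


    def wrap_unary_unary(
        behavior: Callable[[object, object], Awaitable[object]],
        method: str,
    ) -> Callable[[object, object], Awaitable[object]]:
        """Illustrative only: time the RPC and report method + duration on completion."""

        async def wrapped(request: object, context: object) -> object:
            start = time.monotonic()
            try:
                return await behavior(request, context)
            finally:
                print(f"{method} completed in {time.monotonic() - start:.3f}s")

        return wrapped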
@@ -1,19 +1,29 @@
 """Request logging interceptor for gRPC calls.

-Log every RPC call with method, status, duration, peer, and request context
-at INFO level for production observability and traceability.
+Log every RPC call with method, status, duration, peer, and request context at
+INFO level for production observability and traceability.
 """

 from __future__ import annotations

 from collections.abc import Awaitable, Callable
-from typing import cast
-
-from grpc import HandlerCallDetails, RpcMethodHandler
-from grpc.aio import ServerInterceptor
+from typing import TYPE_CHECKING, TypeVar

+from .._types import (
+    HandlerCallDetailsProtocol,
+    RpcMethodHandlerProtocol,
+    ServerInterceptorProtocol,
+)
 from ._handler_factory import create_logging_handler
-from .._types import TRequest, TResponse
+
+_TRequest = TypeVar("_TRequest")
+_TResponse = TypeVar("_TResponse")
+
+
+if TYPE_CHECKING:
+    ServerInterceptor = ServerInterceptorProtocol
+else:
+    from grpc.aio import ServerInterceptor


 class RequestLoggingInterceptor(ServerInterceptor):
@@ -30,25 +40,12 @@ class RequestLoggingInterceptor(ServerInterceptor):
     async def intercept_service(
         self,
         continuation: Callable[
-            [HandlerCallDetails],
-            Awaitable[RpcMethodHandler[TRequest, TResponse]],
+            [HandlerCallDetailsProtocol],
+            Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]],
         ],
-        handler_call_details: HandlerCallDetails,
-    ) -> RpcMethodHandler[TRequest, TResponse]:
-        """Intercept incoming RPC calls to log request timing and status.
-
-        Args:
-            continuation: The next interceptor or handler.
-            handler_call_details: Details about the RPC call.
-
-        Returns:
-            Wrapped RPC handler that logs on completion.
-        """
+        handler_call_details: HandlerCallDetailsProtocol,
+    ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]:
+        """Intercept incoming RPC calls to log request timing and status."""
         handler = await continuation(handler_call_details)
-        method = handler_call_details.method
-
-        # Return wrapped handler that logs on completion
-        return cast(
-            RpcMethodHandler[TRequest, TResponse],
-            create_logging_handler(handler, method),
-        )
+        method = handler_call_details.method or ""
+        return create_logging_handler(handler, method)
@@ -77,7 +77,6 @@ async def require_calendar_service(
         return host.calendar_service
     logger.warning(f"{operation}_unavailable", reason="service_not_enabled")
     await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED)
-    raise  # Unreachable but helps type checker


 class CalendarMixin:
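This deletion, and the many like it below, lean on the abort helpers being annotated NoReturn (the diarization hunk's own comment says as much: "abort is NoReturn"). Once the annotation is in place the checker knows control cannot fall through the call, so the trailing raise adds nothing. A reduced sketch with a hypothetical helper:

    from typing import NoReturn


    async def abort_not_found(context: object, entity: str, key: str) -> NoReturn:
        """Hypothetical stand-in for the real abort helper: it always raises."""
        raise RuntimeError(f"{entity} {key!r} not found")


    async def load(meeting: object | None, context: object) -> object:
        if meeting is None:
            await abort_not_found(context, "Meeting", "m-1")
            # No trailing `raise` needed: awaiting a NoReturn coroutine marks this
            # branch unreachable, so the function still checks as returning object.
        return meeting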
@@ -117,7 +116,6 @@ class CalendarMixin:
         except CalendarServiceError as e:
             logger.error("calendar_list_events_failed", error=str(e), provider=provider)
             await abort_internal(context, str(e))
-            raise  # Unreachable but helps type checker

         proto_events = [_calendar_event_to_proto(event) for event in events]
@@ -166,9 +164,7 @@ class CalendarMixin:
                 status=status.status,
             )

-        authenticated_count = sum(
-            bool(provider.is_authenticated) for provider in providers
-        )
+        authenticated_count = sum(bool(provider.is_authenticated) for provider in providers)
         logger.info(
             "calendar_get_providers_success",
             total_providers=len(providers),
@@ -205,7 +201,6 @@ class CalendarMixin:
                 error=str(e),
             )
             await abort_invalid_argument(context, str(e))
-            raise  # Unreachable but helps type checker

         logger.info(
             "oauth_initiate_success",
@@ -92,20 +92,18 @@ class DiarizationJobMixin:
        Prunes both in-memory task references and database records.
        """
        # Clean up in-memory task references for completed tasks
-        completed_tasks = [
-            job_id for job_id, task in self.diarization_tasks.items() if task.done()
-        ]
+        completed_tasks = [job_id for job_id, task in self.diarization_tasks.items() if task.done()]
        for job_id in completed_tasks:
            self.diarization_tasks.pop(job_id, None)

        # Prune old completed jobs from database
-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
            if not repo.supports_diarization_jobs:
                logger.debug("Job pruning skipped: database required")
                return
-            pruned = await repo.diarization_jobs.prune_completed(
-                self.diarization_job_ttl_seconds
-            )
+            pruned = await repo.diarization_jobs.prune_completed(self.diarization_job_ttl_seconds)
            await repo.commit()
            if pruned > 0:
                logger.debug("Pruned %d completed diarization jobs", pruned)
@@ -121,15 +119,15 @@ class DiarizationJobMixin:
        """
        await self.prune_diarization_jobs()

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
            if not repo.supports_diarization_jobs:
                await abort_not_found(context, "Diarization jobs (database required)", "")
-                raise  # Unreachable but helps type checker

            job = await repo.diarization_jobs.get(request.job_id)
            if job is None:
                await abort_not_found(context, "Diarization job", request.job_id)
-                raise  # Unreachable but helps type checker

            return _build_job_status(job)
@@ -145,7 +143,9 @@ class DiarizationJobMixin:
        job_id = request.job_id
        await _cancel_running_task(self.diarization_tasks, job_id)

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
            if not repo.supports_diarization_jobs:
                await abort_database_required(context, "Diarization job cancellation")
                raise AssertionError(UNREACHABLE_ERROR)  # abort is NoReturn
@@ -185,7 +185,9 @@ class DiarizationJobMixin:
        """
        response = noteflow_pb2.GetActiveDiarizationJobsResponse()

-        async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo:
+        async with cast(
+            DiarizationJobRepositoryProvider, self.create_repository_provider()
+        ) as repo:
            if not repo.supports_diarization_jobs:
                # Return empty list if DB not available
                return response
@@ -175,7 +175,6 @@ class EntitiesMixin:

             if updated is None:
                 await abort_not_found(context, ENTITY_ENTITY, request.entity_id)
-                raise  # Unreachable but helps type checker

             await uow.commit()
@@ -75,9 +75,7 @@ def _merge_workspace_settings(
         trigger_rules=updates.trigger_rules
         if updates.trigger_rules is not None
         else current.trigger_rules,
-        rag_enabled=updates.rag_enabled
-        if updates.rag_enabled is not None
-        else current.rag_enabled,
+        rag_enabled=updates.rag_enabled if updates.rag_enabled is not None else current.rag_enabled,
         default_summarization_template=updates.default_summarization_template
         if updates.default_summarization_template is not None
         else current.default_summarization_template,
@@ -192,14 +190,12 @@ class IdentityMixin:

         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)

         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, membership = await self._verify_workspace_access(
@@ -231,14 +227,12 @@ class IdentityMixin:
         """Get workspace settings."""
         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)

         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, _ = await self._verify_workspace_access(
@@ -259,14 +253,12 @@ class IdentityMixin:
         """Update workspace settings."""
         if not request.workspace_id:
             await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-            raise  # Unreachable but helps type checker

         workspace_id = await parse_workspace_id(request.workspace_id, context)

         async with cast(UnitOfWork, self.create_repository_provider()) as uow:
             if not uow.supports_workspaces:
                 await abort_database_required(context, WORKSPACES_LABEL)
-                raise  # Unreachable but helps type checker

             user_ctx = await self.identity_service.get_or_create_default_user(uow)
             workspace, membership = await self._verify_workspace_access(
@@ -278,7 +270,6 @@ class IdentityMixin:

         if not membership.role.can_admin():
             await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED)
-            raise  # Unreachable but helps type checker

         updates = proto_to_workspace_settings(request.settings)
         if updates is None:
@@ -302,11 +293,9 @@ class IdentityMixin:
             workspace = await uow.workspaces.get(workspace_id)
             if not workspace:
                 await abort_not_found(context, ENTITY_WORKSPACE, str(workspace_id))
-                raise  # Unreachable but helps type checker

             membership = await uow.workspaces.get_membership(workspace_id, user_id)
             if not membership:
                 await abort_not_found(context, "Workspace membership", str(workspace_id))
-                raise  # Unreachable but helps type checker

             return workspace, membership
@@ -42,7 +42,6 @@ if TYPE_CHECKING:

     from .._types import GrpcContext
     from ..protocols import ServicerHost
-from ..errors._constants import UNREACHABLE_ERROR

 logger = get_logger(__name__)
@@ -66,7 +65,6 @@ async def _load_meeting_for_stop(
     if meeting is None:
         logger.warning("StopMeeting: meeting not found", meeting_id=meeting_id_str)
         await abort_not_found(context, ENTITY_MEETING, meeting_id_str)
-        raise AssertionError(UNREACHABLE_ERROR)
     return meeting
@@ -85,7 +83,7 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting:
         )
         return context.meeting

-    previous_state = context.meeting.state.value
+    previous_state = context.meeting.state.name
     await transition_to_stopped(
         context.meeting,
         context.meeting_id,
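Switching previous_state from state.value to state.name changes which string is recorded in the transition log: for an enum member the two are independent. Illustrated with an assumed shape for the meeting-state enum (the real one lives in the domain layer):

    from enum import Enum


    class MeetingState(Enum):  # assumed for illustration
        RECORDING = "recording"


    print(MeetingState.RECORDING.name)   # RECORDING
    print(MeetingState.RECORDING.value)  # recording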
@@ -143,7 +141,9 @@ class MeetingMixin:
        op_context = self.get_operation_context(context)

        async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
-            project_id = await resolve_create_project_id(self, repo, op_context, project_id)
+            project_id = await resolve_create_project_id(
+                cast("ServicerHost", self), repo, op_context, project_id
+            )

            meeting = Meeting.create(
                title=request.title,
@@ -212,7 +212,7 @@ class MeetingMixin:

        async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo:
            if project_id is None and not project_ids:
-                project_id = await resolve_active_project_id(self, repo)
+                project_id = await resolve_active_project_id(cast("ServicerHost", self), repo)

            meetings, total = await repo.meetings.list_all(
                states=states,
@@ -254,7 +254,6 @@ class MeetingMixin:
            if meeting is None:
                logger.warning("GetMeeting: meeting not found", meeting_id=request.meeting_id)
                await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
-                raise  # Unreachable but helps type checker
            # Load segments if requested
            if request.include_segments:
                segments = await repo.segments.get_by_meeting(meeting.id)
@@ -283,7 +282,6 @@ class MeetingMixin:
            if not success:
                logger.warning("DeleteMeeting: meeting not found", meeting_id=request.meeting_id)
                await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
-                raise  # Unreachable but helps type checker

            await repo.commit()
            logger.info("Meeting deleted", meeting_id=request.meeting_id)
@@ -35,8 +35,6 @@ if TYPE_CHECKING:
     from ..protocols import ProjectRepositoryProvider


-
-
 async def _parse_project_and_user_ids(
     request_project_id: str,
     request_user_id: str,
@@ -45,21 +43,19 @@ async def _parse_project_and_user_ids(
     """Parse and validate project and user IDs from request."""
     if not request_project_id:
         await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
-        raise  # Unreachable but helps type checker

     if not request_user_id:
         await abort_invalid_argument(context, ERROR_USER_ID_REQUIRED)
-        raise  # Unreachable but helps type checker

     try:
         project_id = UUID(request_project_id)
         user_id = UUID(request_user_id)
     except ValueError as e:
         await abort_invalid_argument(context, f"{ERROR_INVALID_UUID_PREFIX}{e}")
-        raise  # Unreachable but helps type checker

     return project_id, user_id


 class ProjectMembershipMixin:
     """Mixin providing project membership functionality.
@@ -93,7 +89,6 @@ class ProjectMembershipMixin:
         )
         if membership is None:
             await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-            raise  # Unreachable but helps type checker

         return membership_to_proto(membership)
@@ -119,8 +114,9 @@ class ProjectMembershipMixin:
             role=role,
         )
         if membership is None:
-            await abort_not_found(context, "Membership", f"{request.project_id}/{request.user_id}")
-            raise  # Unreachable but helps type checker
+            await abort_not_found(
+                context, "Membership", f"{request.project_id}/{request.user_id}"
+            )

         return membership_to_proto(membership)
@@ -159,8 +155,9 @@ class ProjectMembershipMixin:
         try:
             project_id = UUID(request.project_id)
         except ValueError:
-            await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}")
-            raise  # Unreachable but helps type checker
+            await abort_invalid_argument(
+                context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}"
+            )

         limit = request.limit if request.limit > 0 else 100
         offset = max(request.offset, 0)
@@ -41,8 +41,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-
-
 async def _require_and_parse_project_id(
     request_project_id: str,
     context: GrpcContext,
@@ -50,7 +48,6 @@ async def _require_and_parse_project_id(
     """Require and parse a project_id from request."""
     if not request_project_id:
         await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     return await parse_project_id(request_project_id, context)
@@ -61,9 +58,9 @@ async def _require_and_parse_workspace_id(
     """Require and parse a workspace_id from request."""
     if not request_workspace_id:
         await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED)
-        raise  # Unreachable but helps type checker
     return await parse_workspace_id(request_workspace_id, context)


 class ProjectMixin:
     """Mixin providing project management functionality.
@@ -89,11 +86,12 @@ class ProjectMixin:

         if not request.name:
             await abort_invalid_argument(context, "name is required")
-            raise  # Unreachable but helps type checker

         slug = request.slug if request.HasField("slug") else None
         description = request.description if request.HasField("description") else None
-        settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        settings = (
+            proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        )

         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)
@@ -123,7 +121,6 @@ class ProjectMixin:
             project = await project_service.get_project(uow, project_id)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker

             return project_to_proto(project)
@@ -138,7 +135,6 @@ class ProjectMixin:

         if not request.slug:
             await abort_invalid_argument(context, "slug is required")
-            raise  # Unreachable but helps type checker

         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)
@@ -146,7 +142,6 @@ class ProjectMixin:
             project = await project_service.get_project_by_slug(uow, workspace_id, request.slug)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.slug)
-                raise  # Unreachable but helps type checker

             return project_to_proto(project)
@@ -159,6 +154,7 @@ class ProjectMixin:
        project_service = await require_project_service(self.project_service, context)
        workspace_id = await _require_and_parse_workspace_id(request.workspace_id, context)

+        include_archived = request.include_archived
        limit = request.limit if request.limit > 0 else 50
        offset = max(request.offset, 0)
@@ -166,21 +162,20 @@ class ProjectMixin:
            await require_feature_projects(uow, context)

            projects = await project_service.list_projects(
-                uow=uow,
-                workspace_id=workspace_id,
-                include_archived=request.include_archived,
+                uow,
+                workspace_id,
+                include_archived=include_archived,
                limit=limit,
                offset=offset,
            )

            total_count = await project_service.count_projects(
-                uow=uow,
-                workspace_id=workspace_id,
-                include_archived=request.include_archived,
+                uow,
+                workspace_id,
+                include_archived=include_archived,
            )

            return noteflow_pb2.ListProjectsResponse(
-                projects=[project_to_proto(p) for p in projects],
+                projects=[project_to_proto(project) for project in projects],
                total_count=total_count,
            )
@@ -196,14 +191,16 @@ class ProjectMixin:
         name = request.name if request.HasField("name") else None
         slug = request.slug if request.HasField("slug") else None
         description = request.description if request.HasField("description") else None
-        settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        settings = (
+            proto_to_project_settings(request.settings) if request.HasField("settings") else None
+        )

         async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow:
             await require_feature_projects(uow, context)

             project = await project_service.update_project(
-                uow=uow,
-                project_id=project_id,
+                uow,
+                project_id,
                 name=name,
                 slug=slug,
                 description=description,
@@ -211,7 +208,6 @@ class ProjectMixin:
             )
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker

             return project_to_proto(project)
@@ -231,11 +227,9 @@ class ProjectMixin:
                 project = await project_service.archive_project(uow, project_id)
             except CannotArchiveDefaultProjectError:
                 await abort_failed_precondition(context, "Cannot archive the default project")
-                raise  # Unreachable but helps type checker

             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker

             return project_to_proto(project)
@@ -254,7 +248,6 @@ class ProjectMixin:
             project = await project_service.restore_project(uow, project_id)
             if project is None:
                 await abort_not_found(context, ENTITY_PROJECT, request.project_id)
-                raise  # Unreachable but helps type checker

             return project_to_proto(project)
@@ -302,7 +295,6 @@ class ProjectMixin:
|
||||
)
|
||||
except ValueError as exc:
|
||||
await abort_invalid_argument(context, str(exc))
|
||||
raise # Unreachable but helps type checker
|
||||
|
||||
await uow.commit()
|
||||
return noteflow_pb2.SetActiveProjectResponse()
|
||||
@@ -327,11 +319,9 @@ class ProjectMixin:
|
||||
)
|
||||
except ValueError as exc:
|
||||
await abort_invalid_argument(context, str(exc))
|
||||
raise # Unreachable but helps type checker
|
||||
|
||||
if project is None:
|
||||
await abort_not_found(context, ENTITY_PROJECT, "default")
|
||||
raise # Unreachable but helps type checker
|
||||
|
||||
response = noteflow_pb2.GetActiveProjectResponse(
|
||||
project=project_to_proto(project),
|
||||
|
||||
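The archive path shows the broader convention here: domain exceptions are translated into gRPC status codes at the service edge, and a precondition failure maps to FAILED_PRECONDITION rather than INVALID_ARGUMENT, because the request was well-formed but current state forbids it. A sketch of what `abort_failed_precondition` presumably wraps (the helper body is an assumption; `context.abort` is the standard grpc.aio call):

    try:
        project = await project_service.archive_project(uow, project_id)
    except CannotArchiveDefaultProjectError:
        # State-dependent refusal: the request is valid, the state is not.
        await context.abort(
            grpc.StatusCode.FAILED_PRECONDITION, "Cannot archive the default project"
        )
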
@@ -59,7 +59,6 @@ async def _decode_chunk_audio(
        )
    except ValueError as e:
        await abort_invalid_argument(context, str(e))
        raise # Unreachable but helps type checker

    conversion = AudioConversionContext(
        source_sample_rate=sample_rate,

@@ -30,10 +30,13 @@ from ._template_resolution import (
if TYPE_CHECKING:
    from collections.abc import Callable

    from noteflow.application.services.ner import NerService
    from noteflow.application.services.summarization import SummarizationService
    from noteflow.application.services.webhooks import WebhookService
    from noteflow.domain.entities import Meeting
    from noteflow.domain.identity import OperationContext
    from noteflow.infrastructure.asr import FasterWhisperEngine
    from noteflow.infrastructure.diarization.engine import DiarizationEngine

logger = get_logger(__name__)

@@ -48,9 +51,41 @@ class _SummaryGenerationContext:
    force_regenerate: bool


@dataclass(frozen=True, slots=True)
class _SummaryRequestContext:
    request: noteflow_pb2.GenerateSummaryRequest
    meeting: Meeting
    segments: list[Segment]
    style_instructions: str | None
    context: GrpcContext
    op_context: OperationContext

    async def resolve_style_prompt(
        self,
        repo: UnitOfWork,
        summarization_service: SummarizationService | None,
    ) -> str | None:
        """Resolve style prompt for this request context."""
        return await resolve_template_prompt(
            TemplateResolutionInputs(
                request=self.request,
                meeting=self.meeting,
                segments=self.segments,
                style_instructions=self.style_instructions,
                context=self.context,
                op_context=self.op_context,
                repo=repo,
                summarization_service=summarization_service,
            )
        )


class SummarizationGenerationMixin:
    """Generate summaries and handle summary webhooks."""

    asr_engine: FasterWhisperEngine | None
    diarization_engine: DiarizationEngine | None
    ner_service: NerService | None
    summarization_service: SummarizationService | None
    webhook_service: WebhookService | None
    create_repository_provider: Callable[..., object]
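
`frozen=True, slots=True` gives `_SummaryRequestContext` value semantics and a fixed attribute set, which is what makes bundling six related values into one object preferable to the six-way tuple unpacking reworked further down. A self-contained sketch of the construct:

    from dataclasses import dataclass

    @dataclass(frozen=True, slots=True)
    class Point:
        x: float
        y: float

    p = Point(1.0, 2.0)
    # p.x = 3.0  # raises dataclasses.FrozenInstanceError: instances are immutable
    # slots=True drops the per-instance __dict__, so attribute names are fixed
    # at class definition time and instances stay small
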
@@ -68,28 +103,28 @@ class SummarizationGenerationMixin:
            meeting_id=request.meeting_id,
            include_provider_details=True,
        )
        meeting_id, op_context, style_instructions, meeting, existing, segments = (
            await self._prepare_summary_request(request, context)
        )
        (
            meeting_id,
            op_context,
            style_instructions,
            meeting,
            existing,
            segments,
        ) = await self._prepare_summary_request(request, context)
        if existing and not request.force_regenerate:
            await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED)
            return summary_to_proto(existing)

        await self._ensure_cloud_provider()

        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
            style_prompt = await self._resolve_style_prompt(
                TemplateResolutionInputs(
                    request=request,
                    meeting=meeting,
                    segments=segments,
                    style_instructions=style_instructions,
                    context=context,
                    op_context=op_context,
                    repo=repo,
                    summarization_service=self.summarization_service,
                )
            )
        request_context = _SummaryRequestContext(
            request=request,
            meeting=meeting,
            segments=segments,
            style_instructions=style_instructions,
            context=context,
            op_context=op_context,
        )
        style_prompt = await self._resolve_style_prompt_for_request(request_context)
        saved, trigger_webhook = await self._generate_summary_with_status(
            _SummaryGenerationContext(
                meeting_id=meeting_id,
@@ -121,6 +156,13 @@ class SummarizationGenerationMixin:
            ),
        )

    async def _resolve_style_prompt_for_request(
        self,
        request_context: _SummaryRequestContext,
    ) -> str | None:
        async with cast(UnitOfWork, self.create_repository_provider()) as repo:
            return await request_context.resolve_style_prompt(repo, self.summarization_service)

    async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None:
        return await resolve_template_prompt(inputs)

@@ -244,7 +286,6 @@ class SummarizationGenerationMixin:
            meeting = await repo.meetings.get(meeting_id)
            if meeting is None:
                await abort_not_found(context, ENTITY_MEETING, request.meeting_id)
                raise # Unreachable but helps type checker

            existing = await repo.summaries.get_by_meeting(meeting.id)
            segments = list(await repo.segments.get_by_meeting(meeting.id))

@@ -9,6 +9,8 @@ from typing import Generic, TypeVar, Never

import grpc

from noteflow.domain.constants.fields import UNKNOWN as UNKNOWN_VALUE

T = TypeVar("T", covariant=True)
U = TypeVar("U")
E = TypeVar("E", bound=BaseException)
@@ -17,7 +19,7 @@ E = TypeVar("E", bound=BaseException)
class ClientErrorCode(StrEnum):
    """Client-facing error codes for gRPC operations."""

    UNKNOWN = "unknown"
    UNKNOWN = UNKNOWN_VALUE
    NOT_CONNECTED = "not_connected"
    NOT_FOUND = "not_found"
    INVALID_ARGUMENT = "invalid_argument"

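Sourcing the member value from the shared domain constant keeps the enum and the constant from drifting apart, and because `StrEnum` members are real strings, call sites comparing against the raw value keep working. A standalone sketch (assuming the constant's value is "unknown"):

    from enum import StrEnum

    UNKNOWN_VALUE = "unknown"  # stand-in for noteflow.domain.constants.fields.UNKNOWN

    class ClientErrorCode(StrEnum):
        UNKNOWN = UNKNOWN_VALUE
        NOT_FOUND = "not_found"

    assert ClientErrorCode.UNKNOWN == "unknown"  # StrEnum members compare as str
    assert str(ClientErrorCode.UNKNOWN) == "unknown"
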
@@ -5,8 +5,8 @@ Tests identity context validation and per-RPC request logging.

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Protocol, cast
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch

import grpc
@@ -19,7 +19,11 @@ from noteflow.grpc.interceptors import (
    IdentityInterceptor,
    RequestLoggingInterceptor,
)
from noteflow.grpc.interceptors._types import ServicerContextProtocol
from noteflow.grpc.interceptors._types import (
    HandlerCallDetailsProtocol,
    RpcMethodHandlerProtocol,
    ServicerContextProtocol,
)
from noteflow.infrastructure.logging import (
    get_request_id,
    get_user_id,
@@ -27,14 +31,13 @@ from noteflow.infrastructure.logging import (
    request_id_var,
)

# Type alias for callable unary-unary handlers
UnaryUnaryCallable = Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]]

class _DummyRequest:
    """Placeholder request type for handler casts."""
# Type alias for test handler protocol (using bytes for wire-level compatibility)
TestHandlerProtocol = RpcMethodHandlerProtocol[bytes, bytes]


class _DummyResponse:
    """Placeholder response type for handler casts."""

# Test data
TEST_REQUEST_ID = "test-request-123"
TEST_USER_ID = "user-456"
@@ -45,25 +48,28 @@ TEST_METHOD = "/noteflow.NoteFlowService/GetMeeting"
pytestmark = pytest.mark.usefixtures("reset_context_vars")


# Type alias for handler call details matching grpc module
HandlerCallDetails = HandlerCallDetailsProtocol


def create_handler_call_details(
    method: str = TEST_METHOD,
    metadata: list[tuple[str, str | bytes]] | None = None,
) -> grpc.HandlerCallDetails:
) -> HandlerCallDetails:
    """Create mock HandlerCallDetails with metadata."""
    details = MagicMock(spec=grpc.HandlerCallDetails)
    details = MagicMock(spec=HandlerCallDetails)
    details.method = method
    details.invocation_metadata = metadata or []
    return details


def create_mock_handler() -> _UnaryUnaryHandler:
def create_mock_handler() -> TestHandlerProtocol:
    """Create a mock RPC method handler."""
    return _MockHandler()


def create_mock_continuation(
    handler: _UnaryUnaryHandler | None = None,
    handler: TestHandlerProtocol | None = None,
) -> AsyncMock:
    """Create a mock continuation function."""
    if handler is None:
@@ -71,46 +77,75 @@ def create_mock_continuation(
    return AsyncMock(return_value=handler)


class _UnaryUnaryHandler(Protocol):
    """Protocol for unary-unary RPC method handlers."""

    unary_unary: Callable[
        [_DummyRequest, ServicerContextProtocol],
        Awaitable[_DummyResponse],
    ] | None
    unary_stream: object | None
    stream_unary: object | None
    stream_stream: object | None
    request_deserializer: object | None
    response_serializer: object | None


class _MockHandler:
    """Concrete handler for tests with typed unary_unary."""
    """Concrete handler for tests with typed unary_unary.

    unary_unary: Callable[
        [_DummyRequest, ServicerContextProtocol],
        Awaitable[_DummyResponse],
    ] | None
    unary_stream: object | None
    stream_unary: object | None
    stream_stream: object | None
    request_deserializer: object | None
    response_serializer: object | None
    Implements RpcMethodHandlerProtocol for interceptor testing.
    """

    def __init__(self) -> None:
        self.unary_unary = cast(
            Callable[
                [_DummyRequest, ServicerContextProtocol],
                Awaitable[_DummyResponse],
            ],
            AsyncMock(return_value="response"),
        self._unary_unary: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None = (
            cast(
                Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]],
                AsyncMock(return_value=b"response"),
            )
        )
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        self.request_deserializer = None
        self.response_serializer = None
        self._unary_stream: (
            Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None
        ) = None
        self._stream_unary: (
            Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None
        ) = None
        self._stream_stream: (
            Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None
        ) = None
        self._request_deserializer: Callable[[bytes], bytes] | None = None
        self._response_serializer: Callable[[bytes], bytes] | None = None

    @property
    def request_deserializer(self) -> Callable[[bytes], bytes] | None:
        """Request deserializer function."""
        return self._request_deserializer

    @property
    def response_serializer(self) -> Callable[[bytes], bytes] | None:
        """Response serializer function."""
        return self._response_serializer

    @property
    def unary_unary(
        self,
    ) -> Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None:
        """Unary-unary handler behavior."""
        return self._unary_unary

    @unary_unary.setter
    def unary_unary(
        self, value: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None
    ) -> None:
        """Set unary-unary handler behavior."""
        self._unary_unary = value

    @property
    def unary_stream(
        self,
    ) -> Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None:
        """Unary-stream handler behavior."""
        return self._unary_stream

    @property
    def stream_unary(
        self,
    ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None:
        """Stream-unary handler behavior."""
        return self._stream_unary

    @property
    def stream_stream(
        self,
    ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None:
        """Stream-stream handler behavior."""
        return self._stream_stream


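The property-based rewrite suggests `RpcMethodHandlerProtocol` declares its members as read-only properties; the mock mirrors them as properties and adds a setter only on `unary_unary`, which the error-logging test below reassigns. A minimal sketch of the structural-typing rule involved (names hypothetical):

    from typing import Protocol

    class HandlerLike(Protocol):
        @property
        def unary_unary(self) -> object | None: ...  # read-only in the protocol

    class MockHandler:
        def __init__(self) -> None:
            self._unary_unary: object | None = None

        @property
        def unary_unary(self) -> object | None:
            return self._unary_unary

        @unary_unary.setter
        def unary_unary(self, value: object | None) -> None:
            self._unary_unary = value  # settable on the concrete class only

    handler: HandlerLike = MockHandler()  # structurally compatible
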
class TestIdentityInterceptor:
@@ -157,8 +192,9 @@ class TestIdentityInterceptor:
        original_handler = create_mock_handler()
        continuation = create_mock_continuation(original_handler)

        handler = await interceptor.intercept_service(continuation, details)
        typed_handler = cast(_UnaryUnaryHandler, handler)
        typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
            continuation, details
        )

        # Handler should be a rejection handler wrapping the original
        assert typed_handler.unary_unary is not None, "handler should have unary_unary"
@@ -166,7 +202,9 @@ class TestIdentityInterceptor:
        # is a rejection wrapper that will abort with UNAUTHENTICATED
        continuation.assert_called_once()
        # The returned handler should NOT be the original handler
        assert typed_handler.unary_unary is not original_handler.unary_unary, "should return rejection handler, not original"
        assert typed_handler.unary_unary is not original_handler.unary_unary, (
            "should return rejection handler, not original"
        )

    @pytest.mark.asyncio
    async def test_reject_handler_aborts_with_unauthenticated(self) -> None:
@@ -175,8 +213,9 @@ class TestIdentityInterceptor:
        details = create_handler_call_details(metadata=[])
        continuation = create_mock_continuation()

        handler = await interceptor.intercept_service(continuation, details)
        typed_handler = cast(_UnaryUnaryHandler, handler)
        typed_handler: TestHandlerProtocol = await interceptor.intercept_service(
            continuation, details
        )

        # Create mock context to verify abort behavior
        context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
@@ -184,11 +223,14 @@ class TestIdentityInterceptor:

        with pytest.raises(grpc.RpcError, match="x-request-id"):
            assert typed_handler.unary_unary is not None
            await typed_handler.unary_unary(MagicMock(), context)
            unary_fn = typed_handler.unary_unary
            await unary_fn(MagicMock(), context)

        context.abort.assert_called_once()
        call_args = context.abort.call_args
        assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, "should abort with UNAUTHENTICATED"
        assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, (
            "should abort with UNAUTHENTICATED"
        )
        assert "x-request-id" in call_args[0][1], "error message should mention x-request-id"

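For the abort assertions to work against a bare `AsyncMock` context, the mocked `abort` has to both record the call and raise, which is exactly what `side_effect` provides. A generic sketch of that pattern (not necessarily the fixture this suite uses):

    from unittest.mock import AsyncMock
    import grpc

    context = AsyncMock()
    # side_effect makes the awaited mock raise *after* recording the call,
    # so both pytest.raises and abort.assert_called_once() can pass
    context.abort = AsyncMock(
        side_effect=grpc.RpcError("UNAUTHENTICATED: missing x-request-id")
    )
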
    @pytest.mark.asyncio
@@ -224,16 +266,16 @@ class TestRequestLoggingInterceptor:
        request_id_var.set(TEST_REQUEST_ID)

        with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
            wrapped_handler = await interceptor.intercept_service(continuation, details)
            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
                continuation, details
            )

            # Execute the wrapped handler
            context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
            context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")
            assert typed_handler.unary_unary is not None
            await typed_handler.unary_unary(MagicMock(), context)
            assert wrapped_handler.unary_unary is not None
            unary_fn = wrapped_handler.unary_unary
            await unary_fn(MagicMock(), context)

        # Verify logging
        mock_logger.info.assert_called_once()
        call_kwargs = mock_logger.info.call_args[1]
        assert call_kwargs["method"] == TEST_METHOD, "should log method"
@@ -246,28 +288,31 @@ class TestRequestLoggingInterceptor:
        """Interceptor logs error status when handler raises."""
        interceptor = RequestLoggingInterceptor()

        # Create handler that raises
        handler = create_mock_handler()
        handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))
        mock_handler = cast(_MockHandler, handler)
        mock_handler.unary_unary = AsyncMock(side_effect=Exception("Test error"))

        details = create_handler_call_details(method=TEST_METHOD)
        continuation = create_mock_continuation(handler)

        with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
            wrapped_handler = await interceptor.intercept_service(continuation, details)
            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
                continuation, details
            )

            context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
            context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345")

            with pytest.raises(Exception, match="Test error"):
                assert typed_handler.unary_unary is not None
                await typed_handler.unary_unary(MagicMock(), context)
                assert wrapped_handler.unary_unary is not None
                unary_fn = wrapped_handler.unary_unary
                await unary_fn(MagicMock(), context)

        # Should still log with INTERNAL status
        mock_logger.info.assert_called_once()
        call_kwargs = mock_logger.info.call_args[1]
        assert call_kwargs["status"] == "INTERNAL", "should log INTERNAL for unhandled exception"
        assert call_kwargs["status"] == "INTERNAL", (
            "should log INTERNAL for unhandled exception"
        )

    @pytest.mark.asyncio
    async def test_passes_through_to_continuation(self) -> None:
@@ -277,7 +322,7 @@ class TestRequestLoggingInterceptor:
        details = create_handler_call_details()
        continuation = create_mock_continuation(handler)

        await interceptor.intercept_service(continuation, details)
        _: TestHandlerProtocol = await interceptor.intercept_service(continuation, details)

        continuation.assert_called_once_with(details)

@@ -290,17 +335,17 @@ class TestRequestLoggingInterceptor:
        continuation = create_mock_continuation(handler)

        with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger:
            wrapped_handler = await interceptor.intercept_service(continuation, details)
            typed_handler = cast(_UnaryUnaryHandler, wrapped_handler)
            wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service(
                continuation, details
            )

            # Context without peer method
            context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol)
            context.peer = MagicMock(side_effect=RuntimeError("No peer"))

            assert typed_handler.unary_unary is not None
            await typed_handler.unary_unary(MagicMock(), context)
            assert wrapped_handler.unary_unary is not None
            unary_fn = wrapped_handler.unary_unary
            await unary_fn(MagicMock(), context)

        # Should still log with None peer
        mock_logger.info.assert_called_once()
        call_kwargs = mock_logger.info.call_args[1]
        assert call_kwargs["peer"] is None, "should handle missing peer"