diff --git a/.cupcake/policies/opencode/warn_new_file_search.rego b/.cupcake/policies/.inactive/warn_new_file_search.rego similarity index 100% rename from .cupcake/policies/opencode/warn_new_file_search.rego rename to .cupcake/policies/.inactive/warn_new_file_search.rego diff --git a/.cupcake/policies/opencode/ban_stdlib_logger.rego b/.cupcake/policies/opencode/ban_stdlib_logger.rego index 2d287bf..4803df0 100644 --- a/.cupcake/policies/opencode/ban_stdlib_logger.rego +++ b/.cupcake/policies/opencode/ban_stdlib_logger.rego @@ -114,7 +114,7 @@ file_path_pattern := `\.py$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -132,7 +132,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -151,7 +151,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_assertion_roulette.rego b/.cupcake/policies/opencode/block_assertion_roulette.rego index 6e54c5e..a5b59c3 100644 --- a/.cupcake/policies/opencode/block_assertion_roulette.rego +++ b/.cupcake/policies/opencode/block_assertion_roulette.rego @@ -126,7 +126,7 @@ assertion_pattern := `^\s*assert\s+[^,\n]+\n\s*assert\s+[^,\n]+$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -144,7 +144,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -163,7 +163,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_biome_ignore.rego b/.cupcake/policies/opencode/block_biome_ignore.rego index d851d54..04811ee 100644 --- a/.cupcake/policies/opencode/block_biome_ignore.rego +++ b/.cupcake/policies/opencode/block_biome_ignore.rego @@ -126,7 +126,7 @@ ignore_pattern := `//\s*biome-ignore|//\s*@ts-ignore|//\s*@ts-expect-error|//\s* deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -144,7 +144,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -163,7 +163,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_broad_exception_handler.rego b/.cupcake/policies/opencode/block_broad_exception_handler.rego index 21670f2..4405af9 100644 --- a/.cupcake/policies/opencode/block_broad_exception_handler.rego +++ b/.cupcake/policies/opencode/block_broad_exception_handler.rego @@ -104,7 +104,7 @@ handler_pattern := `except\s+Exception\s*(?:as\s+\w+)?:\s*\n\s+(?:logger\.|loggi deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} content := new_content content != null @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits content := edit_new_content(edit) @@ -135,7 +135,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_code_quality_test_edits.rego b/.cupcake/policies/opencode/block_code_quality_test_edits.rego index 8e1c000..9333d2a 100644 --- a/.cupcake/policies/opencode/block_code_quality_test_edits.rego +++ b/.cupcake/policies/opencode/block_code_quality_test_edits.rego @@ -104,7 +104,7 @@ file_path_pattern := `src/test/code-quality\.test\.ts$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -118,7 +118,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/block_datetime_now_fallback.rego b/.cupcake/policies/opencode/block_datetime_now_fallback.rego index dac543f..626a5b5 100644 --- a/.cupcake/policies/opencode/block_datetime_now_fallback.rego +++ b/.cupcake/policies/opencode/block_datetime_now_fallback.rego @@ -104,7 +104,7 @@ pattern := `return\s+datetime\.now\s*\(` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} content := new_content content != null @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits content := edit_new_content(edit) @@ -135,7 +135,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_default_value_swallow.rego b/.cupcake/policies/opencode/block_default_value_swallow.rego index aeee7fa..32d1e02 100644 --- a/.cupcake/policies/opencode/block_default_value_swallow.rego +++ b/.cupcake/policies/opencode/block_default_value_swallow.rego @@ -104,7 +104,7 @@ pattern := `except\s+\w*(?:Error|Exception).*?:\s*\n\s+.*?(?:logger\.|logging\.) deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} content := new_content content != null @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits content := edit_new_content(edit) @@ -135,7 +135,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_duplicate_fixtures.rego b/.cupcake/policies/opencode/block_duplicate_fixtures.rego index 8f1bba8..b793240 100644 --- a/.cupcake/policies/opencode/block_duplicate_fixtures.rego +++ b/.cupcake/policies/opencode/block_duplicate_fixtures.rego @@ -127,7 +127,7 @@ fixture_pattern := `@pytest\.fixture[^@]*\ndef\s+(mock_uow|crypto|meetings_dir|w deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -146,7 +146,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -166,7 +166,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_linter_config_frontend.rego b/.cupcake/policies/opencode/block_linter_config_frontend.rego index 7f22594..00d20bc 100644 --- a/.cupcake/policies/opencode/block_linter_config_frontend.rego +++ b/.cupcake/policies/opencode/block_linter_config_frontend.rego @@ -104,7 +104,7 @@ file_path_pattern := `(^|/)client/.*(?:\.?eslint(?:rc|\.config).*|\.?prettier(?: deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/block_linter_config_python.rego b/.cupcake/policies/opencode/block_linter_config_python.rego index 9964c94..fdf85c9 100644 --- a/.cupcake/policies/opencode/block_linter_config_python.rego +++ b/.cupcake/policies/opencode/block_linter_config_python.rego @@ -104,7 +104,7 @@ file_path_pattern := `(?:pyproject\.toml|\.?ruff\.toml|\.?pyrightconfig\.json|\. deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/block_magic_numbers.rego b/.cupcake/policies/opencode/block_magic_numbers.rego index 49f4b9e..62ee131 100644 --- a/.cupcake/policies/opencode/block_magic_numbers.rego +++ b/.cupcake/policies/opencode/block_magic_numbers.rego @@ -126,7 +126,7 @@ number_pattern := `(?:timeout|delay|interval|duration|limit|max|min|size|count|t deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -145,7 +145,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -165,7 +165,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_makefile_edit.rego b/.cupcake/policies/opencode/block_makefile_edit.rego index cfcb5cc..d60887d 100644 --- a/.cupcake/policies/opencode/block_makefile_edit.rego +++ b/.cupcake/policies/opencode/block_makefile_edit.rego @@ -104,7 +104,7 @@ file_path_pattern := `(?:^|/)Makefile$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -118,7 +118,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/block_silent_none_return.rego b/.cupcake/policies/opencode/block_silent_none_return.rego index d585436..6b976ab 100644 --- a/.cupcake/policies/opencode/block_silent_none_return.rego +++ b/.cupcake/policies/opencode/block_silent_none_return.rego @@ -104,7 +104,7 @@ pattern := `except\s+\w*Error.*?:\s*\n\s+.*?(?:logger\.|logging\.).*?\n\s+return deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} content := new_content content != null @@ -119,7 +119,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits content := edit_new_content(edit) @@ -135,7 +135,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_test_loops_conditionals.rego b/.cupcake/policies/opencode/block_test_loops_conditionals.rego index cdbd015..4c38e9a 100644 --- a/.cupcake/policies/opencode/block_test_loops_conditionals.rego +++ b/.cupcake/policies/opencode/block_test_loops_conditionals.rego @@ -126,7 +126,7 @@ pattern := `def test_[^(]+\([^)]*\)[^:]*:[\s\S]*?\b(for|while|if)\s+[^:]+:[\s\S] deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -144,7 +144,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -163,7 +163,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/.cupcake/policies/opencode/block_tests_quality.rego b/.cupcake/policies/opencode/block_tests_quality.rego index c86e51c..6652ad4 100644 --- a/.cupcake/policies/opencode/block_tests_quality.rego +++ b/.cupcake/policies/opencode/block_tests_quality.rego @@ -105,7 +105,7 @@ exclude_pattern := `baselines\.json$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -120,7 +120,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/prevent_any_type.rego b/.cupcake/policies/opencode/prevent_any_type.rego index d6cd27c..0b9be0e 100644 --- a/.cupcake/policies/opencode/prevent_any_type.rego +++ b/.cupcake/policies/opencode/prevent_any_type.rego @@ -128,7 +128,7 @@ any_type_patterns := [ # Block Write/Edit operations that introduce Any in Python files deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} # Only enforce for Python files file_path := lower(resolved_file_path) @@ -149,7 +149,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} content := patch_content content != null @@ -166,7 +166,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := lower(edit_path(edit)) diff --git a/.cupcake/policies/opencode/prevent_type_suppression.rego b/.cupcake/policies/opencode/prevent_type_suppression.rego index 9129014..6410e39 100644 --- a/.cupcake/policies/opencode/prevent_type_suppression.rego +++ b/.cupcake/policies/opencode/prevent_type_suppression.rego @@ -123,7 +123,7 @@ type_suppression_patterns := [ # Block Write/Edit operations that introduce type suppression in Python files deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} # Only enforce for Python files file_path := lower(resolved_file_path) @@ -144,7 +144,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} content := patch_content content != null @@ -161,7 +161,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := lower(edit_path(edit)) diff --git a/.cupcake/policies/opencode/warn_baselines_edit.rego b/.cupcake/policies/opencode/warn_baselines_edit.rego index 76193a4..463137f 100644 --- a/.cupcake/policies/opencode/warn_baselines_edit.rego +++ b/.cupcake/policies/opencode/warn_baselines_edit.rego @@ -104,7 +104,7 @@ file_path_pattern := `tests/quality/baselines\.json$` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -118,7 +118,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) diff --git a/.cupcake/policies/opencode/warn_large_file.rego b/.cupcake/policies/opencode/warn_large_file.rego index 09f8aee..c49d9b7 100644 --- a/.cupcake/policies/opencode/warn_large_file.rego +++ b/.cupcake/policies/opencode/warn_large_file.rego @@ -126,7 +126,7 @@ pattern := `(?:.*\n){500,}` deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Write", "Edit", "NotebookEdit"} + tool_name in {"write", "edit", "notebookedit"} file_path := resolved_file_path regex.match(file_path_pattern, file_path) @@ -144,7 +144,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name == "MultiEdit" + tool_name == "multiedit" some edit in tool_input.edits file_path := edit_path(edit) @@ -163,7 +163,7 @@ deny contains decision if { deny contains decision if { input.hook_event_name == "PreToolUse" - tool_name in {"Patch", "ApplyPatch"} + tool_name in {"patch", "apply_patch"} patch := patch_content patch != null diff --git a/src/noteflow/domain/entities/integration.py b/src/noteflow/domain/entities/integration.py index b577a74..d398952 100644 --- a/src/noteflow/domain/entities/integration.py +++ b/src/noteflow/domain/entities/integration.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import StrEnum from uuid import UUID, uuid4 -from noteflow.domain.constants.fields import CALENDAR, EMAIL +from noteflow.domain.constants.fields import CALENDAR, EMAIL, UNKNOWN as UNKNOWN_VALUE from noteflow.domain.utils.time import utc_now from noteflow.infrastructure.logging import log_state_transition @@ -150,7 +150,7 @@ class SyncErrorCode(StrEnum): AUTH_REQUIRED = "auth_required" PROVIDER_ERROR = "provider_error" INTERNAL_ERROR = "internal_error" - UNKNOWN = "unknown" + UNKNOWN = UNKNOWN_VALUE @dataclass diff --git a/src/noteflow/grpc/interceptors/_types.py b/src/noteflow/grpc/interceptors/_types.py index f512af0..0fdd301 100644 --- a/src/noteflow/grpc/interceptors/_types.py +++ b/src/noteflow/grpc/interceptors/_types.py @@ -16,9 +16,14 @@ from typing import NoReturn, Protocol, TypeVar, cast import grpc + TRequest = TypeVar("TRequest") TResponse = TypeVar("TResponse") +# Type variables for RpcMethodHandlerProtocol variance +TRequest_co = TypeVar("TRequest_co", covariant=True) +TResponse_contra = TypeVar("TResponse_contra", contravariant=True) + # Metadata type alias MetadataLike = Sequence[tuple[str, str | bytes]] @@ -55,42 +60,74 @@ class ServicerContextProtocol(Protocol): ... -class RpcMethodHandlerProtocol(Protocol): +class HandlerCallDetailsProtocol(Protocol): + """Protocol for gRPC HandlerCallDetails. + + This matches the grpc.HandlerCallDetails interface at runtime. + """ + + method: str | None + invocation_metadata: MetadataLike | None + + +class ServerInterceptorProtocol(Protocol): + """Protocol for grpc.aio.ServerInterceptor.""" + + async def intercept_service( + self, + continuation: Callable[ + [HandlerCallDetailsProtocol], + Awaitable[RpcMethodHandlerProtocol[TRequest, TResponse]], + ], + handler_call_details: HandlerCallDetailsProtocol, + ) -> RpcMethodHandlerProtocol[TRequest, TResponse]: + ... + + +class RpcMethodHandlerProtocol[TRequest_co, TResponse_contra](Protocol): """Protocol for RPC method handlers. This matches the grpc.RpcMethodHandler interface at runtime. - Note: The handler properties return objects that are callable with - specific request/response types, but we use object here for Protocol - compatibility. Use casts at call sites for proper typing. """ @property - def request_deserializer(self) -> Callable[[bytes], object] | None: + def request_deserializer(self) -> Callable[[bytes], TRequest_co] | None: """Request deserializer function.""" ... @property - def response_serializer(self) -> Callable[[object], bytes] | None: + def response_serializer(self) -> Callable[[TResponse_contra], bytes] | None: """Response serializer function.""" ... @property - def unary_unary(self) -> object | None: + def unary_unary( + self, + ) -> Callable[[TRequest_co, ServicerContextProtocol], Awaitable[TResponse_contra]] | None: """Unary-unary handler behavior.""" ... @property - def unary_stream(self) -> object | None: + def unary_stream( + self, + ) -> Callable[[TRequest_co, ServicerContextProtocol], AsyncIterator[TResponse_contra]] | None: """Unary-stream handler behavior.""" ... @property - def stream_unary(self) -> object | None: + def stream_unary( + self, + ) -> Callable[[AsyncIterator[TRequest_co], ServicerContextProtocol], Awaitable[TResponse_contra]] | None: """Stream-unary handler behavior.""" ... @property - def stream_stream(self) -> object | None: + def stream_stream( + self, + ) -> Callable[ + [AsyncIterator[TRequest_co], ServicerContextProtocol], + AsyncIterator[TResponse_contra], + ] | None: """Stream-stream handler behavior.""" ... @@ -112,7 +149,7 @@ class GrpcFactoriesProtocol(Protocol): *, request_deserializer: Callable[[bytes], TRequest] | None = None, response_serializer: Callable[[TResponse], bytes] | None = None, - ) -> RpcMethodHandlerProtocol: + ) -> RpcMethodHandlerProtocol[TRequest, TResponse]: """Create a unary-unary RPC method handler.""" ... @@ -125,7 +162,7 @@ class GrpcFactoriesProtocol(Protocol): *, request_deserializer: Callable[[bytes], TRequest] | None = None, response_serializer: Callable[[TResponse], bytes] | None = None, - ) -> RpcMethodHandlerProtocol: + ) -> RpcMethodHandlerProtocol[TRequest, TResponse]: """Create a unary-stream RPC method handler.""" ... @@ -138,7 +175,7 @@ class GrpcFactoriesProtocol(Protocol): *, request_deserializer: Callable[[bytes], TRequest] | None = None, response_serializer: Callable[[TResponse], bytes] | None = None, - ) -> RpcMethodHandlerProtocol: + ) -> RpcMethodHandlerProtocol[TRequest, TResponse]: """Create a stream-unary RPC method handler.""" ... @@ -151,13 +188,11 @@ class GrpcFactoriesProtocol(Protocol): *, request_deserializer: Callable[[bytes], TRequest] | None = None, response_serializer: Callable[[TResponse], bytes] | None = None, - ) -> RpcMethodHandlerProtocol: + ) -> RpcMethodHandlerProtocol[TRequest, TResponse]: """Create a stream-stream RPC method handler.""" ... # Typed grpc module with proper factory function signatures # Use this instead of importing grpc directly when you need typed factories -typed_grpc: GrpcFactoriesProtocol = cast( - GrpcFactoriesProtocol, importlib.import_module("grpc") -) +typed_grpc: GrpcFactoriesProtocol = cast(GrpcFactoriesProtocol, importlib.import_module("grpc")) diff --git a/src/noteflow/grpc/interceptors/identity.py b/src/noteflow/grpc/interceptors/identity.py index b28c639..56aa656 100644 --- a/src/noteflow/grpc/interceptors/identity.py +++ b/src/noteflow/grpc/interceptors/identity.py @@ -11,11 +11,9 @@ from __future__ import annotations from collections.abc import AsyncIterator, Awaitable, Callable from functools import partial -from typing import NoReturn, cast +from typing import TYPE_CHECKING, NoReturn, TypeVar, cast import grpc -from grpc import HandlerCallDetails, RpcMethodHandler -from grpc.aio import ServerInterceptor from noteflow.infrastructure.logging import ( get_logger, @@ -25,13 +23,22 @@ from noteflow.infrastructure.logging import ( ) from ._types import ( + HandlerCallDetailsProtocol, RpcMethodHandlerProtocol, + ServerInterceptorProtocol, ServicerContextProtocol, - TRequest, - TResponse, typed_grpc, ) +_TRequest = TypeVar("_TRequest") +_TResponse = TypeVar("_TResponse") + + +if TYPE_CHECKING: + ServerInterceptor = ServerInterceptorProtocol +else: + from grpc.aio import ServerInterceptor + logger = get_logger(__name__) # Metadata keys for identity context @@ -126,21 +133,18 @@ class IdentityInterceptor(ServerInterceptor): async def intercept_service( self, continuation: Callable[ - [HandlerCallDetails], - Awaitable[RpcMethodHandler[TRequest, TResponse]], + [HandlerCallDetailsProtocol], + Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]], ], - handler_call_details: HandlerCallDetails, - ) -> RpcMethodHandler[TRequest, TResponse]: + handler_call_details: HandlerCallDetailsProtocol, + ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]: """Intercept incoming RPC calls to validate and set identity context.""" metadata = dict(handler_call_details.invocation_metadata or []) request_id = _get_request_id(metadata, handler_call_details.method) if request_id is None: handler = await continuation(handler_call_details) - return cast( - RpcMethodHandler[TRequest, TResponse], - _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID), - ) + return _create_unauthenticated_handler(handler, _ERR_MISSING_REQUEST_ID) _apply_identity_context(metadata, request_id) @@ -156,35 +160,50 @@ class IdentityInterceptor(ServerInterceptor): def _create_unauthenticated_handler( - handler: RpcMethodHandlerProtocol, + handler: RpcMethodHandlerProtocol[_TRequest, _TResponse], message: str, -) -> RpcMethodHandlerProtocol: +) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]: """Create a handler that rejects with UNAUTHENTICATED status.""" request_deserializer = handler.request_deserializer response_serializer = handler.response_serializer + if response_serializer is not None: + assert callable(response_serializer) + response_serializer_any = cast(Callable[[object], bytes] | None, response_serializer) if handler.unary_unary is not None: - return typed_grpc.unary_unary_rpc_method_handler( - partial(_reject_unary_unary, message), - request_deserializer=request_deserializer, - response_serializer=response_serializer, + return cast( + RpcMethodHandlerProtocol[_TRequest, _TResponse], + typed_grpc.unary_unary_rpc_method_handler( + partial(_reject_unary_unary, message), + request_deserializer=request_deserializer, + response_serializer=response_serializer_any, + ), ) if handler.unary_stream is not None: - return typed_grpc.unary_stream_rpc_method_handler( - partial(_reject_unary_stream, message), - request_deserializer=request_deserializer, - response_serializer=response_serializer, + return cast( + RpcMethodHandlerProtocol[_TRequest, _TResponse], + typed_grpc.unary_stream_rpc_method_handler( + partial(_reject_unary_stream, message), + request_deserializer=request_deserializer, + response_serializer=response_serializer_any, + ), ) if handler.stream_unary is not None: - return typed_grpc.stream_unary_rpc_method_handler( - partial(_reject_stream_unary, message), - request_deserializer=request_deserializer, - response_serializer=response_serializer, + return cast( + RpcMethodHandlerProtocol[_TRequest, _TResponse], + typed_grpc.stream_unary_rpc_method_handler( + partial(_reject_stream_unary, message), + request_deserializer=request_deserializer, + response_serializer=response_serializer_any, + ), ) if handler.stream_stream is not None: - return typed_grpc.stream_stream_rpc_method_handler( - partial(_reject_stream_stream, message), - request_deserializer=request_deserializer, - response_serializer=response_serializer, + return cast( + RpcMethodHandlerProtocol[_TRequest, _TResponse], + typed_grpc.stream_stream_rpc_method_handler( + partial(_reject_stream_stream, message), + request_deserializer=request_deserializer, + response_serializer=response_serializer_any, + ), ) return handler diff --git a/src/noteflow/grpc/interceptors/logging/_handler_factory.py b/src/noteflow/grpc/interceptors/logging/_handler_factory.py index 8c71f26..7518421 100644 --- a/src/noteflow/grpc/interceptors/logging/_handler_factory.py +++ b/src/noteflow/grpc/interceptors/logging/_handler_factory.py @@ -3,10 +3,13 @@ from __future__ import annotations from collections.abc import AsyncIterator, Awaitable, Callable +from typing import cast from .._types import ( RpcMethodHandlerProtocol, ServicerContextProtocol, + TRequest, + TResponse, typed_grpc, ) from ._wrappers import ( @@ -17,90 +20,118 @@ from ._wrappers import ( ) -def wrap_unary_unary_handler( - handler: RpcMethodHandlerProtocol, +def _wrap_unary_unary_handler( + handler: RpcMethodHandlerProtocol[TRequest, TResponse], method: str, -) -> RpcMethodHandlerProtocol: - """Create wrapped unary-unary handler with logging.""" +) -> RpcMethodHandlerProtocol[TRequest, TResponse]: if handler.unary_unary is None: raise TypeError("Unary-unary handler is missing") wrapped: Callable[ [object, ServicerContextProtocol], Awaitable[object], - ] = wrap_unary_unary(handler.unary_unary, method) - request_deserializer = handler.request_deserializer - response_serializer = handler.response_serializer - return typed_grpc.unary_unary_rpc_method_handler( - wrapped, - request_deserializer=request_deserializer, - response_serializer=response_serializer, + ] = wrap_unary_unary( + cast(Callable[[object, ServicerContextProtocol], Awaitable[object]], handler.unary_unary), + method, + ) + return cast( + RpcMethodHandlerProtocol[TRequest, TResponse], + typed_grpc.unary_unary_rpc_method_handler( + wrapped, + request_deserializer=cast( + Callable[[bytes], object] | None, handler.request_deserializer + ), + response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer), + ), ) -def wrap_unary_stream_handler( - handler: RpcMethodHandlerProtocol, +def _wrap_unary_stream_handler( + handler: RpcMethodHandlerProtocol[TRequest, TResponse], method: str, -) -> RpcMethodHandlerProtocol: - """Create wrapped unary-stream handler with logging.""" +) -> RpcMethodHandlerProtocol[TRequest, TResponse]: if handler.unary_stream is None: raise TypeError("Unary-stream handler is missing") wrapped: Callable[ [object, ServicerContextProtocol], AsyncIterator[object], - ] = wrap_unary_stream(handler.unary_stream, method) - request_deserializer = handler.request_deserializer - response_serializer = handler.response_serializer - return typed_grpc.unary_stream_rpc_method_handler( - wrapped, - request_deserializer=request_deserializer, - response_serializer=response_serializer, + ] = wrap_unary_stream( + cast( + Callable[[object, ServicerContextProtocol], AsyncIterator[object]], handler.unary_stream + ), + method, + ) + return cast( + RpcMethodHandlerProtocol[TRequest, TResponse], + typed_grpc.unary_stream_rpc_method_handler( + wrapped, + request_deserializer=cast( + Callable[[bytes], object] | None, handler.request_deserializer + ), + response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer), + ), ) -def wrap_stream_unary_handler( - handler: RpcMethodHandlerProtocol, +def _wrap_stream_unary_handler( + handler: RpcMethodHandlerProtocol[TRequest, TResponse], method: str, -) -> RpcMethodHandlerProtocol: - """Create wrapped stream-unary handler with logging.""" +) -> RpcMethodHandlerProtocol[TRequest, TResponse]: if handler.stream_unary is None: raise TypeError("Stream-unary handler is missing") wrapped: Callable[ [AsyncIterator[object], ServicerContextProtocol], Awaitable[object], - ] = wrap_stream_unary(handler.stream_unary, method) - request_deserializer = handler.request_deserializer - response_serializer = handler.response_serializer - return typed_grpc.stream_unary_rpc_method_handler( - wrapped, - request_deserializer=request_deserializer, - response_serializer=response_serializer, + ] = wrap_stream_unary( + cast( + Callable[[AsyncIterator[object], ServicerContextProtocol], Awaitable[object]], + handler.stream_unary, + ), + method, + ) + return cast( + RpcMethodHandlerProtocol[TRequest, TResponse], + typed_grpc.stream_unary_rpc_method_handler( + wrapped, + request_deserializer=cast( + Callable[[bytes], object] | None, handler.request_deserializer + ), + response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer), + ), ) -def wrap_stream_stream_handler( - handler: RpcMethodHandlerProtocol, +def _wrap_stream_stream_handler( + handler: RpcMethodHandlerProtocol[TRequest, TResponse], method: str, -) -> RpcMethodHandlerProtocol: - """Create wrapped stream-stream handler with logging.""" +) -> RpcMethodHandlerProtocol[TRequest, TResponse]: if handler.stream_stream is None: raise TypeError("Stream-stream handler is missing") wrapped: Callable[ [AsyncIterator[object], ServicerContextProtocol], AsyncIterator[object], - ] = wrap_stream_stream(handler.stream_stream, method) - request_deserializer = handler.request_deserializer - response_serializer = handler.response_serializer - return typed_grpc.stream_stream_rpc_method_handler( - wrapped, - request_deserializer=request_deserializer, - response_serializer=response_serializer, + ] = wrap_stream_stream( + cast( + Callable[[AsyncIterator[object], ServicerContextProtocol], AsyncIterator[object]], + handler.stream_stream, + ), + method, + ) + return cast( + RpcMethodHandlerProtocol[TRequest, TResponse], + typed_grpc.stream_stream_rpc_method_handler( + wrapped, + request_deserializer=cast( + Callable[[bytes], object] | None, handler.request_deserializer + ), + response_serializer=cast(Callable[[object], bytes] | None, handler.response_serializer), + ), ) def create_logging_handler( - handler: RpcMethodHandlerProtocol, + handler: RpcMethodHandlerProtocol[TRequest, TResponse], method: str, -) -> RpcMethodHandlerProtocol: +) -> RpcMethodHandlerProtocol[TRequest, TResponse]: """Wrap an RPC handler to add request logging. Args: @@ -111,12 +142,11 @@ def create_logging_handler( Wrapped handler with logging. """ if handler.unary_unary is not None: - return wrap_unary_unary_handler(handler, method) + return _wrap_unary_unary_handler(handler, method) if handler.unary_stream is not None: - return wrap_unary_stream_handler(handler, method) + return _wrap_unary_stream_handler(handler, method) if handler.stream_unary is not None: - return wrap_stream_unary_handler(handler, method) + return _wrap_stream_unary_handler(handler, method) if handler.stream_stream is not None: - return wrap_stream_stream_handler(handler, method) - # Fallback: return original handler if type unknown + return _wrap_stream_stream_handler(handler, method) return handler diff --git a/src/noteflow/grpc/interceptors/logging/logging.py b/src/noteflow/grpc/interceptors/logging/logging.py index ef8c7e6..6caadef 100644 --- a/src/noteflow/grpc/interceptors/logging/logging.py +++ b/src/noteflow/grpc/interceptors/logging/logging.py @@ -1,19 +1,29 @@ """Request logging interceptor for gRPC calls. -Log every RPC call with method, status, duration, peer, and request context -at INFO level for production observability and traceability. +Log every RPC call with method, status, duration, peer, and request context at +INFO level for production observability and traceability. """ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import cast - -from grpc import HandlerCallDetails, RpcMethodHandler -from grpc.aio import ServerInterceptor +from typing import TYPE_CHECKING, TypeVar +from .._types import ( + HandlerCallDetailsProtocol, + RpcMethodHandlerProtocol, + ServerInterceptorProtocol, +) from ._handler_factory import create_logging_handler -from .._types import TRequest, TResponse + +_TRequest = TypeVar("_TRequest") +_TResponse = TypeVar("_TResponse") + + +if TYPE_CHECKING: + ServerInterceptor = ServerInterceptorProtocol +else: + from grpc.aio import ServerInterceptor class RequestLoggingInterceptor(ServerInterceptor): @@ -30,25 +40,12 @@ class RequestLoggingInterceptor(ServerInterceptor): async def intercept_service( self, continuation: Callable[ - [HandlerCallDetails], - Awaitable[RpcMethodHandler[TRequest, TResponse]], + [HandlerCallDetailsProtocol], + Awaitable[RpcMethodHandlerProtocol[_TRequest, _TResponse]], ], - handler_call_details: HandlerCallDetails, - ) -> RpcMethodHandler[TRequest, TResponse]: - """Intercept incoming RPC calls to log request timing and status. - - Args: - continuation: The next interceptor or handler. - handler_call_details: Details about the RPC call. - - Returns: - Wrapped RPC handler that logs on completion. - """ + handler_call_details: HandlerCallDetailsProtocol, + ) -> RpcMethodHandlerProtocol[_TRequest, _TResponse]: + """Intercept incoming RPC calls to log request timing and status.""" handler = await continuation(handler_call_details) - method = handler_call_details.method - - # Return wrapped handler that logs on completion - return cast( - RpcMethodHandler[TRequest, TResponse], - create_logging_handler(handler, method), - ) + method = handler_call_details.method or "" + return create_logging_handler(handler, method) diff --git a/src/noteflow/grpc/mixins/calendar.py b/src/noteflow/grpc/mixins/calendar.py index 7bc5445..9bd9041 100644 --- a/src/noteflow/grpc/mixins/calendar.py +++ b/src/noteflow/grpc/mixins/calendar.py @@ -77,7 +77,6 @@ async def require_calendar_service( return host.calendar_service logger.warning(f"{operation}_unavailable", reason="service_not_enabled") await abort_unavailable(context, _ERR_CALENDAR_NOT_ENABLED) - raise # Unreachable but helps type checker class CalendarMixin: @@ -117,7 +116,6 @@ class CalendarMixin: except CalendarServiceError as e: logger.error("calendar_list_events_failed", error=str(e), provider=provider) await abort_internal(context, str(e)) - raise # Unreachable but helps type checker proto_events = [_calendar_event_to_proto(event) for event in events] @@ -166,9 +164,7 @@ class CalendarMixin: status=status.status, ) - authenticated_count = sum( - bool(provider.is_authenticated) for provider in providers - ) + authenticated_count = sum(bool(provider.is_authenticated) for provider in providers) logger.info( "calendar_get_providers_success", total_providers=len(providers), @@ -205,7 +201,6 @@ class CalendarMixin: error=str(e), ) await abort_invalid_argument(context, str(e)) - raise # Unreachable but helps type checker logger.info( "oauth_initiate_success", diff --git a/src/noteflow/grpc/mixins/diarization_job.py b/src/noteflow/grpc/mixins/diarization_job.py index 655def0..a84d589 100644 --- a/src/noteflow/grpc/mixins/diarization_job.py +++ b/src/noteflow/grpc/mixins/diarization_job.py @@ -92,20 +92,18 @@ class DiarizationJobMixin: Prunes both in-memory task references and database records. """ # Clean up in-memory task references for completed tasks - completed_tasks = [ - job_id for job_id, task in self.diarization_tasks.items() if task.done() - ] + completed_tasks = [job_id for job_id, task in self.diarization_tasks.items() if task.done()] for job_id in completed_tasks: self.diarization_tasks.pop(job_id, None) # Prune old completed jobs from database - async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo: + async with cast( + DiarizationJobRepositoryProvider, self.create_repository_provider() + ) as repo: if not repo.supports_diarization_jobs: logger.debug("Job pruning skipped: database required") return - pruned = await repo.diarization_jobs.prune_completed( - self.diarization_job_ttl_seconds - ) + pruned = await repo.diarization_jobs.prune_completed(self.diarization_job_ttl_seconds) await repo.commit() if pruned > 0: logger.debug("Pruned %d completed diarization jobs", pruned) @@ -121,15 +119,15 @@ class DiarizationJobMixin: """ await self.prune_diarization_jobs() - async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo: + async with cast( + DiarizationJobRepositoryProvider, self.create_repository_provider() + ) as repo: if not repo.supports_diarization_jobs: await abort_not_found(context, "Diarization jobs (database required)", "") - raise # Unreachable but helps type checker job = await repo.diarization_jobs.get(request.job_id) if job is None: await abort_not_found(context, "Diarization job", request.job_id) - raise # Unreachable but helps type checker return _build_job_status(job) @@ -145,7 +143,9 @@ class DiarizationJobMixin: job_id = request.job_id await _cancel_running_task(self.diarization_tasks, job_id) - async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo: + async with cast( + DiarizationJobRepositoryProvider, self.create_repository_provider() + ) as repo: if not repo.supports_diarization_jobs: await abort_database_required(context, "Diarization job cancellation") raise AssertionError(UNREACHABLE_ERROR) # abort is NoReturn @@ -185,7 +185,9 @@ class DiarizationJobMixin: """ response = noteflow_pb2.GetActiveDiarizationJobsResponse() - async with cast(DiarizationJobRepositoryProvider, self.create_repository_provider()) as repo: + async with cast( + DiarizationJobRepositoryProvider, self.create_repository_provider() + ) as repo: if not repo.supports_diarization_jobs: # Return empty list if DB not available return response diff --git a/src/noteflow/grpc/mixins/entities.py b/src/noteflow/grpc/mixins/entities.py index 129d448..1a34525 100644 --- a/src/noteflow/grpc/mixins/entities.py +++ b/src/noteflow/grpc/mixins/entities.py @@ -175,7 +175,6 @@ class EntitiesMixin: if updated is None: await abort_not_found(context, ENTITY_ENTITY, request.entity_id) - raise # Unreachable but helps type checker await uow.commit() diff --git a/src/noteflow/grpc/mixins/identity.py b/src/noteflow/grpc/mixins/identity.py index 954f910..16862b3 100644 --- a/src/noteflow/grpc/mixins/identity.py +++ b/src/noteflow/grpc/mixins/identity.py @@ -75,9 +75,7 @@ def _merge_workspace_settings( trigger_rules=updates.trigger_rules if updates.trigger_rules is not None else current.trigger_rules, - rag_enabled=updates.rag_enabled - if updates.rag_enabled is not None - else current.rag_enabled, + rag_enabled=updates.rag_enabled if updates.rag_enabled is not None else current.rag_enabled, default_summarization_template=updates.default_summarization_template if updates.default_summarization_template is not None else current.default_summarization_template, @@ -192,14 +190,12 @@ class IdentityMixin: if not request.workspace_id: await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED) - raise # Unreachable but helps type checker workspace_id = await parse_workspace_id(request.workspace_id, context) async with cast(UnitOfWork, self.create_repository_provider()) as uow: if not uow.supports_workspaces: await abort_database_required(context, WORKSPACES_LABEL) - raise # Unreachable but helps type checker user_ctx = await self.identity_service.get_or_create_default_user(uow) workspace, membership = await self._verify_workspace_access( @@ -231,14 +227,12 @@ class IdentityMixin: """Get workspace settings.""" if not request.workspace_id: await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED) - raise # Unreachable but helps type checker workspace_id = await parse_workspace_id(request.workspace_id, context) async with cast(UnitOfWork, self.create_repository_provider()) as uow: if not uow.supports_workspaces: await abort_database_required(context, WORKSPACES_LABEL) - raise # Unreachable but helps type checker user_ctx = await self.identity_service.get_or_create_default_user(uow) workspace, _ = await self._verify_workspace_access( @@ -259,14 +253,12 @@ class IdentityMixin: """Update workspace settings.""" if not request.workspace_id: await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED) - raise # Unreachable but helps type checker workspace_id = await parse_workspace_id(request.workspace_id, context) async with cast(UnitOfWork, self.create_repository_provider()) as uow: if not uow.supports_workspaces: await abort_database_required(context, WORKSPACES_LABEL) - raise # Unreachable but helps type checker user_ctx = await self.identity_service.get_or_create_default_user(uow) workspace, membership = await self._verify_workspace_access( @@ -278,7 +270,6 @@ class IdentityMixin: if not membership.role.can_admin(): await abort_permission_denied(context, ERROR_WORKSPACE_ADMIN_REQUIRED) - raise # Unreachable but helps type checker updates = proto_to_workspace_settings(request.settings) if updates is None: @@ -302,11 +293,9 @@ class IdentityMixin: workspace = await uow.workspaces.get(workspace_id) if not workspace: await abort_not_found(context, ENTITY_WORKSPACE, str(workspace_id)) - raise # Unreachable but helps type checker membership = await uow.workspaces.get_membership(workspace_id, user_id) if not membership: await abort_not_found(context, "Workspace membership", str(workspace_id)) - raise # Unreachable but helps type checker return workspace, membership diff --git a/src/noteflow/grpc/mixins/meeting/meeting_mixin.py b/src/noteflow/grpc/mixins/meeting/meeting_mixin.py index e6c5941..cb55fb9 100644 --- a/src/noteflow/grpc/mixins/meeting/meeting_mixin.py +++ b/src/noteflow/grpc/mixins/meeting/meeting_mixin.py @@ -42,7 +42,6 @@ if TYPE_CHECKING: from .._types import GrpcContext from ..protocols import ServicerHost -from ..errors._constants import UNREACHABLE_ERROR logger = get_logger(__name__) @@ -66,7 +65,6 @@ async def _load_meeting_for_stop( if meeting is None: logger.warning("StopMeeting: meeting not found", meeting_id=meeting_id_str) await abort_not_found(context, ENTITY_MEETING, meeting_id_str) - raise AssertionError(UNREACHABLE_ERROR) return meeting @@ -85,7 +83,7 @@ async def _stop_meeting_and_persist(context: _StopMeetingContext) -> Meeting: ) return context.meeting - previous_state = context.meeting.state.value + previous_state = context.meeting.state.name await transition_to_stopped( context.meeting, context.meeting_id, @@ -143,7 +141,9 @@ class MeetingMixin: op_context = self.get_operation_context(context) async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo: - project_id = await resolve_create_project_id(self, repo, op_context, project_id) + project_id = await resolve_create_project_id( + cast("ServicerHost", self), repo, op_context, project_id + ) meeting = Meeting.create( title=request.title, @@ -212,7 +212,7 @@ class MeetingMixin: async with cast(MeetingRepositoryProvider, self.create_repository_provider()) as repo: if project_id is None and not project_ids: - project_id = await resolve_active_project_id(self, repo) + project_id = await resolve_active_project_id(cast("ServicerHost", self), repo) meetings, total = await repo.meetings.list_all( states=states, @@ -254,7 +254,6 @@ class MeetingMixin: if meeting is None: logger.warning("GetMeeting: meeting not found", meeting_id=request.meeting_id) await abort_not_found(context, ENTITY_MEETING, request.meeting_id) - raise # Unreachable but helps type checker # Load segments if requested if request.include_segments: segments = await repo.segments.get_by_meeting(meeting.id) @@ -283,7 +282,6 @@ class MeetingMixin: if not success: logger.warning("DeleteMeeting: meeting not found", meeting_id=request.meeting_id) await abort_not_found(context, ENTITY_MEETING, request.meeting_id) - raise # Unreachable but helps type checker await repo.commit() logger.info("Meeting deleted", meeting_id=request.meeting_id) diff --git a/src/noteflow/grpc/mixins/project/_membership.py b/src/noteflow/grpc/mixins/project/_membership.py index 9037a48..ac96258 100644 --- a/src/noteflow/grpc/mixins/project/_membership.py +++ b/src/noteflow/grpc/mixins/project/_membership.py @@ -35,8 +35,6 @@ if TYPE_CHECKING: from ..protocols import ProjectRepositoryProvider - - async def _parse_project_and_user_ids( request_project_id: str, request_user_id: str, @@ -45,21 +43,19 @@ async def _parse_project_and_user_ids( """Parse and validate project and user IDs from request.""" if not request_project_id: await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED) - raise # Unreachable but helps type checker if not request_user_id: await abort_invalid_argument(context, ERROR_USER_ID_REQUIRED) - raise # Unreachable but helps type checker try: project_id = UUID(request_project_id) user_id = UUID(request_user_id) except ValueError as e: await abort_invalid_argument(context, f"{ERROR_INVALID_UUID_PREFIX}{e}") - raise # Unreachable but helps type checker return project_id, user_id + class ProjectMembershipMixin: """Mixin providing project membership functionality. @@ -93,7 +89,6 @@ class ProjectMembershipMixin: ) if membership is None: await abort_not_found(context, ENTITY_PROJECT, request.project_id) - raise # Unreachable but helps type checker return membership_to_proto(membership) @@ -119,8 +114,9 @@ class ProjectMembershipMixin: role=role, ) if membership is None: - await abort_not_found(context, "Membership", f"{request.project_id}/{request.user_id}") - raise # Unreachable but helps type checker + await abort_not_found( + context, "Membership", f"{request.project_id}/{request.user_id}" + ) return membership_to_proto(membership) @@ -159,8 +155,9 @@ class ProjectMembershipMixin: try: project_id = UUID(request.project_id) except ValueError: - await abort_invalid_argument(context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}") - raise # Unreachable but helps type checker + await abort_invalid_argument( + context, f"{ERROR_INVALID_PROJECT_ID_PREFIX}{request.project_id}" + ) limit = request.limit if request.limit > 0 else 100 offset = max(request.offset, 0) diff --git a/src/noteflow/grpc/mixins/project/_mixin.py b/src/noteflow/grpc/mixins/project/_mixin.py index b81e17f..6ef3793 100644 --- a/src/noteflow/grpc/mixins/project/_mixin.py +++ b/src/noteflow/grpc/mixins/project/_mixin.py @@ -41,8 +41,6 @@ if TYPE_CHECKING: logger = get_logger(__name__) - - async def _require_and_parse_project_id( request_project_id: str, context: GrpcContext, @@ -50,7 +48,6 @@ async def _require_and_parse_project_id( """Require and parse a project_id from request.""" if not request_project_id: await abort_invalid_argument(context, ERROR_PROJECT_ID_REQUIRED) - raise # Unreachable but helps type checker return await parse_project_id(request_project_id, context) @@ -61,9 +58,9 @@ async def _require_and_parse_workspace_id( """Require and parse a workspace_id from request.""" if not request_workspace_id: await abort_invalid_argument(context, ERROR_WORKSPACE_ID_REQUIRED) - raise # Unreachable but helps type checker return await parse_workspace_id(request_workspace_id, context) + class ProjectMixin: """Mixin providing project management functionality. @@ -89,11 +86,12 @@ class ProjectMixin: if not request.name: await abort_invalid_argument(context, "name is required") - raise # Unreachable but helps type checker slug = request.slug if request.HasField("slug") else None description = request.description if request.HasField("description") else None - settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None + settings = ( + proto_to_project_settings(request.settings) if request.HasField("settings") else None + ) async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow: await require_feature_projects(uow, context) @@ -123,7 +121,6 @@ class ProjectMixin: project = await project_service.get_project(uow, project_id) if project is None: await abort_not_found(context, ENTITY_PROJECT, request.project_id) - raise # Unreachable but helps type checker return project_to_proto(project) @@ -138,7 +135,6 @@ class ProjectMixin: if not request.slug: await abort_invalid_argument(context, "slug is required") - raise # Unreachable but helps type checker async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow: await require_feature_projects(uow, context) @@ -146,7 +142,6 @@ class ProjectMixin: project = await project_service.get_project_by_slug(uow, workspace_id, request.slug) if project is None: await abort_not_found(context, ENTITY_PROJECT, request.slug) - raise # Unreachable but helps type checker return project_to_proto(project) @@ -159,6 +154,7 @@ class ProjectMixin: project_service = await require_project_service(self.project_service, context) workspace_id = await _require_and_parse_workspace_id(request.workspace_id, context) + include_archived = request.include_archived limit = request.limit if request.limit > 0 else 50 offset = max(request.offset, 0) @@ -166,21 +162,20 @@ class ProjectMixin: await require_feature_projects(uow, context) projects = await project_service.list_projects( - uow=uow, - workspace_id=workspace_id, - include_archived=request.include_archived, + uow, + workspace_id, + include_archived=include_archived, limit=limit, offset=offset, ) - total_count = await project_service.count_projects( - uow=uow, - workspace_id=workspace_id, - include_archived=request.include_archived, + uow, + workspace_id, + include_archived=include_archived, ) return noteflow_pb2.ListProjectsResponse( - projects=[project_to_proto(p) for p in projects], + projects=[project_to_proto(project) for project in projects], total_count=total_count, ) @@ -196,14 +191,16 @@ class ProjectMixin: name = request.name if request.HasField("name") else None slug = request.slug if request.HasField("slug") else None description = request.description if request.HasField("description") else None - settings = proto_to_project_settings(request.settings) if request.HasField("settings") else None + settings = ( + proto_to_project_settings(request.settings) if request.HasField("settings") else None + ) async with cast(ProjectRepositoryProvider, self.create_repository_provider()) as uow: await require_feature_projects(uow, context) project = await project_service.update_project( - uow=uow, - project_id=project_id, + uow, + project_id, name=name, slug=slug, description=description, @@ -211,7 +208,6 @@ class ProjectMixin: ) if project is None: await abort_not_found(context, ENTITY_PROJECT, request.project_id) - raise # Unreachable but helps type checker return project_to_proto(project) @@ -231,11 +227,9 @@ class ProjectMixin: project = await project_service.archive_project(uow, project_id) except CannotArchiveDefaultProjectError: await abort_failed_precondition(context, "Cannot archive the default project") - raise # Unreachable but helps type checker if project is None: await abort_not_found(context, ENTITY_PROJECT, request.project_id) - raise # Unreachable but helps type checker return project_to_proto(project) @@ -254,7 +248,6 @@ class ProjectMixin: project = await project_service.restore_project(uow, project_id) if project is None: await abort_not_found(context, ENTITY_PROJECT, request.project_id) - raise # Unreachable but helps type checker return project_to_proto(project) @@ -302,7 +295,6 @@ class ProjectMixin: ) except ValueError as exc: await abort_invalid_argument(context, str(exc)) - raise # Unreachable but helps type checker await uow.commit() return noteflow_pb2.SetActiveProjectResponse() @@ -327,11 +319,9 @@ class ProjectMixin: ) except ValueError as exc: await abort_invalid_argument(context, str(exc)) - raise # Unreachable but helps type checker if project is None: await abort_not_found(context, ENTITY_PROJECT, "default") - raise # Unreachable but helps type checker response = noteflow_pb2.GetActiveProjectResponse( project=project_to_proto(project), diff --git a/src/noteflow/grpc/mixins/streaming/_processing/__init__.py b/src/noteflow/grpc/mixins/streaming/_processing/__init__.py index 90f5a41..aae5d72 100644 --- a/src/noteflow/grpc/mixins/streaming/_processing/__init__.py +++ b/src/noteflow/grpc/mixins/streaming/_processing/__init__.py @@ -59,7 +59,6 @@ async def _decode_chunk_audio( ) except ValueError as e: await abort_invalid_argument(context, str(e)) - raise # Unreachable but helps type checker conversion = AudioConversionContext( source_sample_rate=sample_rate, diff --git a/src/noteflow/grpc/mixins/summarization/_generation_mixin.py b/src/noteflow/grpc/mixins/summarization/_generation_mixin.py index dc08362..05b94e4 100644 --- a/src/noteflow/grpc/mixins/summarization/_generation_mixin.py +++ b/src/noteflow/grpc/mixins/summarization/_generation_mixin.py @@ -30,10 +30,13 @@ from ._template_resolution import ( if TYPE_CHECKING: from collections.abc import Callable + from noteflow.application.services.ner import NerService from noteflow.application.services.summarization import SummarizationService from noteflow.application.services.webhooks import WebhookService from noteflow.domain.entities import Meeting from noteflow.domain.identity import OperationContext + from noteflow.infrastructure.asr import FasterWhisperEngine + from noteflow.infrastructure.diarization.engine import DiarizationEngine logger = get_logger(__name__) @@ -48,9 +51,41 @@ class _SummaryGenerationContext: force_regenerate: bool +@dataclass(frozen=True, slots=True) +class _SummaryRequestContext: + request: noteflow_pb2.GenerateSummaryRequest + meeting: Meeting + segments: list[Segment] + style_instructions: str | None + context: GrpcContext + op_context: OperationContext + + async def resolve_style_prompt( + self, + repo: UnitOfWork, + summarization_service: SummarizationService | None, + ) -> str | None: + """Resolve style prompt for this request context.""" + return await resolve_template_prompt( + TemplateResolutionInputs( + request=self.request, + meeting=self.meeting, + segments=self.segments, + style_instructions=self.style_instructions, + context=self.context, + op_context=self.op_context, + repo=repo, + summarization_service=summarization_service, + ) + ) + + class SummarizationGenerationMixin: """Generate summaries and handle summary webhooks.""" + asr_engine: FasterWhisperEngine | None + diarization_engine: DiarizationEngine | None + ner_service: NerService | None summarization_service: SummarizationService | None webhook_service: WebhookService | None create_repository_provider: Callable[..., object] @@ -68,28 +103,28 @@ class SummarizationGenerationMixin: meeting_id=request.meeting_id, include_provider_details=True, ) - meeting_id, op_context, style_instructions, meeting, existing, segments = ( - await self._prepare_summary_request(request, context) - ) + ( + meeting_id, + op_context, + style_instructions, + meeting, + existing, + segments, + ) = await self._prepare_summary_request(request, context) if existing and not request.force_regenerate: await self._mark_summary_step(request.meeting_id, ProcessingStepStatus.COMPLETED) return summary_to_proto(existing) await self._ensure_cloud_provider() - - async with cast(UnitOfWork, self.create_repository_provider()) as repo: - style_prompt = await self._resolve_style_prompt( - TemplateResolutionInputs( - request=request, - meeting=meeting, - segments=segments, - style_instructions=style_instructions, - context=context, - op_context=op_context, - repo=repo, - summarization_service=self.summarization_service, - ) - ) + request_context = _SummaryRequestContext( + request=request, + meeting=meeting, + segments=segments, + style_instructions=style_instructions, + context=context, + op_context=op_context, + ) + style_prompt = await self._resolve_style_prompt_for_request(request_context) saved, trigger_webhook = await self._generate_summary_with_status( _SummaryGenerationContext( meeting_id=meeting_id, @@ -121,6 +156,13 @@ class SummarizationGenerationMixin: ), ) + async def _resolve_style_prompt_for_request( + self, + request_context: _SummaryRequestContext, + ) -> str | None: + async with cast(UnitOfWork, self.create_repository_provider()) as repo: + return await request_context.resolve_style_prompt(repo, self.summarization_service) + async def _resolve_style_prompt(self, inputs: TemplateResolutionInputs) -> str | None: return await resolve_template_prompt(inputs) @@ -244,7 +286,6 @@ class SummarizationGenerationMixin: meeting = await repo.meetings.get(meeting_id) if meeting is None: await abort_not_found(context, ENTITY_MEETING, request.meeting_id) - raise # Unreachable but helps type checker existing = await repo.summaries.get_by_meeting(meeting.id) segments = list(await repo.segments.get_by_meeting(meeting.id)) diff --git a/src/noteflow/grpc/types/__init__.py b/src/noteflow/grpc/types/__init__.py index c826f07..e78c34e 100644 --- a/src/noteflow/grpc/types/__init__.py +++ b/src/noteflow/grpc/types/__init__.py @@ -9,6 +9,8 @@ from typing import Generic, TypeVar, Never import grpc +from noteflow.domain.constants.fields import UNKNOWN as UNKNOWN_VALUE + T = TypeVar("T", covariant=True) U = TypeVar("U") E = TypeVar("E", bound=BaseException) @@ -17,7 +19,7 @@ E = TypeVar("E", bound=BaseException) class ClientErrorCode(StrEnum): """Client-facing error codes for gRPC operations.""" - UNKNOWN = "unknown" + UNKNOWN = UNKNOWN_VALUE NOT_CONNECTED = "not_connected" NOT_FOUND = "not_found" INVALID_ARGUMENT = "invalid_argument" diff --git a/tests/grpc/test_interceptors.py b/tests/grpc/test_interceptors.py index d2c0968..cf82137 100644 --- a/tests/grpc/test_interceptors.py +++ b/tests/grpc/test_interceptors.py @@ -5,8 +5,8 @@ Tests identity context validation and per-RPC request logging. from __future__ import annotations -from collections.abc import Awaitable, Callable -from typing import Protocol, cast +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import grpc @@ -19,7 +19,11 @@ from noteflow.grpc.interceptors import ( IdentityInterceptor, RequestLoggingInterceptor, ) -from noteflow.grpc.interceptors._types import ServicerContextProtocol +from noteflow.grpc.interceptors._types import ( + HandlerCallDetailsProtocol, + RpcMethodHandlerProtocol, + ServicerContextProtocol, +) from noteflow.infrastructure.logging import ( get_request_id, get_user_id, @@ -27,14 +31,13 @@ from noteflow.infrastructure.logging import ( request_id_var, ) +# Type alias for callable unary-unary handlers +UnaryUnaryCallable = Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] -class _DummyRequest: - """Placeholder request type for handler casts.""" +# Type alias for test handler protocol (using bytes for wire-level compatibility) +TestHandlerProtocol = RpcMethodHandlerProtocol[bytes, bytes] -class _DummyResponse: - """Placeholder response type for handler casts.""" - # Test data TEST_REQUEST_ID = "test-request-123" TEST_USER_ID = "user-456" @@ -45,25 +48,28 @@ TEST_METHOD = "/noteflow.NoteFlowService/GetMeeting" pytestmark = pytest.mark.usefixtures("reset_context_vars") +# Type alias for handler call details matching grpc module +HandlerCallDetails = HandlerCallDetailsProtocol + def create_handler_call_details( method: str = TEST_METHOD, metadata: list[tuple[str, str | bytes]] | None = None, -) -> grpc.HandlerCallDetails: +) -> HandlerCallDetails: """Create mock HandlerCallDetails with metadata.""" - details = MagicMock(spec=grpc.HandlerCallDetails) + details = MagicMock(spec=HandlerCallDetails) details.method = method details.invocation_metadata = metadata or [] return details -def create_mock_handler() -> _UnaryUnaryHandler: +def create_mock_handler() -> TestHandlerProtocol: """Create a mock RPC method handler.""" return _MockHandler() def create_mock_continuation( - handler: _UnaryUnaryHandler | None = None, + handler: TestHandlerProtocol | None = None, ) -> AsyncMock: """Create a mock continuation function.""" if handler is None: @@ -71,46 +77,75 @@ def create_mock_continuation( return AsyncMock(return_value=handler) -class _UnaryUnaryHandler(Protocol): - """Protocol for unary-unary RPC method handlers.""" - - unary_unary: Callable[ - [_DummyRequest, ServicerContextProtocol], - Awaitable[_DummyResponse], - ] | None - unary_stream: object | None - stream_unary: object | None - stream_stream: object | None - request_deserializer: object | None - response_serializer: object | None - - class _MockHandler: - """Concrete handler for tests with typed unary_unary.""" + """Concrete handler for tests with typed unary_unary. - unary_unary: Callable[ - [_DummyRequest, ServicerContextProtocol], - Awaitable[_DummyResponse], - ] | None - unary_stream: object | None - stream_unary: object | None - stream_stream: object | None - request_deserializer: object | None - response_serializer: object | None + Implements RpcMethodHandlerProtocol for interceptor testing. + """ def __init__(self) -> None: - self.unary_unary = cast( - Callable[ - [_DummyRequest, ServicerContextProtocol], - Awaitable[_DummyResponse], - ], - AsyncMock(return_value="response"), + self._unary_unary: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None = ( + cast( + Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]], + AsyncMock(return_value=b"response"), + ) ) - self.unary_stream = None - self.stream_unary = None - self.stream_stream = None - self.request_deserializer = None - self.response_serializer = None + self._unary_stream: ( + Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None + ) = None + self._stream_unary: ( + Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None + ) = None + self._stream_stream: ( + Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None + ) = None + self._request_deserializer: Callable[[bytes], bytes] | None = None + self._response_serializer: Callable[[bytes], bytes] | None = None + + @property + def request_deserializer(self) -> Callable[[bytes], bytes] | None: + """Request deserializer function.""" + return self._request_deserializer + + @property + def response_serializer(self) -> Callable[[bytes], bytes] | None: + """Response serializer function.""" + return self._response_serializer + + @property + def unary_unary( + self, + ) -> Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None: + """Unary-unary handler behavior.""" + return self._unary_unary + + @unary_unary.setter + def unary_unary( + self, value: Callable[[bytes, ServicerContextProtocol], Awaitable[bytes]] | None + ) -> None: + """Set unary-unary handler behavior.""" + self._unary_unary = value + + @property + def unary_stream( + self, + ) -> Callable[[bytes, ServicerContextProtocol], AsyncIterator[bytes]] | None: + """Unary-stream handler behavior.""" + return self._unary_stream + + @property + def stream_unary( + self, + ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], Awaitable[bytes]] | None: + """Stream-unary handler behavior.""" + return self._stream_unary + + @property + def stream_stream( + self, + ) -> Callable[[AsyncIterator[bytes], ServicerContextProtocol], AsyncIterator[bytes]] | None: + """Stream-stream handler behavior.""" + return self._stream_stream class TestIdentityInterceptor: @@ -157,8 +192,9 @@ class TestIdentityInterceptor: original_handler = create_mock_handler() continuation = create_mock_continuation(original_handler) - handler = await interceptor.intercept_service(continuation, details) - typed_handler = cast(_UnaryUnaryHandler, handler) + typed_handler: TestHandlerProtocol = await interceptor.intercept_service( + continuation, details + ) # Handler should be a rejection handler wrapping the original assert typed_handler.unary_unary is not None, "handler should have unary_unary" @@ -166,7 +202,9 @@ class TestIdentityInterceptor: # is a rejection wrapper that will abort with UNAUTHENTICATED continuation.assert_called_once() # The returned handler should NOT be the original handler - assert typed_handler.unary_unary is not original_handler.unary_unary, "should return rejection handler, not original" + assert typed_handler.unary_unary is not original_handler.unary_unary, ( + "should return rejection handler, not original" + ) @pytest.mark.asyncio async def test_reject_handler_aborts_with_unauthenticated(self) -> None: @@ -175,8 +213,9 @@ class TestIdentityInterceptor: details = create_handler_call_details(metadata=[]) continuation = create_mock_continuation() - handler = await interceptor.intercept_service(continuation, details) - typed_handler = cast(_UnaryUnaryHandler, handler) + typed_handler: TestHandlerProtocol = await interceptor.intercept_service( + continuation, details + ) # Create mock context to verify abort behavior context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol) @@ -184,11 +223,14 @@ class TestIdentityInterceptor: with pytest.raises(grpc.RpcError, match="x-request-id"): assert typed_handler.unary_unary is not None - await typed_handler.unary_unary(MagicMock(), context) + unary_fn = typed_handler.unary_unary + await unary_fn(MagicMock(), context) context.abort.assert_called_once() call_args = context.abort.call_args - assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, "should abort with UNAUTHENTICATED" + assert call_args[0][0] == grpc.StatusCode.UNAUTHENTICATED, ( + "should abort with UNAUTHENTICATED" + ) assert "x-request-id" in call_args[0][1], "error message should mention x-request-id" @pytest.mark.asyncio @@ -224,16 +266,16 @@ class TestRequestLoggingInterceptor: request_id_var.set(TEST_REQUEST_ID) with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger: - wrapped_handler = await interceptor.intercept_service(continuation, details) - typed_handler = cast(_UnaryUnaryHandler, wrapped_handler) + wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service( + continuation, details + ) - # Execute the wrapped handler context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol) context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345") - assert typed_handler.unary_unary is not None - await typed_handler.unary_unary(MagicMock(), context) + assert wrapped_handler.unary_unary is not None + unary_fn = wrapped_handler.unary_unary + await unary_fn(MagicMock(), context) - # Verify logging mock_logger.info.assert_called_once() call_kwargs = mock_logger.info.call_args[1] assert call_kwargs["method"] == TEST_METHOD, "should log method" @@ -246,28 +288,31 @@ class TestRequestLoggingInterceptor: """Interceptor logs error status when handler raises.""" interceptor = RequestLoggingInterceptor() - # Create handler that raises handler = create_mock_handler() - handler.unary_unary = AsyncMock(side_effect=Exception("Test error")) + mock_handler = cast(_MockHandler, handler) + mock_handler.unary_unary = AsyncMock(side_effect=Exception("Test error")) details = create_handler_call_details(method=TEST_METHOD) continuation = create_mock_continuation(handler) with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger: - wrapped_handler = await interceptor.intercept_service(continuation, details) - typed_handler = cast(_UnaryUnaryHandler, wrapped_handler) + wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service( + continuation, details + ) context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol) context.peer = MagicMock(return_value="ipv4:127.0.0.1:12345") with pytest.raises(Exception, match="Test error"): - assert typed_handler.unary_unary is not None - await typed_handler.unary_unary(MagicMock(), context) + assert wrapped_handler.unary_unary is not None + unary_fn = wrapped_handler.unary_unary + await unary_fn(MagicMock(), context) - # Should still log with INTERNAL status mock_logger.info.assert_called_once() call_kwargs = mock_logger.info.call_args[1] - assert call_kwargs["status"] == "INTERNAL", "should log INTERNAL for unhandled exception" + assert call_kwargs["status"] == "INTERNAL", ( + "should log INTERNAL for unhandled exception" + ) @pytest.mark.asyncio async def test_passes_through_to_continuation(self) -> None: @@ -277,7 +322,7 @@ class TestRequestLoggingInterceptor: details = create_handler_call_details() continuation = create_mock_continuation(handler) - await interceptor.intercept_service(continuation, details) + _: TestHandlerProtocol = await interceptor.intercept_service(continuation, details) continuation.assert_called_once_with(details) @@ -290,17 +335,17 @@ class TestRequestLoggingInterceptor: continuation = create_mock_continuation(handler) with patch("noteflow.grpc.interceptors.logging._logging_ops.logger") as mock_logger: - wrapped_handler = await interceptor.intercept_service(continuation, details) - typed_handler = cast(_UnaryUnaryHandler, wrapped_handler) + wrapped_handler: TestHandlerProtocol = await interceptor.intercept_service( + continuation, details + ) - # Context without peer method context: ServicerContextProtocol = AsyncMock(spec=ServicerContextProtocol) context.peer = MagicMock(side_effect=RuntimeError("No peer")) - assert typed_handler.unary_unary is not None - await typed_handler.unary_unary(MagicMock(), context) + assert wrapped_handler.unary_unary is not None + unary_fn = wrapped_handler.unary_unary + await unary_fn(MagicMock(), context) - # Should still log with None peer mock_logger.info.assert_called_once() call_kwargs = mock_logger.info.call_args[1] assert call_kwargs["peer"] is None, "should handle missing peer"