Files
noteflow/tests/grpc/test_oidc_mixin.py
Travis Vasceannie d8090a98e8
Some checks failed
CI / test-typescript (push) Has been cancelled
CI / test-rust (push) Has been cancelled
CI / test-python (push) Has been cancelled
ci/cd fixes
2026-01-26 00:28:15 +00:00

702 lines
27 KiB
Python

"""Tests for OidcMixin gRPC endpoints.
Tests cover:
- RegisterOidcProvider: provider registration with validation
- ListOidcProviders: listing with optional filtering
- GetOidcProvider: single provider retrieval
- UpdateOidcProvider: provider configuration updates
- DeleteOidcProvider: provider removal
- RefreshOidcDiscovery: discovery refresh operations
- ListOidcPresets: available provider presets
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from noteflow.domain.auth.oidc import (
ClaimMapping,
OidcDiscoveryConfig,
OidcProviderConfig,
OidcProviderPreset,
)
from noteflow.grpc.mixins._types import GrpcContext
from noteflow.grpc.mixins.oidc import OidcMixin
from noteflow.grpc.proto import noteflow_pb2
from noteflow.infrastructure.auth.oidc_discovery import OidcDiscoveryError
if TYPE_CHECKING:
from datetime import datetime
class MockServicerHost(OidcMixin):
    """Minimal concrete host exposing the OidcMixin RPC handlers to the tests."""

    def __init__(self) -> None:
        """Create the host without an OIDC service; one is expected to be created lazily."""

    # The declarations below are only seen by static type checkers; they give
    # the mixin-provided RPC handlers precise signatures for type inference.
    if TYPE_CHECKING:

        async def RegisterOidcProvider(
            self,
            request: noteflow_pb2.RegisterOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto:
            ...

        async def ListOidcProviders(
            self,
            request: noteflow_pb2.ListOidcProvidersRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.ListOidcProvidersResponse:
            ...

        async def GetOidcProvider(
            self,
            request: noteflow_pb2.GetOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto:
            ...

        async def UpdateOidcProvider(
            self,
            request: noteflow_pb2.UpdateOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.OidcProviderProto:
            ...

        async def DeleteOidcProvider(
            self,
            request: noteflow_pb2.DeleteOidcProviderRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.DeleteOidcProviderResponse:
            ...

        async def RefreshOidcDiscovery(
            self,
            request: noteflow_pb2.RefreshOidcDiscoveryRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.RefreshOidcDiscoveryResponse:
            ...

        async def ListOidcPresets(
            self,
            request: noteflow_pb2.ListOidcPresetsRequest,
            context: GrpcContext,
        ) -> noteflow_pb2.ListOidcPresetsResponse:
            ...
@pytest.fixture
def oidc_servicer() -> MockServicerHost:
    """Provide a freshly constructed ``MockServicerHost`` for OIDC mixin tests."""
    servicer = MockServicerHost()
    return servicer
@pytest.fixture
def sample_provider(sample_datetime: datetime) -> OidcProviderConfig:
    """Build an enabled Authentik-preset provider config with discovery data filled in.

    Both timestamps are taken from the ``sample_datetime`` fixture; the provider
    and workspace ids are random UUIDs generated per fixture invocation.
    """
    provider_id = uuid4()
    owning_workspace_id = uuid4()
    discovery_document = OidcDiscoveryConfig(
        issuer="https://auth.example.com",
        authorization_endpoint="https://auth.example.com/authorize",
        token_endpoint="https://auth.example.com/token",
        userinfo_endpoint="https://auth.example.com/userinfo",
        jwks_uri="https://auth.example.com/.well-known/jwks.json",
        scopes_supported=("openid", "profile", "email"),
        claims_supported=("sub", "email", "name"),
    )
    return OidcProviderConfig(
        id=provider_id,
        workspace_id=owning_workspace_id,
        name="Test Authentik",
        preset=OidcProviderPreset.AUTHENTIK,
        issuer_url="https://auth.example.com",
        client_id="test-client-id",
        enabled=True,
        discovery=discovery_document,
        claim_mapping=ClaimMapping(),
        scopes=("openid", "profile", "email"),
        require_email_verified=True,
        allowed_groups=(),
        created_at=sample_datetime,
        updated_at=sample_datetime,
    )
@pytest.fixture
def disabled_provider(
    sample_provider: OidcProviderConfig, sample_datetime: datetime
) -> OidcProviderConfig:
    """Build a copy of ``sample_provider`` whose only difference is ``enabled=False``.

    Every identifying and configuration field is carried over from the sample
    provider; the timestamps come from ``sample_datetime``.
    """
    template = sample_provider
    return OidcProviderConfig(
        # The one field that differs from the template.
        enabled=False,
        # Everything else mirrors the template provider.
        id=template.id,
        workspace_id=template.workspace_id,
        name=template.name,
        preset=template.preset,
        issuer_url=template.issuer_url,
        client_id=template.client_id,
        discovery=template.discovery,
        claim_mapping=template.claim_mapping,
        scopes=template.scopes,
        require_email_verified=template.require_email_verified,
        allowed_groups=template.allowed_groups,
        created_at=sample_datetime,
        updated_at=sample_datetime,
    )
class TestRegisterOidcProvider:
    """RegisterOidcProvider RPC: successful registration, warnings, and abort paths."""

    @staticmethod
    async def _expect_register_abort(
        servicer: MockServicerHost,
        request: noteflow_pb2.RegisterOidcProviderRequest,
        context: MagicMock,
    ) -> None:
        """Call the RPC expecting the "Unreachable" assertion and exactly one abort."""
        with pytest.raises(AssertionError, match="Unreachable"):
            await servicer.RegisterOidcProvider(request, context)
        context.abort.assert_called_once()

    async def test_registers_provider_successfully(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A complete request returns the provider produced by the OIDC service."""
        service = MagicMock()
        service.register_provider = AsyncMock(return_value=(sample_provider, []))
        service.registry = MagicMock()
        request = noteflow_pb2.RegisterOidcProviderRequest(
            workspace_id=str(sample_provider.workspace_id),
            name="Test Authentik",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
            preset="authentik",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RegisterOidcProvider(request, mock_grpc_context)

        assert response.id == str(sample_provider.id), "should return provider id"
        assert response.name == "Test Authentik", "should return provider name"
        assert response.preset == "authentik", "should return preset"
        service.register_provider.assert_called_once()

    async def test_returns_warnings_from_validation(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """Warnings returned by the service are surfaced on the response."""
        validation_warnings = ["Scope 'groups' not supported by provider"]
        service = MagicMock()
        service.register_provider = AsyncMock(
            return_value=(sample_provider, validation_warnings)
        )
        service.registry = MagicMock()
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RegisterOidcProvider(request, mock_grpc_context)

        assert len(response.warnings) == 1, "should include warnings"
        assert "groups" in response.warnings[0], "should include warning message"

    async def test_rejects_missing_name(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without ``name`` triggers an abort."""
        nameless_request = noteflow_pb2.RegisterOidcProviderRequest(
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )
        await self._expect_register_abort(oidc_servicer, nameless_request, mock_grpc_context)

    async def test_rejects_missing_issuer_url(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without ``issuer_url`` triggers an abort."""
        issuerless_request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            client_id="test-client-id",
        )
        await self._expect_register_abort(oidc_servicer, issuerless_request, mock_grpc_context)

    async def test_rejects_invalid_issuer_url_scheme(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An ``ftp://`` issuer URL triggers an abort."""
        bad_scheme_request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="ftp://auth.example.com",
            client_id="test-client-id",
        )
        await self._expect_register_abort(oidc_servicer, bad_scheme_request, mock_grpc_context)

    async def test_rejects_missing_client_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A request without ``client_id`` triggers an abort."""
        clientless_request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
        )
        await self._expect_register_abort(oidc_servicer, clientless_request, mock_grpc_context)

    async def test_handles_discovery_error(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An ``OidcDiscoveryError`` raised by the service leads to an abort."""
        failing_service = MagicMock()
        failing_service.register_provider = AsyncMock(
            side_effect=OidcDiscoveryError("Connection failed")
        )
        request = noteflow_pb2.RegisterOidcProviderRequest(
            name="Test Provider",
            issuer_url="https://auth.example.com",
            client_id="test-client-id",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=failing_service):
            await self._expect_register_abort(oidc_servicer, request, mock_grpc_context)
class TestListOidcProviders:
    """ListOidcProviders RPC: plain listing plus workspace and enabled-only filters."""

    async def test_lists_all_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """An unfiltered request returns every provider the registry reports."""
        service = MagicMock()
        service.registry.list_providers.return_value = [sample_provider]
        request = noteflow_pb2.ListOidcProvidersRequest()

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        assert response.total_count == 1, "should return total count"
        assert len(response.providers) == 1, "should return providers list"
        first_provider = response.providers[0]
        assert first_provider.name == "Test Authentik", "should return correct provider name"

    async def test_filters_by_workspace_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """The workspace id from the request is parsed and forwarded to the registry."""
        workspace_id = uuid4()
        service = MagicMock()
        service.registry.list_providers.return_value = []
        request = noteflow_pb2.ListOidcProvidersRequest(workspace_id=str(workspace_id))

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        service.registry.list_providers.assert_called_once_with(
            workspace_id=workspace_id,
            enabled_only=False,
        )

    async def test_filters_enabled_only(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """``enabled_only=True`` is forwarded to the registry with no workspace filter."""
        service = MagicMock()
        service.registry.list_providers.return_value = []
        request = noteflow_pb2.ListOidcProvidersRequest(enabled_only=True)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        service.registry.list_providers.assert_called_once_with(
            workspace_id=None,
            enabled_only=True,
        )

    async def test_returns_empty_list_when_no_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """An empty registry produces a zero count and an empty provider list."""
        service = MagicMock()
        service.registry.list_providers.return_value = []
        request = noteflow_pb2.ListOidcProvidersRequest()

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.ListOidcProviders(request, mock_grpc_context)

        assert response.total_count == 0, "total count should be zero when no providers"
        assert len(response.providers) == 0, "providers list should be empty"
class TestGetOidcProvider:
    """GetOidcProvider RPC: lookup by id and malformed-id handling."""

    async def test_returns_provider_by_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A provider known to the registry is returned with its id and name."""
        provider_id_text = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider.return_value = sample_provider
        request = noteflow_pb2.GetOidcProviderRequest(provider_id=provider_id_text)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.GetOidcProvider(request, mock_grpc_context)

        assert response.id == provider_id_text, "should return correct provider ID"
        assert response.name == "Test Authentik", "should return correct provider name"

    async def test_aborts_on_invalid_provider_id_format(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A provider id that is not a UUID triggers exactly one abort call."""
        malformed_request = noteflow_pb2.GetOidcProviderRequest(provider_id="not-a-uuid")

        with pytest.raises(AssertionError, match="Unreachable"):
            await oidc_servicer.GetOidcProvider(malformed_request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
class TestUpdateOidcProvider:
    """UpdateOidcProvider RPC: name, scopes, and enabled-flag updates."""

    async def test_updates_provider_name(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A new name in the request is reflected in the response."""
        service = MagicMock()
        service.registry.get_provider.return_value = sample_provider
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(sample_provider.id),
            name="Updated Name",
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert response.name == "Updated Name", "should return updated name"

    async def test_updates_provider_scopes(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A scope list containing ``groups`` is reflected in the response."""
        service = MagicMock()
        service.registry.get_provider.return_value = sample_provider
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(sample_provider.id),
            scopes=["openid", "profile", "email", "groups"],
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert "groups" in response.scopes, "should include added scope"

    async def test_enables_provider(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        disabled_provider: OidcProviderConfig,
    ) -> None:
        """Sending ``enabled=True`` for a disabled provider yields an enabled response."""
        service = MagicMock()
        service.registry.get_provider.return_value = disabled_provider
        request = noteflow_pb2.UpdateOidcProviderRequest(
            provider_id=str(disabled_provider.id),
            enabled=True,
        )

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.UpdateOidcProvider(request, mock_grpc_context)

        assert response.enabled is True, "provider should be enabled"
class TestOidcProviderNotFound:
    """Get, update, and delete RPCs when the registry has no matching provider."""

    @staticmethod
    def _configure_get_provider_none(service: MagicMock) -> None:
        """Make the registry's ``get_provider`` return ``None``."""
        service.registry.get_provider.return_value = None

    @staticmethod
    def _configure_remove_provider_false(service: MagicMock) -> None:
        """Make the registry's ``remove_provider`` return ``False``."""
        service.registry.remove_provider.return_value = False

    # Each case pairs an RPC name with a request for a random (unknown) provider
    # id and the helper that makes the mocked registry report "not found".
    @pytest.mark.parametrize(
        ("method_name", "proto_request", "configure"),
        [
            pytest.param(
                "GetOidcProvider",
                noteflow_pb2.GetOidcProviderRequest(provider_id=str(uuid4())),
                _configure_get_provider_none,
                id="get",
            ),
            pytest.param(
                "UpdateOidcProvider",
                noteflow_pb2.UpdateOidcProviderRequest(
                    provider_id=str(uuid4()),
                    name="New Name",
                ),
                _configure_get_provider_none,
                id="update",
            ),
            pytest.param(
                "DeleteOidcProvider",
                noteflow_pb2.DeleteOidcProviderRequest(provider_id=str(uuid4())),
                _configure_remove_provider_false,
                id="delete",
            ),
        ],
    )
    async def test_provider_not_found_aborts(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        method_name: str,
        proto_request: object,
        configure: Callable[[MagicMock], None],
    ) -> None:
        """Each RPC calls abort exactly once when the provider does not exist."""
        service = MagicMock()
        configure(service)
        rpc = getattr(oidc_servicer, method_name)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            with pytest.raises(AssertionError, match="Unreachable"):
                await rpc(proto_request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
class TestDeleteOidcProvider:
    """DeleteOidcProvider RPC: successful removal and malformed-id handling."""

    async def test_deletes_provider_successfully(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """When the registry reports removal, the response carries ``success=True``."""
        target_id = uuid4()
        service = MagicMock()
        service.registry.remove_provider.return_value = True
        request = noteflow_pb2.DeleteOidcProviderRequest(provider_id=str(target_id))

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.DeleteOidcProvider(request, mock_grpc_context)

        assert response.success is True, "delete should return success=True"
        service.registry.remove_provider.assert_called_once_with(target_id)

    async def test_aborts_on_invalid_provider_id(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """A provider id that is not a UUID triggers exactly one abort call."""
        malformed_request = noteflow_pb2.DeleteOidcProviderRequest(provider_id="invalid-uuid")

        with pytest.raises(AssertionError, match="Unreachable"):
            await oidc_servicer.DeleteOidcProvider(malformed_request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
class TestRefreshOidcDiscovery:
    """RefreshOidcDiscovery RPC: single-provider and all-provider refreshes."""

    async def test_refreshes_single_provider(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """A successful refresh of one provider is counted and keyed by its id."""
        provider_id_text = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider.return_value = sample_provider
        service.registry.refresh_discovery = AsyncMock()
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=provider_id_text)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        assert response.success_count == 1, "should report one success"
        assert response.failure_count == 0, "should report no failures"
        assert provider_id_text in response.results, "provider ID should be in results"

    async def test_reports_single_provider_failure(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
        sample_provider: OidcProviderConfig,
    ) -> None:
        """An ``OidcDiscoveryError`` during refresh is reported as a failure entry."""
        provider_id_text = str(sample_provider.id)
        service = MagicMock()
        service.registry.get_provider.return_value = sample_provider
        service.registry.refresh_discovery = AsyncMock(
            side_effect=OidcDiscoveryError("Network error")
        )
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=provider_id_text)

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        assert response.success_count == 0, "should report no success"
        assert response.failure_count == 1, "should report one failure"
        reported_error = response.results[provider_id_text]
        assert "Network error" in reported_error, "error message should be in results"

    async def test_refreshes_all_providers(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """With no provider id, the bulk refresh outcome is summarised per provider."""
        succeeded_id = uuid4()
        failed_id = uuid4()
        # The mocked bulk refresh maps ``None`` to success and a message to failure.
        bulk_outcome = {
            succeeded_id: None,
            failed_id: "Connection refused",
        }
        service = MagicMock()
        service.refresh_all_discovery = AsyncMock(return_value=bulk_outcome)
        request = noteflow_pb2.RefreshOidcDiscoveryRequest()

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            response = await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        assert response.success_count == 1, "should count successes"
        assert response.failure_count == 1, "should count failures"
        success_entry = response.results[str(succeeded_id)]
        assert success_entry == "", "successful provider should have empty error"
        failure_entry = response.results[str(failed_id)]
        assert "Connection refused" in failure_entry, "failed provider should have error message"

    async def test_aborts_when_single_provider_not_found(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """Requesting a refresh for an unknown provider id triggers exactly one abort."""
        service = MagicMock()
        service.registry.get_provider.return_value = None
        request = noteflow_pb2.RefreshOidcDiscoveryRequest(provider_id=str(uuid4()))

        with patch.object(oidc_servicer, "get_oidc_service", return_value=service):
            with pytest.raises(AssertionError, match="Unreachable"):
                await oidc_servicer.RefreshOidcDiscovery(request, mock_grpc_context)

        mock_grpc_context.abort.assert_called_once()
class TestListOidcPresets:
    """ListOidcPresets RPC: the set of presets and the fields each one carries."""

    async def test_returns_all_presets(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """The response names the Authentik, Keycloak, and Auth0 presets among at least six."""
        response = await oidc_servicer.ListOidcPresets(
            noteflow_pb2.ListOidcPresetsRequest(),
            mock_grpc_context,
        )

        available = []
        for entry in response.presets:
            available.append(entry.preset)

        assert "authentik" in available, "should include Authentik preset"
        assert "keycloak" in available, "should include Keycloak preset"
        assert "auth0" in available, "should include Auth0 preset"
        assert len(response.presets) >= 6, "should have at least 6 presets"

    async def test_presets_include_required_fields(
        self,
        oidc_servicer: MockServicerHost,
        mock_grpc_context: MagicMock,
    ) -> None:
        """The Authentik preset has a display name, a description, and default scopes."""
        response = await oidc_servicer.ListOidcPresets(
            noteflow_pb2.ListOidcPresetsRequest(),
            mock_grpc_context,
        )

        # Take the first preset whose identifier is "authentik", if any.
        authentik_preset = None
        for candidate in response.presets:
            if candidate.preset == "authentik":
                authentik_preset = candidate
                break

        assert authentik_preset is not None, "should find Authentik preset"
        assert authentik_preset.display_name, "preset should have display name"
        assert authentik_preset.description, "preset should have description"
        assert len(authentik_preset.default_scopes) > 0, "preset should have scopes"